From e0c090f227e9b64e595b47d4d1f96f8a2fff5bf7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 2 Jan 2018 09:19:18 +0800 Subject: [PATCH 0001/1161] [SPARK-22932][SQL] Refactor AnalysisContext ## What changes were proposed in this pull request? Add a `reset` function to ensure the state in `AnalysisContext ` is per-query. ## How was this patch tested? The existing test cases Author: gatorsmile Closes #20127 from gatorsmile/refactorAnalysisContext. --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d294d48c0ee7..35b35110e491f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,6 +52,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. * @@ -70,6 +71,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -95,6 +98,17 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -176,7 +190,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -600,7 +614,7 @@ class Analyzer( "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + "aroud this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -1269,7 +1283,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. val i = plans.iterator @@ -1392,7 +1406,7 @@ class Analyzer( grouping, Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1450,7 +1464,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] From a6fc300e91273230e7134ac6db95ccb4436c6f8f Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 2 Jan 2018 23:30:38 +0800 Subject: [PATCH 0002/1161] [SPARK-22897][CORE] Expose stageAttemptId in TaskContext ## What changes were proposed in this pull request? stageAttemptId added in TaskContext and corresponding construction modification ## How was this patch tested? Added a new test in TaskContextSuite, two cases are tested: 1. Normal case without failure 2. Exception case with resubmitted stages Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897) Author: Xianjin YE Closes #20082 from advancedxy/SPARK-22897. --- .../scala/org/apache/spark/TaskContext.scala | 9 +++++- .../org/apache/spark/TaskContextImpl.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 1 + .../spark/JavaTaskContextCompileCheck.java | 2 ++ .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++-- .../spark/memory/MemoryTestingUtils.scala | 1 + .../spark/scheduler/TaskContextSuite.scala | 29 +++++++++++++++++-- .../spark/storage/BlockInfoManagerSuite.scala | 2 +- project/MimaExcludes.scala | 3 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 1 + .../execution/UnsafeRowSerializerSuite.scala | 2 +- .../SortBasedAggregationStoreSuite.scala | 3 +- 13 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4fa..69739745aa6cf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. + */ + def stageAttemptNumber(): Int + /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb06..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,8 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. */ private[spark] class TaskContextImpl( - val stageId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a06..f536fc2a5f0a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1e..f8e233a05a447 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0a..ced5a06516f75 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc248..dcf89e4f75acf 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085d..aa9c36c0aaacb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptNumber getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptNumber()).iterator + }.collect() + assert(stageAttemptNumbers.toSet === Set(0)) + + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { + // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptNumber).iterator + }.collect() + + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f11..9c0699bc981f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813ea..3b452f35c5ec1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae7998..3e31d22e15c0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d0..6af9f8b77f8d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9a..a3ae93810aa3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bedf..3fad7dfddadcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() From 247a08939d58405aef39b2a4e7773aa45474ad12 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 3 Jan 2018 21:40:51 +0800 Subject: [PATCH 0003/1161] [SPARK-22938] Assert that SQLConf.get is accessed only on the driver. ## What changes were proposed in this pull request? Assert if code tries to access SQLConf.get on executor. This can lead to hard to detect bugs, where the executor will read fallbackConf, falling back to default config values, ignoring potentially changed non-default configs. If a config is to be passed to executor code, it needs to be read on the driver, and passed explicitly. ## How was this patch tested? Check in existing tests. Author: Juliusz Sompolski Closes #20136 from juliuszsompolski/SPARK-22938. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f77c54a7af57..80cdc61484c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -1087,6 +1089,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) From 1a87a1609c4d2c9027a2cf669ea3337b89f61fb6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 3 Jan 2018 22:09:30 +0800 Subject: [PATCH 0004/1161] [SPARK-22934][SQL] Make optional clauses order insensitive for CREATE TABLE SQL statement ## What changes were proposed in this pull request? Currently, our CREATE TABLE syntax require the EXACT order of clauses. It is pretty hard to remember the exact order. Thus, this PR is to make optional clauses order insensitive for `CREATE TABLE` SQL statement. ``` CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1 col_type1 [COMMENT col_comment1], ...)] USING datasource [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` The same idea is also applicable to Create Hive Table. ``` CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1[:] col_type1 [COMMENT col_comment1], ...)] [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20133 from gatorsmile/createDataSourceTableDDL. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 24 +- .../sql/catalyst/parser/ParserUtils.scala | 9 + .../spark/sql/execution/SparkSqlParser.scala | 81 +++++-- .../execution/command/DDLParserSuite.scala | 220 ++++++++++++++---- .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 13 +- .../sql/hive/execution/SQLQuerySuite.scala | 124 +++++----- 7 files changed, 335 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d55..6daf01d98426c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? - (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? #createTableLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e6..89347f4b1f7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. */ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972c..d3cfd2a1ffbf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b602..2b1aea08b1223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9cb..591510c1d8283 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1971,8 +1971,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f2e0c695ca38b..65be244418670 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -875,12 +875,13 @@ class HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 07ae3ae945848..47adc77a52d51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,51 +461,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } @@ -539,30 +543,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } From a66fe36cee9363b01ee70e469f1c968f633c5713 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Jan 2018 22:18:13 +0800 Subject: [PATCH 0005/1161] [SPARK-20236][SQL] dynamic partition overwrite ## What changes were proposed in this pull request? When overwriting a partitioned table with dynamic partition columns, the behavior is different between data source and hive tables. data source table: delete all partition directories that match the static partition values provided in the insert statement. hive table: only delete partition directories which have data written into it This PR adds a new config to make users be able to choose hive's behavior. ## How was this patch tested? new tests Author: Wenchen Fan Closes #18714 from cloud-fan/overwrite-partition. --- .../internal/io/FileCommitProtocol.scala | 25 ++++-- .../io/HadoopMapReduceCommitProtocol.scala | 75 ++++++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 21 +++++ .../InsertIntoHadoopFsRelationCommand.scala | 20 ++++- .../SQLHadoopMapReduceCommitProtocol.scala | 10 ++- .../spark/sql/sources/InsertSuite.scala | 78 +++++++++++++++++++ 6 files changed, 200 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 50f51e1af4530..6d0059b6a0272 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -28,8 +28,9 @@ import org.apache.spark.util.Utils * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with 2 arguments: - * (jobId: String, path: String) + * 2. Implementations should have a constructor with 2 or 3 arguments: + * (jobId: String, path: String) or + * (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) * 3. A committer should not be reused across multiple Spark jobs. * * The proper call sequence is: @@ -139,10 +140,22 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. */ - def instantiate(className: String, jobId: String, outputPath: String) - : FileCommitProtocol = { + def instantiate( + className: String, + jobId: String, + outputPath: String, + dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) + // First try the constructor with arguments (jobId: String, outputPath: String, + // dynamicPartitionOverwrite: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 95c99d29c3a9c..6d20ef1f98a3c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -39,8 +39,19 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop + * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime + * dynamically, i.e., we first write files under a staging + * directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, + * we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move + * files from staging directory to the corresponding partition + * directories under destination path. */ -class HadoopMapReduceCommitProtocol(jobId: String, path: String) +class HadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ @@ -67,9 +78,17 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) @transient private var addedAbsPathFiles: mutable.Map[String, String] = null /** - * The staging directory for all files committed with absolute output paths. + * Tracks partitions with default path that have new files written into them by this task, + * e.g. a=1/b=2. Files under these partitions will be saved into staging directory and moved to + * destination directory at the end, if `dynamicPartitionOverwrite` is true. */ - private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + @transient private var partitionPaths: mutable.Set[String] = null + + /** + * The staging directory of this write job. Spark uses it to deal with files with absolute output + * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. + */ + private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() @@ -85,11 +104,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { val filename = getFilename(taskContext, ext) - val stagingDir: String = committer match { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => - Option(f.getWorkPath).map(_.toString).getOrElse(path) - case _ => path + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) } dir.map { d => @@ -106,8 +130,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) // Include a UUID here to prevent file collisions for one task writing to different dirs. // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path( - absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString addedAbsPathFiles(tmpOutputPath) = absOutputPath tmpOutputPath @@ -141,23 +164,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) - val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) - .foldLeft(Map[String, String]())(_ ++ _) - logDebug(s"Committing files staged for absolute locations $filesToMove") + if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + val (allAbsPathFiles, allPartitionPaths) = + taskCommits.map(_.obj.asInstanceOf[(Map[String, String], Set[String])]).unzip + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + + val filesToMove = allAbsPathFiles.foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + if (dynamicPartitionOverwrite) { + val absPartitionPaths = filesToMove.values.map(new Path(_).getParent).toSet + logDebug(s"Clean up absolute partition directories for overwriting: $absPartitionPaths") + absPartitionPaths.foreach(fs.delete(_, true)) + } for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) } - fs.delete(absPathStagingDir, true) + + if (dynamicPartitionOverwrite) { + val partitionPaths = allPartitionPaths.foldLeft(Set[String]())(_ ++ _) + logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") + for (part <- partitionPaths) { + val finalPartPath = new Path(path, part) + fs.delete(finalPartPath, true) + fs.rename(new Path(stagingDir, part), finalPartPath) + } + } + + fs.delete(stagingDir, true) } } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(stagingDir, true) } } @@ -165,13 +207,14 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) committer = setupCommitter(taskContext) committer.setupTask(taskContext) addedAbsPathFiles = mutable.Map[String, String]() + partitionPaths = mutable.Set[String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - new TaskCommitMessage(addedAbsPathFiles.toMap) + new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet) } override def abortTask(taskContext: TaskAttemptContext): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80cdc61484c0f..5d6edf6b8abec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1068,6 +1068,24 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1394,6 +1412,9 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ad24e280d942a..dd7ef0d15c140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -89,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -103,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -126,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b1..39c594a9bc618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. */ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f45946..fef01c860db6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } From 9a2b65a3c0c36316aae0a53aa0f61c5044c2ceff Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Wed, 3 Jan 2018 11:31:32 -0600 Subject: [PATCH 0006/1161] [SPARK-22896] Improvement in String interpolation ## What changes were proposed in this pull request? * String interpolation in ml pipeline example has been corrected as per scala standard. ## How was this patch tested? * manually tested. Author: chetkhatri Closes #20070 from chetkhatri/mllib-chetan-contrib. --- .../spark/examples/ml/JavaQuantileDiscretizerExample.java | 2 +- .../apache/spark/examples/SimpleSkewedGroupByTest.scala | 4 ---- .../org/apache/spark/examples/graphx/Analytics.scala | 6 ++++-- .../org/apache/spark/examples/graphx/SynthBenchmark.scala | 6 +++--- .../apache/spark/examples/ml/ChiSquareTestExample.scala | 6 +++--- .../org/apache/spark/examples/ml/CorrelationExample.scala | 4 ++-- .../org/apache/spark/examples/ml/DataFrameExample.scala | 4 ++-- .../examples/ml/DecisionTreeClassificationExample.scala | 4 ++-- .../spark/examples/ml/DecisionTreeRegressionExample.scala | 4 ++-- .../apache/spark/examples/ml/DeveloperApiExample.scala | 6 +++--- .../examples/ml/EstimatorTransformerParamExample.scala | 6 +++--- .../ml/GradientBoostedTreeClassifierExample.scala | 4 ++-- .../examples/ml/GradientBoostedTreeRegressorExample.scala | 4 ++-- ...ulticlassLogisticRegressionWithElasticNetExample.scala | 2 +- .../ml/MultilayerPerceptronClassifierExample.scala | 2 +- .../org/apache/spark/examples/ml/NaiveBayesExample.scala | 2 +- .../spark/examples/ml/QuantileDiscretizerExample.scala | 4 ++-- .../spark/examples/ml/RandomForestClassifierExample.scala | 4 ++-- .../spark/examples/ml/RandomForestRegressorExample.scala | 4 ++-- .../apache/spark/examples/ml/VectorIndexerExample.scala | 4 ++-- .../spark/examples/mllib/AssociationRulesExample.scala | 6 +++--- .../mllib/BinaryClassificationMetricsExample.scala | 4 ++-- .../mllib/DecisionTreeClassificationExample.scala | 4 ++-- .../examples/mllib/DecisionTreeRegressionExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/FPGrowthExample.scala | 2 +- .../mllib/GradientBoostingClassificationExample.scala | 4 ++-- .../mllib/GradientBoostingRegressionExample.scala | 4 ++-- .../spark/examples/mllib/HypothesisTestingExample.scala | 2 +- .../spark/examples/mllib/IsotonicRegressionExample.scala | 2 +- .../org/apache/spark/examples/mllib/KMeansExample.scala | 2 +- .../org/apache/spark/examples/mllib/LBFGSExample.scala | 2 +- .../examples/mllib/LatentDirichletAllocationExample.scala | 8 +++++--- .../examples/mllib/LinearRegressionWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/PCAExample.scala | 4 ++-- .../spark/examples/mllib/PMMLModelExportExample.scala | 2 +- .../apache/spark/examples/mllib/PrefixSpanExample.scala | 4 ++-- .../mllib/RandomForestClassificationExample.scala | 4 ++-- .../examples/mllib/RandomForestRegressionExample.scala | 4 ++-- .../spark/examples/mllib/RecommendationExample.scala | 2 +- .../apache/spark/examples/mllib/SVMWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/SimpleFPGrowth.scala | 8 +++----- .../spark/examples/mllib/StratifiedSamplingExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/TallSkinnyPCA.scala | 2 +- .../org/apache/spark/examples/mllib/TallSkinnySVD.scala | 2 +- .../apache/spark/examples/streaming/CustomReceiver.scala | 6 +++--- .../apache/spark/examples/streaming/RawNetworkGrep.scala | 2 +- .../examples/streaming/RecoverableNetworkWordCount.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../examples/streaming/clickstream/PageViewStream.scala | 4 ++-- 49 files changed, 94 insertions(+), 96 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index dd20cac621102..43cc30c1a899b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setNumBuckets(3); Dataset result = discretizer.fit(df).transform(df); - result.show(); + result.show(false); // $example off$ spark.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index e64dcbd182d94..2332a661f26a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -60,10 +60,6 @@ object SimpleSkewedGroupByTest { pairs1.count println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") - // Print how many keys each reducer got (for debugging) - // println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 92936bd30dbc0..815404d1218b7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -145,9 +145,11 @@ object Analytics extends Logging { // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) - println("Triangles: " + triangles.vertices.map { + val triangleTypes = triangles.vertices.map { case (vid, data) => data.toLong - }.reduce(_ + _) / 3) + }.reduce(_ + _) / 3 + + println(s"Triangles: ${triangleTypes}") sc.stop() case _ => diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 6d2228c8742aa..57b2edf992208 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -52,7 +52,7 @@ object SynthBenchmark { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } @@ -76,7 +76,7 @@ object SynthBenchmark { case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v case ("seed", v) => seed = v.toInt - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } val conf = new SparkConf() @@ -86,7 +86,7 @@ object SynthBenchmark { val sc = new SparkContext(conf) // Create the graph - println(s"Creating graph...") + println("Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala index dcee1e427ce58..5146fd0316467 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala @@ -52,9 +52,9 @@ object ChiSquareTestExample { val df = data.toDF("label", "features") val chi = ChiSquareTest.test(df, "features", "label").head - println("pValues = " + chi.getAs[Vector](0)) - println("degreesOfFreedom = " + chi.getSeq[Int](1).mkString("[", ",", "]")) - println("statistics = " + chi.getAs[Vector](2)) + println(s"pValues = ${chi.getAs[Vector](0)}") + println(s"degreesOfFreedom ${chi.getSeq[Int](1).mkString("[", ",", "]")}") + println(s"statistics ${chi.getAs[Vector](2)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala index 3f57dc342eb00..d7f1fc8ed74d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala @@ -51,10 +51,10 @@ object CorrelationExample { val df = data.map(Tuple1.apply).toDF("features") val Row(coeff1: Matrix) = Correlation.corr(df, "features").head - println("Pearson correlation matrix:\n" + coeff1.toString) + println(s"Pearson correlation matrix:\n $coeff1") val Row(coeff2: Matrix) = Correlation.corr(df, "features", "spearman").head - println("Spearman correlation matrix:\n" + coeff2.toString) + println(s"Spearman correlation matrix:\n $coeff2") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0658bddf16961..ee4469faab3a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -47,7 +47,7 @@ object DataFrameExample { val parser = new OptionParser[Params]("DataFrameExample") { head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataframe") + .text("input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -93,7 +93,7 @@ object DataFrameExample { // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") val newDF = spark.read.parquet(outputDir) - println(s"Schema from Parquet:") + println("Schema from Parquet:") newDF.printSchema() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index bc6d3275933ea..276cedab11abc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -83,10 +83,10 @@ object DecisionTreeClassificationExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] - println("Learned classification tree model:\n" + treeModel.toDebugString) + println(s"Learned classification tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ee61200ad1d0c..aaaecaea47081 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -73,10 +73,10 @@ object DecisionTreeRegressionExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] - println("Learned regression tree model:\n" + treeModel.toDebugString) + println(s"Learned regression tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index d94d837d10e96..2dc11b07d88ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -53,7 +53,7 @@ object DeveloperApiExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. - println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"MyLogisticRegression parameters:\n ${lr.explainParams()}") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -169,10 +169,10 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + // Number of classes the label can take. 2 indicates binary classification. override val numClasses: Int = 2 - /** Number of features the model was trained on. */ + // Number of features the model was trained on. override val numFeatures: Int = coefficients.size /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index f18d86e1a6921..e5d91f132a3f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -46,7 +46,7 @@ object EstimatorTransformerParamExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. - println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -58,7 +58,7 @@ object EstimatorTransformerParamExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}") // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -73,7 +73,7 @@ object EstimatorTransformerParamExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}") // Prepare test data. val test = spark.createDataFrame(Seq( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 3656773c8b817..ef78c0a1145ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -86,10 +86,10 @@ object GradientBoostedTreeClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${1.0 - accuracy}") val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] - println("Learned classification GBT model:\n" + gbtModel.toDebugString) + println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index e53aab7f326d3..3feb2343f6a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -73,10 +73,10 @@ object GradientBoostedTreeRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] - println("Learned regression GBT model:\n" + gbtModel.toDebugString) + println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 42f0ace7a353d..3e61dbe628c20 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -48,7 +48,7 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") - println(s"Intercepts: ${lrModel.interceptVector}") + println(s"Intercepts: \n${lrModel.interceptVector}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 6fce82d294f8d..646f46a925062 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,7 +66,7 @@ object MultilayerPerceptronClassifierExample { val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") - println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels)) + println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index bd9fcc420a66c..50c70c626b128 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -52,7 +52,7 @@ object NaiveBayesExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test set accuracy = " + accuracy) + println(s"Test set accuracy = $accuracy") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index aedb9e7d3bb70..0fe16fb6dfa9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -36,7 +36,7 @@ object QuantileDiscretizerExample { // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. // Note this is not necessary for normal use cases - .repartition(1) + .repartition(1) // $example on$ val discretizer = new QuantileDiscretizer() @@ -45,7 +45,7 @@ object QuantileDiscretizerExample { .setNumBuckets(3) val result = discretizer.fit(df).transform(df) - result.show() + result.show(false) // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 5eafda8ce4285..6265f83902528 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -85,10 +85,10 @@ object RandomForestClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] - println("Learned classification forest model:\n" + rfModel.toDebugString) + println(s"Learned classification forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 9a0a001c26ef5..2679fcb353a8a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -72,10 +72,10 @@ object RandomForestRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] - println("Learned regression forest model:\n" + rfModel.toDebugString) + println(s"Learned regression forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index afa761aee0b98..96bb8ea2338af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -41,8 +41,8 @@ object VectorIndexerExample { val indexerModel = indexer.fit(data) val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) + println(s"Chose ${categoricalFeatures.size} " + + s"categorical features: ${categoricalFeatures.mkString(", ")}") // Create new column "indexed" with categorical values transformed to indices val indexedData = indexerModel.transform(data) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ff44de56839e5..a07535bb5a38d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -42,9 +42,8 @@ object AssociationRulesExample { val results = ar.run(freqItemsets) results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) + println(s"[${rule.antecedent.mkString(",")}=>${rule.consequent.mkString(",")} ]" + + s" ${rule.confidence}") } // $example off$ @@ -53,3 +52,4 @@ object AssociationRulesExample { } // scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index b9263ac6fcff6..c6312d71cc912 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -86,7 +86,7 @@ object BinaryClassificationMetricsExample { // AUPRC val auPRC = metrics.areaUnderPR - println("Area under precision-recall curve = " + auPRC) + println(s"Area under precision-recall curve = $auPRC") // Compute thresholds used in ROC and PR curves val thresholds = precision.map(_._1) @@ -96,7 +96,7 @@ object BinaryClassificationMetricsExample { // AUROC val auROC = metrics.areaUnderROC - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index b50b4592777ce..c2f89b72c9a2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -55,8 +55,8 @@ object DecisionTreeClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification tree model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index 2af45afae3d5b..1ecf6426e1f95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -54,8 +54,8 @@ object DecisionTreeRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression tree model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 6435abc127752..f724ee1030f04 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -74,7 +74,7 @@ object FPGrowthExample { println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 00bb3348d2a36..3c56e1941aeca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -54,8 +54,8 @@ object GradientBoostingClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification GBT model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index d8c263460839b..c288bf29bf255 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -53,8 +53,8 @@ object GradientBoostingRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression GBT model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 0d391a3637c07..add1719739539 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -68,7 +68,7 @@ object HypothesisTestingExample { // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) featureTestResults.zipWithIndex.foreach { case (k, v) => - println("Column " + (v + 1).toString + ":") + println(s"Column ${(v + 1)} :") println(k) } // summary of the test // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 4aee951f5b04c..a10d6f0dda880 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -56,7 +56,7 @@ object IsotonicRegressionExample { // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() - println("Mean Squared Error = " + meanSquaredError) + println(s"Mean Squared Error = $meanSquaredError") // Save and load model model.save(sc, "target/tmp/myIsotonicRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index c4d71d862f375..b0a6f1671a898 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -43,7 +43,7 @@ object KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) - println("Within Set Sum of Squared Errors = " + WSSSE) + println(s"Within Set Sum of Squared Errors = $WSSSE") // Save and load model clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index fedcefa098381..123782fa6b9cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -82,7 +82,7 @@ object LBFGSExample { println("Loss of each step in training process") loss.foreach(println) - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala index f2c8ec01439f1..d25962c5500ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -42,11 +42,13 @@ object LatentDirichletAllocationExample { val ldaModel = new LDA().setK(3).run(corpus) // Output topics. Each is a distribution over words (matching word count vectors) - println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):") val topics = ldaModel.topicsMatrix for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + print(s"Topic $topic :") + for (word <- Range(0, ldaModel.vocabSize)) { + print(s"${topics(word, topic)}") + } println() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index d399618094487..449b725d1d173 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -52,7 +52,7 @@ object LinearRegressionWithSGDExample { (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println("training Mean Squared Error = " + MSE) + println(s"training Mean Squared Error $MSE") // Save and load model model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index eb36697d94ba1..eff2393cc3abe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -65,8 +65,8 @@ object PCAExample { val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - println("Mean Squared Error = " + MSE) - println("PCA Mean Squared Error = " + MSE_pca) + println(s"Mean Squared Error = $MSE") + println(s"PCA Mean Squared Error = $MSE_pca") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb11..96deafd469bc7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println(s"PMML Model:\n ${clusters.toPMML}") // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 69c72c4336576..8b789277774af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -42,8 +42,8 @@ object PrefixSpanExample { val model = prefixSpan.run(sequences) model.freqSequences.collect().foreach { freqSequence => println( - freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + - ", " + freqSequence.freq) + s"${freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}," + + s" ${freqSequence.freq}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index f1ebdf1a733ed..246e71de25615 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -55,8 +55,8 @@ object RandomForestClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification forest model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index 11d612e651b4b..770e30276bc30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -55,8 +55,8 @@ object RandomForestRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression forest model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 6df742d737e70..0bb2b8c8c2b43 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -56,7 +56,7 @@ object RecommendationExample { val err = (r1 - r2) err * err }.mean() - println("Mean Squared Error = " + MSE) + println(s"Mean Squared Error = $MSE") // Save and load model model.save(sc, "target/tmp/myCollaborativeFilter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala index b73fe9b2b3faa..285e2ce512639 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -57,7 +57,7 @@ object SVMWithSGDExample { val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // Save and load model model.save(sc, "target/tmp/scalaSVMWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b5c3033bcba09..694c3bb18b045 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -42,15 +42,13 @@ object SimpleFPGrowth { val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")},${itemset.freq}") } val minConfidence = 0.8 model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) + println(s"${rule.antecedent.mkString("[", ",", "]")}=> " + + s"${rule.consequent .mkString("[", ",", "]")},${rule.confidence}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala index 16b074ef60699..3d41bef0af88c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -41,10 +41,10 @@ object StratifiedSamplingExample { val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) // $example off$ - println("approxSample size is " + approxSample.collect().size.toString) + println(s"approxSample size is ${approxSample.collect().size}") approxSample.collect().foreach(println) - println("exactSample its size is " + exactSample.collect().size.toString) + println(s"exactSample its size is ${exactSample.collect().size}") exactSample.collect().foreach(println) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 03bc675299c5a..071d341b81614 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -54,7 +54,7 @@ object TallSkinnyPCA { // Compute principal components. val pc = mat.computePrincipalComponents(mat.numCols().toInt) - println("Principal components are:\n" + pc) + println(s"Principal components are:\n $pc") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 067e49b9599e7..8ae6de16d80e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -54,7 +54,7 @@ object TallSkinnySVD { // Compute SVD. val svd = mat.computeSVD(mat.numCols().toInt) - println("Singular values are " + svd.s) + println(s"Singular values are ${svd.s}") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 43044d01b1204..25c7bf2871972 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -82,9 +82,9 @@ class CustomReceiver(host: String, port: Int) var socket: Socket = null var userInput: String = null try { - logInfo("Connecting to " + host + ":" + port) + logInfo(s"Connecting to $host : $port") socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) + logInfo(s"Connected to $host : $port") val reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() @@ -98,7 +98,7 @@ class CustomReceiver(host: String, port: Int) restart("Trying to connect again") } catch { case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) + restart(s"Error connecting to $host : $port", e) case t: Throwable => restart("Error receiving data", t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 5322929d177b4..437ccf0898d7c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -54,7 +54,7 @@ object RawNetworkGrep { ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) + println(s"Grep count: ${r.collect().mkString}")) ssc.start() ssc.awaitTermination() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 49c0427321133..f018f3a26d2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -130,10 +130,10 @@ object RecoverableNetworkWordCount { true } }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts + val output = s"Counts at time $time $counts" println(output) - println("Dropped " + droppedWordsCounter.value + " word(s) totally") - println("Appending to " + outputFile.getAbsolutePath) + println(s"Dropped ${droppedWordsCounter.value} word(s) totally") + println(s"Appending to ${outputFile.getAbsolutePath}") Files.append(output + "\n", outputFile, Charset.defaultCharset()) } ssc @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}") System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 0ddd065f0db2b..2108bc63edea2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -90,13 +90,13 @@ object PageViewGenerator { val viewsPerSecond = args(1).toFloat val sleepDelayMs = (1000.0 / viewsPerSecond).toInt val listener = new ServerSocket(port) - println("Listening on port: " + port) + println(s"Listening on port: $port") while (true) { val socket = listener.accept() new Thread() { override def run(): Unit = { - println("Got client connected from: " + socket.getInetAddress) + println(s"Got client connected from: ${socket.getInetAddress}") val out = new PrintWriter(socket.getOutputStream(), true) while (true) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 1ba093f57b32c..b8e7c7e9e9152 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -104,8 +104,8 @@ object PageViewStream { .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) + .foreach(u => println(s"Saw user $u at time $time"))) + case _ => println(s"Invalid metric entered: $metric") } ssc.start() From b297029130735316e1ac1144dee44761a12bfba7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 07:28:53 +0800 Subject: [PATCH 0007/1161] [SPARK-20960][SQL] make ColumnVector public ## What changes were proposed in this pull request? move `ColumnVector` and related classes to `org.apache.spark.sql.vectorized`, and improve the document. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20116 from cloud-fan/column-vector. --- .../VectorizedParquetRecordReader.java | 7 ++- .../vectorized/ColumnVectorUtils.java | 2 + .../vectorized/MutableColumnarRow.java | 4 ++ .../vectorized/WritableColumnVector.java | 7 ++- .../vectorized/ArrowColumnVector.java | 62 +------------------ .../vectorized/ColumnVector.java | 31 ++++++---- .../vectorized/ColumnarArray.java | 7 +-- .../vectorized/ColumnarBatch.java | 34 +++------- .../vectorized/ColumnarRow.java | 7 +-- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 +- .../VectorizedHashMapGenerator.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 1 + .../execution/datasources/FileScanRDD.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../execution/arrow/ArrowWriterSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 1 + .../vectorized/ColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 20 files changed, 63 insertions(+), 125 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ArrowColumnVector.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnVector.java (79%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarArray.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarBatch.java (73%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarRow.java (96%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e85d411f..cd745b1f0e4e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -248,7 +248,10 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. */ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e5..b5cbe8e2839ba 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe9..70057a9def6c0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e12..d2ae32b06f83b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index af5673e26a501..708333213f3f1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; @@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector { private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } + ensureAccessible(index, 1); } private void ensureAccessible(int index, int count) { @@ -64,20 +60,12 @@ public void close() { accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { ensureAccessible(rowId); @@ -94,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) { return array; } - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { ensureAccessible(rowId); @@ -114,10 +98,6 @@ public byte[] getBytes(int rowId, int count) { return array; } - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { ensureAccessible(rowId); @@ -134,10 +114,6 @@ public short[] getShorts(int rowId, int count) { return array; } - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { ensureAccessible(rowId); @@ -154,10 +130,6 @@ public int[] getInts(int rowId, int count) { return array; } - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { ensureAccessible(rowId); @@ -174,10 +146,6 @@ public long[] getLongs(int rowId, int count) { return array; } - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { ensureAccessible(rowId); @@ -194,10 +162,6 @@ public float[] getFloats(int rowId, int count) { return array; } - // - // APIs dealing with doubles - // - @Override public double getDouble(int rowId) { ensureAccessible(rowId); @@ -214,10 +178,6 @@ public double[] getDoubles(int rowId, int count) { return array; } - // - // APIs dealing with Arrays - // - @Override public int getArrayLength(int rowId) { ensureAccessible(rowId); @@ -230,45 +190,27 @@ public int getArrayOffset(int rowId) { return accessor.getArrayOffset(rowId); } - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { ensureAccessible(rowId); return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { ensureAccessible(rowId); return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override public ArrowColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index dc7c1269bedd9..d1196e1299fee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; @@ -22,24 +22,31 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. * - * Maps are just a special case of a two field struct. + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. */ public abstract class ColumnVector implements AutoCloseable { /** - * Returns the data type of this column. + * Returns the data type of this column vector. */ public final DataType dataType() { return type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec2..0d89a52e7a4fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -23,8 +23,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from @@ -33,7 +32,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa679726..9ae1c6d9993f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,26 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; import org.apache.spark.sql.types.StructType; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. + * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ public final class ColumnarBatch { public static final int DEFAULT_BATCH_SIZE = 4 * 1024; @@ -57,7 +49,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,19 +79,7 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 8bb33ed5b78c0..3c6656dec77cd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -24,8 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ public final class ColumnarRow extends InternalRow { // The data for this row. @@ -34,7 +33,7 @@ public final class ColumnarRow extends InternalRow { private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292ba..5617046e1396e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9a6f1c6dfa6a9..ce3c68810f3b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d63..0cf9b53ce1d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc412430263..bcd1aa0890ba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b12850..933b9753faa61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f2..835ce98462477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed3535654..dc5ba96e69aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 508c116aae92e..c42bc60a59d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a655..7304803a092c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowColumnVectorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f6..944240f3bade5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d0..675f06b31b970 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -918,10 +919,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) From 7d045c5f00e2c7c67011830e2169a4e130c3ace8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 13:14:52 +0800 Subject: [PATCH 0008/1161] [SPARK-22944][SQL] improve FoldablePropagation ## What changes were proposed in this pull request? `FoldablePropagation` is a little tricky as it needs to handle attributes that are miss-derived from children, e.g. outer join outputs. This rule does a kind of stop-able tree transform, to skip to apply this rule when hit a node which may have miss-derived attributes. Logically we should be able to apply this rule above the unsupported nodes, by just treating the unsupported nodes as leaf nodes. This PR improves this rule to not stop the tree transformation, but reduce the foldable expressions that we want to propagate. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20139 from cloud-fan/foldable. --- .../sql/catalyst/optimizer/expressions.scala | 65 +++++++++++-------- .../optimizer/FoldablePropagationSuite.scala | 23 ++++++- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7d830bbb7dc32..1c0b7bd806801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. - case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a8..c28844642aed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } From df95a908baf78800556636a76d58bba9b3dd943f Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 Jan 2018 21:43:14 -0800 Subject: [PATCH 0009/1161] [SPARK-22933][SPARKR] R Structured Streaming API for withWatermark, trigger, partitionBy ## What changes were proposed in this pull request? R Structured Streaming API for withWatermark, trigger, partitionBy ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #20129 from felixcheung/rwater. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 96 +++++++++++++++- R/pkg/R/SQLContext.R | 4 +- R/pkg/R/generics.R | 6 + R/pkg/tests/fulltests/test_streaming.R | 107 ++++++++++++++++++ python/pyspark/sql/streaming.py | 4 + .../sql/execution/streaming/Triggers.scala | 2 +- 7 files changed, 214 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3219c6f0cc47b..c51eb0f39c4b1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -179,6 +179,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fe238f6dd4eb0..9956f7eda91e6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3661,7 +3661,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3707,7 +3708,17 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. #' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible, this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} @@ -3725,7 +3736,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp" +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3737,7 +3749,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3748,12 +3761,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...) write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3967,3 +4011,47 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. +#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb8..9d0a2d5e074e4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -727,7 +727,9 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5369c32544e5e..e0dde3339fabc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,12 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +#' @export +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f517..a354d50c6b54e 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime),][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fb228f99ba7ab..24ae3776a217b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -793,6 +793,10 @@ def trigger(self, processingTime=None, once=None): .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c08..19e3e55cb2829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental From 9fa703e89318922393bae03c0db4575f4f4b4c56 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 4 Jan 2018 19:10:10 +0800 Subject: [PATCH 0010/1161] [SPARK-22950][SQL] Handle ChildFirstURLClassLoader's parent ## What changes were proposed in this pull request? ChildFirstClassLoader's parent is set to null, so we can't get jars from its parent. This will cause ClassNotFoundException during HiveClient initialization with builtin hive jars, where we may should use spark context loader instead. ## How was this patch tested? add new ut cc cloud-fan gatorsmile Author: Kent Yao Closes #20145 from yaooqinn/SPARK-22950. --- .../org/apache/spark/sql/hive/HiveUtils.scala | 4 +++- .../spark/sql/hive/HiveUtilsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd1..c7717d70c996f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a68440..8697d47e89e89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,19 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } } From d5861aba9d80ca15ad3f22793b79822e470d6913 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 19:17:22 +0800 Subject: [PATCH 0011/1161] [SPARK-22945][SQL] add java UDF APIs in the functions object ## What changes were proposed in this pull request? Currently Scala users can use UDF like ``` val foo = udf((i: Int) => Math.random() + i).asNondeterministic df.select(foo('a)) ``` Python users can also do it with similar APIs. However Java users can't do it, we should add Java UDF APIs in the functions object. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20141 from cloud-fan/udf. --- .../apache/spark/sql/UDFRegistration.scala | 90 ++--- .../sql/expressions/UserDefinedFunction.scala | 1 + .../org/apache/spark/sql/functions.scala | 313 ++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 11 + .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +- 5 files changed, 315 insertions(+), 112 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc2468a721e41..f94baef39dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. - * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). | * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f830520..40a058d2cadd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 530a525a01dec..0d11682d80a3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3254,42 +3254,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. - * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3302,10 +3326,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3318,10 +3342,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3334,10 +3358,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3350,10 +3374,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3366,10 +3390,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3382,10 +3406,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3398,10 +3422,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3414,10 +3438,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3430,10 +3454,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3446,10 +3470,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3461,13 +3485,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84b..4f8a31f185724 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,14 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7a..db37be68e42e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): Double = Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { From 5aadbc929cb194e06dbd3bab054a161569289af5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Jan 2018 21:07:31 +0800 Subject: [PATCH 0012/1161] [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction ## What changes were proposed in this pull request? ```Python import random from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType, StringType random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() spark.catalog.registerFunction("random_udf", random_udf, StringType()) spark.sql("SELECT random_udf()").collect() ``` We will get the following error. ``` Py4JError: An error occurred while calling o29.__getnewargs__. Trace: py4j.Py4JException: Method __getnewargs__([]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:214) at java.lang.Thread.run(Thread.java:745) ``` This PR is to support it. ## How was this patch tested? WIP Author: gatorsmile Closes #20137 from gatorsmile/registerFunction. --- python/pyspark/sql/catalog.py | 27 +++++++++++++++---- python/pyspark/sql/context.py | 16 +++++++++--- python/pyspark/sql/tests.py | 49 +++++++++++++++++++++++++---------- python/pyspark/sql/udf.py | 21 ++++++++++----- 4 files changed, 84 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0c..156603128d063 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) + + # This is to check whether the input function is a wrapped/native UserDefinedFunction + if hasattr(f, 'asNondeterministic'): + udf = UserDefinedFunction(f.func, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=f.deterministic) + else: + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef3..b8d86cc098e94 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67bdb3d72d93b..6dc767f9ec46e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -378,6 +378,41 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + twoargs = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) + self.assertEqual(twoargs.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -435,15 +470,6 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) - def test_nondeterministic_udf(self): - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -567,7 +593,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -575,7 +600,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -1299,7 +1322,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 54b5a8656e1c8..5e75eb6545333 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): ) # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +68,10 @@ class UserDefinedFunction(object): .. versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,7 +95,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType - self._deterministic = True + self.deterministic = deterministic @property def returnType(self): @@ -130,7 +133,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType, self._deterministic) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -138,6 +141,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -162,7 +168,8 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - wrapper.asNondeterministic = self.asNondeterministic + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() return wrapper @@ -172,5 +179,5 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ - self._deterministic = False + self.deterministic = False return self From 6f68316e98fad72b171df422566e1fc9a7bbfcde Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 4 Jan 2018 21:15:10 +0800 Subject: [PATCH 0013/1161] [SPARK-22771][SQL] Add a missing return statement in Concat.checkInputDataTypes ## What changes were proposed in this pull request? This pr is a follow-up to fix a bug left in #19977. ## How was this patch tested? Added tests in `StringExpressionsSuite`. Author: Takeshi Yamamuro Closes #20149 from maropu/SPARK-22771-FOLLOWUP. --- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0da55a4a961b..41dc762154a4c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -58,7 +58,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - TypeCheckResult.TypeCheckFailure( + return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have StringType or BinaryType, but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e27..97ddbeba2c5ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From 93f92c0ed7442a4382e97254307309977ff676f8 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 4 Jan 2018 11:39:42 -0800 Subject: [PATCH 0014/1161] [SPARK-21475][CORE][2ND ATTEMPT] Change to use NIO's Files API for external shuffle service ## What changes were proposed in this pull request? This PR is the second attempt of #18684 , NIO's Files API doesn't override `skip` method for `InputStream`, so it will bring in performance issue (mentioned in #20119). But using `FileInputStream`/`FileOutputStream` will also bring in memory issue (https://dzone.com/articles/fileinputstream-fileoutputstream-considered-harmful), which is severe for long running external shuffle service. So here in this proposal, only fixing the external shuffle service related code. ## How was this patch tested? Existing tests. Author: jerryshao Closes #20144 from jerryshao/SPARK-21475-v2. --- .../apache/spark/network/buffer/FileSegmentManagedBuffer.java | 3 ++- .../apache/spark/network/shuffle/ShuffleIndexInformation.java | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..8b8f9892847c3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index eacf485344b76..386738ece51a6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -39,7 +39,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { From d2cddc88eac32f26b18ec26bb59e85c6f09a8c88 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:19:00 -0600 Subject: [PATCH 0015/1161] [SPARK-22850][CORE] Ensure queued events are delivered to all event queues. The code in LiveListenerBus was queueing events before start in the queues themselves; so in situations like the following: bus.post(someEvent) bus.addToEventLogQueue(listener) bus.start() "someEvent" would not be delivered to "listener" if that was the first listener in the queue, because the queue wouldn't exist when the event was posted. This change buffers the events before starting the bus in the bus itself, so that they can be delivered to all registered queues when the bus is started. Also tweaked the unit tests to cover the behavior above. Author: Marcelo Vanzin Closes #20039 from vanzin/SPARK-22850. --- .../spark/scheduler/LiveListenerBus.scala | 45 ++++++++++++++++--- .../spark/scheduler/SparkListenerSuite.scala | 21 +++++---- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 23121402b1025..ba6387a8f08ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -62,6 +62,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + // Visible for testing. + @volatile private[scheduler] var queuedEvents = new mutable.ListBuffer[SparkListenerEvent]() + /** Add a listener to queue shared by all non-internal listeners. */ def addToSharedQueue(listener: SparkListenerInterface): Unit = { addToQueue(listener, SHARED_QUEUE) @@ -125,13 +128,39 @@ private[spark] class LiveListenerBus(conf: SparkConf) { /** Post an event to all queues. */ def post(event: SparkListenerEvent): Unit = { - if (!stopped.get()) { - metrics.numEventsPosted.inc() - val it = queues.iterator() - while (it.hasNext()) { - it.next().post(event) + if (stopped.get()) { + return + } + + metrics.numEventsPosted.inc() + + // If the event buffer is null, it means the bus has been started and we can avoid + // synchronization and post events directly to the queues. This should be the most + // common case during the life of the bus. + if (queuedEvents == null) { + postToQueues(event) + return + } + + // Otherwise, need to synchronize to check whether the bus is started, to make sure the thread + // calling start() picks up the new event. + synchronized { + if (!started.get()) { + queuedEvents += event + return } } + + // If the bus was already started when the check above was made, just post directly to the + // queues. + postToQueues(event) + } + + private def postToQueues(event: SparkListenerEvent): Unit = { + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } } /** @@ -149,7 +178,11 @@ private[spark] class LiveListenerBus(conf: SparkConf) { } this.sparkContext = sc - queues.asScala.foreach(_.start(sc)) + queues.asScala.foreach { q => + q.start(sc) + queuedEvents.foreach(q.post) + } + queuedEvents = null metricsSystem.registerSource(metrics) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1beb36afa95f0..da6ecb82c7e42 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -48,7 +48,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount } - private def queueSize(bus: LiveListenerBus): Int = { + private def sharedQueueSize(bus: LiveListenerBus): Int = { bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() .asInstanceOf[Int] } @@ -73,12 +73,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addToSharedQueue(counter) // Metrics are initially empty. assert(bus.metrics.numEventsPosted.getCount === 0) assert(numDroppedEvents(bus) === 0) - assert(queueSize(bus) === 0) + assert(bus.queuedEvents.size === 0) assert(eventProcessingTimeCount(bus) === 0) // Post five events: @@ -87,7 +86,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. assert(bus.metrics.numEventsPosted.getCount === 5) - assert(queueSize(bus) === 5) + assert(bus.queuedEvents.size === 5) + + // Add the counter to the bus after messages have been queued for later delivery. + bus.addToSharedQueue(counter) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -95,9 +97,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(eventProcessingTimeCount(bus) === 5) + // After the bus is started, there should be no more queued events. + assert(bus.queuedEvents === null) + // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -188,18 +193,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: From 95f9659abe8845f9f3f42fd7ababd79e55c52489 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 15:00:09 -0800 Subject: [PATCH 0016/1161] [SPARK-22948][K8S] Move SparkPodInitContainer to correct package. Author: Marcelo Vanzin Closes #20156 from vanzin/SPARK-22948. --- dev/sparktestsupport/modules.py | 2 +- .../spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala | 2 +- .../deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala | 2 +- .../docker/src/main/dockerfiles/init-container/Dockerfile | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala (99%) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala (98%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index f834563da9dda..7164180a6a7b0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -539,7 +539,7 @@ def __hash__(self): kubernetes = Module( name="kubernetes", dependencies=[], - source_file_regexes=["resource-managers/kubernetes/core"], + source_file_regexes=["resource-managers/kubernetes"], build_profile_flags=["-Pkubernetes"], sbt_test_goals=["kubernetes/test"] ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala similarity index 99% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala index 4a4b628aedbbf..c0f08786b76a1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.concurrent.TimeUnit diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala similarity index 98% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala index 6c557ec4a7c9a..e0f29ecd0fb53 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.UUID diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 055493188fcb7..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -21,4 +21,4 @@ FROM spark-base # command should be invoked from the top level directory of the Spark distribution. E.g.: # docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] From e288fc87a027ec1e1a21401d1f151df20dbfecf3 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 15:35:20 -0800 Subject: [PATCH 0017/1161] [SPARK-22953][K8S] Avoids adding duplicated secret volumes when init-container is used ## What changes were proposed in this pull request? User-specified secrets are mounted into both the main container and init-container (when it is used) in a Spark driver/executor pod, using the `MountSecretsBootstrap`. Because `MountSecretsBootstrap` always adds new secret volumes for the secrets to the pod, the same secret volumes get added twice, one when mounting the secrets to the main container, and the other when mounting the secrets to the init-container. This PR fixes the issue by separating `MountSecretsBootstrap.mountSecrets` out into two methods: `addSecretVolumes` for adding secret volumes to a pod and `mountSecrets` for mounting secret volumes to a container, respectively. `addSecretVolumes` is only called once for each pod, whereas `mountSecrets` is called individually for the main container and the init-container (if it is used). Ref: https://github.com/apache-spark-on-k8s/spark/issues/594. ## How was this patch tested? Unit tested and manually tested. vanzin This replaces https://github.com/apache/spark/pull/20148. hex108 foxish kimoonkim Author: Yinan Li Closes #20159 from liyinan926/master. --- .../deploy/k8s/MountSecretsBootstrap.scala | 30 ++++++++++++------- .../k8s/submit/DriverConfigOrchestrator.scala | 16 +++++----- .../steps/BasicDriverConfigurationStep.scala | 2 +- .../submit/steps/DriverMountSecretsStep.scala | 4 +-- .../InitContainerMountSecretsStep.scala | 11 +++---- .../cluster/k8s/ExecutorPodFactory.scala | 6 ++-- .../k8s/{submit => }/SecretVolumeUtils.scala | 18 +++++------ .../BasicDriverConfigurationStepSuite.scala | 4 +-- .../steps/DriverMountSecretsStepSuite.scala | 4 +-- .../InitContainerMountSecretsStepSuite.scala | 7 +---- .../cluster/k8s/ExecutorPodFactorySuite.scala | 14 +++++---- 11 files changed, 61 insertions(+), 55 deletions(-) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit => }/SecretVolumeUtils.scala (71%) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala index 8286546ce0641..c35e7db51d407 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -24,26 +24,36 @@ import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBui private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { /** - * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. * * @param pod the pod into which the secret volumes are being added. - * @param container the container into which the secret volumes are being mounted. - * @return the updated pod and container with the secrets mounted. + * @return the updated pod with the secret volumes added. */ - def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + def addSecretVolumes(pod: Pod): Pod = { var podBuilder = new PodBuilder(pod) secretNamesToMountPaths.keys.foreach { name => podBuilder = podBuilder .editOrNewSpec() .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() .endSpec() } + podBuilder.build() + } + + /** + * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the + * given container. + * + * @param container the container into which the secret volumes are being mounted. + * @return the updated container with the secrets mounted. + */ + def mountSecrets(container: Container): Container = { var containerBuilder = new ContainerBuilder(container) secretNamesToMountPaths.foreach { case (name, path) => containerBuilder = containerBuilder @@ -53,7 +63,7 @@ private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, .endVolumeMount() } - (podBuilder.build(), containerBuilder.build()) + containerBuilder.build() } private def secretVolumeName(secretName: String): String = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 00c9c4ee49177..c9cc300d65569 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -127,6 +127,12 @@ private[spark] class DriverConfigOrchestrator( Nil } + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { val orchestrator = new InitContainerConfigOrchestrator( sparkJars, @@ -147,19 +153,13 @@ private[spark] class DriverConfigOrchestrator( Nil } - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - initContainerBootstrapStep ++ - mountSecretsStep + mountSecretsStep ++ + initContainerBootstrapStep } private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b7a69a7dfd472..eca46b84c6066 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -119,7 +119,7 @@ private[spark] class BasicDriverConfigurationStep( .endEnv() .addNewEnv() .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .withValue(appArgs.mkString(" ")) .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala index f872e0f4b65d1..91e9a9f211335 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -28,8 +28,8 @@ private[spark] class DriverMountSecretsStep( bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val (pod, container) = bootstrap.mountSecrets( - driverSpec.driverPod, driverSpec.driverContainer) + val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) + val container = bootstrap.mountSecrets(driverSpec.driverContainer) driverSpec.copy( driverPod = pod, driverContainer = container diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala index c0e7bb20cce8c..0daa7b95e8aae 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -28,12 +28,9 @@ private[spark] class InitContainerMountSecretsStep( bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - val (driverPod, initContainer) = bootstrap.mountSecrets( - spec.driverPod, - spec.initContainer) - spec.copy( - driverPod = driverPod, - initContainer = initContainer - ) + // Mount the secret volumes given that the volumes have already been added to the driver pod + // when mounting the secrets into the main driver container. + val initContainer = bootstrap.mountSecrets(spec.initContainer) + spec.copy(initContainer = initContainer) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ba5d891f4c77e..066d7e9f70ca5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -214,7 +214,7 @@ private[spark] class ExecutorPodFactory( val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = mountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(executorPod, containerWithLimitCores) + (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) val (bootstrappedPod, bootstrappedContainer) = @@ -227,7 +227,9 @@ private[spark] class ExecutorPodFactory( val (pod, mayBeSecretsMountedInitContainer) = initContainerMountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + // Mount the secret volumes given that the volumes have already been added to the + // executor pod when mounting the secrets into the main executor container. + (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) val bootstrappedPod = KubernetesUtils.appendInitContainer( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 71% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 8388c16ded268..16780584a674a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ @@ -22,15 +22,15 @@ import io.fabric8.kubernetes.api.model.{Container, Pod} private[spark] object SecretVolumeUtils { - def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { - driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def containerHasVolume( - driverContainer: Container, - volumeName: String, - mountPath: String): Boolean = { - driverContainer.getVolumeMounts.asScala.exists(volumeMount => - volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e864c6a16eeb1..8ee629ac8ddc1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -33,7 +33,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" @@ -82,7 +82,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 3\"") + assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala index 9ec0cb55de5aa..960d0bda1d011 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DriverMountSecretsStepSuite extends SparkFunSuite { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala index eab4e17659456..7ac0bde80dfe6 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps.initcontainer import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} class InitContainerMountSecretsStepSuite extends SparkFunSuite { @@ -44,12 +43,8 @@ class InitContainerMountSecretsStepSuite extends SparkFunSuite { val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( baseInitContainerSpec) - - val podWithSecretsMounted = configuredInitContainerSpec.driverPod val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => assert(SecretVolumeUtils.containerHasVolume( initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7121a802c69c1..884da8aabd880 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -25,7 +25,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -165,17 +165,19 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val factory = new ExecutorPodFactory( conf, - None, + Some(secretsBootstrap), Some(initContainerBootstrap), Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getVolumes.size() === 1) + assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) assert(executor.getSpec.getInitContainers.size() === 1) - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) checkOwnerReferences(executor, driverPodUid) } From 0428368c2c5e135f99f62be20877bbbda43be310 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:34:56 -0800 Subject: [PATCH 0018/1161] [SPARK-22960][K8S] Make build-push-docker-images.sh more dev-friendly. - Make it possible to build images from a git clone. - Make it easy to use minikube to test things. Also fixed what seemed like a bug: the base image wasn't getting the tag provided in the command line. Adding the tag allows users to use multiple Spark builds in the same kubernetes cluster. Tested by deploying images on minikube and running spark-submit from a dev environment; also by building the images with different tags and verifying "docker images" in minikube. Author: Marcelo Vanzin Closes #20154 from vanzin/SPARK-22960. --- docs/running-on-kubernetes.md | 9 +- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 3 +- .../main/dockerfiles/spark-base/Dockerfile | 7 +- sbin/build-push-docker-images.sh | 120 +++++++++++++++--- 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e491329136a3c..2d69f636472ae 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -16,6 +16,9 @@ Kubernetes scheduler that has been added to Spark. you may setup a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. + * Be aware that the default minikube configuration is not enough for running Spark applications. + We recommend 3 CPUs and 4g of memory to be able to start a simple Spark application with a single + executor. * You must have appropriate permissions to list, create, edit and delete [pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources by running `kubectl auth can-i pods`. @@ -197,7 +200,7 @@ kubectl port-forward 4040:4040 Then, the Spark driver UI can be accessed on `http://localhost:4040`. -### Debugging +### Debugging There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the connection is refused for a different reason, the submission logic should indicate the error encountered. However, if there @@ -215,8 +218,8 @@ If the pod has encountered a runtime error, the status can be probed further usi kubectl logs ``` -Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark -application, includling all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of +Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark +application, including all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of the Spark application. ## Kubernetes Features diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 45fbcd9cd0deb..ff5289e10c21e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 0f806cf7e148e..3eabb42d4d852 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 047056ab2633b..e0a249e0ac71f 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 222e777db3a82..da1d6b9e161cc 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -17,6 +17,9 @@ FROM openjdk:8-alpine +ARG spark_jars +ARG img_path + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -34,11 +37,11 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index b3137598692d8..bb8806dd33f37 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -19,29 +19,94 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. -declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ - [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) +function error { + echo "$@" 1>&2 + exit 1 +} + +# Detect whether this is a git clone or a Spark distribution and adjust paths +# accordingly. +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ -f "$SPARK_HOME/RELEASE" ]; then + IMG_PATH="kubernetes/dockerfiles" + SPARK_JARS="jars" +else + IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" + SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." +fi + +declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ + [spark-executor]="$IMG_PATH/executor/Dockerfile" \ + [spark-init]="$IMG_PATH/init-container/Dockerfile" ) + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + local base_image="$(image_ref spark-base 0)" + docker build --build-arg "spark_jars=$SPARK_JARS" \ + --build-arg "img_path=$IMG_PATH" \ + -t "$base_image" \ + -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . done } - function push { for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} + docker push "$(image_ref $image)" done } function usage { - echo "This script must be run from a runnable distribution of Apache Spark." - echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" + cat </dev/null; then + error "Cannot find minikube." + fi + eval $(minikube docker-env) + ;; esac done -if [ -z "$REPO" ] || [ -z "$TAG" ]; then +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi + exit 1 + ;; +esac From df7fc3ef3899cadd252d2837092bebe3442d6523 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 Jan 2018 10:16:34 +0800 Subject: [PATCH 0019/1161] [SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt ## What changes were proposed in this pull request? 32bit Int was used for row rank. That overflowed in a dataframe with more than 2B rows. ## How was this patch tested? Added test, but ignored, as it takes 4 minutes. Author: Juliusz Sompolski Closes #20152 from juliuszsompolski/SPARK-22957. --- .../aggregate/ApproximatePercentile.scala | 12 ++++++------ .../spark/sql/catalyst/util/QuantileSummaries.scala | 8 ++++---- .../org/apache/spark/sql/DataFrameStatSuite.scala | 8 ++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 149ac265e6ed5..a45854a3b5146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -296,8 +296,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +312,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,8 +330,8 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6af..b013add9c9778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -105,7 +105,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * currentCount).toLong } val tuple = Stats(currentSample, 1, delta) @@ -192,10 +192,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +235,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2e..5169d2b5fc6b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. + ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() From 52fc5c17d9d784b846149771b398e741621c0b5c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 5 Jan 2018 14:02:21 +0800 Subject: [PATCH 0020/1161] [SPARK-22825][SQL] Fix incorrect results of Casting Array to String ## What changes were proposed in this pull request? This pr fixed the issue when casting arrays into strings; ``` scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids)) scala> df.write.saveAsTable("t") scala> sql("SELECT cast(ids as String) FROM t").show(false) +------------------------------------------------------------------+ |ids | +------------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df| +------------------------------------------------------------------+ ``` This pr modified the result into; ``` +------------------------------+ |ids | +------------------------------+ |[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]| +------------------------------+ ``` ## How was this patch tested? Added tests in `CastSuite` and `SQLQuerySuite`. Author: Takeshi Yamamuro Closes #20024 from maropu/SPARK-22825. --- .../codegen/UTF8StringBuilder.java | 78 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 68 ++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 25 ++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 - 4 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 0000000000000..f0f66bae245fd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. + */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new byte[16]; + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + } + } + + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer, 0, totalSize()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 274d8813f16db..d4fc5e0f168a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a1..e3ed7171defd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade55..96bf65fce9c4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From cf0aa65576acbe0209c67f04c029058fd73555c1 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 4 Jan 2018 22:45:15 -0800 Subject: [PATCH 0021/1161] [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit ## What changes were proposed in this pull request? Avoid holding all models in memory for `TrainValidationSplit`. ## How was this patch tested? Existing tests. Author: Bago Amirbekian Closes #20143 from MrBago/trainValidMemoryFix. --- .../spark/ml/tuning/CrossValidator.scala | 4 +++- .../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 095b54c0fe83f..a0b507d2e718c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } (executionContext) } - // Wait for metrics to be calculated before unpersisting validation dataset + // Wait for metrics to be calculated val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + + // Unpersist training & validation set once all metrics have been produced trainingDataset.unpersist() validationDataset.unpersist() foldMetrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd18475475..8826ef3271bc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { subModels.get(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Wait for all metrics to be calculated val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) - // Unpersist validation set once all metrics have been produced + // Unpersist training & validation set once all metrics have been produced + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") From 6cff7d19f6a905fe425bd6892fe7ca014c0e696b Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 23:23:41 -0800 Subject: [PATCH 0022/1161] [SPARK-22757][K8S] Enable spark.jars and spark.files in KUBERNETES mode ## What changes were proposed in this pull request? We missed enabling `spark.files` and `spark.jars` in https://github.com/apache/spark/pull/19954. The result is that remote dependencies specified through `spark.files` or `spark.jars` are not included in the list of remote dependencies to be downloaded by the init-container. This PR fixes it. ## How was this patch tested? Manual tests. vanzin This replaces https://github.com/apache/spark/pull/20157. foxish Author: Yinan Li Closes #20160 from liyinan926/SPARK-22757. --- .../src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cbe1f2c3e08a1..1e381965c52ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -584,10 +584,11 @@ object SparkSubmit extends CommandLineUtils with Logging { confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, From 51c33bd0d402af9e0284c6cbc0111f926446bfba Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 5 Jan 2018 21:32:39 +0800 Subject: [PATCH 0023/1161] [SPARK-22961][REGRESSION] Constant columns should generate QueryPlanConstraints ## What changes were proposed in this pull request? #19201 introduced the following regression: given something like `df.withColumn("c", lit(2))`, we're no longer picking up `c === 2` as a constraint and infer filters from it when joins are involved, which may lead to noticeable performance degradation. This patch re-enables this optimization by picking up Aliases of Literals in Projection lists as constraints and making sure they're not treated as aliased columns. ## How was this patch tested? Unit test was added. Author: Adrian Ionescu Closes #20155 from adrian-ionescu/constant_constraints. --- .../sql/catalyst/plans/logical/LogicalPlan.scala | 2 ++ .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5e..ff2a0ec588567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -247,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38dea..9c0a30a47f839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan => // we may avoid producing recursive constraints. private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { - case a: Alias => (a.toAttribute, a.child) + case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec72..a0708bf7eee9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } } From c0b7424ecacb56d3e7a18acc11ba3d5e7be57c43 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 5 Jan 2018 09:58:28 -0800 Subject: [PATCH 0024/1161] [SPARK-22940][SQL] HiveExternalCatalogVersionsSuite should succeed on platforms that don't have wget ## What changes were proposed in this pull request? Modified HiveExternalCatalogVersionsSuite.scala to use Utils.doFetchFile to download different versions of Spark binaries rather than launching wget as an external process. On platforms that don't have wget installed, this suite fails with an error. cloud-fan : would you like to check this change? ## How was this patch tested? 1) test-only of HiveExternalCatalogVersionsSuite on several platforms. Tested bad mirror, read timeout, and redirects. 2) ./dev/run-tests Author: Bruce Robbins Closes #20147 from bersprockets/SPARK-22940-alt. --- .../HiveExternalCatalogVersionsSuite.scala | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a6761..ae4aeb7b4ce4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. + conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll() From 930b90a84871e2504b57ed50efa7b8bb52d3ba44 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 5 Jan 2018 11:51:25 -0800 Subject: [PATCH 0025/1161] [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator ## What changes were proposed in this pull request? Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: https://github.com/apache/spark/pull/19527 or read below for what this PR includes: * configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF. * encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions. I also made some small style cleanups based on IntelliJ warnings. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #20132 from jkbradley/viirya-SPARK-13030. --- .../ml/feature/OneHotEncoderEstimator.scala | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 074622d41e28d..bd1e3426c8780 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** - * Param for how to handle invalid data. + * Param for how to handle invalid data during transform(). * Options are 'keep' (invalid data presented as an extra categorical feature) or * 'error' (throw an error). + * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data " + + "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error).", + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid def getDropLast: Boolean = $(dropLast) protected def validateAndTransformSchema( - schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val existingFields = schema.fields require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat override def load(path: String): OneHotEncoderEstimator = super.load(path) } +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // Returns the category size for a given index with `dropLast` and `handleInvalid` + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account. - private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + private def getConfigedCategorySizes: Array[Int] = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - orgCategorySize + 1 + categorySizes.map(_ + 1) } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. - orgCategorySize - 1 + categorySizes.map(_ - 1) } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - orgCategorySize + categorySizes } } private def encoder: UserDefinedFunction = { - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val dropLast = getDropLast - val handleInvalid = getHandleInvalid - val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index of input. - udf { (label: Double, idx: Int) => - val plainNumCategories = categorySizes(idx) - val size = configedCategorySize(plainNumCategories, idx) - - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label == size && dropLast && !keepInvalid) { - // When `dropLast` is true and `handleInvalid` is not "keep", - // the last category is removed. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (label >= plainNumCategories && keepInvalid) { - // When `handleInvalid` is "keep", encodes invalid data to last category (and removed - // if `dropLast` is true) - if (dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize } else { - Vectors.sparse(size, Array(size - 1), oneValue) + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoderEstimator.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } } - } else if (label < plainNumCategories) { - Vectors.sparse(size, Array(label.toInt), oneValue) + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) } else { - assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) - throw new SparkException(s"Unseen value: $label. To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) } } } @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) - val outputColNames = $(outputCols) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] ( * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. if (attrGroup.attributes.nonEmpty) { - val numCategories = configedCategorySize(categorySizes(idx), idx) + val numCategories = configedSizes(idx) require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column ${inputColName}, " + + s"$numCategories categorical values for input column $inputColName, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] ( val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val encodedColumns = (0 until $(inputCols).length).map { idx => + val encodedColumns = $(inputCols).indices.map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) From ea956833017fcbd8ed2288368bfa2e417a2251c5 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Fri, 5 Jan 2018 17:25:28 -0800 Subject: [PATCH 0026/1161] [SPARK-22914][DEPLOY] Register history.ui.port ## What changes were proposed in this pull request? Register spark.history.ui.port as a known spark conf to be used in substitution expressions even if it's not set explicitly. ## How was this patch tested? Added unit test to demonstrate the issue Author: Gera Shegalov Author: Gera Shegalov Closes #20098 from gerashegalov/gera/register-SHS-port-conf. --- .../spark/deploy/history/HistoryServer.scala | 3 +- .../apache/spark/deploy/history/config.scala | 5 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 17 +++++--- .../deploy/yarn/ApplicationMasterSuite.scala | 43 +++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 75484f5c9f30f..0ec4afad0308c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} @@ -276,7 +277,7 @@ object HistoryServer extends Logging { .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.getInt("spark.history.ui.port", 18080) + val port = conf.get(HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index 22b6d49d8e2a4..efdbf672bb52f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -44,4 +44,9 @@ private[spark] object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") + .doc("Web UI port to bind Spark History Server") + .intConf + .createWithDefault(18080) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d72633..4d5e3bb043671 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -427,11 +427,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -834,6 +831,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 0000000000000..695a82f3583e6 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} From e8af7e8aeca15a6107248f358d9514521ffdc6d3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 6 Jan 2018 09:26:03 +0800 Subject: [PATCH 0027/1161] [SPARK-22937][SQL] SQL elt output binary for binary inputs ## What changes were proposed in this pull request? This pr modified `elt` to output binary for binary inputs. `elt` in the current master always output data as a string. But, in some databases (e.g., MySQL), if all inputs are binary, `elt` also outputs binary (Also, this might be a small surprise). This pr is related to #19977. ## How was this patch tested? Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`. Author: Takeshi Yamamuro Closes #20135 from maropu/SPARK-22937. --- docs/sql-programming-guide.md | 2 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 +++++ .../expressions/stringExpressions.scala | 46 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ .../catalyst/analysis/TypeCoercionSuite.scala | 54 ++++++++ .../inputs/typeCoercion/native/elt.sql | 44 +++++++ .../results/typeCoercion/native/elt.sql.out | 115 ++++++++++++++++++ 7 files changed, 281 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index dc3e384008d27..b50f9360b866c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1783,6 +1783,8 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e9436367c7e2e..e8669c4637d06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -684,6 +685,34 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 41dc762154a4c..e004bfc6af473 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. */ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -307,27 +319,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -335,7 +347,7 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression]) |${index.code} |final int $indexVal = ${index.value}; |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5d6edf6b8abec..80b8965e084a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1052,6 +1052,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3661530cd622b..52a7ebdafd7c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 0000000000000..717616f91db05 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 0000000000000..b62e1b6826045 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 From bf65cd3cda46d5480bfcd13110975c46ca631972 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 5 Jan 2018 17:29:27 -0800 Subject: [PATCH 0028/1161] [SPARK-22960][K8S] Revert use of ARG base_image in images ## What changes were proposed in this pull request? This PR reverts the `ARG base_image` before `FROM` in the images of driver, executor, and init-container, introduced in https://github.com/apache/spark/pull/20154. The reason is Docker versions before 17.06 do not support this use (`ARG` before `FROM`). ## How was this patch tested? Tested manually. vanzin foxish kimoonkim Author: Yinan Li Closes #20170 from liyinan926/master. --- .../docker/src/main/dockerfiles/driver/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/executor/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/init-container/Dockerfile | 3 +-- sbin/build-push-docker-images.sh | 8 ++++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index ff5289e10c21e..45fbcd9cd0deb 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 3eabb42d4d852..0f806cf7e148e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index e0a249e0ac71f..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index bb8806dd33f37..b9532597419a5 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -60,13 +60,13 @@ function image_ref { } function build { - local base_image="$(image_ref spark-base 0)" - docker build --build-arg "spark_jars=$SPARK_JARS" \ + docker build \ + --build-arg "spark_jars=$SPARK_JARS" \ --build-arg "img_path=$IMG_PATH" \ - -t "$base_image" \ + -t spark-base \ -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . + docker build -t "$(image_ref $image)" -f ${path[$image]} . done } From f2dd8b923759e8771b0e5f59bfa7ae4ad7e6a339 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 6 Jan 2018 16:11:20 +0800 Subject: [PATCH 0029/1161] [SPARK-22930][PYTHON][SQL] Improve the description of Vectorized UDFs for non-deterministic cases ## What changes were proposed in this pull request? Add tests for using non deterministic UDFs in aggregate. Update pandas_udf docstring w.r.t to determinism. ## How was this patch tested? test_nondeterministic_udf_in_aggregate Author: Li Jin Closes #20142 from icexelloss/SPARK-22930-pandas-udf-deterministic. --- python/pyspark/sql/functions.py | 12 +++++++- python/pyspark/sql/tests.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4ed562ad48b4..733e32bd825b0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2214,7 +2214,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dc767f9ec46e..689736d8e6456 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -386,6 +386,7 @@ def test_udf3(self): self.assertEqual(row[0], 5) def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() @@ -413,6 +414,18 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -3567,6 +3580,18 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def random_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3950,6 +3975,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf, pandas_udf, col + + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.random_udf + + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.random_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): From be9a804f2ef77a5044d3da7d9374976daf59fc16 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sat, 6 Jan 2018 18:07:45 +0800 Subject: [PATCH 0030/1161] [SPARK-22793][SQL] Memory leak in Spark Thrift Server # What changes were proposed in this pull request? 1. Start HiveThriftServer2. 2. Connect to thriftserver through beeline. 3. Close the beeline. 4. repeat step2 and step 3 for many times. we found there are many directories never be dropped under the path `hive.exec.local.scratchdir` and `hive.exec.scratchdir`, as we know the scratchdir has been added to deleteOnExit when it be created. So it means that the cache size of FileSystem `deleteOnExit` will keep increasing until JVM terminated. In addition, we use `jmap -histo:live [PID]` to printout the size of objects in HiveThriftServer2 Process, we can find the object `org.apache.spark.sql.hive.client.HiveClientImpl` and `org.apache.hadoop.hive.ql.session.SessionState` keep increasing even though we closed all the beeline connections, which may caused the leak of Memory. # How was this patch tested? manual tests This PR follw-up the https://github.com/apache/spark/pull/19989 Author: zuotingbing Closes #20029 from zuotingbing/SPARK-22793. --- .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e3..dc92ad3b0c1ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,7 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client.newSession() + val client: HiveClient = externalCatalog.client new HiveSessionResourceLoader(session, client) } From 7b78041423b6ee330def2336dfd1ff9ae8469c59 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 6 Jan 2018 18:19:57 +0800 Subject: [PATCH 0031/1161] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? Since Hive 1.1, Hive allows users to set parquet compression codec via table-level properties parquet.compression. See the JIRA: https://issues.apache.org/jira/browse/HIVE-7858 . We do support orc.compression for ORC. Thus, for external users, it is more straightforward to support both. See the stackflow question: https://stackoverflow.com/questions/36941122/spark-sql-ignores-parquet-compression-propertie-specified-in-tblproperties In Spark side, our table-level compression conf compression was added by #11464 since Spark 2.0. We need to support both table-level conf. Users might also use session-level conf spark.sql.parquet.compression.codec. The priority rule will be like If other compression codec configuration was found through hive or parquet, the precedence would be compression, parquet.compression, spark.sql.parquet.compression.codec. Acceptable values include: none, uncompressed, snappy, gzip, lzo. The rule for Parquet is consistent with the ORC after the change. Changes: 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the precedence order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". 3.Change `compressionCode` to `compressionCodecClassName`. ## How was this patch tested? Add test. Author: fjh100456 Closes #20076 from fjh100456/ParquetOptionIssue. --- docs/sql-programming-guide.md | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 14 +- .../datasources/parquet/ParquetOptions.scala | 12 +- ...rquetCompressionCodecPrecedenceSuite.scala | 122 ++++++++++++++++++ 4 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b50f9360b866c..3ccaaf4d5b1fa 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,8 +953,10 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + Sets the compression codec used when writing Parquet files. If either `compression` or + `parquet.compression` is specified in the table-specific options/properties, the precedence would be + `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + none, uncompressed, snappy, gzip, lzo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80b8965e084a2..7d1217de254a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,11 +325,13 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + + "`parquet.compression` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -366,8 +368,10 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + + "`orc.compress` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de548..ef67ea7d17cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. + val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 0000000000000..ed8fd2b453456 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", "GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} From 993f21567a1dd33e43ef9a626e0ddfbe46f83f93 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 6 Jan 2018 23:08:26 +0800 Subject: [PATCH 0032/1161] [SPARK-22901][PYTHON][FOLLOWUP] Adds the doc for asNondeterministic for wrapped UDF function ## What changes were proposed in this pull request? This PR wraps the `asNondeterministic` attribute in the wrapped UDF function to set the docstring properly. ```python from pyspark.sql.functions import udf help(udf(lambda x: x).asNondeterministic) ``` Before: ``` Help on function in module pyspark.sql.udf: lambda (END ``` After: ``` Help on function asNondeterministic in module pyspark.sql.udf: asNondeterministic() Updates UserDefinedFunction to nondeterministic. .. versionadded:: 2.3 (END) ``` ## How was this patch tested? Manually tested and a simple test was added. Author: hyukjinkwon Closes #20173 from HyukjinKwon/SPARK-22901-followup. --- python/pyspark/sql/tests.py | 1 + python/pyspark/sql/udf.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 689736d8e6456..122a65b83aef9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -413,6 +413,7 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e75eb6545333..5e80ab9165867 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -169,8 +169,8 @@ def wrapper(*args): wrapper.returnType = self.returnType wrapper.evalType = self.evalType wrapper.deterministic = self.deterministic - wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() - + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper def asNondeterministic(self): From 9a7048b2889bd0fd66e68a0ce3e07e466315a051 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 7 Jan 2018 00:19:21 +0800 Subject: [PATCH 0033/1161] [HOTFIX] Fix style checking failure ## What changes were proposed in this pull request? This PR is to fix the style checking failure. ## How was this patch tested? N/A Author: gatorsmile Closes #20175 from gatorsmile/stylefix. --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7d1217de254a2..5c61f10bb71ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,10 +325,11 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + - "`parquet.compression` is specified in the table-specific options/properties, the precedence" + - "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + - "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) @@ -368,8 +369,8 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + - "`orc.compress` is specified in the table-specific options/properties, the precedence" + + .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf From 18e94149992618a2b4e6f0fd3b3f4594e1745224 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 7 Jan 2018 13:42:01 +0800 Subject: [PATCH 0034/1161] [SPARK-22973][SQL] Fix incorrect results of Casting Map to String ## What changes were proposed in this pull request? This pr fixed the issue when casting maps into strings; ``` scala> Seq(Map(1 -> "a", 2 -> "b")).toDF("a").write.saveAsTable("t") scala> sql("SELECT cast(a as String) FROM t").show(false) +----------------------------------------------------------------+ |a | +----------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeMapData38bdd75d| +----------------------------------------------------------------+ ``` This pr modified the result into; ``` +----------------+ |a | +----------------+ |[1 -> a, 2 -> b]| +----------------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20166 from maropu/SPARK-22973. --- .../spark/sql/catalyst/expressions/Cast.scala | 89 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 28 ++++++ 2 files changed, 117 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d4fc5e0f168a7..f2de4c8e30bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -228,6 +228,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -654,6 +685,53 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc( + | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -676,6 +754,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case MapType(kt, vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e3ed7171defd8..1445bb8a97d40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -878,4 +878,32 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } } From 71d65a32158a55285be197bec4e41fedc9225b94 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 11:39:45 +0800 Subject: [PATCH 0035/1161] [SPARK-22985] Fix argument escaping bug in from_utc_timestamp / to_utc_timestamp codegen ## What changes were proposed in this pull request? This patch adds additional escaping in `from_utc_timestamp` / `to_utc_timestamp` expression codegen in order to a bug where invalid timezones which contain special characters could cause generated code to fail to compile. ## How was this patch tested? New regression tests in `DateExpressionsSuite`. Author: Josh Rosen Closes #20182 from JoshRosen/SPARK-22985-fix-utc-timezone-function-escaping-bugs. --- .../catalyst/expressions/datetimeExpressions.scala | 12 ++++++++---- .../catalyst/expressions/DateExpressionsSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7a674ea7f4d76..424871f2047e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -23,6 +23,8 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -1008,7 +1010,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1017,8 +1019,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") @@ -1185,7 +1188,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { - val tz = right.eval() + val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { ev.copy(code = s""" |boolean ${ev.isNull} = true; @@ -1194,8 +1197,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) } else { val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val escapedTz = StringEscapeUtils.escapeJava(tz.toString) val tzTerm = ctx.addMutableState(tzClass, "tz", - v => s"""$v = $dtu.getTimeZone("$tz");""") + v => s"""$v = $dtu.getTimeZone("$escapedTz");""") val utcTerm = "tzUTC" ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 63f6ceeb21b96..786266a2c13c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -791,6 +792,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate( + ToUTCTimestamp(Literal(Timestamp.valueOf("2015-07-24 00:00:00")), Literal("\"quote")) :: Nil) } test("from_utc_timestamp") { @@ -811,5 +815,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test(null, "UTC", null) test("2015-07-24 00:00:00", null, null) test(null, null, null) + // Test escaping of timezone + GenerateUnsafeProjection.generate(FromUTCTimestamp(Literal(0), Literal("\"quote")) :: Nil) } } From 3e40eb3f1ffac3d2f49459a801e3ce171ed34091 Mon Sep 17 00:00:00 2001 From: Guilherme Berger Date: Mon, 8 Jan 2018 14:32:05 +0900 Subject: [PATCH 0036/1161] [SPARK-22566][PYTHON] Better error message for `_merge_type` in Pandas to Spark DF conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? It provides a better error message when doing `spark_session.createDataFrame(pandas_df)` with no schema and an error occurs in the schema inference due to incompatible types. The Pandas column names are propagated down and the error message mentions which column had the merging error. https://issues.apache.org/jira/browse/SPARK-22566 ## How was this patch tested? Manually in the `./bin/pyspark` console, and with new tests: `./python/run-tests` screen shot 2017-11-21 at 13 29 49 I state that the contribution is my original work and that I license the work to the Apache Spark project under the project’s open source license. Author: Guilherme Berger Closes #19792 from gberger/master. --- python/pyspark/sql/session.py | 17 +++--- python/pyspark/sql/tests.py | 100 ++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 28 +++++++--- 3 files changed, 129 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6e5eec48e8aca..6052fa9e84096 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -325,11 +325,12 @@ def range(self, start, end=None, step=1, numPartitions=None): return DataFrame(jdf, self._wrapped) - def _inferSchemaFromList(self, data): + def _inferSchemaFromList(self, data, names=None): """ Infer schema from list of Row or tuple. :param data: list of Row or tuple + :param names: list of column names :return: :class:`pyspark.sql.types.StructType` """ if not data: @@ -338,12 +339,12 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = reduce(_merge_type, map(_infer_schema, data)) + schema = reduce(_merge_type, (_infer_schema(row, names) for row in data)) if _has_nulltype(schema): raise ValueError("Some of types cannot be determined after inferring") return schema - def _inferSchema(self, rdd, samplingRatio=None): + def _inferSchema(self, rdd, samplingRatio=None, names=None): """ Infer schema from an RDD of Row or tuple. @@ -360,10 +361,10 @@ def _inferSchema(self, rdd, samplingRatio=None): "Use pyspark.sql.Row instead") if samplingRatio is None: - schema = _infer_schema(first) + schema = _infer_schema(first, names=names) if _has_nulltype(schema): for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) + schema = _merge_type(schema, _infer_schema(row, names=names)) if not _has_nulltype(schema): break else: @@ -372,7 +373,7 @@ def _inferSchema(self, rdd, samplingRatio=None): else: if samplingRatio < 0.99: rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) + schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type) return schema def _createFromRDD(self, rdd, schema, samplingRatio): @@ -380,7 +381,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchema(rdd, samplingRatio) + struct = self._inferSchema(rdd, samplingRatio, names=schema) converter = _create_converter(struct) rdd = rdd.map(converter) if isinstance(schema, (list, tuple)): @@ -406,7 +407,7 @@ def _createFromLocal(self, data, schema): data = list(data) if schema is None or isinstance(schema, (list, tuple)): - struct = self._inferSchemaFromList(data) + struct = self._inferSchemaFromList(data, names=schema) converter = _create_converter(struct) data = map(converter, data) if isinstance(schema, (list, tuple)): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 122a65b83aef9..13576ff57001b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -68,6 +68,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings +from pyspark.sql.types import _merge_type from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window @@ -898,6 +899,15 @@ def test_infer_schema(self): result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_schema_not_enough_names(self): + df = self.spark.createDataFrame([["a", "b"]], ["col1"]) + self.assertEqual(df.columns, ['col1', '_2']) + + def test_infer_schema_fails(self): + with self.assertRaisesRegexp(TypeError, 'field a'): + self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], samplingRatio=0.99) + def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), @@ -918,6 +928,10 @@ def test_infer_nested_schema(self): df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_dict_respects_schema(self): + df = self.spark.createDataFrame([{'a': 1}], ["b"]) + self.assertEqual(df.columns, ['b']) + def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] df = self.spark.createDataFrame(data) @@ -1772,6 +1786,92 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_merge_type(self): + self.assertEqual(_merge_type(LongType(), NullType()), LongType()) + self.assertEqual(_merge_type(NullType(), LongType()), LongType()) + + self.assertEqual(_merge_type(LongType(), LongType()), LongType()) + + self.assertEqual(_merge_type( + ArrayType(LongType()), + ArrayType(LongType()) + ), ArrayType(LongType())) + with self.assertRaisesRegexp(TypeError, 'element in array'): + _merge_type(ArrayType(LongType()), ArrayType(DoubleType())) + + self.assertEqual(_merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), LongType()) + ), MapType(StringType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'key of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(DoubleType(), LongType())) + with self.assertRaisesRegexp(TypeError, 'value of map'): + _merge_type( + MapType(StringType(), LongType()), + MapType(StringType(), DoubleType())) + + self.assertEqual(_merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", LongType()), StructField("f2", StringType())]) + ), StructType([StructField("f1", LongType()), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'field f1'): + _merge_type( + StructType([StructField("f1", LongType()), StructField("f2", StringType())]), + StructType([StructField("f1", DoubleType()), StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]) + ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))])) + with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'): + _merge_type( + StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]), + StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]) + ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'element in array field f1'): + _merge_type( + StructType([ + StructField("f1", ArrayType(LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", ArrayType(DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]) + ), StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())])) + with self.assertRaisesRegexp(TypeError, 'value of map field f1'): + _merge_type( + StructType([ + StructField("f1", MapType(StringType(), LongType())), + StructField("f2", StringType())]), + StructType([ + StructField("f1", MapType(StringType(), DoubleType())), + StructField("f2", StringType())])) + + self.assertEqual(_merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]) + ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])) + with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'): + _merge_type( + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]) + ) + def test_filter_with_datetime(self): time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) date = time.date() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 146e673ae9756..0dc5823f72a3c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1073,7 +1073,7 @@ def _infer_type(obj): raise TypeError("not supported type: %s" % type(obj)) -def _infer_schema(row): +def _infer_schema(row, names=None): """Infer the schema from dict/namedtuple/object""" if isinstance(row, dict): items = sorted(row.items()) @@ -1084,7 +1084,10 @@ def _infer_schema(row): elif hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) else: - names = ['_%d' % i for i in range(1, len(row) + 1)] + if names is None: + names = ['_%d' % i for i in range(1, len(row) + 1)] + elif len(names) < len(row): + names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1)) items = zip(names, row) elif hasattr(row, "__dict__"): # object @@ -1109,19 +1112,27 @@ def _has_nulltype(dt): return isinstance(dt, NullType) -def _merge_type(a, b): +def _merge_type(a, b, name=None): + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) + if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) + raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b)))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) - fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), + name=new_name(f.name))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: @@ -1130,11 +1141,12 @@ def _merge_type(a, b): return StructType(fields) elif isinstance(a, ArrayType): - return ArrayType(_merge_type(a.elementType, b.elementType), True) + return ArrayType(_merge_type(a.elementType, b.elementType, + name='element in array %s' % name), True) elif isinstance(a, MapType): - return MapType(_merge_type(a.keyType, b.keyType), - _merge_type(a.valueType, b.valueType), + return MapType(_merge_type(a.keyType, b.keyType, name='key of map %s' % name), + _merge_type(a.valueType, b.valueType, name='value of map %s' % name), True) else: return a From 8fdeb4b9946bd9be045abb919da2e531708b3bd4 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 8 Jan 2018 13:59:08 +0800 Subject: [PATCH 0037/1161] [SPARK-22979][PYTHON][SQL] Avoid per-record type dispatch in Python data conversion (EvaluatePython.fromJava) ## What changes were proposed in this pull request? Seems we can avoid type dispatch for each value when Java objection (from Pyrolite) -> Spark's internal data format because we know the schema ahead. I manually performed the benchmark as below: ```scala test("EvaluatePython.fromJava / EvaluatePython.makeFromJava") { val numRows = 1000 * 1000 val numFields = 30 val random = new Random(System.nanoTime()) val types = Array( BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType, DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal, DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2), new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType) val schema = RandomDataGenerator.randomSchema(random, numFields, types) val rows = mutable.ArrayBuffer.empty[Array[Any]] var i = 0 while (i < numRows) { val row = RandomDataGenerator.randomRow(random, schema) rows += row.toSeq.toArray i += 1 } val benchmark = new Benchmark("EvaluatePython.fromJava / EvaluatePython.makeFromJava", numRows) benchmark.addCase("Before - EvaluatePython.fromJava", 3) { _ => var i = 0 while (i < numRows) { EvaluatePython.fromJava(rows(i), schema) i += 1 } } benchmark.addCase("After - EvaluatePython.makeFromJava", 3) { _ => val fromJava = EvaluatePython.makeFromJava(schema) var i = 0 while (i < numRows) { fromJava(rows(i)) i += 1 } } benchmark.run() } ``` ``` EvaluatePython.fromJava / EvaluatePython.makeFromJava: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Before - EvaluatePython.fromJava 1265 / 1346 0.8 1264.8 1.0X After - EvaluatePython.makeFromJava 571 / 649 1.8 570.8 2.2X ``` If the structure is nested, I think the advantage should be larger than this. ## How was this patch tested? Existing tests should cover this. Also, I manually checked if the values from before / after are actually same via `assert` when performing the benchmarks. Author: hyukjinkwon Closes #20172 from HyukjinKwon/type-dispatch-python-eval. --- .../org/apache/spark/sql/SparkSession.scala | 5 +- .../python/BatchEvalPythonExec.scala | 7 +- .../sql/execution/python/EvaluatePython.scala | 166 ++++++++++++------ 3 files changed, 118 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 272eb844226d4..734573ba31f71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -742,7 +742,10 @@ class SparkSession private( private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + val rowRdd = rdd.mapPartitions { iter => + val fromJava = python.EvaluatePython.makeFromJava(schema) + iter.map(r => fromJava(r).asInstanceOf[InternalRow]) + } internalCreateDataFrame(rowRdd, schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 26ee25f633ea4..f4d83e8dc7c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -79,16 +79,19 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi } else { StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) } + + val fromJava = EvaluatePython.makeFromJava(resultType) + outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => if (udfs.length == 1) { // fast path for single UDF - mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow(0) = fromJava(result) mutableRow } else { - EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + fromJava(result).asInstanceOf[InternalRow] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 9bbfa6018ba77..520afad287648 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -83,82 +83,134 @@ object EvaluatePython { } /** - * Converts `obj` to the type specified by the data type, or returns null if the type of obj is - * unexpected. Because Python doesn't enforce the type. + * Make a converter that converts `obj` to the type specified by the data type, or returns + * null if the type of obj is unexpected. Because Python doesn't enforce the type. */ - def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (c: Boolean, BooleanType) => c + def makeFromJava(dataType: DataType): Any => Any = dataType match { + case BooleanType => (obj: Any) => nullSafeConvert(obj) { + case b: Boolean => b + } - case (c: Byte, ByteType) => c - case (c: Short, ByteType) => c.toByte - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte + case ByteType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c + case c: Short => c.toByte + case c: Int => c.toByte + case c: Long => c.toByte + } - case (c: Byte, ShortType) => c.toShort - case (c: Short, ShortType) => c - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort + case ShortType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toShort + case c: Short => c + case c: Int => c.toShort + case c: Long => c.toShort + } - case (c: Byte, IntegerType) => c.toInt - case (c: Short, IntegerType) => c.toInt - case (c: Int, IntegerType) => c - case (c: Long, IntegerType) => c.toInt + case IntegerType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toInt + case c: Short => c.toInt + case c: Int => c + case c: Long => c.toInt + } - case (c: Byte, LongType) => c.toLong - case (c: Short, LongType) => c.toLong - case (c: Int, LongType) => c.toLong - case (c: Long, LongType) => c + case LongType => (obj: Any) => nullSafeConvert(obj) { + case c: Byte => c.toLong + case c: Short => c.toLong + case c: Int => c.toLong + case c: Long => c + } - case (c: Float, FloatType) => c - case (c: Double, FloatType) => c.toFloat + case FloatType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c + case c: Double => c.toFloat + } - case (c: Float, DoubleType) => c.toDouble - case (c: Double, DoubleType) => c + case DoubleType => (obj: Any) => nullSafeConvert(obj) { + case c: Float => c.toDouble + case c: Double => c + } - case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale) + case dt: DecimalType => (obj: Any) => nullSafeConvert(obj) { + case c: java.math.BigDecimal => Decimal(c, dt.precision, dt.scale) + } - case (c: Int, DateType) => c + case DateType => (obj: Any) => nullSafeConvert(obj) { + case c: Int => c + } - case (c: Long, TimestampType) => c - // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs - case (c: Int, TimestampType) => c.toLong + case TimestampType => (obj: Any) => nullSafeConvert(obj) { + case c: Long => c + // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs + case c: Int => c.toLong + } - case (c, StringType) => UTF8String.fromString(c.toString) + case StringType => (obj: Any) => nullSafeConvert(obj) { + case _ => UTF8String.fromString(obj.toString) + } - case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) - case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + case BinaryType => (obj: Any) => nullSafeConvert(obj) { + case c: String => c.getBytes(StandardCharsets.UTF_8) + case c if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c + } - case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) + case ArrayType(elementType, _) => + val elementFromJava = makeFromJava(elementType) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) + (obj: Any) => nullSafeConvert(obj) { + case c: java.util.List[_] => + new GenericArrayData(c.asScala.map { e => elementFromJava(e) }.toArray) + case c if c.getClass.isArray => + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => elementFromJava(e))) + } - case (javaMap: java.util.Map[_, _], MapType(keyType, valueType, _)) => - ArrayBasedMapData( - javaMap, - (key: Any) => fromJava(key, keyType), - (value: Any) => fromJava(value, valueType)) + case MapType(keyType, valueType, _) => + val keyFromJava = makeFromJava(keyType) + val valueFromJava = makeFromJava(valueType) + + (obj: Any) => nullSafeConvert(obj) { + case javaMap: java.util.Map[_, _] => + ArrayBasedMapData( + javaMap, + (key: Any) => keyFromJava(key), + (value: Any) => valueFromJava(value)) + } - case (c, StructType(fields)) if c.getClass.isArray => - val array = c.asInstanceOf[Array[_]] - if (array.length != fields.length) { - throw new IllegalStateException( - s"Input row doesn't have expected number of values required by the schema. " + - s"${fields.length} fields are required while ${array.length} values are provided." - ) + case StructType(fields) => + val fieldsFromJava = fields.map(f => makeFromJava(f.dataType)).toArray + + (obj: Any) => nullSafeConvert(obj) { + case c if c.getClass.isArray => + val array = c.asInstanceOf[Array[_]] + if (array.length != fields.length) { + throw new IllegalStateException( + s"Input row doesn't have expected number of values required by the schema. " + + s"${fields.length} fields are required while ${array.length} values are provided." + ) + } + + val row = new GenericInternalRow(fields.length) + var i = 0 + while (i < fields.length) { + row(i) = fieldsFromJava(i)(array(i)) + i += 1 + } + row } - new GenericInternalRow(array.zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) + case udt: UserDefinedType[_] => makeFromJava(udt.sqlType) + + case other => (obj: Any) => nullSafeConvert(other)(PartialFunction.empty) + } - // all other unexpected type should be null, or we will have runtime exception - // TODO(davies): we could improve this by try to cast the object to expected type - case (c, _) => null + private def nullSafeConvert(input: Any)(f: PartialFunction[Any, Any]): Any = { + if (input == null) { + null + } else { + f.applyOrElse(input, { + // all other unexpected type should be null, or we will have runtime exception + // TODO(davies): we could improve this by try to cast the object to expected type + _: Any => null + }) + } } private val module = "pyspark.sql.types" From 2c73d2a948bdde798aaf0f87c18846281deb05fd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 8 Jan 2018 16:04:03 +0800 Subject: [PATCH 0038/1161] [SPARK-22983] Don't push filters beneath aggregates with empty grouping expressions ## What changes were proposed in this pull request? The following SQL query should return zero rows, but in Spark it actually returns one row: ``` SELECT 1 from ( SELECT 1 AS z, MIN(a.x) FROM (select 1 as x) a WHERE false ) b where b.z != b.z ``` The problem stems from the `PushDownPredicate` rule: when this rule encounters a filter on top of an Aggregate operator, e.g. `Filter(Agg(...))`, it removes the original filter and adds a new filter onto Aggregate's child, e.g. `Agg(Filter(...))`. This is sometimes okay, but the case above is a counterexample: because there is no explicit `GROUP BY`, we are implicitly computing a global aggregate over the entire table so the original filter was not acting like a `HAVING` clause filtering the number of groups: if we push this filter then it fails to actually reduce the cardinality of the Aggregate output, leading to the wrong answer. In 2016 I fixed a similar problem involving invalid pushdowns of data-independent filters (filters which reference no columns of the filtered relation). There was additional discussion after my fix was merged which pointed out that my patch was an incomplete fix (see #15289), but it looks I must have either misunderstood the comment or forgot to follow up on the additional points raised there. This patch fixes the problem by choosing to never push down filters in cases where there are no grouping expressions. Since there are no grouping keys, the only columns are aggregate columns and we can't push filters defined over aggregate results, so this change won't cause us to miss out on any legitimate pushdown opportunities. ## How was this patch tested? New regression tests in `SQLQueryTestSuite` and `FilterPushdownSuite`. Author: Josh Rosen Closes #20180 from JoshRosen/SPARK-22983-dont-push-filters-beneath-aggs-with-empty-grouping-expressions. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- .../catalyst/optimizer/FilterPushdownSuite.scala | 13 +++++++++++++ .../test/resources/sql-tests/inputs/group-by.sql | 9 +++++++++ .../resources/sql-tests/results/group-by.sql.out | 16 +++++++++++++++- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0d4b02c6e7d8a..df0af8264a329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -795,7 +795,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) case filter @ Filter(condition, aggregate: Aggregate) - if aggregate.aggregateExpressions.forall(_.deterministic) => + if aggregate.aggregateExpressions.forall(_.deterministic) + && aggregate.groupingExpressions.nonEmpty => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 85a5e979f6021..82a10254d846d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -809,6 +809,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("aggregate: don't push filters if the aggregate has no grouping expressions") { + val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) + .select('a, 'b) + .groupBy()(count(1)) + .where(false) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } + test("broadcast hint") { val originalQuery = ResolvedHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 1e1384549a410..c5070b734d521 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -60,3 +60,12 @@ SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; -- Aggregate with empty input and empty GroupBy expressions. SELECT COUNT(1) FROM testData WHERE false; SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; + +-- Aggregate with empty GroupBy expressions and filter on top +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 986bb01c13fe4..c1abc6dff754b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 26 -- !query 0 @@ -227,3 +227,17 @@ SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t struct<1:int> -- !query 24 output 1 + + +-- !query 25 +SELECT 1 from ( + SELECT 1 AS z, + MIN(a.x) + FROM (select 1 as x) a + WHERE false +) b +where b.z != b.z +-- !query 25 schema +struct<1:int> +-- !query 25 output + From eb45b52e826ea9cea48629760db35ef87f91fea0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 8 Jan 2018 19:41:41 +0800 Subject: [PATCH 0039/1161] [SPARK-21865][SQL] simplify the distribution semantic of Spark SQL ## What changes were proposed in this pull request? **The current shuffle planning logic** 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings are compatible with each other, via the `Partitioning.compatibleWith`. 6. If the check in 5 failed, add a shuffle above each child. 7. try to eliminate the shuffles added in 6, via `Partitioning.guarantees`. This design has a major problem with the definition of "compatible". `Partitioning.compatibleWith` is not well defined, ideally a `Partitioning` can't know if it's compatible with other `Partitioning`, without more information from the operator. For example, `t1 join t2 on t1.a = t2.b`, `HashPartitioning(a, 10)` should be compatible with `HashPartitioning(b, 10)` under this case, but the partitioning itself doesn't know it. As a result, currently `Partitioning.compatibleWith` always return false except for literals, which make it almost useless. This also means, if an operator has distribution requirements for multiple children, Spark always add shuffle nodes to all the children(although some of them can be eliminated). However, there is no guarantee that the children's output partitionings are compatible with each other after adding these shuffles, we just assume that the operator will only specify `ClusteredDistribution` for multiple children. I think it's very hard to guarantee children co-partition for all kinds of operators, and we can not even give a clear definition about co-partition between distributions like `ClusteredDistribution(a,b)` and `ClusteredDistribution(c)`. I think we should drop the "compatible" concept in the distribution model, and let the operator achieve the co-partition requirement by special distribution requirements. **Proposed shuffle planning logic after this PR** (The first 4 are same as before) 1. Each operator specifies the distribution requirements for its children, via the `Distribution` interface. 2. Each operator specifies its output partitioning, via the `Partitioning` interface. 3. `Partitioning.satisfy` determines whether a `Partitioning` can satisfy a `Distribution`. 4. For each operator, check each child of it, add a shuffle node above the child if the child partitioning can not satisfy the required distribution. 5. For each operator, check if its children's output partitionings have the same number of partitions. 6. If the check in 5 failed, pick the max number of partitions from children's output partitionings, and add shuffle to child whose number of partitions doesn't equal to the max one. The new distribution model is very simple, we only have one kind of relationship, which is `Partitioning.satisfy`. For multiple children, Spark only guarantees they have the same number of partitions, and it's the operator's responsibility to leverage this guarantee to achieve more complicated requirements. For example, non-broadcast joins can use the newly added `HashPartitionedDistribution` to achieve co-partition. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #19080 from cloud-fan/exchange. --- .../plans/physical/partitioning.scala | 286 +++++++----------- .../sql/catalyst/PartitioningSuite.scala | 55 ---- .../spark/sql/execution/SparkPlan.scala | 16 +- .../exchange/EnsureRequirements.scala | 120 +++----- .../joins/ShuffledHashJoinExec.scala | 2 +- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 81 ++--- 8 files changed, 194 insertions(+), 370 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index e57c842ce2a36..0189bd73c56bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -30,18 +30,43 @@ import org.apache.spark.sql.types.{DataType, IntegerType} * - Intra-partition ordering of data: In this case the distribution describes guarantees made * about how tuples are distributed within a single partition. */ -sealed trait Distribution +sealed trait Distribution { + /** + * The required number of partitions for this distribution. If it's None, then any number of + * partitions is allowed for this distribution. + */ + def requiredNumPartitions: Option[Int] + + /** + * Creates a default partitioning for this distribution, which can satisfy this distribution while + * matching the given number of partitions. + */ + def createPartitioning(numPartitions: Int): Partitioning +} /** * Represents a distribution where no promises are made about co-location of data. */ -case object UnspecifiedDistribution extends Distribution +case object UnspecifiedDistribution extends Distribution { + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.") + } +} /** * Represents a distribution that only has a single partition and all tuples of the dataset * are co-located. */ -case object AllTuples extends Distribution +case object AllTuples extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.") + SinglePartition + } +} /** * Represents data where tuples that share the same values for the `clustering` @@ -51,12 +76,41 @@ case object AllTuples extends Distribution */ case class ClusteredDistribution( clustering: Seq[Expression], - numPartitions: Option[Int] = None) extends Distribution { + requiredNumPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") + HashPartitioning(clustering, numPartitions) + } +} + +/** + * Represents data where tuples have been clustered according to the hash of the given + * `expressions`. The hash function is defined as `HashPartitioning.partitionIdExpression`, so only + * [[HashPartitioning]] can satisfy this distribution. + * + * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the + * number of partitions, this distribution strictly requires which partition the tuple should be in. + */ +case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { + require( + expressions != Nil, + "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + HashPartitioning(expressions, numPartitions) + } } /** @@ -73,46 +127,31 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - // TODO: This is not really valid... - def clustering: Set[Expression] = ordering.map(_.child).toSet + override def requiredNumPartitions: Option[Int] = None + + override def createPartitioning(numPartitions: Int): Partitioning = { + RangePartitioning(ordering, numPartitions) + } } /** * Represents data where tuples are broadcasted to every node. It is quite common that the * entire set of tuples is transformed into different data structure. */ -case class BroadcastDistribution(mode: BroadcastMode) extends Distribution +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { + override def requiredNumPartitions: Option[Int] = Some(1) + + override def createPartitioning(numPartitions: Int): Partitioning = { + assert(numPartitions == 1, + "The default partitioning of BroadcastDistribution can only have 1 partition.") + BroadcastPartitioning(mode) + } +} /** - * Describes how an operator's output is split across partitions. The `compatibleWith`, - * `guarantees`, and `satisfies` methods describe relationships between child partitionings, - * target partitionings, and [[Distribution]]s. These relations are described more precisely in - * their individual method docs, but at a high level: - * - * - `satisfies` is a relationship between partitionings and distributions. - * - `compatibleWith` is relationships between an operator's child output partitionings. - * - `guarantees` is a relationship between a child's existing output partitioning and a target - * output partitioning. - * - * Diagrammatically: - * - * +--------------+ - * | Distribution | - * +--------------+ - * ^ - * | - * satisfies - * | - * +--------------+ +--------------+ - * | Child | | Target | - * +----| Partitioning |----guarantees--->| Partitioning | - * | +--------------+ +--------------+ - * | ^ - * | | - * | compatibleWith - * | | - * +------------+ - * + * Describes how an operator's output is split across partitions. It has 2 major properties: + * 1. number of partitions. + * 2. if it can satisfy a given distribution. */ sealed trait Partitioning { /** Returns the number of partitions that the data is split across */ @@ -123,113 +162,35 @@ sealed trait Partitioning { * to satisfy the partitioning scheme mandated by the `required` [[Distribution]], * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). - */ - def satisfies(required: Distribution): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] - * guarantees the same partitioning scheme described by `other`. - * - * Compatibility of partitionings is only checked for operators that have multiple children - * and that require a specific child output [[Distribution]], such as joins. - * - * Intuitively, partitionings are compatible if they route the same partitioning key to the same - * partition. For instance, two hash partitionings are only compatible if they produce the same - * number of output partitionings and hash records according to the same hash function and - * same partitioning key schema. - * - * Put another way, two partitionings are compatible with each other if they satisfy all of the - * same distribution guarantees. - */ - def compatibleWith(other: Partitioning): Boolean - - /** - * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees - * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning - * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance - * optimization to allow the exchange planner to avoid redundant repartitionings. By default, - * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number - * of partitions, same strategy (range or hash), etc). - * - * In order to enable more aggressive optimization, this strict equality check can be relaxed. - * For example, say that the planner needs to repartition all of an operator's children so that - * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children - * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens - * to be hash-partitioned with a single partition then we do not need to re-shuffle this child; - * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees` - * [[SinglePartition]]. - * - * The SinglePartition example given above is not particularly interesting; guarantees' real - * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion - * of null-safe partitionings, under which partitionings can specify whether rows whose - * partitioning keys contain null values will be grouped into the same partition or whether they - * will have an unknown / random distribution. If a partitioning does not require nulls to be - * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered - * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot - * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a - * symmetric relation. * - * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows - * produced by `A` could have also been produced by `B`. + * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if + * the [[Partitioning]] only have one partition. Implementations can overwrite this method with + * special logic. */ - def guarantees(other: Partitioning): Boolean = this == other -} - -object Partitioning { - def allCompatible(partitionings: Seq[Partitioning]): Boolean = { - // Note: this assumes transitivity - partitionings.sliding(2).map { - case Seq(a) => true - case Seq(a, b) => - if (a.numPartitions != b.numPartitions) { - assert(!a.compatibleWith(b) && !b.compatibleWith(a)) - false - } else { - a.compatibleWith(b) && b.compatibleWith(a) - } - }.forall(_ == true) - } -} - -case class UnknownPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { + def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true + case AllTuples => numPartitions == 1 case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false } +case class UnknownPartitioning(numPartitions: Int) extends Partitioning + /** * Represents a partitioning where rows are distributed evenly across output partitions * by starting from a random target partition number and distributing rows in a round-robin * fashion. This partitioning is used when implementing the DataFrame.repartition() operator. */ -case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = false - - override def guarantees(other: Partitioning): Boolean = false -} +case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) + case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } - - override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 - - override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1 } /** @@ -244,22 +205,19 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering, desiredPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: HashPartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case h: HashClusteredDistribution => + expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { + case (l, r) => l.semanticEquals(r) + } + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } /** @@ -288,25 +246,18 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = required match { - case UnspecifiedDistribution => true - case OrderedDistribution(requiredOrdering) => - val minSize = Seq(requiredOrdering.size, ordering.size).min - requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, desiredPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true - case _ => false - } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false - } - - override def guarantees(other: Partitioning): Boolean = other match { - case o: RangePartitioning => this.semanticEquals(o) - case _ => false + override def satisfies(required: Distribution): Boolean = { + super.satisfies(required) || { + required match { + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering, requiredNumPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case _ => false + } + } } } @@ -347,20 +298,6 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) override def satisfies(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) - /** - * Returns true if any `partitioning` of this collection is compatible with - * the given [[Partitioning]]. - */ - override def compatibleWith(other: Partitioning): Boolean = - partitionings.exists(_.compatibleWith(other)) - - /** - * Returns true if any `partitioning` of this collection guarantees - * the given [[Partitioning]]. - */ - override def guarantees(other: Partitioning): Boolean = - partitionings.exists(_.guarantees(other)) - override def toString: String = { partitionings.map(_.toString).mkString("(", " or ", ")") } @@ -377,9 +314,4 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { case BroadcastDistribution(m) if m == mode => true case _ => false } - - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning(m) if m == mode => true - case _ => false - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala deleted file mode 100644 index 5b802ccc637dd..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning} - -class PartitioningSuite extends SparkFunSuite { - test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") { - val expressions = Seq(Literal(2), Literal(3)) - // Consider two HashPartitionings that have the same _set_ of hash expressions but which are - // created with different orderings of those expressions: - val partitioningA = HashPartitioning(expressions, 100) - val partitioningB = HashPartitioning(expressions.reverse, 100) - // These partitionings are not considered equal: - assert(partitioningA != partitioningB) - // However, they both satisfy the same clustered distribution: - val distribution = ClusteredDistribution(expressions) - assert(partitioningA.satisfies(distribution)) - assert(partitioningB.satisfies(distribution)) - // These partitionings compute different hashcodes for the same input row: - def computeHashCode(partitioning: HashPartitioning): Int = { - val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty) - hashExprProj.apply(InternalRow.empty).hashCode() - } - assert(computeHashCode(partitioningA) != computeHashCode(partitioningB)) - // Thus, these partitionings are incompatible: - assert(!partitioningA.compatibleWith(partitioningB)) - assert(!partitioningB.compatibleWith(partitioningA)) - assert(!partitioningA.guarantees(partitioningB)) - assert(!partitioningB.guarantees(partitioningA)) - - // Just to be sure that we haven't cheated by having these methods always return false, - // check that identical partitionings are still compatible with and guarantee each other: - assert(partitioningA === partitioningA) - assert(partitioningA.guarantees(partitioningA)) - assert(partitioningA.compatibleWith(partitioningA)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 787c1cfbfb3d8..82300efc01632 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! - /** Specifies any partition requirements on the input data for this operator. */ + /** + * Specifies the data distribution requirements of all the children for this operator. By default + * it's [[UnspecifiedDistribution]] for each child, which means each child can have any + * distribution. + * + * If an operator overwrites this method, and specifies distribution requirements(excluding + * [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark + * guarantees that the outputs of these children will have same number of partitions, so that the + * operator can safely zip partitions of these children's result RDDs. Some operators can leverage + * this guarantee to satisfy some interesting requirement, e.g., non-broadcast joins can specify + * HashClusteredDistribution(a,b) for its left child, and specify HashClusteredDistribution(c,d) + * for its right child, then it's guaranteed that left and right child are co-partitioned by + * a,b/c,d, which means tuples of same value are in the partitions of same index, e.g., + * (a=1,b=2) and (c=1,d=2) are both in the second partition of left and right child. + */ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index c8e236be28b42..e3d28388c5470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -46,23 +46,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None } - /** - * Given a required distribution, returns a partitioning that satisfies that distribution. - * @param requiredDistribution The distribution that is required by the operator - * @param numPartitions Used when the distribution doesn't require a specific number of partitions - */ - private def createPartitioning( - requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { - requiredDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(clustering, desiredPartitions) => - HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) - case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) - case dist => sys.error(s"Do not know how to satisfy distribution $dist") - } - } - /** * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. @@ -88,8 +71,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // shuffle data when we have more than one children because data generated by // these children may not be partitioned in the same way. // Please see the comment in withCoordinator for more details. - val supportsDistribution = - requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + val supportsDistribution = requiredChildDistributions.forall { dist => + dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution] + } children.length > 1 && supportsDistribution } @@ -142,8 +126,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // // It will be great to introduce a new Partitioning to represent the post-shuffle // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } @@ -162,71 +145,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { assert(requiredChildDistributions.length == children.length) assert(requiredChildOrderings.length == children.length) - // Ensure that the operator's children satisfy their output distribution requirements: + // Ensure that the operator's children satisfy their output distribution requirements. children = children.zip(requiredChildDistributions).map { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + val numPartitions = distribution.requiredNumPartitions + .getOrElse(defaultNumPreShufflePartitions) + ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child) } - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { - case UnspecifiedDistribution => false - case BroadcastDistribution(_) => false + // Get the indexes of children which have specified distribution requirements and need to have + // same number of partitions. + val childrenIndexes = requiredChildDistributions.zipWithIndex.filter { + case (UnspecifiedDistribution, _) => false + case (_: BroadcastDistribution, _) => false case _ => true - } - if (children.length > 1 - && requiredChildDistributions.exists(requireCompatiblePartitioning) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + }.map(_._2) - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) + val childrenNumPartitions = + childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet + + if (childrenNumPartitions.size > 1) { + // Get the number of partitions which is explicitly required by the distributions. + val requiredNumPartitions = { + val numPartitionsSet = childrenIndexes.flatMap { + index => requiredChildDistributions(index).requiredNumPartitions + }.toSet + assert(numPartitionsSet.size <= 1, + s"$operator have incompatible requirements of the number of partitions for its children") + numPartitionsSet.headOption } - children = if (useExistingPartitioning) { - // We do not need to shuffle any child's output. - children - } else { - // We need to shuffle at least one child's output. - // Now, we will determine the number of partitions that will be used by created - // partitioning schemes. - val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. - if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions - } + val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max) - children.zip(requiredChildDistributions).map { - case (child, distribution) => - val targetPartitioning = createPartitioning(distribution, numPartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - child match { - // If child is an exchange, we replace it with - // a new one having targetPartitioning. - case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) - case _ => ShuffleExchangeExec(targetPartitioning, child) - } + children = children.zip(requiredChildDistributions).zipWithIndex.map { + case ((child, distribution), index) if childrenIndexes.contains(index) => + if (child.outputPartitioning.numPartitions == targetNumPartitions) { + child + } else { + val defaultPartitioning = distribution.createPartitioning(targetNumPartitions) + child match { + // If child is an exchange, we replace it with a new one having defaultPartitioning. + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c) + case _ => ShuffleExchangeExec(defaultPartitioning, child) + } } - } + + case ((child, _), _) => child } } @@ -324,10 +292,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchangeExec(partitioning, child, _) => - child.children match { - case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => - if (childPartitioning.guarantees(partitioning)) child else operator + // TODO: remove this after we create a physical operator for `RepartitionByExpression`. + case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => + child.outputPartitioning match { + case lower: HashPartitioning if upper.semanticEquals(lower) => child case _ => operator } case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 66e8031bb5191..897a4dae39f32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -46,7 +46,7 @@ case class ShuffledHashJoinExec( "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 94405410cce90..2de2f30eb05d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -78,7 +78,7 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d1bd8a7076863..03d1bbf2ab882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -456,7 +456,7 @@ case class CoGroupExec( right: SparkPlan) extends BinaryExecNode with ObjectProducerExec { override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil override def requiredChildOrdering: Seq[Seq[SortOrder]] = leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b50642d275ba8..f8b26f5b28cc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -260,11 +260,16 @@ class PlannerSuite extends SharedSQLContext { // do they satisfy the distribution requirements? As a result, we need at least four test cases. private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = { - if (outputPlan.children.length > 1 - && outputPlan.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) { - val childPartitionings = outputPlan.children.map(_.outputPartitioning) - if (!Partitioning.allCompatible(childPartitionings)) { - fail(s"Partitionings are not compatible: $childPartitionings") + if (outputPlan.children.length > 1) { + val childPartitionings = outputPlan.children.zip(outputPlan.requiredChildDistribution) + .filter { + case (_, UnspecifiedDistribution) => false + case (_, _: BroadcastDistribution) => false + case _ => true + }.map(_._1.outputPartitioning) + + if (childPartitionings.map(_.numPartitions).toSet.size > 1) { + fail(s"Partitionings doesn't have same number of partitions: $childPartitionings") } } outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach { @@ -274,40 +279,7 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements with incompatible child partitionings which satisfy distribution") { - // Consider an operator that requires inputs that are clustered by two expressions (e.g. - // sort merge join where there are multiple columns in the equi-join condition) - val clusteringA = Literal(1) :: Nil - val clusteringB = Literal(2) :: Nil - val distribution = ClusteredDistribution(clusteringA ++ clusteringB) - // Say that the left and right inputs are each partitioned by _one_ of the two join columns: - val leftPartitioning = HashPartitioning(clusteringA, 1) - val rightPartitioning = HashPartitioning(clusteringB, 1) - // Individually, each input's partitioning satisfies the clustering distribution: - assert(leftPartitioning.satisfies(distribution)) - assert(rightPartitioning.satisfies(distribution)) - // However, these partitionings are not compatible with each other, so we still need to - // repartition both inputs prior to performing the join: - assert(!leftPartitioning.compatibleWith(rightPartitioning)) - assert(!rightPartitioning.compatibleWith(leftPartitioning)) - val inputPlan = DummySparkPlan( - children = Seq( - DummySparkPlan(outputPartitioning = leftPartitioning), - DummySparkPlan(outputPartitioning = rightPartitioning) - ), - requiredChildDistribution = Seq(distribution, distribution), - requiredChildOrdering = Seq(Seq.empty, Seq.empty) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { - fail(s"Exchange should have been added:\n$outputPlan") - } - } - test("EnsureRequirements with child partitionings with different numbers of output partitions") { - // This is similar to the previous test, except it checks that partitionings are not compatible - // unless they produce the same number of partitions. val clustering = Literal(1) :: Nil val distribution = ClusteredDistribution(clustering) val inputPlan = DummySparkPlan( @@ -386,18 +358,15 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + test("EnsureRequirements eliminates Exchange if child has same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { @@ -407,17 +376,13 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements does not eliminate Exchange with different partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) - // Number of partitions differ - val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) - val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) - assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchangeExec(finalPartitioning, - DummySparkPlan( - children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, - requiredChildDistribution = Seq(distribution), - requiredChildOrdering = Seq(Seq.empty)), - None) + val partitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!partitioning.satisfies(distribution)) + val inputPlan = ShuffleExchangeExec( + partitioning, + DummySparkPlan(outputPartitioning = partitioning), + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { From 40b983c3b44b6771f07302ce87987fa4716b5ebf Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 8 Jan 2018 23:49:07 +0800 Subject: [PATCH 0040/1161] [SPARK-22952][CORE] Deprecate stageAttemptId in favour of stageAttemptNumber ## What changes were proposed in this pull request? 1. Deprecate attemptId in StageInfo and add `def attemptNumber() = attemptId` 2. Replace usage of stageAttemptId with stageAttemptNumber ## How was this patch tested? I manually checked the compiler warning info Author: Xianjin YE Closes #20178 from advancedxy/SPARK-22952. --- .../apache/spark/scheduler/DAGScheduler.scala | 15 +++--- .../apache/spark/scheduler/StageInfo.scala | 4 +- .../spark/scheduler/StatsReportListener.scala | 2 +- .../spark/status/AppStatusListener.scala | 7 +-- .../org/apache/spark/status/LiveEntity.scala | 4 +- .../spark/ui/scope/RDDOperationGraph.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- .../spark/status/AppStatusListenerSuite.scala | 54 ++++++++++--------- .../execution/ui/SQLAppStatusListener.scala | 2 +- 9 files changed, 51 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c2498d4808e91..199937b8c27af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -815,7 +815,8 @@ class DAGScheduler( private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { // Note that there is a chance that this task is launched after the stage is cancelled. // In that case, we wouldn't have the stage anymore in stageIdToStage. - val stageAttemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val stageAttemptId = + stageIdToStage.get(task.stageId).map(_.latestInfo.attemptNumber).getOrElse(-1) listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo)) } @@ -1050,7 +1051,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) stage.pendingPartitions += id - new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, + new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1060,7 +1061,7 @@ class DAGScheduler( val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) - new ResultTask(stage.id, stage.latestInfo.attemptId, + new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1076,7 +1077,7 @@ class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run @@ -1245,7 +1246,7 @@ class DAGScheduler( val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (stageIdToStage(task.stageId).latestInfo.attemptId == task.stageAttemptId) { + if (stageIdToStage(task.stageId).latestInfo.attemptNumber == task.stageAttemptId) { // This task was for the currently running attempt of the stage. Since the task // completed successfully from the perspective of the TaskSetManager, mark it as // no longer pending (the TaskSetManager may consider the task complete even @@ -1324,10 +1325,10 @@ class DAGScheduler( val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) - if (failedStage.latestInfo.attemptId != task.stageAttemptId) { + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + - s"(attempt ID ${failedStage.latestInfo.attemptId}) running") + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index c513ed36d1680..903e25b7986f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.RDDInfo @DeveloperApi class StageInfo( val stageId: Int, - val attemptId: Int, + @deprecated("Use attemptNumber instead", "2.3.0") val attemptId: Int, val name: String, val numTasks: Int, val rddInfos: Seq[RDDInfo], @@ -56,6 +56,8 @@ class StageInfo( completionTime = Some(System.currentTimeMillis) } + def attemptNumber(): Int = attemptId + private[spark] def getStatusString: String = { if (completionTime.isDefined) { if (failureReason.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index 3c8cab7504c17..3c7af4f6146fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -79,7 +79,7 @@ class StatsReportListener extends SparkListener with Logging { x => info.completionTime.getOrElse(System.currentTimeMillis()) - x ).getOrElse("-") - s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Stage(${info.stageId}, ${info.attemptNumber}); Name: '${info.name}'; " + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + s"Took: $timeTaken msec" } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 487a782e865e8..88b75ddd5993a 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -529,7 +529,8 @@ private[spark] class AppStatusListener( } override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { - val maybeStage = Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId))) + val maybeStage = + Option(liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptNumber))) maybeStage.foreach { stage => val now = System.nanoTime() stage.info = event.stageInfo @@ -785,7 +786,7 @@ private[spark] class AppStatusListener( } private def getOrCreateStage(info: StageInfo): LiveStage = { - val stage = liveStages.computeIfAbsent((info.stageId, info.attemptId), + val stage = liveStages.computeIfAbsent((info.stageId, info.attemptNumber), new Function[(Int, Int), LiveStage]() { override def apply(key: (Int, Int)): LiveStage = new LiveStage() }) @@ -912,7 +913,7 @@ private[spark] class AppStatusListener( private def cleanupTasks(stage: LiveStage): Unit = { val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { - val stageKey = Array(stage.info.stageId, stage.info.attemptId) + val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) .last(stageKey) diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 52e83f250d34e..305c2fafa6aac 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -412,14 +412,14 @@ private class LiveStage extends LiveEntity { def executorSummary(executorId: String): LiveExecutorStageSummary = { executorSummaries.getOrElseUpdate(executorId, - new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + new LiveExecutorStageSummary(info.stageId, info.attemptNumber, executorId)) } def toApi(): v1.StageData = { new v1.StageData( status, info.stageId, - info.attemptId, + info.attemptNumber, info.numTasks, activeTasks, diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 827a8637b9bd2..948858224d724 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -116,7 +116,7 @@ private[spark] object RDDOperationGraph extends Logging { // Use a special prefix here to differentiate this cluster from other operation clusters val stageClusterId = STAGE_CLUSTER_PREFIX + stage.stageId val stageClusterName = s"Stage ${stage.stageId}" + - { if (stage.attemptId == 0) "" else s" (attempt ${stage.attemptId})" } + { if (stage.attemptNumber == 0) "" else s" (attempt ${stage.attemptNumber})" } val rootCluster = new RDDOperationCluster(stageClusterId, stageClusterName) var rootNodeCount = 0 diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 5e60218c5740b..ff83301d631c4 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -263,7 +263,7 @@ private[spark] object JsonProtocol { val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ - ("Stage Attempt ID" -> stageInfo.attemptId) ~ + ("Stage Attempt ID" -> stageInfo.attemptNumber) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 997c7de8dd02b..b8c84e24c2c3f 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -195,7 +195,9 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val s1Tasks = createTasks(4, execIds) s1Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, + stages.head.attemptNumber, + task)) } assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) @@ -213,10 +215,11 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[TaskDataWrapper](task.taskId) { wrapper => assert(wrapper.info.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptId) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) + assert(wrapper.stageAttemptId === stages.head.attemptNumber) + assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, + val runtime = Array[AnyRef](stages.head.stageId: JInteger, + stages.head.attemptNumber: JInteger, -1L: JLong) assert(Arrays.equals(wrapper.runtime, runtime)) @@ -237,7 +240,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { Some(1L), None, true, false, None) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) } check[StageDataWrapper](key(stages.head)) { stage => @@ -254,12 +257,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskResultLost, s1Tasks.head, null)) time += 1 val reattempt = newAttempt(s1Tasks.head, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt)) assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) @@ -289,7 +292,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val killed = s1Tasks.drop(1).head killed.finishTime = time killed.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", TaskKilled("killed"), killed, null)) check[JobDataWrapper](1) { job => @@ -311,13 +314,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val denied = newAttempt(killed, nextTaskId()) val denyReason = TaskCommitDenied(1, 1, 1) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, denied)) time += 1 denied.finishTime = time denied.failed = true - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", denyReason, denied, null)) check[JobDataWrapper](1) { job => @@ -337,7 +340,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a new attempt. val reattempt2 = newAttempt(denied, nextTaskId()) - listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptNumber, reattempt2)) // Succeed all tasks in stage 1. @@ -350,7 +353,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 pending.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptNumber, "taskType", Success, task, s1Metrics)) } @@ -414,13 +417,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val s2Tasks = createTasks(4, execIds) s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, + stages.last.attemptNumber, + task)) } time += 1 s2Tasks.foreach { task => task.markFinished(TaskState.FAILED, time) - listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptNumber, "taskType", TaskResultLost, task, null)) } @@ -455,7 +460,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // - Re-submit stage 2, all tasks, and succeed them and the stage. val oldS2 = stages.last - val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptNumber + 1, oldS2.name, oldS2.numTasks, oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) time += 1 @@ -466,14 +471,14 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val newS2Tasks = createTasks(4, execIds) newS2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptNumber, task)) } time += 1 newS2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, - task, null)) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptNumber, "taskType", + Success, task, null)) } time += 1 @@ -522,14 +527,15 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val j2s2Tasks = createTasks(4, execIds) j2s2Tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, + j2Stages.last.attemptNumber, task)) } time += 1 j2s2Tasks.foreach { task => task.markFinished(TaskState.FINISHED, time) - listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptNumber, "taskType", Success, task, null)) } @@ -919,13 +925,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { time += 1 val tasks = createTasks(2, Array("1")) tasks.foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) // Start a 3rd task. The finished tasks should be deleted. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -934,7 +940,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a 4th task. The first task should be deleted, even if it's still running. createTasks(1, Array("1")).foreach { task => - listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptId, task)) + listener.onTaskStart(SparkListenerTaskStart(attempt2.stageId, attempt2.attemptNumber, task)) } assert(store.count(classOf[TaskDataWrapper]) === 2) intercept[NoSuchElementException] { @@ -960,7 +966,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } - private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId) + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d8adbe7bee13e..73a105266e1c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -99,7 +99,7 @@ class SQLAppStatusListener( // Reset the metrics tracking object for the new attempt. Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics => metrics.taskMetrics.clear() - metrics.attemptId = event.stageInfo.attemptId + metrics.attemptId = event.stageInfo.attemptNumber } } From eed82a0b211352215316ec70dc48aefc013ad0b2 Mon Sep 17 00:00:00 2001 From: foxish Date: Mon, 8 Jan 2018 13:01:45 -0800 Subject: [PATCH 0041/1161] [SPARK-22992][K8S] Remove assumption of the DNS domain ## What changes were proposed in this pull request? Remove the use of FQDN to access the driver because it assumes that it's set up in a DNS zone - `cluster.local` which is common but not ubiquitous Note that we already access the in-cluster API server through `kubernetes.default.svc`, so, by extension, this should work as well. The alternative is to introduce DNS zones for both of those addresses. ## How was this patch tested? Unit tests cc vanzin liyinan926 mridulm mccheah Author: foxish Closes #20187 from foxish/cluster.local. --- .../deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala | 2 +- .../k8s/submit/steps/DriverServiceBootstrapStepSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala index eb594e4f16ec0..34af7cde6c1a9 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala @@ -83,7 +83,7 @@ private[spark] class DriverServiceBootstrapStep( .build() val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc.cluster.local" + val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" val resolvedSparkConf = driverSpec.driverSparkConf.clone() .set(DRIVER_HOST_KEY, driverHostname) .set("spark.driver.port", driverPort.toString) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala index 006ce2668f8a0..78c8c3ba1afbd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala @@ -85,7 +85,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } @@ -120,7 +120,7 @@ class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc.cluster.local" + val expectedHostName = s"$expectedServiceName.my-namespace.svc" verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) } From 4f7e75883436069c2d9028c4cd5daa78e8d59560 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 8 Jan 2018 13:24:08 -0800 Subject: [PATCH 0042/1161] [SPARK-22912] v2 data source support in MicroBatchExecution ## What changes were proposed in this pull request? Support for v2 data sources in microbatch streaming. ## How was this patch tested? A very basic new unit test on the toy v2 implementation of rate source. Once we have a v1 source fully migrated to v2, we'll need to do more detailed compatibility testing. Author: Jose Torres Closes #20097 from jose-torres/v2-impl. --- ...pache.spark.sql.sources.DataSourceRegister | 1 + .../datasources/v2/DataSourceV2Relation.scala | 10 ++ .../streaming/MicroBatchExecution.scala | 112 ++++++++++++++---- .../streaming/ProgressReporter.scala | 6 +- .../streaming/RateSourceProvider.scala | 10 +- .../execution/streaming/StreamExecution.scala | 4 +- .../streaming/StreamingRelation.scala | 4 +- .../continuous/ContinuousExecution.scala | 4 +- .../ContinuousRateStreamSource.scala | 17 +-- .../sources/RateStreamSourceV2.scala | 31 ++++- .../sql/streaming/DataStreamReader.scala | 25 +++- .../sql/streaming/StreamingQueryManager.scala | 24 ++-- .../streaming/RateSourceV2Suite.scala | 68 +++++++++-- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 2 +- 15 files changed, 241 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 6cdfe2fae5642..0259c774bbf4a 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -7,3 +7,4 @@ org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 7eb99a645001a..cba20dd902007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,6 +35,16 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. + */ +class StreamingDataSourceV2Relation( + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { new DataSourceV2Relation(reader.readSchema().toAttributes, reader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 9a7a13fcc5806..42240eeb58d4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import java.util.Optional + +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} @@ -24,7 +27,10 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -33,10 +39,11 @@ class MicroBatchExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: Sink, + sink: BaseStreamingSink, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, + extraOptions: Map[String, String], deleteCheckpointOnStop: Boolean) extends StreamExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, @@ -57,6 +64,13 @@ class MicroBatchExecution( var nextSourceId = 0L val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val v2ToExecutionRelationMap = MutableMap[StreamingRelationV2, StreamingExecutionRelation]() + // We transform each distinct streaming relation into a StreamingExecutionRelation, keeping a + // map as we go to ensure each identical relation gets the same StreamingExecutionRelation + // object. For each microbatch, the StreamingExecutionRelation will be replaced with a logical + // plan for the data within that batch. + // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, + // since the existing logical plan has already used those attributes. The per-microbatch + // transformation is responsible for replacing attributes with their final values. val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSource, _, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -64,19 +78,26 @@ class MicroBatchExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" val source = dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(v2DataSource, _, _, output, v1DataSource) - if !v2DataSource.isInstanceOf[MicroBatchReadSupport] => + case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val reader = source.createMicroBatchReader( + Optional.empty(), // user specified schema + metadataPath, + new DataSourceV2Options(options.asJava)) + nextSourceId += 1 + StreamingExecutionRelation(reader, output)(sparkSession) + }) + case s @ StreamingRelationV2(_, _, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = v1DataSource.createSource(metadataPath) + assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output)(sparkSession) }) } @@ -192,7 +213,8 @@ class MicroBatchExecution( source.getBatch(start, end) } case nonV1Tuple => - throw new IllegalStateException(s"Unexpected V2 source in $nonV1Tuple") + // The V2 API does not have the same edge case requiring getBatch to be called + // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets @@ -236,14 +258,27 @@ class MicroBatchExecution( val hasNewData = { awaitProgressLock.lock() try { - val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { case s: Source => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("getOffset") { (s, s.getOffset) } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) + + (s, Some(s.getEndOffset)) + } }.toMap - availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) if (dataAvailable) { true @@ -317,6 +352,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) + case (reader: MicroBatchReader, off) => + reader.commit(reader.deserializeOffset(off.json)) } } else { throw new IllegalStateException(s"batch $currentBatchId doesn't exist") @@ -357,7 +394,16 @@ class MicroBatchExecution( s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) + Some(source -> batch.logicalPlan) + case (reader: MicroBatchReader, available) + if committedOffsets.get(reader).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + reader.setOffsetRange( + toJava(current), + Optional.of(available.asInstanceOf[OffsetV2])) + logDebug(s"Retrieving data from $reader: $current -> $available") + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -365,15 +411,14 @@ class MicroBatchExecution( // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. - val withNewSources = logicalPlan transform { + val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => - newData.get(source).map { data => - val newPlan = data.logicalPlan - assert(output.size == newPlan.output.size, + newData.get(source).map { dataPlan => + assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + - s"${Utils.truncatedString(newPlan.output, ",")}") - replacements ++= output.zip(newPlan.output) - newPlan + s"${Utils.truncatedString(dataPlan.output, ",")}") + replacements ++= output.zip(dataPlan.output) + dataPlan }.getOrElse { LocalRelation(output, isStreaming = true) } @@ -381,7 +426,7 @@ class MicroBatchExecution( // Rewire the plan to use the new attributes that were returned by the source. val replacementMap = AttributeMap(replacements) - val triggerLogicalPlan = withNewSources transformAllExpressions { + val newAttributePlan = newBatchesPlan transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => @@ -392,6 +437,20 @@ class MicroBatchExecution( cd.dataType, cd.timeZoneId) } + val triggerLogicalPlan = sink match { + case _: Sink => newAttributePlan + case s: MicroBatchWriteSupport => + val writer = s.createMicroBatchWriter( + s"$runId", + currentBatchId, + newAttributePlan.schema, + outputMode, + new DataSourceV2Options(extraOptions.asJava)) + assert(writer.isPresent, "microbatch writer must always be present") + WriteToDataSourceV2(writer.get, newAttributePlan) + case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") + } + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -409,7 +468,12 @@ class MicroBatchExecution( reportTimeTaken("addBatch") { SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { - sink.addBatch(currentBatchId, nextBatch) + sink match { + case s: Sink => s.addBatch(currentBatchId, nextBatch) + case s: MicroBatchWriteSupport => + // This doesn't accumulate any data - it just forces execution of the microbatch writer. + nextBatch.collect() + } } } @@ -421,4 +485,8 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } + + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { + Optional.ofNullable(scalaOption.orNull) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1c9043613cb69..d1e5be9c12762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -53,7 +53,7 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[BaseStreamingSource, DataFrame] + protected def newData: Map[BaseStreamingSource, LogicalPlan] protected def availableOffsets: StreamProgress protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] @@ -225,8 +225,8 @@ trait ProgressReporter extends Logging { // // 3. For each source, we sum the metrics of the associated execution plan leaves. // - val logicalPlanLeafToSource = newData.flatMap { case (source, df) => - df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } } val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index d02cf882b61ac..66eb0169ac1ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -29,12 +29,12 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.ContinuousRateStreamReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamV2Reader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} @@ -112,7 +112,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister schema: Optional[StructType], checkpointLocation: String, options: DataSourceV2Options): ContinuousReader = { - new ContinuousRateStreamReader(options) + new RateStreamContinuousReader(options) } override def shortName(): String = "rate" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3e76bf7b7ca8f..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -163,7 +163,7 @@ abstract class StreamExecution( var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[BaseStreamingSource, DataFrame] = _ + protected var newData: Map[BaseStreamingSource, LogicalPlan] = _ @volatile protected var streamDeathCause: StreamingQueryException = null @@ -418,7 +418,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a9d50e3a112e7..a0ee683a895d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -61,7 +61,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ case class StreamingExecutionRelation( - source: Source, + source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) extends LeafNode { @@ -92,7 +92,7 @@ case class StreamingRelationV2( sourceName: String, extraOptions: Map[String, String], output: Seq[Attribute], - v1DataSource: DataSource)(session: SparkSession) + v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2843ab13bde2b..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} @@ -174,7 +174,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) - DataSourceV2Relation(newOutput, reader) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index c9aa78a5a2e28..b4b21e7d2052f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -case class ContinuousRateStreamPartitionOffset( +case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class ContinuousRateStreamReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceV2Options) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats @@ -48,7 +48,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { assert(offsets.length == numPartitions) val tuples = offsets.map { - case ContinuousRateStreamPartitionOffset(i, currVal, nextRead) => + case RateStreamPartitionOffset(i, currVal, nextRead) => (i, ValueRunTimeMsPair(currVal, nextRead)) } RateStreamOffset(Map(tuples: _*)) @@ -86,7 +86,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamReadTask( + RateStreamContinuousReadTask( start.value, start.runTimeMs, i, @@ -101,7 +101,7 @@ class ContinuousRateStreamReader(options: DataSourceV2Options) } -case class RateStreamReadTask( +case class RateStreamContinuousReadTask( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -109,10 +109,11 @@ case class RateStreamReadTask( rowsPerSecond: Double) extends ReadTask[Row] { override def createDataReader(): DataReader[Row] = - new RateStreamDataReader(startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) + new RateStreamContinuousDataReader( + startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamDataReader( +class RateStreamContinuousDataReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, @@ -151,5 +152,5 @@ class RateStreamDataReader( override def close(): Unit = {} override def getOffset(): PartitionOffset = - ContinuousRateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) + RateStreamPartitionOffset(partitionIndex, currentValue, nextReadTime) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 97bada08bcd2b..c0ed12cec25ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -28,17 +28,38 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.SystemClock +import org.apache.spark.util.{ManualClock, SystemClock} -class RateStreamV2Reader(options: DataSourceV2Options) +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceV2Options) extends MicroBatchReader { implicit val defaultFormats: DefaultFormats = DefaultFormats - val clock = new SystemClock + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } private val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt @@ -111,7 +132,7 @@ class RateStreamV2Reader(options: DataSourceV2Options) val packedRows = mutable.ListBuffer[(Long, Long)]() var outVal = startVal + numPartitions - var outTimeMs = startTimeMs + msPerPartitionBetweenRows + var outTimeMs = startTimeMs while (outVal <= endVal) { packedRows.append((outTimeMs, outVal)) outVal += numPartitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2e92beecf2c17..52f2e2639cd86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.Locale +import java.util.{Locale, Optional} import scala.collection.JavaConverters._ @@ -27,8 +27,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -166,19 +167,31 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) + val v1Relation = ds match { + case _: StreamSourceProvider => Some(StreamingRelation(v1DataSource)) + case _ => None + } ds match { + case s: MicroBatchReadSupport => + val tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + Dataset.ofRows( + sparkSession, + StreamingRelationV2( + s, source, extraOptions.toMap, + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( - java.util.Optional.ofNullable(userSpecifiedSchema.orNull), + Optional.ofNullable(userSpecifiedSchema.orNull), Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, options) - // Generate the V1 node to catch errors thrown within generation. - StreamingRelation(v1DataSource) Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1DataSource)(sparkSession)) + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index b508f4406138f..4b27e0d4ef47b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -29,10 +29,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -240,31 +240,35 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo "is not supported in streaming DataFrames/Datasets and will be disabled.") } - sink match { - case v1Sink: Sink => - new StreamingQueryWrapper(new MicroBatchExecution( + (sink, trigger) match { + case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) => + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v1Sink, + v2Sink, trigger, triggerClock, outputMode, + extraOptions, deleteCheckpointOnStop)) - case v2Sink: ContinuousWriteSupport => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) - new StreamingQueryWrapper(new ContinuousExecution( + case (_: MicroBatchWriteSupport, _) | (_: Sink, _) => + new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v2Sink, + sink, trigger, triggerClock, outputMode, extraOptions, deleteCheckpointOnStop)) + case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => + throw new AnalysisException( + "Sink only supports continuous writes, but a continuous trigger was not specified.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index e11705a227f48..85085d43061bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -18,20 +18,64 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamSourceV2, RateStreamV2Reader} +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + test("microbatch - numPartitions propagated") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) val tasks = reader.createReadTasks() @@ -39,7 +83,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - set offset") { - val reader = new RateStreamV2Reader(DataSourceV2Options.empty()) + val reader = new RateStreamMicroBatchReader(DataSourceV2Options.empty()) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -48,7 +92,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - infer offsets") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) reader.setOffsetRange(Optional.empty(), Optional.empty()) @@ -69,7 +113,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - predetermined batch size") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) @@ -80,7 +124,7 @@ class RateSourceV2Suite extends StreamTest { } test("microbatch - data read") { - val reader = new RateStreamV2Reader( + val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { @@ -107,14 +151,14 @@ class RateSourceV2Suite extends StreamTest { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { case ds: ContinuousReadSupport => val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) - assert(reader.isInstanceOf[ContinuousRateStreamReader]) + assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") } } test("continuous data") { - val reader = new ContinuousRateStreamReader( + val reader = new RateStreamContinuousReader( new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) val tasks = reader.createReadTasks() @@ -122,17 +166,17 @@ class RateSourceV2Suite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamReadTask => + case t: RateStreamContinuousReadTask => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamDataReader] + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) assert(r.getOffset() == - ContinuousRateStreamPartitionOffset( + RateStreamPartitionOffset( t.partitionIndex, t.partitionIndex + rowIndex * 2, startTimeMs + (rowIndex + 1) * 100)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4b7f0fbe97d4e..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -105,7 +105,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. */ - def addData(query: Option[StreamExecution]): (Source, Offset) + def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) } /** A trait that can be extended when testing a source. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index eda0d8ad48313..9562c10feafe9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -61,7 +61,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 68ce792b5857f0291154f524ac651036db868bb9 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 9 Jan 2018 10:15:01 +0800 Subject: [PATCH 0043/1161] [SPARK-22972] Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc ## What changes were proposed in this pull request? Fix the warning: Couldn't find corresponding Hive SerDe for data source provider org.apache.spark.sql.hive.orc. ## How was this patch tested? test("SPARK-22972: hive orc source") assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") .equals(HiveSerDe.sourceToSerDe("orc"))) Author: xubo245 <601450868@qq.com> Closes #20165 from xubo245/HiveSerDe. --- .../apache/spark/sql/internal/HiveSerDe.scala | 1 + .../sql/hive/orc/HiveOrcSourceSuite.scala | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index b9515ec7bca2a..dac463641cfab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -73,6 +73,7 @@ object HiveSerDe { val key = source.toLowerCase(Locale.ROOT) match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s if s.startsWith("org.apache.spark.sql.hive.orc") => "orc" case s if s.equals("orcfile") => "orc" case s if s.equals("parquetfile") => "parquet" case s if s.equals("avrofile") => "avro" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 17b7d8cfe127e..d556a030e2186 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -62,6 +64,33 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { """.stripMargin) } + test("SPARK-22972: hive orc source") { + val tableName = "normal_orc_as_source_hive" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).toURI}' + |) + """.stripMargin) + + val tableMetadata = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(tableName)) + assert(tableMetadata.storage.inputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(tableMetadata.storage.outputFormat == + Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(tableMetadata.storage.serde == + Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.hive.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + assert(HiveSerDe.sourceToSerDe("org.apache.spark.sql.orc") + .equals(HiveSerDe.sourceToSerDe("orc"))) + } + } + test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { val location = Utils.createTempDir() val uri = location.toURI From 849043ce1d28a976659278d29368da0799329db8 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 9 Jan 2018 10:44:21 +0800 Subject: [PATCH 0044/1161] [SPARK-22990][CORE] Fix method isFairScheduler in JobsTab and StagesTab ## What changes were proposed in this pull request? In current implementation, the function `isFairScheduler` is always false, since it is comparing String with `SchedulingMode` Author: Wang Gengliang Closes #20186 from gengliangwang/isFairScheduler. --- .../src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala | 8 ++++---- .../main/scala/org/apache/spark/ui/jobs/StagesTab.scala | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 99eab1b2a27d8..ff1b75e5c5065 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -34,10 +34,10 @@ private[ui] class JobsTab(parent: SparkUI, store: AppStatusStore) val killEnabled = parent.killEnabled def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def getSparkUser: String = parent.getSparkUser diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index be05a963f0e68..10b032084ce4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -37,10 +37,10 @@ private[ui] class StagesTab(val parent: SparkUI, val store: AppStatusStore) attachPage(new PoolPage(this)) def isFairScheduler: Boolean = { - store.environmentInfo().sparkProperties.toMap - .get("spark.scheduler.mode") - .map { mode => mode == SchedulingMode.FAIR } - .getOrElse(false) + store + .environmentInfo() + .sparkProperties + .contains(("spark.scheduler.mode", SchedulingMode.FAIR.toString)) } def handleKillRequest(request: HttpServletRequest): Unit = { From f20131dd35939734fe16b0005a086aa72400893b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 9 Jan 2018 11:49:10 +0800 Subject: [PATCH 0045/1161] [SPARK-22984] Fix incorrect bitmap copying and offset adjustment in GenerateUnsafeRowJoiner ## What changes were proposed in this pull request? This PR fixes a longstanding correctness bug in `GenerateUnsafeRowJoiner`. This class was introduced in https://github.com/apache/spark/pull/7821 (July 2015 / Spark 1.5.0+) and is used to combine pairs of UnsafeRows in TungstenAggregationIterator, CartesianProductExec, and AppendColumns. ### Bugs fixed by this patch 1. **Incorrect combining of null-tracking bitmaps**: when concatenating two UnsafeRows, the implementation "Concatenate the two bitsets together into a single one, taking padding into account". If one row has no columns then it has a bitset size of 0, but the code was incorrectly assuming that if the left row had a non-zero number of fields then the right row would also have at least one field, so it was copying invalid bytes and and treating them as part of the bitset. I'm not sure whether this bug was also present in the original implementation or whether it was introduced in https://github.com/apache/spark/pull/7892 (which fixed another bug in this code). 2. **Incorrect updating of data offsets for null variable-length fields**: after updating the bitsets and copying fixed-length and variable-length data, we need to perform adjustments to the offsets pointing the start of variable length fields's data. The existing code was _conditionally_ adding a fixed offset to correct for the new length of the combined row, but it is unsafe to do this if the variable-length field has a null value: we always represent nulls by storing `0` in the fixed-length slot, but this code was incorrectly incrementing those values. This bug was present since the original version of `GenerateUnsafeRowJoiner`. ### Why this bug remained latent for so long The PR which introduced `GenerateUnsafeRowJoiner` features several randomized tests, including tests of the cases where one side of the join has no fields and where string-valued fields are null. However, the existing assertions were too weak to uncover this bug: - If a null field has a non-zero value in its fixed-length data slot then this will not cause problems for field accesses because the null-tracking bitmap should still be correct and we will not try to use the incorrect offset for anything. - If the null tracking bitmap is corrupted by joining against a row with no fields then the corruption occurs in field numbers past the actual field numbers contained in the row. Thus valid `isNullAt()` calls will not read the incorrectly-set bits. The existing `GenerateUnsafeRowJoinerSuite` tests only exercised `.get()` and `isNullAt()`, but didn't actually check the UnsafeRows for bit-for-bit equality, preventing these bugs from failing assertions. It turns out that there was even a [GenerateUnsafeRowJoinerBitsetSuite](https://github.com/apache/spark/blob/03377d2522776267a07b7d6ae9bddf79a4e0f516/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala) but it looks like it also didn't catch this problem because it only tested the bitsets in an end-to-end fashion by accessing them through the `UnsafeRow` interface instead of actually comparing the bitsets' bytes. ### Impact of these bugs - This bug will cause `equals()` and `hashCode()` to be incorrect for these rows, which will be problematic in case`GenerateUnsafeRowJoiner`'s results are used as join or grouping keys. - Chained / repeated invocations of `GenerateUnsafeRowJoiner` may result in reads from invalid null bitmap positions causing fields to incorrectly become NULL (see the end-to-end example below). - It looks like this generally only happens in `CartesianProductExec`, which our query optimizer often avoids executing (usually we try to plan a `BroadcastNestedLoopJoin` instead). ### End-to-end test case demonstrating the problem The following query demonstrates how this bug may result in incorrect query results: ```sql set spark.sql.autoBroadcastJoinThreshold=-1; -- Needed to trigger CartesianProductExec create table a as select * from values 1; create table b as select * from values 2; SELECT t3.col1, t1.col1 FROM a t1 CROSS JOIN b t2 CROSS JOIN b t3 ``` This should return `(2, 1)` but instead was returning `(null, 1)`. Column pruning ends up trimming off all columns from `t2`, so when `t2` joins with another table this triggers the bitmap-copying bug. This incorrect bitmap is subsequently copied again when performing the final join, causing the final output to have an incorrectly-set null bit for the first field. ## How was this patch tested? Strengthened the assertions in existing tests in GenerateUnsafeRowJoinerSuite. Also verified that the end-to-end test case which uncovered this now passes. Author: Josh Rosen Closes #20181 from JoshRosen/SPARK-22984-fix-generate-unsaferow-joiner-bitmap-bugs. --- .../codegen/GenerateUnsafeRowJoiner.scala | 52 +++++++++- .../GenerateUnsafeRowJoinerSuite.scala | 95 ++++++++++++++++++- 2 files changed, 138 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index be5f5a73b5d47..febf7b0c96c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => - val bits = if (bitset1Remainder > 0) { + val bits = if (bitset1Remainder > 0 && bitset2Words != 0) { if (i < bitset1Words - 1) { s"$getLong(obj1, offset1 + ${i * 8})" } else if (i == bitset1Words - 1) { @@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } else { // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the // offset in the upper 32 bit of the words, we can just shift the offset to the left by - // 32 and increment that amount in place. + // 32 and increment that amount in place. However, we need to handle the important special + // case of a null field, in which case the offset should be zero and should not have a + // shift added to it. val shift = if (i < schema1.size) { s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" @@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 - s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n" + // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's + // output as a de-facto specification for the internal layout of data. + // + // Null-valued fields will always have a data offset of 0 because + // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's + // position in the fixed-length section of the row. As a result, we must NOT add + // `shift` to the offset for null fields. + // + // We could perform a null-check here by inspecting the null-tracking bitmap, but doing + // so could be expensive and will add significant bloat to the generated code. Instead, + // we'll rely on the invariant "stored offset == 0 for variable-length data type implies + // that the field's value is null." + // + // To establish that this invariant holds, we'll prove that a non-null field can never + // have a stored offset of 0. There are two cases to consider: + // + // 1. The non-null field's data is of non-zero length: reading this field's value + // must read data from the variable-length section of the row, so the stored offset + // will actually be used in address calculation and must be correct. The offsets + // count bytes from the start of the UnsafeRow so these offsets will always be + // non-zero because the storage of the offsets themselves takes up space at the + // start of the row. + // 2. The non-null field's data is of zero length (i.e. its data is empty). In this + // case, we have to worry about the possibility that an arbitrary offset value was + // stored because we never actually read any bytes using this offset and therefore + // would not crash if it was incorrect. The variable-sized data writing paths in + // UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with + // no special handling for the case where `numBytes == 0`. Internally, + // setOffsetAndSize computes the offset without taking the size into account. Thus + // the stored offset is the same non-zero offset that would be used if the field's + // dataSize was non-zero (and in (1) above we've shown that case behaves as we + // expect). + // + // Thus it is safe to perform `existingOffset != 0` checks here in the place of + // more expensive null-bit checks. + s""" + |existingOffset = $getLong(buf, $cursor); + |if (existingOffset != 0) { + | $putLong(buf, $cursor, existingOffset + ($shift << 32)); + |} + """.stripMargin } } val updateOffsets = ctx.splitExpressions( expressions = updateOffset, funcName = "copyBitsetFunc", - arguments = ("long", "numBytesVariableRow1") :: Nil) + arguments = ("long", "numBytesVariableRow1") :: Nil, + makeSplitFunction = (s: String) => "long existingOffset;\n" + s) // ------------------------ Finally, put everything together --------------------------- // val codeBody = s""" @@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyFixedLengthRow2 | $copyVariableLengthRow1 | $copyVariableLengthRow2 + | long existingOffset; | $updateOffsets | | out.pointTo(buf, sizeInBytes); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala index f203f25ad10d4..75c6beeb32150 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala @@ -22,8 +22,10 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for [[GenerateUnsafeRowJoiner]]. @@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { testConcat(64, 64, fixed) } + test("rows with all empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8)) + testConcat(schema, row, schema, row) + } + + test("rows with all empty int arrays") { + val schema = StructType(Seq( + StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType)))) + val emptyIntArray = + ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(emptyIntArray, emptyIntArray)) + testConcat(schema, row, schema, row) + } + + test("alternating empty and non-empty strings") { + val schema = StructType(Seq( + StructField("f1", StringType), StructField("f2", StringType))) + val row: UnsafeRow = UnsafeProjection.create(schema).apply( + InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo"))) + testConcat(schema, row, schema, row) + } + test("randomized fix width types") { for (i <- 0 until 20) { testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed) @@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite { val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply() val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow]) val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow]) + testConcat(schema1, row1, schema2, row2) + } + + private def testConcat( + schema1: StructType, + row1: UnsafeRow, + schema2: StructType, + row2: UnsafeRow) { // Run the joiner. val mergedSchema = StructType(schema1 ++ schema2) val concater = GenerateUnsafeRowJoiner.create(schema1, schema2) - val output = concater.join(row1, row2) + val output: UnsafeRow = concater.join(row1, row2) + + // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures + // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written + // correctly. + val expectedOutput: UnsafeRow = { + val joinedRowProjection = UnsafeProjection.create(mergedSchema) + val joined = new JoinedRow() + joinedRowProjection.apply(joined.apply(row1, row2)) + } // Test everything equals ... for (i <- mergedSchema.indices) { + val dataType = mergedSchema(i).dataType if (i < schema1.size) { assert(output.isNullAt(i) === row1.isNullAt(i)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row1.get(i, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } else { assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size)) if (!output.isNullAt(i)) { - assert(output.get(i, mergedSchema(i).dataType) === - row2.get(i - schema1.size, mergedSchema(i).dataType)) + assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType)) + assert(output.get(i, dataType) === expectedOutput.get(i, dataType)) } } } + + + assert( + expectedOutput.getSizeInBytes == output.getSizeInBytes, + "output isn't same size in bytes as slow path") + + // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in + // case this assertion fails: + val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]] + .take(output.getSizeInBytes) + val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]] + .take(expectedOutput.getSizeInBytes) + + val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields()) + val actualBitset = actualBytes.take(bitsetWidth) + val expectedBitset = expectedBytes.take(bitsetWidth) + assert(actualBitset === expectedBitset, "bitsets were not equal") + + val fixedLengthSize = expectedOutput.numFields() * 8 + val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize) + if (actualFixedLength !== expectedFixedLength) { + actualFixedLength.grouped(8) + .zip(expectedFixedLength.grouped(8)) + .zip(mergedSchema.fields.toIterator) + .foreach { + case ((actual, expected), field) => + assert(actual === expected, s"Fixed length sections are not equal for field $field") + } + fail("Fixed length sections were not equal") + } + + val variableLengthStart = bitsetWidth + fixedLengthSize + val actualVariableLength = actualBytes.drop(variableLengthStart) + val expectedVariableLength = expectedBytes.drop(variableLengthStart) + assert(actualVariableLength === expectedVariableLength, "fixed length sections were not equal") + + assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal") } } From 8486ad419d8f1779e277ec71c39e1516673a83ab Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 21:58:26 -0800 Subject: [PATCH 0046/1161] [SPARK-21292][DOCS] refreshtable example ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20198 from felixcheung/rrefreshdoc. --- docs/sql-programming-guide.md | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3ccaaf4d5b1fa..72f79d6909ecc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -915,6 +915,14 @@ spark.catalog.refreshTable("my_table") +
+ +{% highlight r %} +refreshTable("my_table") +{% endhighlight %} + +
+
{% highlight sql %} @@ -1498,10 +1506,10 @@ that these options will be deprecated in future release as more optimizations ar ## Broadcast Hint for SQL Queries The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view. -When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, +When Spark deciding the join methods, the broadcast hash join (i.e., BHJ) is preferred, even if the statistics is above the configuration `spark.sql.autoBroadcastJoinThreshold`. When both sides of a join are specified, Spark broadcasts the one having the lower statistics. -Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) +Note Spark does not guarantee BHJ is always chosen, since not all cases (e.g. full outer join) support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
@@ -1780,7 +1788,7 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. @@ -2167,7 +2175,7 @@ Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are Spark SQL currently does not support the reuse of aggregation. * `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating an aggregate over a fixed window. - + ### Incompatible Hive UDF Below are the scenarios in which Hive and Spark generate different results: From 02214b094390e913f52e71d55c9bb8a81c9e7ef9 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 8 Jan 2018 22:08:19 -0800 Subject: [PATCH 0047/1161] [SPARK-21293][SPARKR][DOCS] structured streaming doc update ## What changes were proposed in this pull request? doc update Author: Felix Cheung Closes #20197 from felixcheung/rwadoc. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 2 +- docs/sparkr.md | 2 +- .../structured-streaming-programming-guide.md | 32 +++++++++++++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 2e662424b25f2..feca617c2554c 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1042,7 +1042,7 @@ unlink(modelPath) ## Structured Streaming -SparkR supports the Structured Streaming API (experimental). +SparkR supports the Structured Streaming API. You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. diff --git a/docs/sparkr.md b/docs/sparkr.md index 997ea60fb6cf0..6685b585a393a 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -596,7 +596,7 @@ The following example shows how to save/load a MLlib model by SparkR. # Structured Streaming -SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) +SparkR supports the Structured Streaming API. Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) # R Function Name Conflicts diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 31fcfabb9cacc..de13e281916db 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -827,8 +827,8 @@ df.isStreaming() {% endhighlight %}
-{% highlight bash %} -Not available. +{% highlight r %} +isStreaming(df) {% endhighlight %}
@@ -885,6 +885,19 @@ windowedCounts = words.groupBy( ).count() {% endhighlight %} + +
+{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
@@ -959,6 +972,21 @@ windowedCounts = words \ .count() {% endhighlight %} + +
+{% highlight r %} +words <- ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group + +words <- withWatermark(words, "timestamp", "10 minutes") +windowedCounts <- count( + groupBy( + words, + window(words$timestamp, "10 minutes", "5 minutes"), + words$word)) +{% endhighlight %} +
From 0959aa581a399279be3f94214bcdffc6a1b6d60a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 9 Jan 2018 16:31:20 +0800 Subject: [PATCH 0048/1161] [SPARK-23000] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite in Spark 2.3 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ The test suite DataSourceWithHiveMetastoreCatalogSuite of Branch 2.3 always failed in hadoop 2.6 The table `t` exists in `default`, but `runSQLHive` reported the table does not exist. Obviously, Hive client's default database is different. The fix is to clean the environment and use `DEFAULT` as the database. ``` org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' Stacktrace sbt.ForkMain$ForkError: org.apache.spark.sql.execution.QueryExecutionException: FAILED: SemanticException [Error 10001]: Line 1:14 Table not found 't' at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:699) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20196 from gatorsmile/testFix. --- .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 6 +++++- .../apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 7b7f4e0f10210..102f40bacc985 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -823,7 +823,8 @@ private[hive] class HiveClientImpl( } def reset(): Unit = withHiveState { - client.getAllTables("default").asScala.foreach { t => + try { + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).asScala.foreach { index => @@ -837,6 +838,9 @@ private[hive] class HiveClientImpl( logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } + } finally { + runSqlHive("USE default") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 18137e7ea1d63..cf4ce83124d88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -146,6 +146,11 @@ class DataSourceWithHiveMetastoreCatalogSuite 'id cast StringType as 'd2 ).coalesce(1) + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession.metadataHive.reset() + } + Seq( "parquet" -> (( "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", From 6a4206ff04746481d7c8e307dfd0d31ff1402555 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Tue, 9 Jan 2018 01:32:48 -0800 Subject: [PATCH 0049/1161] [SPARK-22998][K8S] Set missing value for SPARK_MOUNTED_CLASSPATH in the executors ## What changes were proposed in this pull request? The environment variable `SPARK_MOUNTED_CLASSPATH` is referenced in the executor's Dockerfile, where its value is added to the classpath of the executor. However, the scheduler backend code missed setting it when creating the executor pods. This PR fixes it. ## How was this patch tested? Unit tested. vanzin Can you help take a look? Thanks! foxish Author: Yinan Li Closes #20193 from liyinan926/master. --- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 ++++- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 066d7e9f70ca5..bcacb3934d36a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -94,6 +94,8 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) + /** * Configure and construct an executor pod with the given parameters. */ @@ -145,7 +147,8 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) + (ENV_EXECUTOR_ID, executorId), + (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 884da8aabd880..7cfbe54c95390 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -197,7 +197,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + ENV_EXECUTOR_POD_IP -> null, + ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) From f44ba910f58083458e1133502e193a9d6f2bf766 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 9 Jan 2018 21:48:14 +0800 Subject: [PATCH 0050/1161] [SPARK-16060][SQL] Support Vectorized ORC Reader ## What changes were proposed in this pull request? This PR adds an ORC columnar-batch reader to native `OrcFileFormat`. Since both Spark `ColumnarBatch` and ORC `RowBatch` are used together, it is faster than the current Spark implementation. This replaces the prior PR, #17924. Also, this PR adds `OrcReadBenchmark` to show the performance improvement. ## How was this patch tested? Pass the existing test cases. Author: Dongjoon Hyun Closes #19943 from dongjoon-hyun/SPARK-16060. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../orc/OrcColumnarBatchReader.java | 523 ++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 75 ++- .../execution/datasources/orc/OrcUtils.scala | 7 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 435 +++++++++++++++ 5 files changed, 1022 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5c61f10bb71ad..74949db883f7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -386,6 +386,11 @@ object SQLConf { .checkValues(Set("hive", "native")) .createWithDefault("native") + val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") + .doc("Enables vectorized orc decoding.") + .booleanConf + .createWithDefault(true) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf @@ -1183,6 +1188,8 @@ class SQLConf extends Serializable with Logging { def orcCompressionCodec: String = getConf(ORC_COMPRESSION) + def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java new file mode 100644 index 0000000000000..5c28d0e6e507a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.io.IOException; +import java.util.stream.IntStream; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.apache.orc.storage.common.type.HiveDecimal; +import org.apache.orc.storage.ql.exec.vector.*; +import org.apache.orc.storage.serde2.io.HiveDecimalWritable; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +/** + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. + * After creating, `initialize` and `initBatch` should be called sequentially. + */ +public class OrcColumnarBatchReader extends RecordReader { + + /** + * The default size of batch. We use this value for both ORC and Spark consistently + * because they have different default values like the following. + * + * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 + * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 + */ + public static final int DEFAULT_SIZE = 4 * 1024; + + // ORC File Reader + private Reader reader; + + // Vectorized ORC Row Batch + private VectorizedRowBatch batch; + + /** + * The column IDs of the physical ORC file schema which are required by this reader. + * -1 means this required column doesn't exist in the ORC file. + */ + private int[] requestedColIds; + + // Record reader from ORC row batch. + private org.apache.orc.RecordReader recordReader; + + private StructField[] requiredFields; + + // The result columnar batch for vectorized execution by whole-stage codegen. + private ColumnarBatch columnarBatch; + + // Writable column vectors of the result columnar batch. + private WritableColumnVector[] columnVectors; + + /** + * The memory mode of the columnarBatch + */ + private final MemoryMode MEMORY_MODE; + + public OrcColumnarBatchReader(boolean useOffHeap) { + MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + } + + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + return columnarBatch; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return recordReader.getProgress(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + return nextBatch(); + } + + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + } + + /** + * Initialize ORC file reader and batch record reader. + * Please note that `initBatch` is needed to be called after this. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + FileSplit fileSplit = (FileSplit)inputSplit; + Configuration conf = taskAttemptContext.getConfiguration(); + reader = OrcFile.createReader( + fileSplit.getPath(), + OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) + .filesystem(fileSplit.getPath().getFileSystem(conf))); + + Reader.Options options = + OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); + recordReader = reader.rows(options); + } + + /** + * Initialize columnar batch by setting required schema and partition information. + * With this information, this creates ColumnarBatch with the full schema. + */ + public void initBatch( + TypeDescription orcSchema, + int[] requestedColIds, + StructField[] requiredFields, + StructType partitionSchema, + InternalRow partitionValues) { + batch = orcSchema.createRowBatch(DEFAULT_SIZE); + assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. + + this.requiredFields = requiredFields; + this.requestedColIds = requestedColIds; + assert(requiredFields.length == requestedColIds.length); + + StructType resultSchema = new StructType(requiredFields); + for (StructField f : partitionSchema.fields()) { + resultSchema = resultSchema.add(f); + } + + int capacity = DEFAULT_SIZE; + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + } + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].setIsConstant(); + } + } + } + + /** + * Return true if there exists more data in the next batch. If exists, prepare the next batch + * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. + */ + private boolean nextBatch() throws IOException { + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); + + recordReader.nextBatch(batch); + int batchSize = batch.size; + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + for (int i = 0; i < requiredFields.length; i++) { + StructField field = requiredFields[i]; + WritableColumnVector toColumn = columnVectors[i]; + + if (requestedColIds[i] >= 0) { + ColumnVector fromColumn = batch.cols[requestedColIds[i]]; + + if (fromColumn.isRepeating) { + putRepeatingValues(batchSize, field, fromColumn, toColumn); + } else if (fromColumn.noNulls) { + putNonNullValues(batchSize, field, fromColumn, toColumn); + } else { + putValues(batchSize, field, fromColumn, toColumn); + } + } + } + return true; + } + + private void putRepeatingValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + if (fromColumn.isNull[0]) { + toColumn.putNulls(0, batchSize); + } else { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1); + } else if (type instanceof ByteType) { + toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof ShortType) { + toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof IntegerType || type instanceof DateType) { + toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]); + } else if (type instanceof TimestampType) { + toColumn.putLongs(0, batchSize, + fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0)); + } else if (type instanceof FloatType) { + toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int size = data.vector[0].length; + arrayData.reserve(size); + arrayData.putBytes(0, size, data.vector[0], 0); + for (int index = 0; index < batchSize; index++) { + toColumn.putArray(index, 0, size); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + putDecimalWritables( + toColumn, + batchSize, + decimalType.precision(), + decimalType.scale(), + ((DecimalColumnVector)fromColumn).vector[0]); + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + } + + private void putNonNullValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putBoolean(index, data[index] == 1); + } + } else if (type instanceof ByteType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putByte(index, (byte)data[index]); + } + } else if (type instanceof ShortType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putShort(index, (short)data[index]); + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] data = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putInt(index, (int)data[index]); + } + } else if (type instanceof LongType) { + toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0); + } else if (type instanceof TimestampType) { + TimestampColumnVector data = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + toColumn.putLong(index, fromTimestampColumnVector(data, index)); + } + } else if (type instanceof FloatType) { + double[] data = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + toColumn.putFloat(index, (float)data[index]); + } + } else if (type instanceof DoubleType) { + toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector data = ((BytesColumnVector)fromColumn); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(data.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { + arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]); + toColumn.putArray(index, pos, data.length[index]); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + DecimalColumnVector data = ((DecimalColumnVector)fromColumn); + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + data.vector[index]); + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + private void putValues( + int batchSize, + StructField field, + ColumnVector fromColumn, + WritableColumnVector toColumn) { + DataType type = field.dataType(); + if (type instanceof BooleanType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putBoolean(index, vector[index] == 1); + } + } + } else if (type instanceof ByteType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putByte(index, (byte)vector[index]); + } + } + } else if (type instanceof ShortType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putShort(index, (short)vector[index]); + } + } + } else if (type instanceof IntegerType || type instanceof DateType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putInt(index, (int)vector[index]); + } + } + } else if (type instanceof LongType) { + long[] vector = ((LongColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, vector[index]); + } + } + } else if (type instanceof TimestampType) { + TimestampColumnVector vector = ((TimestampColumnVector)fromColumn); + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putLong(index, fromTimestampColumnVector(vector, index)); + } + } + } else if (type instanceof FloatType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putFloat(index, (float)vector[index]); + } + } + } else if (type instanceof DoubleType) { + double[] vector = ((DoubleColumnVector)fromColumn).vector; + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + toColumn.putDouble(index, vector[index]); + } + } + } else if (type instanceof StringType || type instanceof BinaryType) { + BytesColumnVector vector = (BytesColumnVector)fromColumn; + WritableColumnVector arrayData = toColumn.getChildColumn(0); + int totalNumBytes = IntStream.of(vector.length).sum(); + arrayData.reserve(totalNumBytes); + for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + arrayData.putBytes(pos, vector.length[index], vector.vector[index], vector.start[index]); + toColumn.putArray(index, pos, vector.length[index]); + } + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType)type; + HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; + if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(batchSize * 16); + } + for (int index = 0; index < batchSize; index++) { + if (fromColumn.isNull[index]) { + toColumn.putNull(index); + } else { + putDecimalWritable( + toColumn, + index, + decimalType.precision(), + decimalType.scale(), + vector[index]); + } + } + } else { + throw new UnsupportedOperationException("Unsupported Data Type: " + type); + } + } + + /** + * Returns the number of micros since epoch from an element of TimestampColumnVector. + */ + private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { + return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + } + + /** + * Put a `HiveDecimalWritable` to a `WritableColumnVector`. + */ + private static void putDecimalWritable( + WritableColumnVector toColumn, + int index, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInt(index, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLong(index, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.putArray(index, index * 16, bytes.length); + } + } + + /** + * Put `HiveDecimalWritable`s to a `WritableColumnVector`. + */ + private static void putDecimalWritables( + WritableColumnVector toColumn, + int size, + int precision, + int scale, + HiveDecimalWritable decimalWritable) { + HiveDecimal decimal = decimalWritable.getHiveDecimal(); + Decimal value = + Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale()); + value.changePrecision(precision, scale); + + if (precision <= Decimal.MAX_INT_DIGITS()) { + toColumn.putInts(0, size, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + toColumn.putLongs(0, size, value.toUnscaledLong()); + } else { + byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); + WritableColumnVector arrayData = toColumn.getChildColumn(0); + arrayData.reserve(bytes.length); + arrayData.putBytes(0, bytes.length, bytes, 0); + for (int index = 0; index < size; index++) { + toColumn.putArray(index, 0, bytes.length); + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index f7471cd7debce..b8bacfa1838ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -118,6 +118,13 @@ class OrcFileFormat } } + override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { + val conf = sparkSession.sessionState.conf + conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + override def isSplitable( sparkSession: SparkSession, options: Map[String, String], @@ -139,6 +146,11 @@ class OrcFileFormat } } + val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) + val sqlConf = sparkSession.sessionState.conf + val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled + val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis @@ -146,8 +158,14 @@ class OrcFileFormat (file: PartitionedFile) => { val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + val reader = OrcFile.createReader(filePath, readerOptions) + val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf) + isCaseSensitive, dataSchema, requiredSchema, reader, conf) if (requestedColIdsOrEmptyFile.isEmpty) { Iterator.empty @@ -155,29 +173,46 @@ class OrcFileFormat val requestedColIds = requestedColIdsOrEmptyFile.get assert(requestedColIds.length == requiredSchema.length, "[BUG] requested column IDs do not match required schema") - conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, + val taskConf = new Configuration(conf) + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, requestedColIds.filter(_ != -1).sorted.mkString(",")) - val fileSplit = - new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) - val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - - val orcRecordReader = new OrcInputFormat[OrcStruct] - .createRecordReader(fileSplit, taskAttemptContext) - val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) - - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes - val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) - val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) - - if (partitionSchema.length == 0) { - iter.map(value => unsafeProjection(deserializer.deserialize(value))) + val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) + + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + val batchReader = + new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + batchReader.initialize(fileSplit, taskAttemptContext) + batchReader.initBatch( + reader.getSchema, + requestedColIds, + requiredSchema.fields, + partitionSchema, + file.partitionValues) + + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + iter.asInstanceOf[Iterator[InternalRow]] } else { - val joinedRow = new JoinedRow() - iter.map(value => - unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + val orcRecordReader = new OrcInputFormat[OrcStruct] + .createRecordReader(fileSplit, taskAttemptContext) + val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) + val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds) + + if (partitionSchema.length == 0) { + iter.map(value => unsafeProjection(deserializer.deserialize(value))) + } else { + val joinedRow = new JoinedRow() + iter.map(value => + unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues))) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index b03ee06d04a16..13a23996f4ade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcFile, TypeDescription} +import org.apache.orc.{OrcFile, Reader, TypeDescription} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -80,11 +80,8 @@ object OrcUtils extends Logging { isCaseSensitive: Boolean, dataSchema: StructType, requiredSchema: StructType, - file: Path, + reader: Reader, conf: Configuration): Option[Array[Int]] = { - val fs = file.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) val orcFieldNames = reader.getSchema.getFieldNames.asScala if (orcFieldNames.isEmpty) { // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala new file mode 100644 index 0000000000000..37ed846acd1eb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure ORC read performance. + * + * This is in `sql/hive` module in order to compare `sql/core` and `sql/hive` ORC data sources. + */ +// scalastyle:off line.size.limit +object OrcReadBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("OrcReadBenchmark") + .config(conf) + .getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private val NATIVE_ORC_FORMAT = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName + private val HIVE_ORC_FORMAT = classOf[org.apache.spark.sql.hive.orc.OrcFileFormat].getCanonicalName + + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val dirORC = dir.getCanonicalPath + + if (partition.isDefined) { + df.write.partitionBy(partition.get).orc(dirORC) + } else { + df.write.orc(dirORC) + } + + spark.read.format(NATIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("nativeOrcTable") + spark.read.format(HIVE_ORC_FORMAT).load(dirORC).createOrReplaceTempView("hiveOrcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1192 / 1221 13.2 75.8 1.0X + Native ORC Vectorized 161 / 170 97.5 10.3 7.4X + Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1287 / 1333 12.2 81.8 1.0X + Native ORC Vectorized 164 / 172 95.6 10.5 7.8X + Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1304 / 1388 12.1 82.9 1.0X + Native ORC Vectorized 227 / 240 69.3 14.4 5.7X + Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1331 / 1357 11.8 84.6 1.0X + Native ORC Vectorized 289 / 297 54.4 18.4 4.6X + Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1410 / 1428 11.2 89.7 1.0X + Native ORC Vectorized 328 / 335 48.0 20.8 4.3X + Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1467 / 1485 10.7 93.3 1.0X + Native ORC Vectorized 402 / 411 39.1 25.6 3.6X + Hive built-in ORC 2023 / 2042 7.8 128.6 0.7X + */ + sqlBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2729 / 2744 3.8 260.2 1.0X + Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X + Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Read data column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read data column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read partition column - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() + } + + benchmark.addCase("Read both columns - Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X + Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X + Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X + Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X + Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X + Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X + Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X + Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X + Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1")) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1325 / 1328 7.9 126.4 1.0X + Native ORC Vectorized 320 / 330 32.8 30.5 4.1X + Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + val benchmark = new Benchmark(s"String with Nulls Scan ($fractionOfNulls%)", values) + + benchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("Native ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("Hive built-in ORC") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2553 / 2554 4.1 243.4 1.0X + Native ORC Vectorized 953 / 954 11.0 90.9 2.7X + Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + + String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2389 / 2408 4.4 227.8 1.0X + Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X + Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + + String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1295 / 1311 8.1 123.5 1.0X + Native ORC Vectorized 449 / 457 23.4 42.8 2.9X + Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + sqlBenchmark.addCase("Native ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + sqlBenchmark.addCase("Native ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + + sqlBenchmark.addCase("Hive built-in ORC") { _ => + spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 1103 / 1124 1.0 1052.0 1.0X + Native ORC Vectorized 92 / 100 11.4 87.9 12.0X + Hive built-in ORC 383 / 390 2.7 365.4 2.9X + + SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 2245 / 2250 0.5 2141.0 1.0X + Native ORC Vectorized 157 / 165 6.7 150.2 14.3X + Hive built-in ORC 587 / 593 1.8 559.4 3.8X + + SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Native ORC MR 3343 / 3350 0.3 3188.3 1.0X + Native ORC Vectorized 265 / 280 3.9 253.2 12.6X + Hive built-in ORC 828 / 842 1.3 789.8 4.0X + */ + sqlBenchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + repeatedStringScanBenchmark(1024 * 1024 * 10) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + columnsBenchmark(1024 * 1024 * 1, 100) + columnsBenchmark(1024 * 1024 * 1, 200) + columnsBenchmark(1024 * 1024 * 1, 300) + } +} +// scalastyle:on line.size.limit From 2250cb75b99d257e698fe5418a51d8cddb4d5104 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 9 Jan 2018 21:58:55 +0800 Subject: [PATCH 0051/1161] [SPARK-22981][SQL] Fix incorrect results of Casting Struct to String ## What changes were proposed in this pull request? This pr fixed the issue when casting structs into strings; ``` scala> val df = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") scala> df.write.saveAsTable("t") scala> sql("SELECT CAST(a AS STRING) FROM t").show +-------------------+ | a| +-------------------+ |[0,1,1800000001,61]| |[0,2,1800000001,62]| +-------------------+ ``` This pr modified the result into; ``` +------+ | a| +------+ |[1, a]| |[2, b]| +------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20176 from maropu/SPARK-22981. --- .../spark/sql/catalyst/expressions/Cast.scala | 71 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 16 +++++ 2 files changed, 87 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f2de4c8e30bec..f21aa1e9e3135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -259,6 +259,29 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case StructType(fields) => + buildCast[InternalRow](_, row => { + val builder = new UTF8StringBuilder + builder.append("[") + if (row.numFields > 0) { + val st = fields.map(_.dataType) + val toUTF8StringFuncs = st.map(castToString) + if (!row.isNullAt(0)) { + builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < row.numFields) { + builder.append(",") + if (!row.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -732,6 +755,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeStructToStringBuilder( + st: Seq[DataType], + row: String, + buffer: String, + ctx: CodegenContext): String = { + val structToStringCode = st.zipWithIndex.map { case (ft, i) => + val fieldToStringCode = castToStringCode(ft, ctx) + val field = ctx.freshName("field") + val fieldStr = ctx.freshName("fieldStr") + s""" + |${if (i != 0) s"""$buffer.append(",");""" else ""} + |if (!$row.isNullAt($i)) { + | ${if (i != 0) s"""$buffer.append(" ");""" else ""} + | + | // Append $i field into the string buffer + | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | UTF8String $fieldStr = null; + | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} + | $buffer.append($fieldStr); + |} + """.stripMargin + } + + val writeStructCode = ctx.splitExpressions( + expressions = structToStringCode, + funcName = "fieldToString", + arguments = ("InternalRow", row) :: (classOf[UTF8StringBuilder].getName, buffer) :: Nil) + + s""" + |$buffer.append("["); + |$writeStructCode + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -765,6 +823,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case StructType(fields) => + (c, evPrim, evNull) => { + val row = ctx.freshName("row") + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx) + s""" + |InternalRow $row = $c; + |$bufferClass $buffer = new $bufferClass(); + |$writeStructCode + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1445bb8a97d40..5b25bdf907c3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -906,4 +906,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") } + + test("SPARK-22981 Cast struct to string") { + val ret1 = cast(Literal.create((1, "a", 0.1)), StringType) + checkEvaluation(ret1, "[1, a, 0.1]") + val ret2 = cast(Literal.create(Tuple3[Int, String, String](1, null, "a")), StringType) + checkEvaluation(ret2, "[1,, a]") + val ret3 = cast(Literal.create( + (Date.valueOf("2014-12-03"), Timestamp.valueOf("2014-12-03 15:05:00"))), StringType) + checkEvaluation(ret3, "[2014-12-03, 2014-12-03 15:05:00]") + val ret4 = cast(Literal.create(((1, "a"), 5, 0.1)), StringType) + checkEvaluation(ret4, "[[1, a], 5, 0.1]") + val ret5 = cast(Literal.create((Seq(1, 2, 3), "a", 0.1)), StringType) + checkEvaluation(ret5, "[[1, 2, 3], a, 0.1]") + val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) + checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]") + } } From 96ba217a06fbe1dad703447d7058cb7841653861 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 10:15:27 +0800 Subject: [PATCH 0052/1161] [SPARK-23005][CORE] Improve RDD.take on small number of partitions ## What changes were proposed in this pull request? In current implementation of RDD.take, we overestimate the number of partitions we need to try by 50%: `(1.5 * num * partsScanned / buf.size).toInt` However, when the number is small, the result of `.toInt` is not what we want. E.g, 2.9 will become 2, which should be 3. Use Math.ceil to fix the problem. Also clean up the code in RDD.scala. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20200 from gengliangwang/Take. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 27 +++++++++---------- .../spark/sql/execution/SparkPlan.scala | 5 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 8798dfc925362..7859781e98223 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - @transient var name: String = null + @transient var name: String = _ /** Assign a name to this RDD */ def setName(_name: String): this.type = { @@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag]( // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null + private var dependencies_ : Seq[Dependency[_]] = _ + @transient private var partitions_ : Array[Partition] = _ /** An Option holding our checkpoint RDD, if we are checkpointed */ private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD) @@ -297,7 +297,7 @@ abstract class RDD[T: ClassTag]( private[spark] def getNarrowAncestors: Seq[RDD[_]] = { val ancestors = new mutable.HashSet[RDD[_]] - def visit(rdd: RDD[_]) { + def visit(rdd: RDD[_]): Unit = { val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]]) val narrowParents = narrowDependencies.map(_.rdd) val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains) @@ -449,7 +449,7 @@ abstract class RDD[T: ClassTag]( if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ val distributePartition = (index: Int, items: Iterator[T]) => { - var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions) + var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner // will mod it with the number of total partitions. @@ -951,7 +951,7 @@ abstract class RDD[T: ClassTag]( def collectPartition(p: Int): Array[T] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } - (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) + partitions.indices.iterator.flatMap(i => collectPartition(i)) } /** @@ -1338,6 +1338,7 @@ abstract class RDD[T: ClassTag]( // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1L + val left = num - buf.size if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1345,13 +1346,12 @@ abstract class RDD[T: ClassTag]( if (buf.isEmpty) { numPartsToTry = partsScanned * scaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor) } } - val left = num - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) @@ -1677,8 +1677,7 @@ abstract class RDD[T: ClassTag]( // an RDD and its parent in every batch, in which case the parent may never be checkpointed // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847). private val checkpointAllMarkedAncestors = - Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)) - .map(_.toBoolean).getOrElse(false) + Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean) /** Returns the first parent RDD */ protected[spark] def firstParent[U: ClassTag]: RDD[U] = { @@ -1686,7 +1685,7 @@ abstract class RDD[T: ClassTag]( } /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */ - protected[spark] def parent[U: ClassTag](j: Int) = { + protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = { dependencies(j).rdd.asInstanceOf[RDD[U]] } @@ -1754,7 +1753,7 @@ abstract class RDD[T: ClassTag]( * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ - protected def clearDependencies() { + protected def clearDependencies(): Unit = { dependencies_ = null } @@ -1790,7 +1789,7 @@ abstract class RDD[T: ClassTag]( val lastDepStrings = debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) - (frontDepStrings ++ lastDepStrings) + frontDepStrings ++ lastDepStrings } } // The first RDD in the dependency stack has no parents, so no need for a +- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 82300efc01632..398758a3331b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -351,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (buf.isEmpty) { numPartsToTry = partsScanned * limitScaleUpFactor } else { - // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1) + val left = n - buf.size + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) } } From 6f169ca9e1444fe8fd1ab6f3fbf0a8be1670f1b5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 10:20:34 +0800 Subject: [PATCH 0053/1161] [MINOR] fix a typo in BroadcastJoinSuite ## What changes were proposed in this pull request? `BroadcastNestedLoopJoinExec` should be `BroadcastHashJoinExec` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20202 from cloud-fan/typo. --- .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 6da46ea3480b3..0bcd54e1fceab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -318,7 +318,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { case b: BroadcastNestedLoopJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) - case b: BroadcastNestedLoopJoinExec => + case b: BroadcastHashJoinExec => assert(b.getClass.getSimpleName === joinMethod) assert(b.buildSide === buildSide) case w: WholeStageCodegenExec => From 7bcc2666810cefc85dfa0d6679ac7a0de9e23154 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:00:07 +0900 Subject: [PATCH 0054/1161] [SPARK-23018][PYTHON] Fix createDataFrame from Pandas timestamp series assignment ## What changes were proposed in this pull request? This fixes createDataFrame from Pandas to only assign modified timestamp series back to a copied version of the Pandas DataFrame. Previously, if the Pandas DataFrame was only a reference (e.g. a slice of another) each series will still get assigned back to the reference even if it is not a modified timestamp column. This caused the following warning "SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame." ## How was this patch tested? existing tests Author: Bryan Cutler Closes #20213 from BryanCutler/pyspark-createDataFrame-copy-slice-warn-SPARK-23018. --- python/pyspark/sql/session.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6052fa9e84096..3e4574729a631 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -459,21 +459,23 @@ def _convert_from_pandas(self, pdf, schema, timezone): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if isinstance(field.dataType, TimestampType): s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) - if not copied and s is not pdf[field.name]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[field.name] = s + if s is not pdf[field.name]: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s else: for column, series in pdf.iteritems(): - s = _check_series_convert_timestamps_tz_local(pdf[column], timezone) - if not copied and s is not pdf[column]: - # Copy once if the series is modified to prevent the original Pandas - # DataFrame from being updated - pdf = pdf.copy() - copied = True - pdf[column] = s + s = _check_series_convert_timestamps_tz_local(series, timezone) + if s is not series: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s # Convert pandas.DataFrame to list of numpy records np_records = pdf.to_records(index=False) From e5998372487af20114e160264a594957344ff433 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 10 Jan 2018 14:55:24 +0900 Subject: [PATCH 0055/1161] [SPARK-23009][PYTHON] Fix for non-str col names to createDataFrame from Pandas ## What changes were proposed in this pull request? This the case when calling `SparkSession.createDataFrame` using a Pandas DataFrame that has non-str column labels. The column name conversion logic to handle non-string or unicode in python2 is: ``` if column is not any type of string: name = str(column) else if column is unicode in Python 2: name = column.encode('utf-8') ``` ## How was this patch tested? Added a new test with a Pandas DataFrame that has int column labels Author: Bryan Cutler Closes #20210 from BryanCutler/python-createDataFrame-int-col-error-SPARK-23009. --- python/pyspark/sql/session.py | 4 +++- python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 3e4574729a631..604021c1f45cc 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -648,7 +648,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr # If no schema supplied by user then get the names of columns only if schema is None: - schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in data.columns] + schema = [str(x) if not isinstance(x, basestring) else + (x.encode('utf-8') if not isinstance(x, str) else x) + for x in data.columns] if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \ and len(data) > 0: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13576ff57001b..80a94a91a87b3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3532,6 +3532,15 @@ def test_toPandas_with_array_type(self): self.assertTrue(expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]) + def test_createDataFrame_with_int_col_names(self): + import numpy as np + import pandas as pd + pdf = pd.DataFrame(np.random.rand(4, 2)) + df, df_arrow = self._createDataFrame_toggle(pdf) + pdf_col_names = [str(c) for c in pdf.columns] + self.assertEqual(pdf_col_names, df.columns) + self.assertEqual(pdf_col_names, df_arrow.columns) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class PandasUDFTests(ReusedSQLTestCase): From edf0a48c2ec696b92ed6a96dcee6eeb1a046b20b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 15:01:11 +0800 Subject: [PATCH 0056/1161] [SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel ## What changes were proposed in this pull request? This patch fixes a severe asynchronous IO bug in Spark's Netty-based file transfer code. At a high-level, the problem is that an unsafe asynchronous `close()` of a pipe's source channel creates a race condition where file transfer code closes a file descriptor then attempts to read from it. If the closed file descriptor's number has been reused by an `open()` call then this invalid read may cause unrelated file operations to return incorrect results. **One manifestation of this problem is incorrect query results.** For a high-level overview of how file download works, take a look at the control flow in `NettyRpcEnv.openChannel()`: this code creates a pipe to buffer results, then submits an asynchronous stream request to a lower-level TransportClient. The callback passes received data to the sink end of the pipe. The source end of the pipe is passed back to the caller of `openChannel()`. Thus `openChannel()` returns immediately and callers interact with the returned pipe source channel. Because the underlying stream request is asynchronous, errors may occur after `openChannel()` has returned and after that method's caller has started to `read()` from the returned channel. For example, if a client requests an invalid stream from a remote server then the "stream does not exist" error may not be received from the remote server until after `openChannel()` has returned. In order to be able to propagate the "stream does not exist" error to the file-fetching application thread, this code wraps the pipe's source channel in a special `FileDownloadChannel` which adds an `setError(t: Throwable)` method, then calls this `setError()` method in the FileDownloadCallback's `onFailure` method. It is possible for `FileDownloadChannel`'s `read()` and `setError()` methods to be called concurrently from different threads: the `setError()` method is called from within the Netty RPC system's stream callback handlers, while the `read()` methods are called from higher-level application code performing remote stream reads. The problem lies in `setError()`: the existing code closed the wrapped pipe source channel. Because `read()` and `setError()` occur in different threads, this means it is possible for one thread to be calling `source.read()` while another asynchronously calls `source.close()`. Java's IO libraries do not guarantee that this will be safe and, in fact, it's possible for these operations to interleave in such a way that a lower-level `read()` system call occurs right after a `close()` call. In the best-case, this fails as a read of a closed file descriptor; in the worst-case, the file descriptor number has been re-used by an intervening `open()` operation and the read corrupts the result of an unrelated file IO operation being performed by a different thread. The solution here is to remove the `stream.close()` call in `onError()`: the thread that is performing the `read()` calls is responsible for closing the stream in a `finally` block, so there's no need to close it here. If that thread is blocked in a `read()` then it will become unblocked when the sink end of the pipe is closed in `FileDownloadCallback.onFailure()`. After making this change, we also need to refine the `read()` method to always check for a `setError()` result, even if the underlying channel `read()` call has succeeded. This patch also makes a slight cleanup to a dodgy-looking `catch e: Exception` block to use a safer `try-finally` error handling idiom. This bug was introduced in SPARK-11956 / #9941 and is present in Spark 1.6.0+. ## How was this patch tested? This fix was tested manually against a workload which non-deterministically hit this bug. Author: Josh Rosen Closes #20179 from JoshRosen/SPARK-22982-fix-unsafe-async-io-in-file-download-channel. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 37 +++++++++++-------- .../shuffle/IndexShuffleBlockResolver.scala | 21 +++++++++-- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index f951591e02a5c..a2936d6ad539c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv( val pipe = Pipe.open() val source = new FileDownloadChannel(pipe.source()) - try { + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) val callback = new FileDownloadCallback(pipe.sink(), source, client) client.stream(parsedUri.getPath(), callback) - } catch { - case e: Exception => - pipe.sink().close() - source.close() - throw e - } + })(catchBlock = { + pipe.sink().close() + source.close() + }) source } @@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv( fileDownloadFactory.createClient(host, port) } - private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel { @volatile private var error: Throwable = _ def setError(e: Throwable): Unit = { + // This setError callback is invoked by internal RPC threads in order to propagate remote + // exceptions to application-level threads which are reading from this channel. When an + // RPC error occurs, the RPC system will call setError() and then will close the + // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe + // sink will cause `source.read()` operations to return EOF, unblocking the application-level + // reading thread. Thus there is no need to actually call `source.close()` here in the + // onError() callback and, in fact, calling it here would be dangerous because the close() + // would be asynchronous with respect to the read() call and could trigger race-conditions + // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic. error = e - source.close() } override def read(dst: ByteBuffer): Int = { Try(source.read(dst)) match { + // See the documentation above in setError(): if an RPC error has occurred then setError() + // will be called to propagate the RPC error and then `source`'s corresponding + // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate + // the remote RPC exception (and not any exceptions triggered by the pipe close, such as + // ChannelClosedException), hence this `error != null` check: + case _ if error != null => throw error case Success(bytesRead) => bytesRead - case Failure(readErr) => - if (error != null) { - throw error - } else { - throw readErr - } + case Failure(readErr) => throw readErr } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 15540485170d0..266ee42e39cca 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -18,8 +18,8 @@ package org.apache.spark.shuffle import java.io._ - -import com.google.common.io.ByteStreams +import java.nio.channels.Channels +import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging @@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver( // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - val in = new DataInputStream(new FileInputStream(indexFile)) + // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code + // which is incorrectly using our file descriptor then this code will fetch the wrong offsets + // (which may cause a reducer to be sent a different reducer's data). The explicit position + // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this + // class of issue from re-occurring in the future which is why they are left here even though + // SPARK-22982 is fixed. + val channel = Files.newByteChannel(indexFile.toPath) + channel.position(blockId.reduceId * 8) + val in = new DataInputStream(Channels.newInputStream(channel)) try { - ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() + val actualPosition = channel.position() + val expectedPosition = blockId.reduceId * 8 + 16 + if (actualPosition != expectedPosition) { + throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + + s"expected $expectedPosition but actual position was $actualPosition.") + } new FileSegmentManagedBuffer( transportConf, getDataFile(blockId.shuffleId, blockId.mapId), From eaac60a1e20e29084b7151ffca964cfaa5ba99d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jan 2018 15:16:27 +0800 Subject: [PATCH 0057/1161] [SPARK-16060][SQL][FOLLOW-UP] add a wrapper solution for vectorized orc reader ## What changes were proposed in this pull request? This is mostly from https://github.com/apache/spark/pull/13775 The wrapper solution is pretty good for string/binary type, as the ORC column vector doesn't keep bytes in a continuous memory region, and has a significant overhead when copying the data to Spark columnar batch. For other cases, the wrapper solution is almost same with the current solution. I think we can treat the wrapper solution as a baseline and keep improving the writing to Spark solution. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20205 from cloud-fan/orc. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../datasources/orc/OrcColumnVector.java | 251 ++++++++++++++++++ .../orc/OrcColumnarBatchReader.java | 106 ++++++-- .../datasources/orc/OrcFileFormat.scala | 6 +- .../spark/sql/hive/orc/OrcReadBenchmark.scala | 236 ++++++++++------ 5 files changed, 490 insertions(+), 116 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 74949db883f7a..36e802a9faa6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -391,6 +391,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") + .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + + "vectorized ORC reader.") + .internal() + .booleanConf + .createWithDefault(false) + val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java new file mode 100644 index 0000000000000..f94c55d860304 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.math.BigDecimal; + +import org.apache.orc.storage.ql.exec.vector.*; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarVector. + */ +public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private ColumnVector baseData; + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + final private boolean isTimestamp; + + private int batchSize; + + OrcColumnVector(DataType type, ColumnVector vector) { + super(type); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + baseData = vector; + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public void close() { + + } + + @Override + public int numNulls() { + if (baseData.isRepeating) { + if (baseData.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (baseData.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (baseData.isNull[i]) count++; + } + return count; + } + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return baseData.isRepeating ? 0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return baseData.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } + + @Override + public int getInt(int rowId) { + return (int) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + } else { + return longData.vector[index]; + } + } + + @Override + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } + + @Override + public int getArrayLength(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getArrayOffset(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + int index = getRowIndex(rowId); + byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector arrayData() { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5c28d0e6e507a..36fdf2bdf84d2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -51,13 +51,13 @@ public class OrcColumnarBatchReader extends RecordReader { /** - * The default size of batch. We use this value for both ORC and Spark consistently - * because they have different default values like the following. + * The default size of batch. We use this value for ORC reader to make it consistent with Spark's + * columnar batch, because their default batch sizes are different like the following: * * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 */ - public static final int DEFAULT_SIZE = 4 * 1024; + private static final int DEFAULT_SIZE = 4 * 1024; // ORC File Reader private Reader reader; @@ -82,13 +82,18 @@ public class OrcColumnarBatchReader extends RecordReader { // Writable column vectors of the result columnar batch. private WritableColumnVector[] columnVectors; - /** - * The memory mode of the columnarBatch - */ + // The wrapped ORC column vectors. It should be null if `copyToSpark` is true. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + // The memory mode of the columnarBatch private final MemoryMode MEMORY_MODE; - public OrcColumnarBatchReader(boolean useOffHeap) { + // Whether or not to copy the ORC columnar batch to Spark columnar batch. + private final boolean copyToSpark; + + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; + this.copyToSpark = copyToSpark; } @@ -167,27 +172,61 @@ public void initBatch( } int capacity = DEFAULT_SIZE; - if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); - } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); - } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); - if (partitionValues.numFields() > 0) { - int partitionIdx = requiredFields.length; - for (int i = 0; i < partitionValues.numFields(); i++) { - ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); - columnVectors[i + partitionIdx].setIsConstant(); + if (copyToSpark) { + if (MEMORY_MODE == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } - } - // Initialize the missing columns once. - for (int i = 0; i < requiredFields.length; i++) { - if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); - columnVectors[i].setIsConstant(); + // Initialize the missing columns once. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] == -1) { + columnVectors[i].putNulls(0, capacity); + columnVectors[i].setIsConstant(); + } + } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); + } + } + + columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + } else { + // Just wrap the ORC column vector instead of copying it to Spark column vector. + orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; + + for (int i = 0; i < requiredFields.length; i++) { + DataType dt = requiredFields[i].dataType(); + int colId = requestedColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } } + + if (partitionValues.numFields() > 0) { + int partitionIdx = requiredFields.length; + for (int i = 0; i < partitionValues.numFields(); i++) { + DataType dt = partitionSchema.fields()[i].dataType(); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + ColumnVectorUtils.populate(partitionCol, partitionValues, i); + partitionCol.setIsConstant(); + orcVectorWrappers[partitionIdx + i] = partitionCol; + } + } + + columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); } } @@ -196,17 +235,26 @@ public void initBatch( * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. */ private boolean nextBatch() throws IOException { - for (WritableColumnVector vector : columnVectors) { - vector.reset(); - } - columnarBatch.setNumRows(0); - recordReader.nextBatch(batch); int batchSize = batch.size; if (batchSize == 0) { return false; } columnarBatch.setNumRows(batchSize); + + if (!copyToSpark) { + for (int i = 0; i < requiredFields.length; i++) { + if (requestedColIds[i] != -1) { + ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); + } + } + return true; + } + + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + for (int i = 0; i < requiredFields.length; i++) { StructField field = requiredFields[i]; WritableColumnVector toColumn = columnVectors[i]; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b8bacfa1838ae..2dd314d165348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -150,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -183,8 +185,8 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { - val batchReader = - new OrcColumnarBatchReader(enableOffHeapColumnVector && taskContext.isDefined) + val batchReader = new OrcColumnarBatchReader( + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala index 37ed846acd1eb..bf6efa7c4c08c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcReadBenchmark.scala @@ -86,7 +86,7 @@ object OrcReadBenchmark { } def numericScanBenchmark(values: Int, dataType: DataType): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + val benchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -95,61 +95,73 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") { _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1192 / 1221 13.2 75.8 1.0X - Native ORC Vectorized 161 / 170 97.5 10.3 7.4X - Hive built-in ORC 1399 / 1413 11.2 89.0 0.9X + Native ORC MR 1135 / 1171 13.9 72.2 1.0X + Native ORC Vectorized 152 / 163 103.4 9.7 7.5X + Native ORC Vectorized with copy 149 / 162 105.4 9.5 7.6X + Hive built-in ORC 1380 / 1384 11.4 87.7 0.8X SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1287 / 1333 12.2 81.8 1.0X - Native ORC Vectorized 164 / 172 95.6 10.5 7.8X - Hive built-in ORC 1629 / 1650 9.7 103.6 0.8X + Native ORC MR 1182 / 1244 13.3 75.2 1.0X + Native ORC Vectorized 145 / 156 108.7 9.2 8.2X + Native ORC Vectorized with copy 148 / 158 106.4 9.4 8.0X + Hive built-in ORC 1591 / 1636 9.9 101.2 0.7X SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1304 / 1388 12.1 82.9 1.0X - Native ORC Vectorized 227 / 240 69.3 14.4 5.7X - Hive built-in ORC 1866 / 1867 8.4 118.6 0.7X + Native ORC MR 1271 / 1271 12.4 80.8 1.0X + Native ORC Vectorized 206 / 212 76.3 13.1 6.2X + Native ORC Vectorized with copy 200 / 213 78.8 12.7 6.4X + Hive built-in ORC 1776 / 1787 8.9 112.9 0.7X SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1331 / 1357 11.8 84.6 1.0X - Native ORC Vectorized 289 / 297 54.4 18.4 4.6X - Hive built-in ORC 1922 / 1929 8.2 122.2 0.7X + Native ORC MR 1344 / 1355 11.7 85.4 1.0X + Native ORC Vectorized 258 / 268 61.0 16.4 5.2X + Native ORC Vectorized with copy 252 / 257 62.4 16.0 5.3X + Hive built-in ORC 1818 / 1823 8.7 115.6 0.7X SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1410 / 1428 11.2 89.7 1.0X - Native ORC Vectorized 328 / 335 48.0 20.8 4.3X - Hive built-in ORC 1929 / 2012 8.2 122.6 0.7X + Native ORC MR 1333 / 1352 11.8 84.8 1.0X + Native ORC Vectorized 310 / 324 50.7 19.7 4.3X + Native ORC Vectorized with copy 312 / 320 50.4 19.9 4.3X + Hive built-in ORC 1904 / 1918 8.3 121.0 0.7X SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1467 / 1485 10.7 93.3 1.0X - Native ORC Vectorized 402 / 411 39.1 25.6 3.6X - Hive built-in ORC 2023 / 2042 7.8 128.6 0.7X + Native ORC MR 1408 / 1585 11.2 89.5 1.0X + Native ORC Vectorized 359 / 368 43.8 22.8 3.9X + Native ORC Vectorized with copy 364 / 371 43.2 23.2 3.9X + Hive built-in ORC 1881 / 1954 8.4 119.6 0.7X */ - sqlBenchmark.run() + benchmark.run() } } } @@ -176,19 +188,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(c1), sum(length(c2)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2729 / 2744 3.8 260.2 1.0X - Native ORC Vectorized 1318 / 1344 8.0 125.7 2.1X - Hive built-in ORC 3731 / 3782 2.8 355.8 0.7X + Native ORC MR 2566 / 2592 4.1 244.7 1.0X + Native ORC Vectorized 1098 / 1113 9.6 104.7 2.3X + Native ORC Vectorized with copy 1527 / 1593 6.9 145.6 1.7X + Hive built-in ORC 3561 / 3705 2.9 339.6 0.7X */ benchmark.run() } @@ -205,63 +224,84 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) - benchmark.addCase("Read data column - Native ORC MR") { _ => + benchmark.addCase("Data column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read data column - Native ORC Vectorized") { _ => + benchmark.addCase("Data column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read data column - Hive built-in ORC") { _ => + benchmark.addCase("Data column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Data column - Hive built-in ORC") { _ => spark.sql("SELECT sum(id) FROM hiveOrcTable").collect() } - benchmark.addCase("Read partition column - Native ORC MR") { _ => + benchmark.addCase("Partition column - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read partition column - Native ORC Vectorized") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() } - benchmark.addCase("Read partition column - Hive built-in ORC") { _ => + benchmark.addCase("Partition column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Partition column - Hive built-in ORC") { _ => spark.sql("SELECT sum(p) FROM hiveOrcTable").collect() } - benchmark.addCase("Read both columns - Native ORC MR") { _ => + benchmark.addCase("Both columns - Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } } - benchmark.addCase("Read both columns - Native ORC Vectorized") { _ => + benchmark.addCase("Both columns - Native ORC Vectorized") { _ => spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() } - benchmark.addCase("Read both columns - Hive built-in ORC") { _ => + benchmark.addCase("Both column - Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Both columns - Hive built-in ORC") { _ => spark.sql("SELECT sum(p), sum(id) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Read data column - Native ORC MR 1531 / 1536 10.3 97.4 1.0X - Read data column - Native ORC Vectorized 295 / 298 53.3 18.8 5.2X - Read data column - Hive built-in ORC 2125 / 2126 7.4 135.1 0.7X - Read partition column - Native ORC MR 1049 / 1062 15.0 66.7 1.5X - Read partition column - Native ORC Vectorized 54 / 57 290.1 3.4 28.2X - Read partition column - Hive built-in ORC 1282 / 1291 12.3 81.5 1.2X - Read both columns - Native ORC MR 1594 / 1598 9.9 101.3 1.0X - Read both columns - Native ORC Vectorized 332 / 336 47.4 21.1 4.6X - Read both columns - Hive built-in ORC 2145 / 2187 7.3 136.4 0.7X + Data only - Native ORC MR 1447 / 1457 10.9 92.0 1.0X + Data only - Native ORC Vectorized 256 / 266 61.4 16.3 5.6X + Data only - Native ORC Vectorized with copy 263 / 273 59.8 16.7 5.5X + Data only - Hive built-in ORC 1960 / 1988 8.0 124.6 0.7X + Partition only - Native ORC MR 1039 / 1043 15.1 66.0 1.4X + Partition only - Native ORC Vectorized 48 / 53 326.6 3.1 30.1X + Partition only - Native ORC Vectorized with copy 48 / 53 328.4 3.0 30.2X + Partition only - Hive built-in ORC 1234 / 1242 12.7 78.4 1.2X + Both columns - Native ORC MR 1465 / 1475 10.7 93.1 1.0X + Both columns - Native ORC Vectorized 292 / 301 53.9 18.6 5.0X + Both column - Native ORC Vectorized with copy 348 / 354 45.1 22.2 4.2X + Both columns - Hive built-in ORC 2051 / 2060 7.7 130.4 0.7X */ benchmark.run() } @@ -287,19 +327,26 @@ object OrcReadBenchmark { spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(length(c1)) FROM nativeOrcTable").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT sum(length(c1)) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1325 / 1328 7.9 126.4 1.0X - Native ORC Vectorized 320 / 330 32.8 30.5 4.1X - Hive built-in ORC 1971 / 1972 5.3 188.0 0.7X + Native ORC MR 1271 / 1278 8.3 121.2 1.0X + Native ORC Vectorized 200 / 212 52.4 19.1 6.4X + Native ORC Vectorized with copy 342 / 347 30.7 32.6 3.7X + Hive built-in ORC 1874 / 2105 5.6 178.7 0.7X */ benchmark.run() } @@ -331,32 +378,42 @@ object OrcReadBenchmark { "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM nativeOrcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + benchmark.addCase("Hive built-in ORC") { _ => spark.sql("SELECT SUM(LENGTH(c2)) FROM hiveOrcTable " + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2553 / 2554 4.1 243.4 1.0X - Native ORC Vectorized 953 / 954 11.0 90.9 2.7X - Hive built-in ORC 3875 / 3898 2.7 369.6 0.7X + Native ORC MR 2394 / 2886 4.4 228.3 1.0X + Native ORC Vectorized 699 / 729 15.0 66.7 3.4X + Native ORC Vectorized with copy 959 / 1025 10.9 91.5 2.5X + Hive built-in ORC 3899 / 3901 2.7 371.9 0.6X String with Nulls Scan (0.5%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2389 / 2408 4.4 227.8 1.0X - Native ORC Vectorized 1208 / 1209 8.7 115.2 2.0X - Hive built-in ORC 2940 / 2952 3.6 280.4 0.8X + Native ORC MR 2234 / 2255 4.7 213.1 1.0X + Native ORC Vectorized 854 / 869 12.3 81.4 2.6X + Native ORC Vectorized with copy 1099 / 1128 9.5 104.8 2.0X + Hive built-in ORC 2767 / 2793 3.8 263.9 0.8X String with Nulls Scan (0.95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1295 / 1311 8.1 123.5 1.0X - Native ORC Vectorized 449 / 457 23.4 42.8 2.9X - Hive built-in ORC 1649 / 1660 6.4 157.3 0.8X + Native ORC MR 1166 / 1202 9.0 111.2 1.0X + Native ORC Vectorized 338 / 345 31.1 32.2 3.5X + Native ORC Vectorized with copy 418 / 428 25.1 39.9 2.8X + Hive built-in ORC 1730 / 1761 6.1 164.9 0.7X */ benchmark.run() } @@ -364,7 +421,7 @@ object OrcReadBenchmark { } def columnsBenchmark(values: Int, width: Int): Unit = { - val sqlBenchmark = new Benchmark(s"SQL Single Column Scan from $width columns", values) + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) withTempPath { dir => withTempTable("t1", "nativeOrcTable", "hiveOrcTable") { @@ -376,43 +433,52 @@ object OrcReadBenchmark { prepareTable(dir, spark.sql("SELECT * FROM t1")) - sqlBenchmark.addCase("Native ORC MR") { _ => + benchmark.addCase("Native ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } } - sqlBenchmark.addCase("Native ORC Vectorized") { _ => + benchmark.addCase("Native ORC Vectorized") { _ => spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() } - sqlBenchmark.addCase("Hive built-in ORC") { _ => + benchmark.addCase("Native ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM nativeOrcTable").collect() + } + } + + benchmark.addCase("Hive built-in ORC") { _ => spark.sql(s"SELECT sum(c$middle) FROM hiveOrcTable").collect() } /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.13.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - SQL Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 1103 / 1124 1.0 1052.0 1.0X - Native ORC Vectorized 92 / 100 11.4 87.9 12.0X - Hive built-in ORC 383 / 390 2.7 365.4 2.9X + Native ORC MR 1050 / 1053 1.0 1001.1 1.0X + Native ORC Vectorized 95 / 101 11.0 90.9 11.0X + Native ORC Vectorized with copy 95 / 102 11.0 90.9 11.0X + Hive built-in ORC 348 / 358 3.0 331.8 3.0X - SQL Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 2245 / 2250 0.5 2141.0 1.0X - Native ORC Vectorized 157 / 165 6.7 150.2 14.3X - Hive built-in ORC 587 / 593 1.8 559.4 3.8X + Native ORC MR 2099 / 2108 0.5 2002.1 1.0X + Native ORC Vectorized 179 / 187 5.8 171.1 11.7X + Native ORC Vectorized with copy 176 / 188 6.0 167.6 11.9X + Hive built-in ORC 562 / 581 1.9 535.9 3.7X - SQL Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Native ORC MR 3343 / 3350 0.3 3188.3 1.0X - Native ORC Vectorized 265 / 280 3.9 253.2 12.6X - Hive built-in ORC 828 / 842 1.3 789.8 4.0X + Native ORC MR 3221 / 3246 0.3 3071.4 1.0X + Native ORC Vectorized 312 / 322 3.4 298.0 10.3X + Native ORC Vectorized with copy 306 / 320 3.4 291.6 10.5X + Hive built-in ORC 815 / 824 1.3 777.3 4.0X */ - sqlBenchmark.run() + benchmark.run() } } } From 70bcc9d5ae33d6669bb5c97db29087ccead770fb Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 9 Jan 2018 23:32:47 -0800 Subject: [PATCH 0058/1161] [SPARK-22993][ML] Clarify HasCheckpointInterval param doc ## What changes were proposed in this pull request? Add a note to the `HasCheckpointInterval` parameter doc that clarifies that this setting is ignored when no checkpoint directory has been set on the spark context. ## How was this patch tested? No tests necessary, just a doc update. Author: sethah Closes #20188 from sethah/als_checkpoint_doc. --- R/pkg/R/mllib_recommendation.R | 2 ++ R/pkg/R/mllib_tree.R | 6 ++++++ .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 4 +++- .../org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- python/pyspark/ml/param/_shared_params_code_gen.py | 5 +++-- python/pyspark/ml/param/shared.py | 4 ++-- 6 files changed, 18 insertions(+), 7 deletions(-) diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index fa794249085d7..5441c4a4022a9 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -48,6 +48,8 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @param numUserBlocks number of user blocks used to parallelize computation (> 0). #' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 89a58bf0aadae..4e5ddf22ee16d 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -161,6 +161,8 @@ print.summary.decisionTree <- function(x) { #' >= 1. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -382,6 +384,8 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching @@ -595,6 +599,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @param minInstancesPerNode Minimum number of instances each child must have after split. #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' Note: this setting will be ignored if the checkpoint directory is not +#' set. #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a5d57a15317e6..6ad44af9ef7eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -63,7 +63,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + - "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), + "every 10 iterations. Note: this setting will be ignored if the checkpoint directory " + + "is not set in the SparkContext", + isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + "will filter out rows with bad values), or error (which will throw an error). More " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 13425dacc9f18..be8b2f273164b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -282,10 +282,10 @@ trait HasOutputCols extends Params { trait HasCheckpointInterval extends Params { /** - * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index d55d209d09398..1d0f60acc6983 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -121,8 +121,9 @@ def get$Name(self): ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, - "TypeConverters.toInt"), + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + + "this setting will be ignored if the checkpoint directory is not set in the SparkContext.", + None, "TypeConverters.toInt"), ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None, "TypeConverters.toFloat"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index e5c5ddfba6c1f..813f7a59f3fd1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -281,10 +281,10 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext. """ - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", typeConverter=TypeConverters.toInt) + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.", typeConverter=TypeConverters.toInt) def __init__(self): super(HasCheckpointInterval, self).__init__() From f340b6b3066033d40b7e163fd5fb68e9820adfb1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 Jan 2018 00:45:47 -0800 Subject: [PATCH 0059/1161] [SPARK-22997] Add additional defenses against use of freed MemoryBlocks ## What changes were proposed in this pull request? This patch modifies Spark's `MemoryAllocator` implementations so that `free(MemoryBlock)` mutates the passed block to clear pointers (in the off-heap case) or null out references to backing `long[]` arrays (in the on-heap case). The goal of this change is to add an extra layer of defense against use-after-free bugs because currently it's hard to detect corruption caused by blind writes to freed memory blocks. ## How was this patch tested? New unit tests in `PlatformSuite`, including new tests for existing functionality because we did not have sufficient mutation coverage of the on-heap memory allocator's pooling logic. Author: Josh Rosen Closes #20191 from JoshRosen/SPARK-22997-add-defenses-against-use-after-free-bugs-in-memory-allocator. --- .../unsafe/memory/HeapMemoryAllocator.java | 35 +++++++++---- .../spark/unsafe/memory/MemoryBlock.java | 21 +++++++- .../unsafe/memory/UnsafeMemoryAllocator.java | 11 ++++ .../spark/unsafe/PlatformUtilSuite.java | 50 ++++++++++++++++++- .../spark/memory/TaskMemoryManager.java | 13 ++++- .../spark/memory/TaskMemoryManagerSuite.java | 29 +++++++++++ 6 files changed, 146 insertions(+), 13 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index cc9cc429643ad..3acfe3696cb1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -31,8 +31,7 @@ public class HeapMemoryAllocator implements MemoryAllocator { @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap<>(); + private final Map>> bufferPoolsBySize = new HashMap<>(); private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; @@ -49,13 +48,14 @@ private boolean shouldPool(long size) { public MemoryBlock allocate(long size) throws OutOfMemoryError { if (shouldPool(size)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(size); if (pool != null) { while (!pool.isEmpty()) { - final WeakReference blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); + final WeakReference arrayReference = pool.pop(); + final long[] array = arrayReference.get(); + if (array != null) { + assert (array.length * 8L >= size); + MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -76,18 +76,35 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { + assert (memory.obj != null) : + "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + + // Mark the page as freed (so we can detect double-frees). + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to null out its reference to the long[] array. + long[] array = (long[]) memory.obj; + memory.setObjAndOffset(null, 0); + if (shouldPool(size)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(size); if (pool == null) { pool = new LinkedList<>(); bufferPoolsBySize.put(size, pool); } - pool.add(new WeakReference<>(memory)); + pool.add(new WeakReference<>(array)); } } else { // Do nothing diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index cd1d378bc1470..c333857358d30 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -26,6 +26,25 @@ */ public class MemoryBlock extends MemoryLocation { + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ + public static final int NO_PAGE_NUMBER = -1; + + /** + * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager. + * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator + * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM + * before being passed to MemoryAllocator.free() (it is an error to allocate a page in + * TaskMemoryManager and then directly free it in a MemoryAllocator without going through + * the TMM freePage() call). + */ + public static final int FREED_IN_TMM_PAGE_NUMBER = -2; + + /** + * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows + * us to detect double-frees. + */ + public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; + private final long length; /** @@ -33,7 +52,7 @@ public class MemoryBlock extends MemoryLocation { * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, * which lives in a different package. */ - public int pageNumber = -1; + public int pageNumber = NO_PAGE_NUMBER; public MemoryBlock(@Nullable Object obj, long offset, long length) { super(obj, offset); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 55bcdf1ed7b06..4368fb615ba1e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -38,9 +38,20 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { public void free(MemoryBlock memory) { assert (memory.obj == null) : "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "page has already been freed"; + assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) + || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the + // MemoryBlock to reset its pointer. + memory.offset = 0; + // Mark the page as freed (so we can detect double-frees). + memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4b141339ec816..62854837b05ed 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -62,6 +62,52 @@ public void overlappingCopyMemory() { } } + @Test + public void onHeapMemoryAllocatorPoolingReUsesLongArrays() { + MemoryBlock block1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject1 = block1.getBaseObject(); + MemoryAllocator.HEAP.free(block1); + MemoryBlock block2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object baseObject2 = block2.getBaseObject(); + Assert.assertSame(baseObject1, baseObject2); + MemoryAllocator.HEAP.free(block2); + } + + @Test + public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + Assert.assertNotNull(block.getBaseObject()); + MemoryAllocator.HEAP.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test + public void freeingOffHeapMemoryBlockResetsOffset() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + Assert.assertNull(block.getBaseObject()); + Assert.assertNotEquals(0, block.getBaseOffset()); + MemoryAllocator.UNSAFE.free(block); + Assert.assertNull(block.getBaseObject()); + Assert.assertEquals(0, block.getBaseOffset()); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + } + + @Test(expected = AssertionError.class) + public void onHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.HEAP.allocate(1024); + MemoryAllocator.HEAP.free(block); + MemoryAllocator.HEAP.free(block); + } + + @Test(expected = AssertionError.class) + public void offHeapMemoryAllocatorThrowsAssertionErrorOnDoubleFree() { + MemoryBlock block = MemoryAllocator.UNSAFE.allocate(1024); + MemoryAllocator.UNSAFE.free(block); + MemoryAllocator.UNSAFE.free(block); + } + @Test public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); @@ -71,9 +117,11 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Object onheap1BaseObject = onheap1.getBaseObject(); + long onheap1BaseOffset = onheap1.getBaseOffset(); MemoryAllocator.HEAP.free(onheap1); Assert.assertEquals( - Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + Platform.getByte(onheap1BaseObject, onheap1BaseOffset), MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); Assert.assertEquals( diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index e8d3730daa7a4..632d718062212 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -321,8 +321,12 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != -1) : + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; assert(allocatedPages.get(page.pageNumber)); pageTable[page.pageNumber] = null; synchronized (this) { @@ -332,6 +336,10 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } long pageSize = page.size(); + // Clear the page number before passing the block to the MemoryAllocator's free(). + // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -358,7 +366,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { - assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page"; return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } @@ -424,6 +432,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 46b0516e36141..a0664b30d6cc2 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -21,6 +21,7 @@ import org.junit.Test; import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; public class TaskMemoryManagerSuite { @@ -68,6 +69,34 @@ public void encodePageNumberAndOffsetOnHeap() { Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); } + @Test + public void freeingPageSetsPageNumberToSpecialConstant() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + c.freePage(dataPage); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + } + + @Test(expected = AssertionError.class) + public void freeingPageDirectlyInAllocatorTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); + MemoryAllocator.HEAP.free(dataPage); + } + + @Test(expected = AssertionError.class) + public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() { + final TaskMemoryManager manager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = MemoryAllocator.HEAP.allocate(256); + manager.freePage(dataPage, c); + } + @Test public void cooperativeSpilling() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); From 344e3aab87178e45957333479a07e07f202ca1fd Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 10 Jan 2018 09:44:30 -0800 Subject: [PATCH 0060/1161] [SPARK-23019][CORE] Wait until SparkContext.stop() finished in SparkLauncherSuite ## What changes were proposed in this pull request? In current code ,the function `waitFor` call https://github.com/apache/spark/blob/cfcd746689c2b84824745fa6d327ffb584c7a17d/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java#L155 only wait until DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. https://github.com/apache/spark/blob/1c9f95cb771ac78775a77edd1abfeb2d8ae2a124/core/src/main/scala/org/apache/spark/SparkContext.scala#L1924 Thus, in the Jenkins test https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-maven-hadoop-2.6/ , `JdbcRDDSuite` failed because the previous test `SparkLauncherSuite` exit before SparkContext.stop() is finished. To repo: ``` $ build/sbt > project core > testOnly *SparkLauncherSuite *JavaJdbcRDDSuite ``` To Fix: Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM in SparkLauncherSuite. Can' come up with any better solution for now. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20221 from gengliangwang/SPARK-23019. --- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index c2261c204cd45..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; @@ -133,6 +134,10 @@ public void testInProcessLauncher() throws Exception { p.put(e.getKey(), e.getValue()); } System.setProperties(p); + // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. + // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. + // See SPARK-23019 and SparkContext.stop() for details. + TimeUnit.MILLISECONDS.sleep(500); } } From 9b33dfc408de986f4203bb0ac0c3f5c56effd69d Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Wed, 10 Jan 2018 14:25:04 -0800 Subject: [PATCH 0061/1161] [SPARK-22951][SQL] fix aggregation after dropDuplicates on empty data frames ## What changes were proposed in this pull request? (courtesy of liancheng) Spark SQL supports both global aggregation and grouping aggregation. Global aggregation always return a single row with the initial aggregation state as the output, even there are zero input rows. Spark implements this by simply checking the number of grouping keys and treats an aggregation as a global aggregation if it has zero grouping keys. However, this simple principle drops the ball in the following case: ```scala spark.emptyDataFrame.dropDuplicates().agg(count($"*") as "c").show() // +---+ // | c | // +---+ // | 1 | // +---+ ``` The reason is that: 1. `df.dropDuplicates()` is roughly translated into something equivalent to: ```scala val allColumns = df.columns.map { col } df.groupBy(allColumns: _*).agg(allColumns.head, allColumns.tail: _*) ``` This translation is implemented in the rule `ReplaceDeduplicateWithAggregate`. 2. `spark.emptyDataFrame` contains zero columns and zero rows. Therefore, rule `ReplaceDeduplicateWithAggregate` makes a confusing transformation roughly equivalent to the following one: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy().agg(Map.empty[String, String]) ``` The above transformation is confusing because the resulting aggregate operator contains no grouping keys (because `emptyDataFrame` contains no columns), and gets recognized as a global aggregation. As a result, Spark SQL allocates a single row filled by the initial aggregation state and uses it as the output, and returns a wrong result. To fix this issue, this PR tweaks `ReplaceDeduplicateWithAggregate` by appending a literal `1` to the grouping key list of the resulting `Aggregate` operator when the input plan contains zero output columns. In this way, `spark.emptyDataFrame.dropDuplicates()` is now translated into a grouping aggregation, roughly depicted as: ```scala spark.emptyDataFrame.dropDuplicates() => spark.emptyDataFrame.groupBy(lit(1)).agg(Map.empty[String, String]) ``` Which is now properly treated as a grouping aggregation and returns the correct answer. ## How was this patch tested? New unit tests added Author: Feng Liu Closes #20174 from liufengdb/fix-duplicate. --- .../sql/catalyst/optimizer/Optimizer.scala | 8 ++++++- .../optimizer/ReplaceOperatorSuite.scala | 10 +++++++- .../spark/sql/DataFrameAggregateSuite.scala | 24 +++++++++++++++++-- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index df0af8264a329..c794ba8619322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1222,7 +1222,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) } } - Aggregate(keys, aggCols, child) + // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping + // aggregations by checking the number of grouping keys. The key difference here is that a + // global aggregation always returns at least one row even if there are no input rows. Here + // we append a literal when the grouping key list is empty so that the result aggregate + // operator is properly treated as a grouping aggregation. + val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys + Aggregate(nonemptyKeys, aggCols, child) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 0fa1aaeb9e164..e9701ffd2c54b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ @@ -198,6 +198,14 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("add one grouping key if necessary when replace Deduplicate with Aggregate") { + val input = LocalRelation() + val query = Deduplicate(Seq.empty, input) // dropDuplicates() + val optimized = Optimize.execute(query.analyze) + val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input) + comparePlans(optimized, correctAnswer) + } + test("don't replace streaming Deduplicate") { val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true) val attrA = input.output(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 06848e4d2b297..e7776e36702ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import scala.util.Random +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -27,7 +29,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.DecimalType case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -456,7 +458,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") - checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), Row(null, null, null, null, null)) @@ -666,4 +667,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(exchangePlans.length == 1) } } + + Seq(true, false).foreach { codegen => + test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " + + s"results when codegen is enabled: $codegen") { + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) { + // explicit global aggregations + val emptyAgg = Map.empty[String, String] + checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) + + // global aggregation is converted to grouping aggregation: + assert(spark.emptyDataFrame.dropDuplicates().count() == 0) + } + } + } } From a6647ffbf7a312a3e119a9beef90880cc915aa60 Mon Sep 17 00:00:00 2001 From: Mingjie Tang Date: Thu, 11 Jan 2018 11:51:03 +0800 Subject: [PATCH 0062/1161] [SPARK-22587] Spark job fails if fs.defaultFS and application jar are different url ## What changes were proposed in this pull request? Two filesystems comparing does not consider the authority of URI. This is specific for WASB file storage system, where userInfo is honored to differentiate filesystems. For example: wasbs://user1xyz.net, wasbs://user2xyz.net would consider as two filesystem. Therefore, we have to add the authority to compare two filesystem, and two filesystem with different authority can not be the same FS. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Mingjie Tang Closes #19885 from merlintang/EAR-7377. --- .../org/apache/spark/deploy/yarn/Client.scala | 24 +++++++++++--- .../spark/deploy/yarn/ClientSuite.scala | 33 +++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 15328d08b3b5c..8cd3cd9746a3a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1421,15 +1421,20 @@ private object Client extends Logging { } /** - * Return whether the two file systems are the same. + * Return whether two URI represent file system are the same */ - private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { - val srcUri = srcFs.getUri() - val dstUri = destFs.getUri() + private[spark] def compareUri(srcUri: URI, dstUri: URI): Boolean = { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + val srcAuthority = srcUri.getAuthority() + val dstAuthority = dstUri.getAuthority() + if (srcAuthority != null && !srcAuthority.equalsIgnoreCase(dstAuthority)) { + return false + } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() @@ -1447,6 +1452,17 @@ private object Client extends Logging { } Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() + + } + + /** + * Return whether the two file systems are the same. + */ + protected def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + + compareUri(srcUri, dstUri) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 9d5f5eb621118..7fa597167f3f0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -357,6 +357,39 @@ class ClientSuite extends SparkFunSuite with Matchers { sparkConf.get(SECONDARY_JARS) should be (Some(Seq(new File(jar2.toURI).getName))) } + private val matching = Seq( + ("files URI match test1", "file:///file1", "file:///file2"), + ("files URI match test2", "file:///c:file1", "file://c:file2"), + ("files URI match test3", "file://host/file1", "file://host/file2"), + ("wasb URI match test", "wasb://bucket1@user", "wasb://bucket1@user/"), + ("hdfs URI match test", "hdfs:/path1", "hdfs:/path1") + ) + + matching.foreach { t => + test(t._1) { + assert(Client.compareUri(new URI(t._2), new URI(t._3)), + s"No match between ${t._2} and ${t._3}") + } + } + + private val unmatching = Seq( + ("files URI unmatch test1", "file:///file1", "file://host/file2"), + ("files URI unmatch test2", "file://host/file1", "file:///file2"), + ("files URI unmatch test3", "file://host/file1", "file://host2/file2"), + ("wasb URI unmatch test1", "wasb://bucket1@user", "wasb://bucket2@user/"), + ("wasb URI unmatch test2", "wasb://bucket1@user", "wasb://bucket1@user2/"), + ("s3 URI unmatch test", "s3a://user@pass:bucket1/", "s3a://user2@pass2:bucket1/"), + ("hdfs URI unmatch test1", "hdfs://namenode1/path1", "hdfs://namenode1:8080/path2"), + ("hdfs URI unmatch test2", "hdfs://namenode1:8020/path1", "hdfs://namenode1:8080/path2") + ) + + unmatching.foreach { t => + test(t._1) { + assert(!Client.compareUri(new URI(t._2), new URI(t._3)), + s"match between ${t._2} and ${t._3}") + } + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 87c98de8b23f0e978958fc83677fdc4c339b7e6a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 18:17:34 +0800 Subject: [PATCH 0063/1161] [SPARK-23001][SQL] Fix NullPointerException when DESC a database with NULL description ## What changes were proposed in this pull request? When users' DB description is NULL, users might hit `NullPointerException`. This PR is to fix the issue. ## How was this patch tested? Added test cases Author: gatorsmile Closes #20215 from gatorsmile/SPARK-23001. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 2 +- .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ++++++ .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 9 +++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 102f40bacc985..4b923f5235a90 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -330,7 +330,7 @@ private[hive] class HiveClientImpl( Option(client.getDatabase(dbName)).map { d => CatalogDatabase( name = d.getName, - description = d.getDescription, + description = Option(d.getDescription).getOrElse(""), locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 2e35fdeba464d..0a522b6a11c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -107,4 +107,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { .filter(_.contains("Num Buckets")).head assert(bucketString.contains("10")) } + + test("SPARK-23001: NullPointerException when running desc database") { + val catalog = newBasicCatalog() + catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) + assert(catalog.getDatabase("dbWithNullDesc").description == "") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 94473a08dd317..ff90e9dda5f7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -163,6 +163,15 @@ class VersionsSuite extends SparkFunSuite with Logging { client.createDatabase(tempDB, ignoreIfExists = true) } + test(s"$version: createDatabase with null description") { + withTempDir { tmpDir => + val dbWithNullDesc = + CatalogDatabase("dbWithNullDesc", description = null, tmpDir.toURI, Map()) + client.createDatabase(dbWithNullDesc, ignoreIfExists = true) + assert(client.getDatabase("dbWithNullDesc").description == "") + } + } + test(s"$version: setCurrentDatabase") { client.setCurrentDatabase("default") } From 1c70da3bfbb4016e394de2c73eb0db7cdd9a6968 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 19:41:48 +0800 Subject: [PATCH 0064/1161] [SPARK-20657][CORE] Speed up rendering of the stages page. There are two main changes to speed up rendering of the tasks list when rendering the stage page. The first one makes the code only load the tasks being shown in the current page of the tasks table, and information related to only those tasks. One side-effect of this change is that the graph that shows task-related events now only shows events for the tasks in the current page, instead of the previously hardcoded limit of "events for the first 1000 tasks". That ends up helping with readability, though. To make sorting efficient when using a disk store, the task wrapper was extended to include many new indices, one for each of the sortable columns in the UI, and metrics for which quantiles are calculated. The second changes the way metric quantiles are calculated for stages. Instead of using the "Distribution" class to process data for all task metrics, which requires scanning all tasks of a stage, the code now uses the KVStore "skip()" functionality to only read tasks that contain interesting information for the quantiles that are desired. This is still not cheap; because there are many metrics that the UI and API track, the code needs to scan the index for each metric to gather the information. Savings come mainly from skipping deserialization when using the disk store, but the in-memory code also seems to be faster than before (most probably because of other changes in this patch). To make subsequent calls faster, some quantiles are cached in the status store. This makes UIs much faster after the first time a stage has been loaded. With the above changes, a lot of code in the UI layer could be simplified. Author: Marcelo Vanzin Closes #20013 from vanzin/SPARK-20657. --- .../apache/spark/util/kvstore/LevelDB.java | 1 + .../spark/status/AppStatusListener.scala | 57 +- .../apache/spark/status/AppStatusStore.scala | 389 +++++--- .../apache/spark/status/AppStatusUtils.scala | 68 ++ .../org/apache/spark/status/LiveEntity.scala | 344 ++++--- .../spark/status/api/v1/StagesResource.scala | 3 +- .../org/apache/spark/status/api/v1/api.scala | 3 + .../org/apache/spark/status/storeTypes.scala | 327 ++++++- .../apache/spark/ui/jobs/ExecutorTable.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 919 ++++++------------ ...mmary_w__custom_quantiles_expectation.json | 3 + ...sk_summary_w_shuffle_read_expectation.json | 3 + ...k_summary_w_shuffle_write_expectation.json | 3 + .../spark/status/AppStatusListenerSuite.scala | 105 +- .../spark/status/AppStatusStoreSuite.scala | 104 ++ .../org/apache/spark/ui/StagePageSuite.scala | 10 +- scalastyle-config.xml | 2 +- 18 files changed, 1361 insertions(+), 986 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 4f9e10ca20066..0e491efac9181 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -83,6 +83,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { if (versionData != null) { long version = serializer.deserializeLong(versionData); if (version != STORE_VERSION) { + close(); throw new UnsupportedStoreVersionException(); } } else { diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 88b75ddd5993a..b4edcf23abc09 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -377,6 +377,10 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => stage.activeTasks += 1 stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + + val locality = event.taskInfo.taskLocality.toString() + val count = stage.localitySummary.getOrElse(locality, 0L) + 1L + stage.localitySummary = stage.localitySummary ++ Map(locality -> count) maybeUpdate(stage, now) stage.jobs.foreach { job => @@ -433,7 +437,7 @@ private[spark] class AppStatusListener( } task.errorMessage = errorMessage val delta = task.updateMetrics(event.taskMetrics) - update(task, now) + update(task, now, last = true) delta }.orNull @@ -450,7 +454,7 @@ private[spark] class AppStatusListener( Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { - stage.metrics.update(metricsDelta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, metricsDelta) } stage.activeTasks -= 1 stage.completedTasks += completedDelta @@ -486,7 +490,7 @@ private[spark] class AppStatusListener( esummary.failedTasks += failedDelta esummary.killedTasks += killedDelta if (metricsDelta != null) { - esummary.metrics.update(metricsDelta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, metricsDelta) } maybeUpdate(esummary, now) @@ -604,11 +608,11 @@ private[spark] class AppStatusListener( maybeUpdate(task, now) Option(liveStages.get((sid, sAttempt))).foreach { stage => - stage.metrics.update(delta) + stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, delta) maybeUpdate(stage, now) val esummary = stage.executorSummary(event.execId) - esummary.metrics.update(delta) + esummary.metrics = LiveEntityHelpers.addMetrics(esummary.metrics, delta) maybeUpdate(esummary, now) } } @@ -690,7 +694,7 @@ private[spark] class AppStatusListener( // can update the executor information too. liveRDDs.get(block.rddId).foreach { rdd => if (updatedStorageLevel.isDefined) { - rdd.storageLevel = updatedStorageLevel.get + rdd.setStorageLevel(updatedStorageLevel.get) } val partition = rdd.partition(block.name) @@ -814,7 +818,7 @@ private[spark] class AppStatusListener( /** Update a live entity only if it hasn't been updated in the last configured period. */ private def maybeUpdate(entity: LiveEntity, now: Long): Unit = { - if (liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { + if (live && liveUpdatePeriodNs >= 0 && now - entity.lastWriteTime > liveUpdatePeriodNs) { update(entity, now) } } @@ -865,7 +869,7 @@ private[spark] class AppStatusListener( } stages.foreach { s => - val key = s.id + val key = Array(s.info.stageId, s.info.attemptId) kvstore.delete(s.getClass(), key) val execSummaries = kvstore.view(classOf[ExecutorStageSummaryWrapper]) @@ -885,15 +889,15 @@ private[spark] class AppStatusListener( .asScala tasks.foreach { t => - kvstore.delete(t.getClass(), t.info.taskId) + kvstore.delete(t.getClass(), t.taskId) } // Check whether there are remaining attempts for the same stage. If there aren't, then // also delete the RDD graph data. val remainingAttempts = kvstore.view(classOf[StageDataWrapper]) .index("stageId") - .first(s.stageId) - .last(s.stageId) + .first(s.info.stageId) + .last(s.info.stageId) .closeableIterator() val hasMoreAttempts = try { @@ -905,8 +909,10 @@ private[spark] class AppStatusListener( } if (!hasMoreAttempts) { - kvstore.delete(classOf[RDDOperationGraphWrapper], s.stageId) + kvstore.delete(classOf[RDDOperationGraphWrapper], s.info.stageId) } + + cleanupCachedQuantiles(key) } } @@ -919,9 +925,9 @@ private[spark] class AppStatusListener( // Try to delete finished tasks only. val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => - !live || t.info.status != TaskState.RUNNING.toString() + !live || t.status != TaskState.RUNNING.toString() } - toDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + toDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-toDelete.size) // If there are more running tasks than the configured limit, delete running tasks. This @@ -930,13 +936,34 @@ private[spark] class AppStatusListener( val remaining = countToDelete - toDelete.size if (remaining > 0) { val runningTasksToDelete = view.max(remaining).iterator().asScala.toList - runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.info.taskId) } + runningTasksToDelete.foreach { t => kvstore.delete(t.getClass(), t.taskId) } stage.savedTasks.addAndGet(-remaining) } + + // On live applications, cleanup any cached quantiles for the stage. This makes sure that + // quantiles will be recalculated after tasks are replaced with newer ones. + // + // This is not needed in the SHS since caching only happens after the event logs are + // completely processed. + if (live) { + cleanupCachedQuantiles(stageKey) + } } stage.cleaning = false } + private def cleanupCachedQuantiles(stageKey: Array[Int]): Unit = { + val cachedQuantiles = kvstore.view(classOf[CachedQuantile]) + .index("stage") + .first(stageKey) + .last(stageKey) + .asScala + .toList + cachedQuantiles.foreach { q => + kvstore.delete(q.getClass(), q.id) + } + } + /** * Remove at least (retainedSize / 10) items to reduce friction. Because tracking may be done * asynchronously, this method may return 0 in case enough items have been deleted already. diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 5a942f5284018..efc28538a33db 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{JobExecutionStatus, SparkConf} import org.apache.spark.status.api.v1 import org.apache.spark.ui.scope._ -import org.apache.spark.util.Distribution +import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} /** @@ -98,7 +98,11 @@ private[spark] class AppStatusStore( val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) .closeableIterator() try { - it.next().info + if (it.hasNext()) { + it.next().info + } else { + throw new NoSuchElementException(s"No stage with id $stageId") + } } finally { it.close() } @@ -110,107 +114,238 @@ private[spark] class AppStatusStore( if (details) stageWithDetails(stage) else stage } + def taskCount(stageId: Int, stageAttemptId: Int): Long = { + store.count(classOf[TaskDataWrapper], "stage", Array(stageId, stageAttemptId)) + } + + def localitySummary(stageId: Int, stageAttemptId: Int): Map[String, Long] = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).locality + } + + /** + * Calculates a summary of the task metrics for the given stage attempt, returning the + * requested quantiles for the recorded metrics. + * + * This method can be expensive if the requested quantiles are not cached; the method + * will only cache certain quantiles (every 0.05 step), so it's recommended to stick to + * those to avoid expensive scans of all task data. + */ def taskSummary( stageId: Int, stageAttemptId: Int, - quantiles: Array[Double]): v1.TaskMetricDistributions = { - - val stage = Array(stageId, stageAttemptId) - - val rawMetrics = store.view(classOf[TaskDataWrapper]) - .index("stage") - .first(stage) - .last(stage) - .asScala - .flatMap(_.info.taskMetrics) - .toList - .view - - def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = - Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) - - // We need to do a lot of similar munging to nested metrics here. For each one, - // we want (a) extract the values for nested metrics (b) make a distribution for each metric - // (c) shove the distribution into the right field in our return type and (d) only return - // a result if the option is defined for any of the tasks. MetricHelper is a little util - // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just - // implement one "build" method, which just builds the quantiles for each field. - - val inputMetrics = - new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics - - def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( - bytesRead = submetricQuantiles(_.bytesRead), - recordsRead = submetricQuantiles(_.recordsRead) - ) - }.build - - val outputMetrics = - new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics - - def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( - bytesWritten = submetricQuantiles(_.bytesWritten), - recordsWritten = submetricQuantiles(_.recordsWritten) - ) - }.build - - val shuffleReadMetrics = - new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = - raw.shuffleReadMetrics - - def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( - readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, - readRecords = submetricQuantiles(_.recordsRead), - remoteBytesRead = submetricQuantiles(_.remoteBytesRead), - remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), - remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), - localBlocksFetched = submetricQuantiles(_.localBlocksFetched), - totalBlocksFetched = submetricQuantiles { s => - s.localBlocksFetched + s.remoteBlocksFetched - }, - fetchWaitTime = submetricQuantiles(_.fetchWaitTime) - ) - }.build - - val shuffleWriteMetrics = - new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, - quantiles) { - def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = - raw.shuffleWriteMetrics - - def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( - writeBytes = submetricQuantiles(_.bytesWritten), - writeRecords = submetricQuantiles(_.recordsWritten), - writeTime = submetricQuantiles(_.writeTime) - ) - }.build - - new v1.TaskMetricDistributions( + unsortedQuantiles: Array[Double]): Option[v1.TaskMetricDistributions] = { + val stageKey = Array(stageId, stageAttemptId) + val quantiles = unsortedQuantiles.sorted + + // We don't know how many tasks remain in the store that actually have metrics. So scan one + // metric and count how many valid tasks there are. Use skip() instead of next() since it's + // cheaper for disk stores (avoids deserialization). + val count = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(TaskIndexNames.EXEC_RUN_TIME) + .first(0L) + .closeableIterator() + ) { it => + var _count = 0L + while (it.hasNext()) { + _count += 1 + it.skip(1) + } + _count + } + } + + if (count <= 0) { + return None + } + + // Find out which quantiles are already cached. The data in the store must match the expected + // task count to be considered, otherwise it will be re-scanned and overwritten. + val cachedQuantiles = quantiles.filter(shouldCacheQuantile).flatMap { q => + val qkey = Array(stageId, stageAttemptId, quantileToString(q)) + asOption(store.read(classOf[CachedQuantile], qkey)).filter(_.taskCount == count) + } + + // If there are no missing quantiles, return the data. Otherwise, just compute everything + // to make the code simpler. + if (cachedQuantiles.size == quantiles.size) { + def toValues(fn: CachedQuantile => Double): IndexedSeq[Double] = cachedQuantiles.map(fn) + + val distributions = new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = toValues(_.executorDeserializeTime), + executorDeserializeCpuTime = toValues(_.executorDeserializeCpuTime), + executorRunTime = toValues(_.executorRunTime), + executorCpuTime = toValues(_.executorCpuTime), + resultSize = toValues(_.resultSize), + jvmGcTime = toValues(_.jvmGcTime), + resultSerializationTime = toValues(_.resultSerializationTime), + gettingResultTime = toValues(_.gettingResultTime), + schedulerDelay = toValues(_.schedulerDelay), + peakExecutionMemory = toValues(_.peakExecutionMemory), + memoryBytesSpilled = toValues(_.memoryBytesSpilled), + diskBytesSpilled = toValues(_.diskBytesSpilled), + inputMetrics = new v1.InputMetricDistributions( + toValues(_.bytesRead), + toValues(_.recordsRead)), + outputMetrics = new v1.OutputMetricDistributions( + toValues(_.bytesWritten), + toValues(_.recordsWritten)), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + toValues(_.shuffleReadBytes), + toValues(_.shuffleRecordsRead), + toValues(_.shuffleRemoteBlocksFetched), + toValues(_.shuffleLocalBlocksFetched), + toValues(_.shuffleFetchWaitTime), + toValues(_.shuffleRemoteBytesRead), + toValues(_.shuffleRemoteBytesReadToDisk), + toValues(_.shuffleTotalBlocksFetched)), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + toValues(_.shuffleWriteBytes), + toValues(_.shuffleWriteRecords), + toValues(_.shuffleWriteTime))) + + return Some(distributions) + } + + // Compute quantiles by scanning the tasks in the store. This is not really stable for live + // stages (e.g. the number of recorded tasks may change while this code is running), but should + // stabilize once the stage finishes. It's also slow, especially with disk stores. + val indices = quantiles.map { q => math.min((q * count).toLong, count - 1) } + + def scanTasks(index: String)(fn: TaskDataWrapper => Long): IndexedSeq[Double] = { + Utils.tryWithResource( + store.view(classOf[TaskDataWrapper]) + .parent(stageKey) + .index(index) + .first(0L) + .closeableIterator() + ) { it => + var last = Double.NaN + var currentIdx = -1L + indices.map { idx => + if (idx == currentIdx) { + last + } else { + val diff = idx - currentIdx + currentIdx = idx + if (it.skip(diff - 1)) { + last = fn(it.next()).toDouble + last + } else { + Double.NaN + } + } + }.toIndexedSeq + } + } + + val computedQuantiles = new v1.TaskMetricDistributions( quantiles = quantiles, - executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), - executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), - executorRunTime = metricQuantiles(_.executorRunTime), - executorCpuTime = metricQuantiles(_.executorCpuTime), - resultSize = metricQuantiles(_.resultSize), - jvmGcTime = metricQuantiles(_.jvmGcTime), - resultSerializationTime = metricQuantiles(_.resultSerializationTime), - memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), - diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), - inputMetrics = inputMetrics, - outputMetrics = outputMetrics, - shuffleReadMetrics = shuffleReadMetrics, - shuffleWriteMetrics = shuffleWriteMetrics - ) + executorDeserializeTime = scanTasks(TaskIndexNames.DESER_TIME) { t => + t.executorDeserializeTime + }, + executorDeserializeCpuTime = scanTasks(TaskIndexNames.DESER_CPU_TIME) { t => + t.executorDeserializeCpuTime + }, + executorRunTime = scanTasks(TaskIndexNames.EXEC_RUN_TIME) { t => t.executorRunTime }, + executorCpuTime = scanTasks(TaskIndexNames.EXEC_CPU_TIME) { t => t.executorCpuTime }, + resultSize = scanTasks(TaskIndexNames.RESULT_SIZE) { t => t.resultSize }, + jvmGcTime = scanTasks(TaskIndexNames.GC_TIME) { t => t.jvmGcTime }, + resultSerializationTime = scanTasks(TaskIndexNames.SER_TIME) { t => + t.resultSerializationTime + }, + gettingResultTime = scanTasks(TaskIndexNames.GETTING_RESULT_TIME) { t => + t.gettingResultTime + }, + schedulerDelay = scanTasks(TaskIndexNames.SCHEDULER_DELAY) { t => t.schedulerDelay }, + peakExecutionMemory = scanTasks(TaskIndexNames.PEAK_MEM) { t => t.peakExecutionMemory }, + memoryBytesSpilled = scanTasks(TaskIndexNames.MEM_SPILL) { t => t.memoryBytesSpilled }, + diskBytesSpilled = scanTasks(TaskIndexNames.DISK_SPILL) { t => t.diskBytesSpilled }, + inputMetrics = new v1.InputMetricDistributions( + scanTasks(TaskIndexNames.INPUT_SIZE) { t => t.inputBytesRead }, + scanTasks(TaskIndexNames.INPUT_RECORDS) { t => t.inputRecordsRead }), + outputMetrics = new v1.OutputMetricDistributions( + scanTasks(TaskIndexNames.OUTPUT_SIZE) { t => t.outputBytesWritten }, + scanTasks(TaskIndexNames.OUTPUT_RECORDS) { t => t.outputRecordsWritten }), + shuffleReadMetrics = new v1.ShuffleReadMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_READS) { m => + m.shuffleLocalBytesRead + m.shuffleRemoteBytesRead + }, + scanTasks(TaskIndexNames.SHUFFLE_READ_RECORDS) { t => t.shuffleRecordsRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_BLOCKS) { t => t.shuffleRemoteBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_LOCAL_BLOCKS) { t => t.shuffleLocalBlocksFetched }, + scanTasks(TaskIndexNames.SHUFFLE_READ_TIME) { t => t.shuffleFetchWaitTime }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS) { t => t.shuffleRemoteBytesRead }, + scanTasks(TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK) { t => + t.shuffleRemoteBytesReadToDisk + }, + scanTasks(TaskIndexNames.SHUFFLE_TOTAL_BLOCKS) { m => + m.shuffleLocalBlocksFetched + m.shuffleRemoteBlocksFetched + }), + shuffleWriteMetrics = new v1.ShuffleWriteMetricDistributions( + scanTasks(TaskIndexNames.SHUFFLE_WRITE_SIZE) { t => t.shuffleBytesWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_RECORDS) { t => t.shuffleRecordsWritten }, + scanTasks(TaskIndexNames.SHUFFLE_WRITE_TIME) { t => t.shuffleWriteTime })) + + // Go through the computed quantiles and cache the values that match the caching criteria. + computedQuantiles.quantiles.zipWithIndex + .filter { case (q, _) => quantiles.contains(q) && shouldCacheQuantile(q) } + .foreach { case (q, idx) => + val cached = new CachedQuantile(stageId, stageAttemptId, quantileToString(q), count, + executorDeserializeTime = computedQuantiles.executorDeserializeTime(idx), + executorDeserializeCpuTime = computedQuantiles.executorDeserializeCpuTime(idx), + executorRunTime = computedQuantiles.executorRunTime(idx), + executorCpuTime = computedQuantiles.executorCpuTime(idx), + resultSize = computedQuantiles.resultSize(idx), + jvmGcTime = computedQuantiles.jvmGcTime(idx), + resultSerializationTime = computedQuantiles.resultSerializationTime(idx), + gettingResultTime = computedQuantiles.gettingResultTime(idx), + schedulerDelay = computedQuantiles.schedulerDelay(idx), + peakExecutionMemory = computedQuantiles.peakExecutionMemory(idx), + memoryBytesSpilled = computedQuantiles.memoryBytesSpilled(idx), + diskBytesSpilled = computedQuantiles.diskBytesSpilled(idx), + + bytesRead = computedQuantiles.inputMetrics.bytesRead(idx), + recordsRead = computedQuantiles.inputMetrics.recordsRead(idx), + + bytesWritten = computedQuantiles.outputMetrics.bytesWritten(idx), + recordsWritten = computedQuantiles.outputMetrics.recordsWritten(idx), + + shuffleReadBytes = computedQuantiles.shuffleReadMetrics.readBytes(idx), + shuffleRecordsRead = computedQuantiles.shuffleReadMetrics.readRecords(idx), + shuffleRemoteBlocksFetched = + computedQuantiles.shuffleReadMetrics.remoteBlocksFetched(idx), + shuffleLocalBlocksFetched = computedQuantiles.shuffleReadMetrics.localBlocksFetched(idx), + shuffleFetchWaitTime = computedQuantiles.shuffleReadMetrics.fetchWaitTime(idx), + shuffleRemoteBytesRead = computedQuantiles.shuffleReadMetrics.remoteBytesRead(idx), + shuffleRemoteBytesReadToDisk = + computedQuantiles.shuffleReadMetrics.remoteBytesReadToDisk(idx), + shuffleTotalBlocksFetched = computedQuantiles.shuffleReadMetrics.totalBlocksFetched(idx), + + shuffleWriteBytes = computedQuantiles.shuffleWriteMetrics.writeBytes(idx), + shuffleWriteRecords = computedQuantiles.shuffleWriteMetrics.writeRecords(idx), + shuffleWriteTime = computedQuantiles.shuffleWriteMetrics.writeTime(idx)) + store.write(cached) + } + + Some(computedQuantiles) } + /** + * Whether to cache information about a specific metric quantile. We cache quantiles at every 0.05 + * step, which covers the default values used both in the API and in the stages page. + */ + private def shouldCacheQuantile(q: Double): Boolean = (math.round(q * 100) % 5) == 0 + + private def quantileToString(q: Double): String = math.round(q * 100).toString + def taskList(stageId: Int, stageAttemptId: Int, maxTasks: Int): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) store.view(classOf[TaskDataWrapper]).index("stage").first(stageKey).last(stageKey).reverse() - .max(maxTasks).asScala.map(_.info).toSeq.reverse + .max(maxTasks).asScala.map(_.toApi).toSeq.reverse } def taskList( @@ -219,18 +354,43 @@ private[spark] class AppStatusStore( offset: Int, length: Int, sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val (indexName, ascending) = sortBy match { + case v1.TaskSorting.ID => + (None, true) + case v1.TaskSorting.INCREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), true) + case v1.TaskSorting.DECREASING_RUNTIME => + (Some(TaskIndexNames.EXEC_RUN_TIME), false) + } + taskList(stageId, stageAttemptId, offset, length, indexName, ascending) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: Option[String], + ascending: Boolean): Seq[v1.TaskData] = { val stageKey = Array(stageId, stageAttemptId) val base = store.view(classOf[TaskDataWrapper]) val indexed = sortBy match { - case v1.TaskSorting.ID => + case Some(index) => + base.index(index).parent(stageKey) + + case _ => + // Sort by ID, which is the "stage" index. base.index("stage").first(stageKey).last(stageKey) - case v1.TaskSorting.INCREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) - case v1.TaskSorting.DECREASING_RUNTIME => - base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) - .reverse() } - indexed.skip(offset).max(length).asScala.map(_.info).toSeq + + val ordered = if (ascending) indexed else indexed.reverse() + ordered.skip(offset).max(length).asScala.map(_.toApi).toSeq + } + + def executorSummary(stageId: Int, attemptId: Int): Map[String, v1.ExecutorStageSummary] = { + val stageKey = Array(stageId, attemptId) + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey).last(stageKey) + .asScala.map { exec => (exec.executorId -> exec.info) }.toMap } def rddList(cachedOnly: Boolean = true): Seq[v1.RDDStorageInfo] = { @@ -256,12 +416,6 @@ private[spark] class AppStatusStore( .map { t => (t.taskId, t) } .toMap - val stageKey = Array(stage.stageId, stage.attemptId) - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage").first(stageKey) - .last(stageKey).closeableIterator().asScala - .map { exec => (exec.executorId -> exec.info) } - .toMap - new v1.StageData( stage.status, stage.stageId, @@ -295,7 +449,7 @@ private[spark] class AppStatusStore( stage.rddIds, stage.accumulatorUpdates, Some(tasks), - Some(execs), + Some(executorSummary(stage.stageId, stage.attemptId)), stage.killedTasksSummary) } @@ -352,22 +506,3 @@ private[spark] object AppStatusStore { } } - -/** - * Helper for getting distributions from nested metric types. - */ -private abstract class MetricHelper[I, O]( - rawMetrics: Seq[v1.TaskMetrics], - quantiles: Array[Double]) { - - def getSubmetrics(raw: v1.TaskMetrics): I - - def build: O - - val data: Seq[I] = rawMetrics.map(getSubmetrics) - - /** applies the given function to all input metrics, and returns the quantiles */ - def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { - Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) - } -} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala new file mode 100644 index 0000000000000..341bd4e0cd016 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +private[spark] object AppStatusUtils { + + def schedulerDelay(task: TaskData): Long = { + if (task.taskMetrics.isDefined && task.duration.isDefined) { + val m = task.taskMetrics.get + schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, + m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) + } else { + 0L + } + } + + def gettingResultTime(task: TaskData): Long = { + gettingResultTime(task.launchTime.getTime(), fetchStart(task), task.duration.getOrElse(-1L)) + } + + def schedulerDelay( + launchTime: Long, + fetchStart: Long, + duration: Long, + deserializeTime: Long, + serializeTime: Long, + runTime: Long): Long = { + math.max(0, duration - runTime - deserializeTime - serializeTime - + gettingResultTime(launchTime, fetchStart, duration)) + } + + def gettingResultTime(launchTime: Long, fetchStart: Long, duration: Long): Long = { + if (fetchStart > 0) { + if (duration > 0) { + launchTime + duration - fetchStart + } else { + System.currentTimeMillis() - fetchStart + } + } else { + 0L + } + } + + private def fetchStart(task: TaskData): Long = { + if (task.resultFetchStart.isDefined) { + task.resultFetchStart.get.getTime() + } else { + -1 + } + } +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 305c2fafa6aac..4295e664e131c 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.HashMap +import com.google.common.collect.Interners + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} @@ -119,7 +121,9 @@ private class LiveTask( import LiveEntityHelpers._ - private var recordedMetrics: v1.TaskMetrics = null + // The task metrics use a special value when no metrics have been reported. The special value is + // checked when calculating indexed values when writing to the store (see [[TaskDataWrapper]]). + private var metrics: v1.TaskMetrics = createMetrics(default = -1L) var errorMessage: Option[String] = None @@ -129,8 +133,8 @@ private class LiveTask( */ def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { if (metrics != null) { - val old = recordedMetrics - recordedMetrics = new v1.TaskMetrics( + val old = this.metrics + val newMetrics = createMetrics( metrics.executorDeserializeTime, metrics.executorDeserializeCpuTime, metrics.executorRunTime, @@ -141,73 +145,35 @@ private class LiveTask( metrics.memoryBytesSpilled, metrics.diskBytesSpilled, metrics.peakExecutionMemory, - new v1.InputMetrics( - metrics.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead), - new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten), - new v1.ShuffleReadMetrics( - metrics.shuffleReadMetrics.remoteBlocksFetched, - metrics.shuffleReadMetrics.localBlocksFetched, - metrics.shuffleReadMetrics.fetchWaitTime, - metrics.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead), - new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten, - metrics.shuffleWriteMetrics.writeTime, - metrics.shuffleWriteMetrics.recordsWritten)) - if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten) + + this.metrics = newMetrics + + // Only calculate the delta if the old metrics contain valid information, otherwise + // the new metrics are the delta. + if (old.executorDeserializeTime >= 0L) { + subtractMetrics(newMetrics, old) + } else { + newMetrics + } } else { null } } - /** - * Return a new TaskMetrics object containing the delta of the various fields of the given - * metrics objects. This is currently targeted at updating stage data, so it does not - * necessarily calculate deltas for all the fields. - */ - private def calculateMetricsDelta( - metrics: v1.TaskMetrics, - old: v1.TaskMetrics): v1.TaskMetrics = { - val shuffleWriteDelta = new v1.ShuffleWriteMetrics( - metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, - 0L, - metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) - - val shuffleReadDelta = new v1.ShuffleReadMetrics( - 0L, 0L, 0L, - metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, - metrics.shuffleReadMetrics.remoteBytesReadToDisk - - old.shuffleReadMetrics.remoteBytesReadToDisk, - metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, - metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) - - val inputDelta = new v1.InputMetrics( - metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, - metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) - - val outputDelta = new v1.OutputMetrics( - metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, - metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) - - new v1.TaskMetrics( - 0L, 0L, - metrics.executorRunTime - old.executorRunTime, - metrics.executorCpuTime - old.executorCpuTime, - 0L, 0L, 0L, - metrics.memoryBytesSpilled - old.memoryBytesSpilled, - metrics.diskBytesSpilled - old.diskBytesSpilled, - 0L, - inputDelta, - outputDelta, - shuffleReadDelta, - shuffleWriteDelta) - } - override protected def doUpdate(): Any = { val duration = if (info.finished) { info.duration @@ -215,22 +181,48 @@ private class LiveTask( info.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis())) } - val task = new v1.TaskData( + new TaskDataWrapper( info.taskId, info.index, info.attemptNumber, - new Date(info.launchTime), - if (info.gettingResult) Some(new Date(info.gettingResultTime)) else None, - Some(duration), - info.executorId, - info.host, - info.status, - info.taskLocality.toString(), + info.launchTime, + if (info.gettingResult) info.gettingResultTime else -1L, + duration, + weakIntern(info.executorId), + weakIntern(info.host), + weakIntern(info.status), + weakIntern(info.taskLocality.toString()), info.speculative, newAccumulatorInfos(info.accumulables), errorMessage, - Option(recordedMetrics)) - new TaskDataWrapper(task, stageId, stageAttemptId) + + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGcTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + metrics.peakExecutionMemory, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten, + + stageId, + stageAttemptId) } } @@ -313,50 +305,19 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE } -/** Metrics tracked per stage (both total and per executor). */ -private class MetricsTracker { - var executorRunTime = 0L - var executorCpuTime = 0L - var inputBytes = 0L - var inputRecords = 0L - var outputBytes = 0L - var outputRecords = 0L - var shuffleReadBytes = 0L - var shuffleReadRecords = 0L - var shuffleWriteBytes = 0L - var shuffleWriteRecords = 0L - var memoryBytesSpilled = 0L - var diskBytesSpilled = 0L - - def update(delta: v1.TaskMetrics): Unit = { - executorRunTime += delta.executorRunTime - executorCpuTime += delta.executorCpuTime - inputBytes += delta.inputMetrics.bytesRead - inputRecords += delta.inputMetrics.recordsRead - outputBytes += delta.outputMetrics.bytesWritten - outputRecords += delta.outputMetrics.recordsWritten - shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + - delta.shuffleReadMetrics.remoteBytesRead - shuffleReadRecords += delta.shuffleReadMetrics.recordsRead - shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten - shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten - memoryBytesSpilled += delta.memoryBytesSpilled - diskBytesSpilled += delta.diskBytesSpilled - } - -} - private class LiveExecutorStageSummary( stageId: Int, attemptId: Int, executorId: String) extends LiveEntity { + import LiveEntityHelpers._ + var taskTime = 0L var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 - val metrics = new MetricsTracker() + var metrics = createMetrics(default = 0L) override protected def doUpdate(): Any = { val info = new v1.ExecutorStageSummary( @@ -364,14 +325,14 @@ private class LiveExecutorStageSummary( failedTasks, succeededTasks, killedTasks, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.remoteBytesRead + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) @@ -402,7 +363,9 @@ private class LiveStage extends LiveEntity { var firstLaunchTime = Long.MaxValue - val metrics = new MetricsTracker() + var localitySummary: Map[String, Long] = Map() + + var metrics = createMetrics(default = 0L) val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() @@ -435,14 +398,14 @@ private class LiveStage extends LiveEntity { info.completionTime.map(new Date(_)), info.failureReason, - metrics.inputBytes, - metrics.inputRecords, - metrics.outputBytes, - metrics.outputRecords, - metrics.shuffleReadBytes, - metrics.shuffleReadRecords, - metrics.shuffleWriteBytes, - metrics.shuffleWriteRecords, + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead, + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten, + metrics.shuffleReadMetrics.localBytesRead + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.recordsRead, + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, metrics.diskBytesSpilled, @@ -459,13 +422,15 @@ private class LiveStage extends LiveEntity { } override protected def doUpdate(): Any = { - new StageDataWrapper(toApi(), jobIds) + new StageDataWrapper(toApi(), jobIds, localitySummary) } } private class LiveRDDPartition(val blockName: String) { + import LiveEntityHelpers._ + // Pointers used by RDDPartitionSeq. @volatile var prev: LiveRDDPartition = null @volatile var next: LiveRDDPartition = null @@ -485,7 +450,7 @@ private class LiveRDDPartition(val blockName: String) { diskUsed: Long): Unit = { value = new v1.RDDPartitionInfo( blockName, - storageLevel, + weakIntern(storageLevel), memoryUsed, diskUsed, executors) @@ -495,6 +460,8 @@ private class LiveRDDPartition(val blockName: String) { private class LiveRDDDistribution(exec: LiveExecutor) { + import LiveEntityHelpers._ + val executorId = exec.executorId var memoryUsed = 0L var diskUsed = 0L @@ -508,7 +475,7 @@ private class LiveRDDDistribution(exec: LiveExecutor) { def toApi(): v1.RDDDataDistribution = { if (lastUpdate == null) { lastUpdate = new v1.RDDDataDistribution( - exec.hostPort, + weakIntern(exec.hostPort), memoryUsed, exec.maxMemory - exec.memoryUsed, diskUsed, @@ -524,7 +491,9 @@ private class LiveRDDDistribution(exec: LiveExecutor) { private class LiveRDD(val info: RDDInfo) extends LiveEntity { - var storageLevel: String = info.storageLevel.description + import LiveEntityHelpers._ + + var storageLevel: String = weakIntern(info.storageLevel.description) var memoryUsed = 0L var diskUsed = 0L @@ -533,6 +502,10 @@ private class LiveRDD(val info: RDDInfo) extends LiveEntity { private val distributions = new HashMap[String, LiveRDDDistribution]() + def setStorageLevel(level: String): Unit = { + this.storageLevel = weakIntern(level) + } + def partition(blockName: String): LiveRDDPartition = { partitions.getOrElseUpdate(blockName, { val part = new LiveRDDPartition(blockName) @@ -593,6 +566,9 @@ private class SchedulerPool(name: String) extends LiveEntity { private object LiveEntityHelpers { + private val stringInterner = Interners.newWeakInterner[String]() + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { accums .filter { acc => @@ -604,13 +580,119 @@ private object LiveEntityHelpers { .map { acc => new v1.AccumulableInfo( acc.id, - acc.name.orNull, + acc.name.map(weakIntern).orNull, acc.update.map(_.toString()), acc.value.map(_.toString()).orNull) } .toSeq } + /** String interning to reduce the memory usage. */ + def weakIntern(s: String): String = { + stringInterner.intern(s) + } + + // scalastyle:off argcount + def createMetrics( + executorDeserializeTime: Long, + executorDeserializeCpuTime: Long, + executorRunTime: Long, + executorCpuTime: Long, + resultSize: Long, + jvmGcTime: Long, + resultSerializationTime: Long, + memoryBytesSpilled: Long, + diskBytesSpilled: Long, + peakExecutionMemory: Long, + inputBytesRead: Long, + inputRecordsRead: Long, + outputBytesWritten: Long, + outputRecordsWritten: Long, + shuffleRemoteBlocksFetched: Long, + shuffleLocalBlocksFetched: Long, + shuffleFetchWaitTime: Long, + shuffleRemoteBytesRead: Long, + shuffleRemoteBytesReadToDisk: Long, + shuffleLocalBytesRead: Long, + shuffleRecordsRead: Long, + shuffleBytesWritten: Long, + shuffleWriteTime: Long, + shuffleRecordsWritten: Long): v1.TaskMetrics = { + new v1.TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new v1.InputMetrics( + inputBytesRead, + inputRecordsRead), + new v1.OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new v1.ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new v1.ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten)) + } + // scalastyle:on argcount + + def createMetrics(default: Long): v1.TaskMetrics = { + createMetrics(default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default, + default, default, default, default, default, default, default, default) + } + + /** Add m2 values to m1. */ + def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = addMetrics(m1, m2, 1) + + /** Subtract m2 values from m1. */ + def subtractMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics): v1.TaskMetrics = { + addMetrics(m1, m2, -1) + } + + private def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics, mult: Int): v1.TaskMetrics = { + createMetrics( + m1.executorDeserializeTime + m2.executorDeserializeTime * mult, + m1.executorDeserializeCpuTime + m2.executorDeserializeCpuTime * mult, + m1.executorRunTime + m2.executorRunTime * mult, + m1.executorCpuTime + m2.executorCpuTime * mult, + m1.resultSize + m2.resultSize * mult, + m1.jvmGcTime + m2.jvmGcTime * mult, + m1.resultSerializationTime + m2.resultSerializationTime * mult, + m1.memoryBytesSpilled + m2.memoryBytesSpilled * mult, + m1.diskBytesSpilled + m2.diskBytesSpilled * mult, + m1.peakExecutionMemory + m2.peakExecutionMemory * mult, + m1.inputMetrics.bytesRead + m2.inputMetrics.bytesRead * mult, + m1.inputMetrics.recordsRead + m2.inputMetrics.recordsRead * mult, + m1.outputMetrics.bytesWritten + m2.outputMetrics.bytesWritten * mult, + m1.outputMetrics.recordsWritten + m2.outputMetrics.recordsWritten * mult, + m1.shuffleReadMetrics.remoteBlocksFetched + m2.shuffleReadMetrics.remoteBlocksFetched * mult, + m1.shuffleReadMetrics.localBlocksFetched + m2.shuffleReadMetrics.localBlocksFetched * mult, + m1.shuffleReadMetrics.fetchWaitTime + m2.shuffleReadMetrics.fetchWaitTime * mult, + m1.shuffleReadMetrics.remoteBytesRead + m2.shuffleReadMetrics.remoteBytesRead * mult, + m1.shuffleReadMetrics.remoteBytesReadToDisk + + m2.shuffleReadMetrics.remoteBytesReadToDisk * mult, + m1.shuffleReadMetrics.localBytesRead + m2.shuffleReadMetrics.localBytesRead * mult, + m1.shuffleReadMetrics.recordsRead + m2.shuffleReadMetrics.recordsRead * mult, + m1.shuffleWriteMetrics.bytesWritten + m2.shuffleWriteMetrics.bytesWritten * mult, + m1.shuffleWriteMetrics.writeTime + m2.shuffleWriteMetrics.writeTime * mult, + m1.shuffleWriteMetrics.recordsWritten + m2.shuffleWriteMetrics.recordsWritten * mult) + } + } /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 3b879545b3d2e..96249e4bfd5fa 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -87,7 +87,8 @@ private[v1] class StagesResource extends BaseAppResource { } } - ui.store.taskSummary(stageId, stageAttemptId, quantiles) + ui.store.taskSummary(stageId, stageAttemptId, quantiles).getOrElse( + throw new NotFoundException(s"No tasks reported metrics for $stageId / $stageAttemptId yet.")) } @GET diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 45eaf935fb083..7d8e4de3c8efb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -261,6 +261,9 @@ class TaskMetricDistributions private[spark]( val resultSize: IndexedSeq[Double], val jvmGcTime: IndexedSeq[Double], val resultSerializationTime: IndexedSeq[Double], + val gettingResultTime: IndexedSeq[Double], + val schedulerDelay: IndexedSeq[Double], + val peakExecutionMemory: IndexedSeq[Double], val memoryBytesSpilled: IndexedSeq[Double], val diskBytesSpilled: IndexedSeq[Double], diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 1cfd30df49091..c9cb996a55fcc 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -17,9 +17,11 @@ package org.apache.spark.status -import java.lang.{Integer => JInteger, Long => JLong} +import java.lang.{Long => JLong} +import java.util.Date import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.status.KVUtils._ import org.apache.spark.status.api.v1._ @@ -49,10 +51,10 @@ private[spark] class ApplicationEnvironmentInfoWrapper(val info: ApplicationEnvi private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { @JsonIgnore @KVIndex - private[this] val id: String = info.id + private def id: String = info.id @JsonIgnore @KVIndex("active") - private[this] val active: Boolean = info.isActive + private def active: Boolean = info.isActive @JsonIgnore @KVIndex("host") val host: String = info.hostPort.split(":")(0) @@ -69,51 +71,271 @@ private[spark] class JobDataWrapper( val skippedStages: Set[Int]) { @JsonIgnore @KVIndex - private[this] val id: Int = info.jobId + private def id: Int = info.jobId } private[spark] class StageDataWrapper( val info: StageData, - val jobIds: Set[Int]) { + val jobIds: Set[Int], + @JsonDeserialize(contentAs = classOf[JLong]) + val locality: Map[String, Long]) { @JsonIgnore @KVIndex - def id: Array[Int] = Array(info.stageId, info.attemptId) + private[this] val id: Array[Int] = Array(info.stageId, info.attemptId) @JsonIgnore @KVIndex("stageId") - def stageId: Int = info.stageId + private def stageId: Int = info.stageId + @JsonIgnore @KVIndex("active") + private def active: Boolean = info.status == StageStatus.ACTIVE + +} + +/** + * Tasks have a lot of indices that are used in a few different places. This object keeps logical + * names for these indices, mapped to short strings to save space when using a disk store. + */ +private[spark] object TaskIndexNames { + final val ACCUMULATORS = "acc" + final val ATTEMPT = "att" + final val DESER_CPU_TIME = "dct" + final val DESER_TIME = "des" + final val DISK_SPILL = "dbs" + final val DURATION = "dur" + final val ERROR = "err" + final val EXECUTOR = "exe" + final val EXEC_CPU_TIME = "ect" + final val EXEC_RUN_TIME = "ert" + final val GC_TIME = "gc" + final val GETTING_RESULT_TIME = "grt" + final val INPUT_RECORDS = "ir" + final val INPUT_SIZE = "is" + final val LAUNCH_TIME = "lt" + final val LOCALITY = "loc" + final val MEM_SPILL = "mbs" + final val OUTPUT_RECORDS = "or" + final val OUTPUT_SIZE = "os" + final val PEAK_MEM = "pem" + final val RESULT_SIZE = "rs" + final val SCHEDULER_DELAY = "dly" + final val SER_TIME = "rst" + final val SHUFFLE_LOCAL_BLOCKS = "slbl" + final val SHUFFLE_READ_RECORDS = "srr" + final val SHUFFLE_READ_TIME = "srt" + final val SHUFFLE_REMOTE_BLOCKS = "srbl" + final val SHUFFLE_REMOTE_READS = "srby" + final val SHUFFLE_REMOTE_READS_TO_DISK = "srbd" + final val SHUFFLE_TOTAL_READS = "stby" + final val SHUFFLE_TOTAL_BLOCKS = "stbl" + final val SHUFFLE_WRITE_RECORDS = "swr" + final val SHUFFLE_WRITE_SIZE = "sws" + final val SHUFFLE_WRITE_TIME = "swt" + final val STAGE = "stage" + final val STATUS = "sta" + final val TASK_INDEX = "idx" } /** - * The task information is always indexed with the stage ID, since that is how the UI and API - * consume it. That means every indexed value has the stage ID and attempt ID included, aside - * from the actual data being indexed. + * Unlike other data types, the task data wrapper does not keep a reference to the API's TaskData. + * That is to save memory, since for large applications there can be a large number of these + * elements (by default up to 100,000 per stage), and every bit of wasted memory adds up. + * + * It also contains many secondary indices, which are used to sort data efficiently in the UI at the + * expense of storage space (and slower write times). */ private[spark] class TaskDataWrapper( - val info: TaskData, + // Storing this as an object actually saves memory; it's also used as the key in the in-memory + // store, so in that case you'd save the extra copy of the value here. + @KVIndexParam + val taskId: JLong, + @KVIndexParam(value = TaskIndexNames.TASK_INDEX, parent = TaskIndexNames.STAGE) + val index: Int, + @KVIndexParam(value = TaskIndexNames.ATTEMPT, parent = TaskIndexNames.STAGE) + val attempt: Int, + @KVIndexParam(value = TaskIndexNames.LAUNCH_TIME, parent = TaskIndexNames.STAGE) + val launchTime: Long, + val resultFetchStart: Long, + @KVIndexParam(value = TaskIndexNames.DURATION, parent = TaskIndexNames.STAGE) + val duration: Long, + @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) + val executorId: String, + val host: String, + @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) + val status: String, + @KVIndexParam(value = TaskIndexNames.LOCALITY, parent = TaskIndexNames.STAGE) + val taskLocality: String, + val speculative: Boolean, + val accumulatorUpdates: Seq[AccumulableInfo], + val errorMessage: Option[String], + + // The following is an exploded view of a TaskMetrics API object. This saves 5 objects + // (= 80 bytes of Java object overhead) per instance of this wrapper. If the first value + // (executorDeserializeTime) is -1L, it means the metrics for this task have not been + // recorded. + @KVIndexParam(value = TaskIndexNames.DESER_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeTime: Long, + @KVIndexParam(value = TaskIndexNames.DESER_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorDeserializeCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_RUN_TIME, parent = TaskIndexNames.STAGE) + val executorRunTime: Long, + @KVIndexParam(value = TaskIndexNames.EXEC_CPU_TIME, parent = TaskIndexNames.STAGE) + val executorCpuTime: Long, + @KVIndexParam(value = TaskIndexNames.RESULT_SIZE, parent = TaskIndexNames.STAGE) + val resultSize: Long, + @KVIndexParam(value = TaskIndexNames.GC_TIME, parent = TaskIndexNames.STAGE) + val jvmGcTime: Long, + @KVIndexParam(value = TaskIndexNames.SER_TIME, parent = TaskIndexNames.STAGE) + val resultSerializationTime: Long, + @KVIndexParam(value = TaskIndexNames.MEM_SPILL, parent = TaskIndexNames.STAGE) + val memoryBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.DISK_SPILL, parent = TaskIndexNames.STAGE) + val diskBytesSpilled: Long, + @KVIndexParam(value = TaskIndexNames.PEAK_MEM, parent = TaskIndexNames.STAGE) + val peakExecutionMemory: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_SIZE, parent = TaskIndexNames.STAGE) + val inputBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.INPUT_RECORDS, parent = TaskIndexNames.STAGE) + val inputRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_SIZE, parent = TaskIndexNames.STAGE) + val outputBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.OUTPUT_RECORDS, parent = TaskIndexNames.STAGE) + val outputRecordsWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_LOCAL_BLOCKS, parent = TaskIndexNames.STAGE) + val shuffleLocalBlocksFetched: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_TIME, parent = TaskIndexNames.STAGE) + val shuffleFetchWaitTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS, parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_REMOTE_READS_TO_DISK, + parent = TaskIndexNames.STAGE) + val shuffleRemoteBytesReadToDisk: Long, + val shuffleLocalBytesRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_READ_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsRead: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_SIZE, parent = TaskIndexNames.STAGE) + val shuffleBytesWritten: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_TIME, parent = TaskIndexNames.STAGE) + val shuffleWriteTime: Long, + @KVIndexParam(value = TaskIndexNames.SHUFFLE_WRITE_RECORDS, parent = TaskIndexNames.STAGE) + val shuffleRecordsWritten: Long, + val stageId: Int, val stageAttemptId: Int) { - @JsonIgnore @KVIndex - def id: Long = info.taskId + def hasMetrics: Boolean = executorDeserializeTime >= 0 + + def toApi: TaskData = { + val metrics = if (hasMetrics) { + Some(new TaskMetrics( + executorDeserializeTime, + executorDeserializeCpuTime, + executorRunTime, + executorCpuTime, + resultSize, + jvmGcTime, + resultSerializationTime, + memoryBytesSpilled, + diskBytesSpilled, + peakExecutionMemory, + new InputMetrics( + inputBytesRead, + inputRecordsRead), + new OutputMetrics( + outputBytesWritten, + outputRecordsWritten), + new ShuffleReadMetrics( + shuffleRemoteBlocksFetched, + shuffleLocalBlocksFetched, + shuffleFetchWaitTime, + shuffleRemoteBytesRead, + shuffleRemoteBytesReadToDisk, + shuffleLocalBytesRead, + shuffleRecordsRead), + new ShuffleWriteMetrics( + shuffleBytesWritten, + shuffleWriteTime, + shuffleRecordsWritten))) + } else { + None + } - @JsonIgnore @KVIndex("stage") - def stage: Array[Int] = Array(stageId, stageAttemptId) + new TaskData( + taskId, + index, + attempt, + new Date(launchTime), + if (resultFetchStart > 0L) Some(new Date(resultFetchStart)) else None, + if (duration > 0L) Some(duration) else None, + executorId, + host, + status, + taskLocality, + speculative, + accumulatorUpdates, + errorMessage, + metrics) + } + + @JsonIgnore @KVIndex(TaskIndexNames.STAGE) + private def stage: Array[Int] = Array(stageId, stageAttemptId) - @JsonIgnore @KVIndex("runtime") - def runtime: Array[AnyRef] = { - val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) - Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.SCHEDULER_DELAY, parent = TaskIndexNames.STAGE) + def schedulerDelay: Long = { + if (hasMetrics) { + AppStatusUtils.schedulerDelay(launchTime, resultFetchStart, duration, executorDeserializeTime, + resultSerializationTime, executorRunTime) + } else { + -1L + } } - @JsonIgnore @KVIndex("startTime") - def startTime: Array[AnyRef] = { - Array(stageId: JInteger, stageAttemptId: JInteger, info.launchTime.getTime(): JLong) + @JsonIgnore @KVIndex(value = TaskIndexNames.GETTING_RESULT_TIME, parent = TaskIndexNames.STAGE) + def gettingResultTime: Long = { + if (hasMetrics) { + AppStatusUtils.gettingResultTime(launchTime, resultFetchStart, duration) + } else { + -1L + } } - @JsonIgnore @KVIndex("active") - def active: Boolean = info.duration.isEmpty + /** + * Sorting by accumulators is a little weird, and the previous behavior would generate + * insanely long keys in the index. So this implementation just considers the first + * accumulator and its String representation. + */ + @JsonIgnore @KVIndex(value = TaskIndexNames.ACCUMULATORS, parent = TaskIndexNames.STAGE) + private def accumulators: String = { + if (accumulatorUpdates.nonEmpty) { + val acc = accumulatorUpdates.head + s"${acc.name}:${acc.value}" + } else { + "" + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_READS, parent = TaskIndexNames.STAGE) + private def shuffleTotalReads: Long = { + if (hasMetrics) { + shuffleLocalBytesRead + shuffleRemoteBytesRead + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.SHUFFLE_TOTAL_BLOCKS, parent = TaskIndexNames.STAGE) + private def shuffleTotalBlocks: Long = { + if (hasMetrics) { + shuffleLocalBlocksFetched + shuffleRemoteBlocksFetched + } else { + -1L + } + } + + @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) + private def error: String = if (errorMessage.isDefined) errorMessage.get else "" } @@ -134,10 +356,13 @@ private[spark] class ExecutorStageSummaryWrapper( val info: ExecutorStageSummary) { @JsonIgnore @KVIndex - val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + private val _id: Array[Any] = Array(stageId, stageAttemptId, executorId) @JsonIgnore @KVIndex("stage") - private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + private def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore + def id: Array[Any] = _id } @@ -203,3 +428,53 @@ private[spark] class AppSummary( def id: String = classOf[AppSummary].getName() } + +/** + * A cached view of a specific quantile for one stage attempt's metrics. + */ +private[spark] class CachedQuantile( + val stageId: Int, + val stageAttemptId: Int, + val quantile: String, + val taskCount: Long, + + // The following fields are an exploded view of a single entry for TaskMetricDistributions. + val executorDeserializeTime: Double, + val executorDeserializeCpuTime: Double, + val executorRunTime: Double, + val executorCpuTime: Double, + val resultSize: Double, + val jvmGcTime: Double, + val resultSerializationTime: Double, + val gettingResultTime: Double, + val schedulerDelay: Double, + val peakExecutionMemory: Double, + val memoryBytesSpilled: Double, + val diskBytesSpilled: Double, + + val bytesRead: Double, + val recordsRead: Double, + + val bytesWritten: Double, + val recordsWritten: Double, + + val shuffleReadBytes: Double, + val shuffleRecordsRead: Double, + val shuffleRemoteBlocksFetched: Double, + val shuffleLocalBlocksFetched: Double, + val shuffleFetchWaitTime: Double, + val shuffleRemoteBytesRead: Double, + val shuffleRemoteBytesReadToDisk: Double, + val shuffleTotalBlocksFetched: Double, + + val shuffleWriteBytes: Double, + val shuffleWriteRecords: Double, + val shuffleWriteTime: Double) { + + @KVIndex @JsonIgnore + def id: Array[Any] = Array(stageId, stageAttemptId, quantile) + + @KVIndex("stage") @JsonIgnore + def stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 41d42b52430a5..95c12b1e73653 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -87,7 +87,9 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { } private def createExecutorTable(stage: StageData) : Seq[Node] = { - stage.executorSummary.getOrElse(Map.empty).toSeq.sortBy(_._1).map { case (k, v) => + val executorSummary = store.executorSummary(stage.stageId, stage.attemptId) + + executorSummary.toSeq.sortBy(_._1).map { case (k, v) => val executor = store.asOption(store.executorSummary(k)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 740f12e7d13d4..bf59152c8c0cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -201,7 +201,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP val stages = jobData.stageIds.map { stageId => // This could be empty if the listener hasn't received information about the // stage or if the stage information has been garbage collected - store.stageData(stageId).lastOption.getOrElse { + store.asOption(store.lastStageAttempt(stageId)).getOrElse { new v1.StageData( v1.StageStatus.PENDING, stageId, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 11a6a34344976..7c6e06cf183ba 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder import java.util.Date +import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} @@ -29,15 +30,14 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.SparkConf import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality -import org.apache.spark.status.AppStatusStore +import org.apache.spark.status._ import org.apache.spark.status.api.v1._ import org.apache.spark.ui._ -import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.Utils /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends WebUIPage("stage") { import ApiHelper._ - import StagePage._ private val TIMELINE_LEGEND = {
@@ -67,17 +67,17 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We // if we find that it's okay. private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private def getLocalitySummaryString(stageData: StageData, taskList: Seq[TaskData]): String = { - val localities = taskList.map(_.taskLocality) - val localityCounts = localities.groupBy(identity).mapValues(_.size) + private def getLocalitySummaryString(localitySummary: Map[String, Long]): String = { val names = Map( TaskLocality.PROCESS_LOCAL.toString() -> "Process local", TaskLocality.NODE_LOCAL.toString() -> "Node local", TaskLocality.RACK_LOCAL.toString() -> "Rack local", TaskLocality.ANY.toString() -> "Any") - val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => - s"${names(locality)}: $count" - } + val localityNamesAndCounts = names.flatMap { case (key, name) => + localitySummary.get(key).map { count => + s"$name: $count" + } + }.toSeq localityNamesAndCounts.sorted.mkString("; ") } @@ -108,7 +108,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val stageHeader = s"Details for Stage $stageId (Attempt $stageAttemptId)" val stageData = parent.store - .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = true)) + .asOption(parent.store.stageAttempt(stageId, stageAttemptId, details = false)) .getOrElse { val content =
@@ -117,8 +117,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } - val tasks = stageData.tasks.getOrElse(Map.empty).values.toSeq - if (tasks.isEmpty) { + val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) + + val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + + stageData.numFailedTasks + stageData.numKilledTasks + if (totalTasks == 0) { val content =

Summary Metrics

No tasks have started yet @@ -127,18 +130,14 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We return UIUtils.headerSparkPage(stageHeader, content, parent) } + val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) val numCompleted = stageData.numCompleteTasks - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks - val totalTasksNumStr = if (totalTasks == tasks.size) { + val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${tasks.size}" + s"$totalTasks, showing ${storedTasks}" } - val externalAccumulables = stageData.accumulatorUpdates - val hasAccumulators = externalAccumulables.size > 0 - val summary =
    @@ -148,7 +147,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
  • Locality Level Summary: - {getLocalitySummaryString(stageData, tasks)} + {getLocalitySummaryString(localitySummary)}
  • {if (hasInput(stageData)) {
  • @@ -266,7 +265,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableTable = UIUtils.listingTable( accumulableHeaders, accumulableRow, - externalAccumulables.toSeq) + stageData.accumulatorUpdates.toSeq) val page: Int = { // If the user has changed to a larger page size, then go to page 1 in order to avoid @@ -280,16 +279,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( - parent.conf, + stageData, UIUtils.prependBaseUri(parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", - tasks, - hasAccumulators, - hasInput(stageData), - hasOutput(stageData), - hasShuffleRead(stageData), - hasShuffleWrite(stageData), - hasBytesSpilled(stageData), currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, @@ -320,217 +312,155 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We | } |}); """.stripMargin - } + } } - val taskIdsInPage = if (taskTable == null) Set.empty[Long] - else taskTable.dataSource.slicedTaskIds + val metricsSummary = store.taskSummary(stageData.stageId, stageData.attemptId, + Array(0, 0.25, 0.5, 0.75, 1.0)) - // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.status == "SUCCESS" && t.taskMetrics.isDefined) - - val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { - None - } else { - def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] = { - Distribution(data).get.getQuantiles() - } - def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { - getDistributionQuantiles(times).map { millis => - {UIUtils.formatDuration(millis.toLong)} - } - } - def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = { - getDistributionQuantiles(data).map(d => {Utils.bytesToString(d.toLong)}) + val summaryTable = metricsSummary.map { metrics => + def timeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { millis => + {UIUtils.formatDuration(millis.toLong)} } + } - val deserializationTimes = validTasks.map { task => - task.taskMetrics.get.executorDeserializeTime.toDouble - } - val deserializationQuantiles = - - - Task Deserialization Time - - +: getFormattedTimeQuantiles(deserializationTimes) - - val serviceTimes = validTasks.map(_.taskMetrics.get.executorRunTime.toDouble) - val serviceQuantiles = Duration +: getFormattedTimeQuantiles(serviceTimes) - - val gcTimes = validTasks.map(_.taskMetrics.get.jvmGcTime.toDouble) - val gcQuantiles = - - GC Time - - +: getFormattedTimeQuantiles(gcTimes) - - val serializationTimes = validTasks.map(_.taskMetrics.get.resultSerializationTime.toDouble) - val serializationQuantiles = - - - Result Serialization Time - - +: getFormattedTimeQuantiles(serializationTimes) - - val gettingResultTimes = validTasks.map(getGettingResultTime(_, currentTime).toDouble) - val gettingResultQuantiles = - - - Getting Result Time - - +: - getFormattedTimeQuantiles(gettingResultTimes) - - val peakExecutionMemory = validTasks.map(_.taskMetrics.get.peakExecutionMemory.toDouble) - val peakExecutionMemoryQuantiles = { - - - Peak Execution Memory - - +: getFormattedSizeQuantiles(peakExecutionMemory) + def sizeQuantiles(data: IndexedSeq[Double]): Seq[Node] = { + data.map { size => + {Utils.bytesToString(size.toLong)} } + } - // The scheduler delay includes the network delay to send the task to the worker - // machine and to send back the result (but not the time to fetch the task result, - // if it needed to be fetched from the block manager on the worker). - val schedulerDelays = validTasks.map { task => - getSchedulerDelay(task, task.taskMetrics.get, currentTime).toDouble - } - val schedulerDelayTitle = Scheduler Delay - val schedulerDelayQuantiles = schedulerDelayTitle +: - getFormattedTimeQuantiles(schedulerDelays) - def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double]) - : Seq[Elem] = { - val recordDist = getDistributionQuantiles(records).iterator - getDistributionQuantiles(data).map(d => - {s"${Utils.bytesToString(d.toLong)} / ${recordDist.next().toLong}"} - ) + def sizeQuantilesWithRecords( + data: IndexedSeq[Double], + records: IndexedSeq[Double]) : Seq[Node] = { + data.zip(records).map { case (d, r) => + {s"${Utils.bytesToString(d.toLong)} / ${r.toLong}"} } + } - val inputSizes = validTasks.map(_.taskMetrics.get.inputMetrics.bytesRead.toDouble) - val inputRecords = validTasks.map(_.taskMetrics.get.inputMetrics.recordsRead.toDouble) - val inputQuantiles = Input Size / Records +: - getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) + def titleCell(title: String, tooltip: String): Seq[Node] = { + + + {title} + + + } - val outputSizes = validTasks.map(_.taskMetrics.get.outputMetrics.bytesWritten.toDouble) - val outputRecords = validTasks.map(_.taskMetrics.get.outputMetrics.recordsWritten.toDouble) - val outputQuantiles = Output Size / Records +: - getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) + def simpleTitleCell(title: String): Seq[Node] = {title} - val shuffleReadBlockedTimes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime.toDouble - } - val shuffleReadBlockedQuantiles = - - - Shuffle Read Blocked Time - - +: - getFormattedTimeQuantiles(shuffleReadBlockedTimes) - - val shuffleReadTotalSizes = validTasks.map { task => - totalBytesRead(task.taskMetrics.get.shuffleReadMetrics).toDouble - } - val shuffleReadTotalRecords = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.recordsRead.toDouble - } - val shuffleReadTotalQuantiles = - - - Shuffle Read Size / Records - - +: - getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - - val shuffleReadRemoteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead.toDouble - } - val shuffleReadRemoteQuantiles = - - - Shuffle Remote Reads - - +: - getFormattedSizeQuantiles(shuffleReadRemoteSizes) - - val shuffleWriteSizes = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.bytesWritten.toDouble - } + val deserializationQuantiles = titleCell("Task Deserialization Time", + ToolTips.TASK_DESERIALIZATION_TIME) ++ timeQuantiles(metrics.executorDeserializeTime) - val shuffleWriteRecords = validTasks.map { task => - task.taskMetrics.get.shuffleWriteMetrics.recordsWritten.toDouble - } + val serviceQuantiles = simpleTitleCell("Duration") ++ timeQuantiles(metrics.executorRunTime) - val shuffleWriteQuantiles = Shuffle Write Size / Records +: - getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) + val gcQuantiles = titleCell("GC Time", ToolTips.GC_TIME) ++ timeQuantiles(metrics.jvmGcTime) - val memoryBytesSpilledSizes = validTasks.map(_.taskMetrics.get.memoryBytesSpilled.toDouble) - val memoryBytesSpilledQuantiles = Shuffle spill (memory) +: - getFormattedSizeQuantiles(memoryBytesSpilledSizes) + val serializationQuantiles = titleCell("Result Serialization Time", + ToolTips.RESULT_SERIALIZATION_TIME) ++ timeQuantiles(metrics.resultSerializationTime) - val diskBytesSpilledSizes = validTasks.map(_.taskMetrics.get.diskBytesSpilled.toDouble) - val diskBytesSpilledQuantiles = Shuffle spill (disk) +: - getFormattedSizeQuantiles(diskBytesSpilledSizes) + val gettingResultQuantiles = titleCell("Getting Result Time", ToolTips.GETTING_RESULT_TIME) ++ + timeQuantiles(metrics.gettingResultTime) - val listings: Seq[Seq[Node]] = Seq( - {serviceQuantiles}, - {schedulerDelayQuantiles}, - - {deserializationQuantiles} - - {gcQuantiles}, - - {serializationQuantiles} - , - {gettingResultQuantiles}, - - {peakExecutionMemoryQuantiles} - , - if (hasInput(stageData)) {inputQuantiles} else Nil, - if (hasOutput(stageData)) {outputQuantiles} else Nil, - if (hasShuffleRead(stageData)) { - - {shuffleReadBlockedQuantiles} - - {shuffleReadTotalQuantiles} - - {shuffleReadRemoteQuantiles} - - } else { - Nil - }, - if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, - if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - // The summary table does not use CSS to stripe rows, which doesn't work with hidden - // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). - Some(UIUtils.listingTable( - quantileHeaders, - identity[Seq[Node]], - listings, - fixedWidth = true, - id = Some("task-summary-table"), - stripeRowsWithCss = false)) + val peakExecutionMemoryQuantiles = titleCell("Peak Execution Memory", + ToolTips.PEAK_EXECUTION_MEMORY) ++ sizeQuantiles(metrics.peakExecutionMemory) + + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelayQuantiles = titleCell("Scheduler Delay", ToolTips.SCHEDULER_DELAY) ++ + timeQuantiles(metrics.schedulerDelay) + + def inputQuantiles: Seq[Node] = { + simpleTitleCell("Input Size / Records") ++ + sizeQuantilesWithRecords(metrics.inputMetrics.bytesRead, metrics.inputMetrics.recordsRead) + } + + def outputQuantiles: Seq[Node] = { + simpleTitleCell("Output Size / Records") ++ + sizeQuantilesWithRecords(metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten) } + def shuffleReadBlockedQuantiles: Seq[Node] = { + titleCell("Shuffle Read Blocked Time", ToolTips.SHUFFLE_READ_BLOCKED_TIME) ++ + timeQuantiles(metrics.shuffleReadMetrics.fetchWaitTime) + } + + def shuffleReadTotalQuantiles: Seq[Node] = { + titleCell("Shuffle Read Size / Records", ToolTips.SHUFFLE_READ) ++ + sizeQuantilesWithRecords(metrics.shuffleReadMetrics.readBytes, + metrics.shuffleReadMetrics.readRecords) + } + + def shuffleReadRemoteQuantiles: Seq[Node] = { + titleCell("Shuffle Remote Reads", ToolTips.SHUFFLE_READ_REMOTE_SIZE) ++ + sizeQuantiles(metrics.shuffleReadMetrics.remoteBytesRead) + } + + def shuffleWriteQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle Write Size / Records") ++ + sizeQuantilesWithRecords(metrics.shuffleWriteMetrics.writeBytes, + metrics.shuffleWriteMetrics.writeRecords) + } + + def memoryBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (memory)") ++ sizeQuantiles(metrics.memoryBytesSpilled) + } + + def diskBytesSpilledQuantiles: Seq[Node] = { + simpleTitleCell("Shuffle spill (disk)") ++ sizeQuantiles(metrics.diskBytesSpilled) + } + + val listings: Seq[Seq[Node]] = Seq( + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + + {peakExecutionMemoryQuantiles} + , + if (hasInput(stageData)) {inputQuantiles} else Nil, + if (hasOutput(stageData)) {outputQuantiles} else Nil, + if (hasShuffleRead(stageData)) { + + {shuffleReadBlockedQuantiles} + + {shuffleReadTotalQuantiles} + + {shuffleReadRemoteQuantiles} + + } else { + Nil + }, + if (hasShuffleWrite(stageData)) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled(stageData)) {diskBytesSpilledQuantiles} else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", + "Max") + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false) + } + val executorTable = new ExecutorTable(stageData, parent.store) val maybeAccumulableTable: Seq[Node] = - if (hasAccumulators) {

    Accumulators

    ++ accumulableTable } else Seq() + if (hasAccumulators(stageData)) {

    Accumulators

    ++ accumulableTable } else Seq() val aggMetrics = taskIdsInPage.contains(t.taskId) }, + Option(taskTable).map(_.dataSource.tasks).getOrElse(Nil), currentTime) ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ @@ -593,10 +523,9 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) - val gettingResultTime = getGettingResultTime(taskInfo, currentTime) + val gettingResultTime = AppStatusUtils.gettingResultTime(taskInfo) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = - metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelay = AppStatusUtils.schedulerDelay(taskInfo) val schedulerDelayProportion = toProportion(schedulerDelay) val executorOverhead = serializationTime + deserializationTime @@ -708,7 +637,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We { if (MAX_TIMELINE_TASKS < tasks.size) { - This stage has more than the maximum number of tasks that can be shown in the + This page has more than the maximum number of tasks that can be shown in the visualization! Only the most recent {MAX_TIMELINE_TASKS} tasks (of {tasks.size} total) are shown. @@ -733,402 +662,49 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We } -private[ui] object StagePage { - private[ui] def getGettingResultTime(info: TaskData, currentTime: Long): Long = { - info.resultFetchStart match { - case Some(start) => - info.duration match { - case Some(duration) => - info.launchTime.getTime() + duration - start.getTime() - - case _ => - currentTime - start.getTime() - } - - case _ => - 0L - } - } - - private[ui] def getSchedulerDelay( - info: TaskData, - metrics: TaskMetrics, - currentTime: Long): Long = { - info.duration match { - case Some(duration) => - val executorOverhead = metrics.executorDeserializeTime + metrics.resultSerializationTime - math.max( - 0, - duration - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - - case _ => - // The task is still running and the metrics like executorRunTime are not available. - 0L - } - } - -} - -private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) - -private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) - -private[ui] case class TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable: Long, - shuffleReadBlockedTimeReadable: String, - shuffleReadSortable: Long, - shuffleReadReadable: String, - shuffleReadRemoteSortable: Long, - shuffleReadRemoteReadable: String) - -private[ui] case class TaskTableRowShuffleWriteData( - writeTimeSortable: Long, - writeTimeReadable: String, - shuffleWriteSortable: Long, - shuffleWriteReadable: String) - -private[ui] case class TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable: Long, - memoryBytesSpilledReadable: String, - diskBytesSpilledSortable: Long, - diskBytesSpilledReadable: String) - -/** - * Contains all data that needs for sorting and generating HTML. Using this one rather than - * TaskData to avoid creating duplicate contents during sorting the data. - */ -private[ui] class TaskTableRowData( - val index: Int, - val taskId: Long, - val attempt: Int, - val speculative: Boolean, - val status: String, - val taskLocality: String, - val executorId: String, - val host: String, - val launchTime: Long, - val duration: Long, - val formatDuration: String, - val schedulerDelay: Long, - val taskDeserializationTime: Long, - val gcTime: Long, - val serializationTime: Long, - val gettingResultTime: Long, - val peakExecutionMemoryUsed: Long, - val accumulators: Option[String], // HTML - val input: Option[TaskTableRowInputData], - val output: Option[TaskTableRowOutputData], - val shuffleRead: Option[TaskTableRowShuffleReadData], - val shuffleWrite: Option[TaskTableRowShuffleWriteData], - val bytesSpilled: Option[TaskTableRowBytesSpilledData], - val error: String, - val logs: Map[String, String]) - private[ui] class TaskDataSource( - tasks: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, + stage: StageData, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedDataSource[TaskTableRowData](pageSize) { - import StagePage._ + store: AppStatusStore) extends PagedDataSource[TaskData](pageSize) { + import ApiHelper._ // Keep an internal cache of executor log maps so that long task lists render faster. private val executorIdToLogs = new HashMap[String, Map[String, String]]() - // Convert TaskData to TaskTableRowData which contains the final contents to show in the table - // so that we can avoid creating duplicate contents during sorting the data - private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - - private var _slicedTaskIds: Set[Long] = _ + private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = data.size + override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks - override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { - val r = data.slice(from, to) - _slicedTaskIds = r.map(_.taskId).toSet - r - } - - def slicedTaskIds: Set[Long] = _slicedTaskIds - - private def taskRow(info: TaskData): TaskTableRowData = { - val metrics = info.taskMetrics - val duration = info.duration.getOrElse(1L) - val formatDuration = info.duration.map(d => UIUtils.formatDuration(d)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGcTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info, currentTime) - - val externalAccumulableReadable = info.accumulatorUpdates.map { acc => - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}") + override def sliceData(from: Int, to: Int): Seq[TaskData] = { + if (_tasksToShow == null) { + _tasksToShow = store.taskList(stage.stageId, stage.attemptId, from, to - from, + indexName(sortColumn), !desc) } - val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) - - val maybeInput = metrics.map(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)}") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.map(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.map(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(ApiHelper.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite - .map(_.recordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime) - val writeTimeSortable = maybeWriteTime.getOrElse(0L) - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val input = - if (hasInput) { - Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) - } else { - None - } - - val output = - if (hasOutput) { - Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords")) - } else { - None - } - - val shuffleRead = - if (hasShuffleRead) { - Some(TaskTableRowShuffleReadData( - shuffleReadBlockedTimeSortable, - shuffleReadBlockedTimeReadable, - shuffleReadSortable, - s"$shuffleReadReadable / $shuffleReadRecords", - shuffleReadRemoteSortable, - shuffleReadRemoteReadable - )) - } else { - None - } - - val shuffleWrite = - if (hasShuffleWrite) { - Some(TaskTableRowShuffleWriteData( - writeTimeSortable, - writeTimeReadable, - shuffleWriteSortable, - s"$shuffleWriteReadable / $shuffleWriteRecords" - )) - } else { - None - } - - val bytesSpilled = - if (hasBytesSpilled) { - Some(TaskTableRowBytesSpilledData( - memoryBytesSpilledSortable, - memoryBytesSpilledReadable, - diskBytesSpilledSortable, - diskBytesSpilledReadable - )) - } else { - None - } - - new TaskTableRowData( - info.index, - info.taskId, - info.attempt, - info.speculative, - info.status, - info.taskLocality.toString, - info.executorId, - info.host, - info.launchTime.getTime(), - duration, - formatDuration, - schedulerDelay, - taskDeserializationTime, - gcTime, - serializationTime, - gettingResultTime, - peakExecutionMemoryUsed, - if (hasAccumulators) Some(externalAccumulableReadable.mkString("
    ")) else None, - input, - output, - shuffleRead, - shuffleWrite, - bytesSpilled, - info.errorMessage.getOrElse(""), - executorLogs(info.executorId)) + _tasksToShow } - private def executorLogs(id: String): Map[String, String] = { + def tasks: Seq[TaskData] = _tasksToShow + + def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) } - /** - * Return Ordering according to sortColumn and desc - */ - private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { - val ordering: Ordering[TaskTableRowData] = sortColumn match { - case "Index" => Ordering.by(_.index) - case "ID" => Ordering.by(_.taskId) - case "Attempt" => Ordering.by(_.attempt) - case "Status" => Ordering.by(_.status) - case "Locality Level" => Ordering.by(_.taskLocality) - case "Executor ID" => Ordering.by(_.executorId) - case "Host" => Ordering.by(_.host) - case "Launch Time" => Ordering.by(_.launchTime) - case "Duration" => Ordering.by(_.duration) - case "Scheduler Delay" => Ordering.by(_.schedulerDelay) - case "Task Deserialization Time" => Ordering.by(_.taskDeserializationTime) - case "GC Time" => Ordering.by(_.gcTime) - case "Result Serialization Time" => Ordering.by(_.serializationTime) - case "Getting Result Time" => Ordering.by(_.gettingResultTime) - case "Peak Execution Memory" => Ordering.by(_.peakExecutionMemoryUsed) - case "Accumulators" => - if (hasAccumulators) { - Ordering.by(_.accumulators.get) - } else { - throw new IllegalArgumentException( - "Cannot sort by Accumulators because of no accumulators") - } - case "Input Size / Records" => - if (hasInput) { - Ordering.by(_.input.get.inputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Input Size / Records because of no inputs") - } - case "Output Size / Records" => - if (hasOutput) { - Ordering.by(_.output.get.outputSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Output Size / Records because of no outputs") - } - // ShuffleRead - case "Shuffle Read Blocked Time" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadBlockedTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") - } - case "Shuffle Read Size / Records" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") - } - case "Shuffle Remote Reads" => - if (hasShuffleRead) { - Ordering.by(_.shuffleRead.get.shuffleReadRemoteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Remote Reads because of no shuffle reads") - } - // ShuffleWrite - case "Write Time" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.writeTimeSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Write Time because of no shuffle writes") - } - case "Shuffle Write Size / Records" => - if (hasShuffleWrite) { - Ordering.by(_.shuffleWrite.get.shuffleWriteSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") - } - // BytesSpilled - case "Shuffle Spill (Memory)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.memoryBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Memory) because of no spills") - } - case "Shuffle Spill (Disk)" => - if (hasBytesSpilled) { - Ordering.by(_.bytesSpilled.get.diskBytesSpilledSortable) - } else { - throw new IllegalArgumentException( - "Cannot sort by Shuffle Spill (Disk) because of no spills") - } - case "Errors" => Ordering.by(_.error) - case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") - } - if (desc) { - ordering.reverse - } else { - ordering - } - } - } private[ui] class TaskPagedTable( - conf: SparkConf, + stage: StageData, basePath: String, - data: Seq[TaskData], - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, - store: AppStatusStore) extends PagedTable[TaskTableRowData] { + store: AppStatusStore) extends PagedTable[TaskData] { + + import ApiHelper._ override def tableId: String = "task-table" @@ -1142,13 +718,7 @@ private[ui] class TaskPagedTable( override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( - data, - hasAccumulators, - hasInput, - hasOutput, - hasShuffleRead, - hasShuffleWrite, - hasBytesSpilled, + stage, currentTime, pageSize, sortColumn, @@ -1180,22 +750,22 @@ private[ui] class TaskPagedTable( ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ - {if (hasShuffleRead) { + {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + {if (hasShuffleRead(stage)) { Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), ("Shuffle Read Size / Records", ""), ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ - {if (hasShuffleWrite) { + {if (hasShuffleWrite(stage)) { Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) } else { Nil }} ++ - {if (hasBytesSpilled) { + {if (hasBytesSpilled(stage)) { Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) } else { Nil @@ -1237,7 +807,17 @@ private[ui] class TaskPagedTable( {headerRow} } - def row(task: TaskTableRowData): Seq[Node] = { + def row(task: TaskData): Seq[Node] = { + def formatDuration(value: Option[Long], hideZero: Boolean = false): String = { + value.map { v => + if (v > 0 || !hideZero) UIUtils.formatDuration(v) else "" + }.getOrElse("") + } + + def formatBytes(value: Option[Long]): String = { + Utils.bytesToString(value.getOrElse(0L)) + } + {task.index} {task.taskId} @@ -1249,62 +829,98 @@ private[ui] class TaskPagedTable(
    {task.host}
    { - task.logs.map { + dataSource.executorLogs(task.executorId).map { case (logName, logUrl) => } }
    - {UIUtils.formatDate(new Date(task.launchTime))} - {task.formatDuration} + {UIUtils.formatDate(task.launchTime)} + {formatDuration(task.duration)} - {UIUtils.formatDuration(task.schedulerDelay)} + {UIUtils.formatDuration(AppStatusUtils.schedulerDelay(task))} - {UIUtils.formatDuration(task.taskDeserializationTime)} + {formatDuration(task.taskMetrics.map(_.executorDeserializeTime))} - {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + {formatDuration(task.taskMetrics.map(_.jvmGcTime), hideZero = true)} - {UIUtils.formatDuration(task.serializationTime)} + {formatDuration(task.taskMetrics.map(_.resultSerializationTime))} - {UIUtils.formatDuration(task.gettingResultTime)} + {UIUtils.formatDuration(AppStatusUtils.gettingResultTime(task))} - {Utils.bytesToString(task.peakExecutionMemoryUsed)} + {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} - {if (task.accumulators.nonEmpty) { - {Unparsed(task.accumulators.get)} + {if (hasAccumulators(stage)) { + accumulatorsInfo(task) }} - {if (task.input.nonEmpty) { - {task.input.get.inputReadable} + {if (hasInput(stage)) { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead) + val records = m.inputMetrics.recordsRead + {bytesRead} / {records} + } }} - {if (task.output.nonEmpty) { - {task.output.get.outputReadable} + {if (hasOutput(stage)) { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten) + val records = m.outputMetrics.recordsWritten + {bytesWritten} / {records} + } }} - {if (task.shuffleRead.nonEmpty) { + {if (hasShuffleRead(stage)) { - {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + {formatDuration(task.taskMetrics.map(_.shuffleReadMetrics.fetchWaitTime))} - {task.shuffleRead.get.shuffleReadReadable} + { + metricInfo(task) { m => + val bytesRead = Utils.bytesToString(totalBytesRead(m.shuffleReadMetrics)) + val records = m.shuffleReadMetrics.recordsRead + Unparsed(s"$bytesRead / $records") + } + } - {task.shuffleRead.get.shuffleReadRemoteReadable} + {formatBytes(task.taskMetrics.map(_.shuffleReadMetrics.remoteBytesRead))} }} - {if (task.shuffleWrite.nonEmpty) { - {task.shuffleWrite.get.writeTimeReadable} - {task.shuffleWrite.get.shuffleWriteReadable} + {if (hasShuffleWrite(stage)) { + { + formatDuration( + task.taskMetrics.map { m => + TimeUnit.NANOSECONDS.toMillis(m.shuffleWriteMetrics.writeTime) + }, + hideZero = true) + } + { + metricInfo(task) { m => + val bytesWritten = Utils.bytesToString(m.shuffleWriteMetrics.bytesWritten) + val records = m.shuffleWriteMetrics.recordsWritten + Unparsed(s"$bytesWritten / $records") + } + } }} - {if (task.bytesSpilled.nonEmpty) { - {task.bytesSpilled.get.memoryBytesSpilledReadable} - {task.bytesSpilled.get.diskBytesSpilledReadable} + {if (hasBytesSpilled(stage)) { + {formatBytes(task.taskMetrics.map(_.memoryBytesSpilled))} + {formatBytes(task.taskMetrics.map(_.diskBytesSpilled))} }} - {errorMessageCell(task.error)} + {errorMessageCell(task.errorMessage.getOrElse(""))} } + private def accumulatorsInfo(task: TaskData): Seq[Node] = { + task.accumulatorUpdates.map { acc => + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + } + } + + private def metricInfo(task: TaskData)(fn: TaskMetrics => Seq[Node]): Seq[Node] = { + task.taskMetrics.map(fn).getOrElse(Nil) + } + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default @@ -1333,6 +949,36 @@ private[ui] class TaskPagedTable( private object ApiHelper { + + private val COLUMN_TO_INDEX = Map( + "ID" -> null.asInstanceOf[String], + "Index" -> TaskIndexNames.TASK_INDEX, + "Attempt" -> TaskIndexNames.ATTEMPT, + "Status" -> TaskIndexNames.STATUS, + "Locality Level" -> TaskIndexNames.LOCALITY, + "Executor ID / Host" -> TaskIndexNames.EXECUTOR, + "Launch Time" -> TaskIndexNames.LAUNCH_TIME, + "Duration" -> TaskIndexNames.DURATION, + "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, + "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, + "GC Time" -> TaskIndexNames.GC_TIME, + "Result Serialization Time" -> TaskIndexNames.SER_TIME, + "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, + "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, + "Accumulators" -> TaskIndexNames.ACCUMULATORS, + "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, + "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, + "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, + "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, + "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, + "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, + "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, + "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, + "Errors" -> TaskIndexNames.ERROR) + + def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 def hasOutput(stageData: StageData): Boolean = stageData.outputBytes > 0 @@ -1349,4 +995,11 @@ private object ApiHelper { metrics.localBytesRead + metrics.remoteBytesRead } + def indexName(sortColumn: String): Option[String] = { + COLUMN_TO_INDEX.get(sortColumn) match { + case Some(v) => Option(v) + case _ => throw new IllegalArgumentException(s"Invalid sort column: $sortColumn") + } + } + } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index f8e27703c0def..5c42ac1d87f4c 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 6.0, 53.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index a28bda16a956e..e6b705989cc97 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 1034.0, 1034.0, 1034.0, 1034.0, 1034.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 4.0, 4.0, 6.0, 7.0, 9.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index ede3eaed1d1d2..788f28cf7b365 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -7,6 +7,9 @@ "resultSize" : [ 2010.0, 2065.0, 2065.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 0.0, 5.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 1.0 ], + "gettingResultTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "schedulerDelay" : [ 2.0, 4.0, 6.0, 13.0, 40.0 ], + "peakExecutionMemory" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "inputMetrics" : { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b8c84e24c2c3f..ca66b6b9db890 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -213,45 +213,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { s1Tasks.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.taskId === task.taskId) + assert(wrapper.taskId === task.taskId) assert(wrapper.stageId === stages.head.stageId) - assert(wrapper.stageAttemptId === stages.head.attemptNumber) - assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptNumber))) - - val runtime = Array[AnyRef](stages.head.stageId: JInteger, - stages.head.attemptNumber: JInteger, - -1L: JLong) - assert(Arrays.equals(wrapper.runtime, runtime)) - - assert(wrapper.info.index === task.index) - assert(wrapper.info.attempt === task.attemptNumber) - assert(wrapper.info.launchTime === new Date(task.launchTime)) - assert(wrapper.info.executorId === task.executorId) - assert(wrapper.info.host === task.host) - assert(wrapper.info.status === task.status) - assert(wrapper.info.taskLocality === task.taskLocality.toString()) - assert(wrapper.info.speculative === task.speculative) + assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(wrapper.index === task.index) + assert(wrapper.attempt === task.attemptNumber) + assert(wrapper.launchTime === task.launchTime) + assert(wrapper.executorId === task.executorId) + assert(wrapper.host === task.host) + assert(wrapper.status === task.status) + assert(wrapper.taskLocality === task.taskLocality.toString()) + assert(wrapper.speculative === task.speculative) } } - // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. - s1Tasks.foreach { task => - val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), - Some(1L), None, true, false, None) - listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( - task.executorId, - Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) - } + // Send two executor metrics update. Only update one metric to avoid a lot of boilerplate code. + // The tasks are distributed among the two executors, so the executor-level metrics should + // hold half of the cummulative value of the metric being updated. + Seq(1L, 2L).foreach { value => + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(value), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptNumber, Seq(accum))))) + } - check[StageDataWrapper](key(stages.head)) { stage => - assert(stage.info.memoryBytesSpilled === s1Tasks.size) - } + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size * value) + } - val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") - .first(key(stages.head)).last(key(stages.head)).asScala.toSeq - assert(execs.size > 0) - execs.foreach { exec => - assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size * value / 2) + } } // Fail one of the tasks, re-start it. @@ -278,13 +275,13 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](s1Tasks.head.taskId) { task => - assert(task.info.status === s1Tasks.head.status) - assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + assert(task.status === s1Tasks.head.status) + assert(task.errorMessage == Some(TaskResultLost.toErrorString)) } check[TaskDataWrapper](reattempt.taskId) { task => - assert(task.info.index === s1Tasks.head.index) - assert(task.info.attempt === reattempt.attemptNumber) + assert(task.index === s1Tasks.head.index) + assert(task.attempt === reattempt.attemptNumber) } // Kill one task, restart it. @@ -306,8 +303,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](killed.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some("killed")) + assert(task.index === killed.index) + assert(task.errorMessage === Some("killed")) } // Start a new attempt and finish it with TaskCommitDenied, make sure it's handled like a kill. @@ -334,8 +331,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } check[TaskDataWrapper](denied.taskId) { task => - assert(task.info.index === killed.index) - assert(task.info.errorMessage === Some(denyReason.toErrorString)) + assert(task.index === killed.index) + assert(task.errorMessage === Some(denyReason.toErrorString)) } // Start a new attempt. @@ -373,10 +370,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { pending.foreach { task => check[TaskDataWrapper](task.taskId) { wrapper => - assert(wrapper.info.errorMessage === None) - assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) - assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) - assert(wrapper.info.duration === Some(task.duration)) + assert(wrapper.errorMessage === None) + assert(wrapper.executorCpuTime === 2L) + assert(wrapper.executorRunTime === 4L) + assert(wrapper.duration === task.duration) } } @@ -894,6 +891,23 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(store.count(classOf[StageDataWrapper]) === 3) assert(store.count(classOf[RDDOperationGraphWrapper]) === 3) + val dropped = stages.drop(1).head + + // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be + // calculcated, we need at least one finished task. + time += 1 + val task = createTasks(1, Array("1")).head + listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) + + time += 1 + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, + "taskType", Success, task, null)) + + new AppStatusStore(store) + .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 3) + stages.drop(1).foreach { s => time += 1 s.completionTime = Some(time) @@ -905,6 +919,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { intercept[NoSuchElementException] { store.read(classOf[StageDataWrapper], Array(2, 0)) } + assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 0) val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3") time += 1 diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala new file mode 100644 index 0000000000000..92f90f3d96ddf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.TaskMetricDistributions +import org.apache.spark.util.Distribution +import org.apache.spark.util.kvstore._ + +class AppStatusStoreSuite extends SparkFunSuite { + + private val uiQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0) + private val stageId = 1 + private val attemptId = 1 + + test("quantile calculation: 1 task") { + compareQuantiles(1, uiQuantiles) + } + + test("quantile calculation: few tasks") { + compareQuantiles(4, uiQuantiles) + } + + test("quantile calculation: more tasks") { + compareQuantiles(100, uiQuantiles) + } + + test("quantile calculation: lots of tasks") { + compareQuantiles(4096, uiQuantiles) + } + + test("quantile calculation: custom quantiles") { + compareQuantiles(4096, Array(0.01, 0.33, 0.5, 0.42, 0.69, 0.99)) + } + + test("quantile cache") { + val store = new InMemoryStore() + (0 until 4096).foreach { i => store.write(newTaskData(i)) } + + val appStore = new AppStatusStore(store) + + appStore.taskSummary(stageId, attemptId, Array(0.13d)) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "13")) + } + + appStore.taskSummary(stageId, attemptId, Array(0.25d)) + val d1 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + + // Add a new task to force the cached quantile to be evicted, and make sure it's updated. + store.write(newTaskData(4096)) + appStore.taskSummary(stageId, attemptId, Array(0.25d, 0.50d, 0.73d)) + + val d2 = store.read(classOf[CachedQuantile], Array(stageId, attemptId, "25")) + assert(d1.taskCount != d2.taskCount) + + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "50")) + intercept[NoSuchElementException] { + store.read(classOf[CachedQuantile], Array(stageId, attemptId, "73")) + } + + assert(store.count(classOf[CachedQuantile]) === 2) + } + + private def compareQuantiles(count: Int, quantiles: Array[Double]): Unit = { + val store = new InMemoryStore() + val values = (0 until count).map { i => + val task = newTaskData(i) + store.write(task) + i.toDouble + }.toArray + + val summary = new AppStatusStore(store).taskSummary(stageId, attemptId, quantiles).get + val dist = new Distribution(values, 0, values.length).getQuantiles(quantiles.sorted) + + dist.zip(summary.executorRunTime).foreach { case (expected, actual) => + assert(expected === actual) + } + } + + private def newTaskData(i: Int): TaskDataWrapper = { + new TaskDataWrapper( + i, i, i, i, i, i, i.toString, i.toString, i.toString, i.toString, false, Nil, None, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, i, i, i, i, i, i, + i, i, i, i, stageId, attemptId) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 661d0d48d2f37..0aeddf730cd35 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.config._ import org.apache.spark.ui.jobs.{StagePage, StagesTab} class StagePageSuite extends SparkFunSuite with LocalSparkContext { @@ -35,15 +36,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 test("peak execution memory should displayed") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" assert(html.contains(targetString)) } test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { - val conf = new SparkConf(false) - val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT) + val html = renderStagePage().toString().toLowerCase(Locale.ROOT) // verify min/25/50/75/max show task value not cumulative values assert(html.contains(s"$peakExecutionMemory.0 b" * 5)) } @@ -52,7 +51,8 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. */ - private def renderStagePage(conf: SparkConf): Seq[Node] = { + private def renderStagePage(): Seq[Node] = { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) val statusStore = AppStatusStore.createLiveStore(conf) val listener = statusStore.listener.get diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7bdd3fac773a3..e2fa5754afaee 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -93,7 +93,7 @@ This file is divided into 3 sections: - + From 0552c36e02434c60dad82024334d291f6008b822 Mon Sep 17 00:00:00 2001 From: wuyi5 Date: Thu, 11 Jan 2018 22:17:15 +0900 Subject: [PATCH 0065/1161] [SPARK-22967][TESTS] Fix VersionSuite's unit tests by change Windows path into URI path ## What changes were proposed in this pull request? Two unit test will fail due to Windows format path: 1.test(s"$version: read avro file containing decimal") ``` org.apache.hadoop.hive.ql.metadata.HiveException: MetaException(message:java.lang.IllegalArgumentException: Can not create a Path from an empty string); ``` 2.test(s"$version: SPARK-17920: Insert into/overwrite avro table") ``` Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; org.apache.spark.sql.AnalysisException: Unable to infer the schema. The schema specification is required to create the table `default`.`tab2`.; ``` This pr fix these two unit test by change Windows path into URI path. ## How was this patch tested? Existed. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: wuyi5 Closes #20199 from Ngone51/SPARK-22967. --- .../org/apache/spark/sql/hive/client/VersionsSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index ff90e9dda5f7c..e64389e56b5a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -811,7 +811,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: read avro file containing decimal") { val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val location = new File(url.getFile) + val location = new File(url.getFile).toURI.toString val tableName = "tab1" val avroSchema = @@ -851,6 +851,8 @@ class VersionsSuite extends SparkFunSuite with Logging { } test(s"$version: SPARK-17920: Insert into/overwrite avro table") { + // skipped because it's failed in the condition on Windows + assume(!(Utils.isWindows && version == "0.12")) withTempDir { dir => val avroSchema = """ @@ -875,10 +877,10 @@ class VersionsSuite extends SparkFunSuite with Logging { val writer = new PrintWriter(schemaFile) writer.write(avroSchema) writer.close() - val schemaPath = schemaFile.getCanonicalPath + val schemaPath = schemaFile.toURI.toString val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal") - val srcLocation = new File(url.getFile).getCanonicalPath + val srcLocation = new File(url.getFile).toURI.toString val destTableName = "tab1" val srcTableName = "tab2" From 76892bcf2c08efd7e9c5b16d377e623d82fe695e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 21:32:36 +0800 Subject: [PATCH 0066/1161] [SPARK-23000][TEST-HADOOP2.6] Fix Flaky test suite DataSourceWithHiveMetastoreCatalogSuite ## What changes were proposed in this pull request? The Spark 2.3 branch still failed due to the flaky test suite `DataSourceWithHiveMetastoreCatalogSuite `. https://amplab.cs.berkeley.edu/jenkins/job/spark-branch-2.3-test-sbt-hadoop-2.6/ Although https://github.com/apache/spark/pull/20207 is unable to reproduce it in Spark 2.3, it sounds like the current DB of Spark's Catalog is changed based on the following stacktrace. Thus, we just need to reset it. ``` [info] DataSourceWithHiveMetastoreCatalogSuite: 02:40:39.486 ERROR org.apache.hadoop.hive.ql.parse.CalcitePlanner: org.apache.hadoop.hive.ql.parse.SemanticException: Line 1:14 Table not found 't' at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1594) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.getMetaData(SemanticAnalyzer.java:1545) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.genResolvedParseTree(SemanticAnalyzer.java:10077) at org.apache.hadoop.hive.ql.parse.SemanticAnalyzer.analyzeInternal(SemanticAnalyzer.java:10128) at org.apache.hadoop.hive.ql.parse.CalcitePlanner.analyzeInternal(CalcitePlanner.java:209) at org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.analyze(BaseSemanticAnalyzer.java:227) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:424) at org.apache.hadoop.hive.ql.Driver.compile(Driver.java:308) at org.apache.hadoop.hive.ql.Driver.compileInternal(Driver.java:1122) at org.apache.hadoop.hive.ql.Driver.runInternal(Driver.java:1170) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1059) at org.apache.hadoop.hive.ql.Driver.run(Driver.java:1049) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:694) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$runHive$1.apply(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl$$anonfun$withHiveState$1.apply(HiveClientImpl.scala:272) at org.apache.spark.sql.hive.client.HiveClientImpl.liftedTree1$1(HiveClientImpl.scala:210) at org.apache.spark.sql.hive.client.HiveClientImpl.retryLocked(HiveClientImpl.scala:209) at org.apache.spark.sql.hive.client.HiveClientImpl.withHiveState(HiveClientImpl.scala:255) at org.apache.spark.sql.hive.client.HiveClientImpl.runHive(HiveClientImpl.scala:683) at org.apache.spark.sql.hive.client.HiveClientImpl.runSqlHive(HiveClientImpl.scala:673) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1$$anonfun$apply$mcV$sp$3.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:185) at org.apache.spark.sql.test.SQLTestUtilsBase$class.withTable(SQLTestUtils.scala:273) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:139) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$9$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:163) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:186) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:68) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:183) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:196) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:289) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:196) at org.scalatest.FunSuite.runTest(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:229) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:396) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:384) at scala.collection.immutable.List.foreach(List.scala:381) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:384) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:379) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:461) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:229) at org.scalatest.FunSuite.runTests(FunSuite.scala:1560) at org.scalatest.Suite$class.run(Suite.scala:1147) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1560) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:233) at org.scalatest.SuperEngine.runImpl(Engine.scala:521) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:233) at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:31) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:213) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:210) at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:31) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:314) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:480) at sbt.ForkMain$Run$2.call(ForkMain.java:296) at sbt.ForkMain$Run$2.call(ForkMain.java:286) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ``` ## How was this patch tested? N/A Author: gatorsmile Closes #20218 from gatorsmile/testFixAgain. --- .../org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index cf4ce83124d88..ba9b944e4a055 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -148,6 +148,7 @@ class DataSourceWithHiveMetastoreCatalogSuite override def beforeAll(): Unit = { super.beforeAll() + sparkSession.sessionState.catalog.reset() sparkSession.metadataHive.reset() } From b46e58b74c82dac37b7b92284ea3714919c5a886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 11 Jan 2018 22:33:42 +0900 Subject: [PATCH 0067/1161] [SPARK-19732][FOLLOW-UP] Document behavior changes made in na.fill and fillna ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18164 introduces the behavior changes. We need to document it. ## How was this patch tested? N/A Author: gatorsmile Closes #20234 from gatorsmile/docBehaviorChange. --- docs/sql-programming-guide.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 72f79d6909ecc..258c769ff593b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1788,12 +1788,10 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - - - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - - - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. + - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). + - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. ## Upgrading From Spark SQL 2.1 to 2.2 From 6d230dccf65300651f989392159d84bfaf08f18f Mon Sep 17 00:00:00 2001 From: FanDonglai Date: Thu, 11 Jan 2018 09:06:40 -0600 Subject: [PATCH 0068/1161] Update PageRank.scala ## What changes were proposed in this pull request? Hi, acording to code below, "if (id == src) (0.0, Double.NegativeInfinity) else (0.0, 0.0)" I think the comment can be wrong ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: FanDonglai Closes #20220 from ddna1021/master. --- .../src/main/scala/org/apache/spark/graphx/lib/PageRank.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index fd7b7f7c1c487..ebd65e8320e5c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -303,7 +303,7 @@ object PageRank extends Logging { val src: VertexId = srcId.getOrElse(-1L) // Initialize the pagerankGraph with each edge attribute - // having weight 1/outDegree and each vertex with attribute 1.0. + // having weight 1/outDegree and each vertex with attribute 0. val pagerankGraph: Graph[(Double, Double), Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { From 0b2eefb674151a0af64806728b38d9410da552ec Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 11 Jan 2018 10:37:35 -0800 Subject: [PATCH 0069/1161] [SPARK-22994][K8S] Use a single image for all Spark containers. This change allows a user to submit a Spark application on kubernetes having to provide a single image, instead of one image for each type of container. The image's entry point now takes an extra argument that identifies the process that is being started. The configuration still allows the user to provide different images for each container type if they so desire. On top of that, the entry point was simplified a bit to share more code; mainly, the same env variable is used to propagate the user-defined classpath to the different containers. Aside from being modified to match the new behavior, the 'build-push-docker-images.sh' script was renamed to 'docker-image-tool.sh' to more closely match its purpose; the old name was a little awkward and now also not entirely correct, since there is a single image. It was also moved to 'bin' since it's not necessarily an admin tool. Docs have been updated to match the new behavior. Tested locally with minikube. Author: Marcelo Vanzin Closes #20192 from vanzin/SPARK-22994. --- .../docker-image-tool.sh | 68 ++++++------- docs/running-on-kubernetes.md | 58 +++++------ .../org/apache/spark/deploy/k8s/Config.scala | 17 ++-- .../apache/spark/deploy/k8s/Constants.scala | 3 +- .../deploy/k8s/InitContainerBootstrap.scala | 1 + .../steps/BasicDriverConfigurationStep.scala | 3 +- .../cluster/k8s/ExecutorPodFactory.scala | 3 +- .../DriverConfigOrchestratorSuite.scala | 12 +-- .../BasicDriverConfigurationStepSuite.scala | 4 +- ...InitContainerConfigOrchestratorSuite.scala | 4 +- .../cluster/k8s/ExecutorPodFactorySuite.scala | 4 +- .../src/main/dockerfiles/driver/Dockerfile | 35 ------- .../src/main/dockerfiles/executor/Dockerfile | 35 ------- .../dockerfiles/init-container/Dockerfile | 24 ----- .../main/dockerfiles/spark-base/entrypoint.sh | 37 ------- .../{spark-base => spark}/Dockerfile | 10 +- .../src/main/dockerfiles/spark/entrypoint.sh | 97 +++++++++++++++++++ 17 files changed, 189 insertions(+), 226 deletions(-) rename sbin/build-push-docker-images.sh => bin/docker-image-tool.sh (63%) delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile delete mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile delete mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh rename resource-managers/kubernetes/docker/src/main/dockerfiles/{spark-base => spark}/Dockerfile (87%) create mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh diff --git a/sbin/build-push-docker-images.sh b/bin/docker-image-tool.sh similarity index 63% rename from sbin/build-push-docker-images.sh rename to bin/docker-image-tool.sh index b9532597419a5..071406336d1b1 100755 --- a/sbin/build-push-docker-images.sh +++ b/bin/docker-image-tool.sh @@ -24,29 +24,11 @@ function error { exit 1 } -# Detect whether this is a git clone or a Spark distribution and adjust paths -# accordingly. if [ -z "${SPARK_HOME}" ]; then SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi . "${SPARK_HOME}/bin/load-spark-env.sh" -if [ -f "$SPARK_HOME/RELEASE" ]; then - IMG_PATH="kubernetes/dockerfiles" - SPARK_JARS="jars" -else - IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" - SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -if [ ! -d "$IMG_PATH" ]; then - error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." -fi - -declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ - [spark-executor]="$IMG_PATH/executor/Dockerfile" \ - [spark-init]="$IMG_PATH/init-container/Dockerfile" ) - function image_ref { local image="$1" local add_repo="${2:-1}" @@ -60,35 +42,49 @@ function image_ref { } function build { - docker build \ - --build-arg "spark_jars=$SPARK_JARS" \ - --build-arg "img_path=$IMG_PATH" \ - -t spark-base \ - -f "$IMG_PATH/spark-base/Dockerfile" . - for image in "${!path[@]}"; do - docker build -t "$(image_ref $image)" -f ${path[$image]} . - done + local BUILD_ARGS + local IMG_PATH + + if [ ! -f "$SPARK_HOME/RELEASE" ]; then + # Set image build arguments accordingly if this is a source repo and not a distribution archive. + IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles + BUILD_ARGS=( + --build-arg + img_path=$IMG_PATH + --build-arg + spark_jars=assembly/target/scala-$SPARK_SCALA_VERSION/jars + ) + else + # Not passed as an argument to docker, but used to validate the Spark directory. + IMG_PATH="kubernetes/dockerfiles" + fi + + if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." + fi + + docker build "${BUILD_ARGS[@]}" \ + -t $(image_ref spark) \ + -f "$IMG_PATH/spark/Dockerfile" . } function push { - for image in "${!path[@]}"; do - docker push "$(image_ref $image)" - done + docker push "$(image_ref spark)" } function usage { cat < -t my-tag build - ./sbin/build-push-docker-images.sh -r -t my-tag push - -Docker files are under the `kubernetes/dockerfiles/` directory and can be customized further before -building using the supplied script, or manually. + ./bin/docker-image-tool.sh -r -t my-tag build + ./bin/docker-image-tool.sh -r -t my-tag push ## Cluster Mode @@ -79,8 +76,7 @@ $ bin/spark-submit \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ + --conf spark.kubernetes.container.image= \ local:///path/to/examples.jar ``` @@ -126,13 +122,7 @@ Those dependencies can be added to the classpath by referencing them with `local ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. This requires users to specify the container -image for the init-container using the configuration property `spark.kubernetes.initContainer.image`. For example, users -simply add the following option to the `spark-submit` command to specify the init-container image: - -``` ---conf spark.kubernetes.initContainer.image= -``` +the dependencies so the driver and executor containers can use them locally. The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and `spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., @@ -147,9 +137,7 @@ $ bin/spark-submit \ --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 --conf spark.executor.instances=5 \ - --conf spark.kubernetes.driver.container.image= \ - --conf spark.kubernetes.executor.container.image= \ - --conf spark.kubernetes.initContainer.image= + --conf spark.kubernetes.container.image= \ https://path/to/examples.jar ``` @@ -322,21 +310,27 @@ specific to Spark on Kubernetes. - spark.kubernetes.driver.container.image + spark.kubernetes.container.image (none) - Container image to use for the driver. - This is usually of the form example.com/repo/spark-driver:v1.0.0. - This configuration is required and must be provided by the user. + Container image to use for the Spark application. + This is usually of the form example.com/repo/spark:v1.0.0. + This configuration is required and must be provided by the user, unless explicit + images are provided for each different container type. + + + + spark.kubernetes.driver.container.image + (value of spark.kubernetes.container.image) + + Custom container image to use for the driver. spark.kubernetes.executor.container.image - (none) + (value of spark.kubernetes.container.image) - Container image to use for the executors. - This is usually of the form example.com/repo/spark-executor:v1.0.0. - This configuration is required and must be provided by the user. + Custom container image to use for executors. @@ -643,9 +637,9 @@ specific to Spark on Kubernetes. spark.kubernetes.initContainer.image - (none) + (value of spark.kubernetes.container.image) - Container image for the init-container of the driver and executors for downloading dependencies. This is usually of the form example.com/repo/spark-init:v1.0.0. This configuration is optional and must be provided by the user if any non-container local dependency is used and must be downloaded remotely. + Custom container image for the init container of both driver and executors. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index e5d79d9a9d9da..471196ac0e3f6 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -29,17 +29,23 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("default") + val CONTAINER_IMAGE = + ConfigBuilder("spark.kubernetes.container.image") + .doc("Container image to use for Spark containers. Individual container types " + + "(e.g. driver or executor) can also be configured to use different images if desired, " + + "by setting the container type-specific image name.") + .stringConf + .createOptional + val DRIVER_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.driver.container.image") .doc("Container image to use for the driver.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val EXECUTOR_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.executor.container.image") .doc("Container image to use for the executors.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val CONTAINER_IMAGE_PULL_POLICY = ConfigBuilder("spark.kubernetes.container.image.pullPolicy") @@ -148,8 +154,7 @@ private[spark] object Config extends Logging { val INIT_CONTAINER_IMAGE = ConfigBuilder("spark.kubernetes.initContainer.image") .doc("Image for the driver and executor's init-container for downloading dependencies.") - .stringConf - .createOptional + .fallbackConf(CONTAINER_IMAGE) val INIT_CONTAINER_MOUNT_TIMEOUT = ConfigBuilder("spark.kubernetes.mountDependencies.timeout") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 111cb2a3b75e5..9411956996843 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -60,10 +60,9 @@ private[spark] object Constants { val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" - val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" + val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala index dfeccf9e2bd1c..f6a57dfe00171 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala @@ -77,6 +77,7 @@ private[spark] class InitContainerBootstrap( .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) .endVolumeMount() .addToVolumeMounts(sharedVolumeMounts: _*) + .addToArgs("init") .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index eca46b84c6066..164e2e5594778 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -66,7 +66,7 @@ private[spark] class BasicDriverConfigurationStep( override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => new EnvVarBuilder() - .withName(ENV_SUBMIT_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(classPath) .build() } @@ -133,6 +133,7 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits("memory", driverMemoryLimitQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() + .addToArgs("driver") .build() val baseDriverPod = new PodBuilder(driverSpec.driverPod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index bcacb3934d36a..141bd2827e7c5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -128,7 +128,7 @@ private[spark] class ExecutorPodFactory( .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() - .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withName(ENV_CLASSPATH) .withValue(cp) .build() } @@ -181,6 +181,7 @@ private[spark] class ExecutorPodFactory( .endResources() .addAllToEnv(executorEnv.asJava) .withPorts(requiredPorts.asJava) + .addToArgs("executor") .build() val executorPod = new PodBuilder() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index f193b1f4d3664..65274c6f50e01 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -34,8 +34,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, @@ -55,8 +54,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { } test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, LAUNCH_TIME, @@ -75,8 +73,8 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with an init-container.") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE, IC_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( @@ -98,7 +96,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { test("Submission steps with driver secrets to mount") { val sparkConf = new SparkConf(false) - .set(DRIVER_CONTAINER_IMAGE, DRIVER_IMAGE) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index 8ee629ac8ddc1..b136f2c02ffba 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -47,7 +47,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(DRIVER_CONTAINER_IMAGE, "spark-driver:latest") + .set(CONTAINER_IMAGE, "spark-driver:latest") .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") @@ -79,7 +79,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .asScala .map(env => (env.getName, env.getValue)) .toMap - assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") + assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala index 20f2e5bc15df3..09b42e4484d86 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala @@ -40,7 +40,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including basic configuration step") { val sparkConf = new SparkConf(true) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) val orchestrator = new InitContainerConfigOrchestrator( @@ -59,7 +59,7 @@ class InitContainerConfigOrchestratorSuite extends SparkFunSuite { test("including step to mount user-specified secrets") { val sparkConf = new SparkConf(false) - .set(INIT_CONTAINER_IMAGE, DOCKER_IMAGE) + .set(CONTAINER_IMAGE, DOCKER_IMAGE) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7cfbe54c95390..a3c615be031d2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -54,7 +54,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef baseConf = new SparkConf() .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(EXECUTOR_CONTAINER_IMAGE, executorImage) + .set(CONTAINER_IMAGE, executorImage) } test("basic executor pod has reasonable defaults") { @@ -107,7 +107,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkEnv(executor, Map("SPARK_JAVA_OPT_0" -> "foo=bar", - "SPARK_EXECUTOR_EXTRA_CLASSPATH" -> "bar=baz", + ENV_CLASSPATH -> "bar=baz", "qux" -> "quux")) checkOwnerReferences(executor, driverPodUid) } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile deleted file mode 100644 index 45fbcd9cd0deb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-driver:latest -f kubernetes/dockerfiles/driver/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_SUBMIT_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_SUBMIT_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_DRIVER_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS $SPARK_DRIVER_CLASS $SPARK_DRIVER_ARGS diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile deleted file mode 100644 index 0f806cf7e148e..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# Before building the docker image, first build and make a Spark distribution following -# the instructions in http://spark.apache.org/docs/latest/building-spark.html. -# If this docker file is being used in the context of building your images from a Spark -# distribution, the docker build command should be invoked from the top level directory -# of the Spark distribution. E.g.: -# docker build -t spark-executor:latest -f kubernetes/dockerfiles/executor/Dockerfile . - -COPY examples /opt/spark/examples - -CMD SPARK_CLASSPATH="${SPARK_HOME}/jars/*" && \ - env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt && \ - readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt && \ - if ! [ -z ${SPARK_MOUNTED_CLASSPATH}+x} ]; then SPARK_CLASSPATH="$SPARK_MOUNTED_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_EXECUTOR_EXTRA_CLASSPATH+x} ]; then SPARK_CLASSPATH="$SPARK_EXECUTOR_EXTRA_CLASSPATH:$SPARK_CLASSPATH"; fi && \ - if ! [ -z ${SPARK_MOUNTED_FILES_DIR+x} ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." .; fi && \ - ${JAVA_HOME}/bin/java "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" org.apache.spark.executor.CoarseGrainedExecutorBackend --driver-url $SPARK_DRIVER_URL --executor-id $SPARK_EXECUTOR_ID --cores $SPARK_EXECUTOR_CORES --app-id $SPARK_APPLICATION_ID --hostname $SPARK_EXECUTOR_POD_IP diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile deleted file mode 100644 index 047056ab2633b..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -FROM spark-base - -# If this docker file is being used in the context of building your images from a Spark distribution, the docker build -# command should be invoked from the top level directory of the Spark distribution. E.g.: -# docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . - -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh deleted file mode 100755 index 82559889f4beb..0000000000000 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/entrypoint.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# echo commands to the terminal output -set -ex - -# Check whether there is a passwd entry for the container UID -myuid=$(id -u) -mygid=$(id -g) -uidentry=$(getent passwd $myuid) - -# If there is no passwd entry for the container UID, attempt to create one -if [ -z "$uidentry" ] ; then - if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd - else - echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" - fi -fi - -# Execute the container CMD under tini for better hygiene -/sbin/tini -s -- "$@" diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile similarity index 87% rename from resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile rename to resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index da1d6b9e161cc..491b7cf692478 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -17,15 +17,15 @@ FROM openjdk:8-alpine -ARG spark_jars -ARG img_path +ARG spark_jars=jars +ARG img_path=kubernetes/dockerfiles # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark # distribution, the docker build command should be invoked from the top level directory # of the Spark distribution. E.g.: -# docker build -t spark-base:latest -f kubernetes/dockerfiles/spark-base/Dockerfile . +# docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile . RUN set -ex && \ apk upgrade --no-cache && \ @@ -41,7 +41,9 @@ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY ${img_path}/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark/entrypoint.sh /opt/ +COPY examples /opt/spark/examples +COPY data /opt/spark/data ENV SPARK_HOME /opt/spark diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh new file mode 100755 index 0000000000000..0c28c75857871 --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# echo commands to the terminal output +set -ex + +# Check whether there is a passwd entry for the container UID +myuid=$(id -u) +mygid=$(id -g) +uidentry=$(getent passwd $myuid) + +# If there is no passwd entry for the container UID, attempt to create one +if [ -z "$uidentry" ] ; then + if [ -w /etc/passwd ] ; then + echo "$myuid:x:$myuid:$mygid:anonymous uid:$SPARK_HOME:/bin/false" >> /etc/passwd + else + echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" + fi +fi + +SPARK_K8S_CMD="$1" +if [ -z "$SPARK_K8S_CMD" ]; then + echo "No command to execute has been provided." 1>&2 + exit 1 +fi +shift 1 + +SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" +env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" +fi +if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then + cp -R "$SPARK_MOUNTED_FILES_DIR/." . +fi + +case "$SPARK_K8S_CMD" in + driver) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_DRIVER_JAVA_OPTS[@]}" + -cp "$SPARK_CLASSPATH" + -Xms$SPARK_DRIVER_MEMORY + -Xmx$SPARK_DRIVER_MEMORY + -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS + $SPARK_DRIVER_CLASS + $SPARK_DRIVER_ARGS + ) + ;; + + executor) + CMD=( + ${JAVA_HOME}/bin/java + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + -Xms$SPARK_EXECUTOR_MEMORY + -Xmx$SPARK_EXECUTOR_MEMORY + -cp "$SPARK_CLASSPATH" + org.apache.spark.executor.CoarseGrainedExecutorBackend + --driver-url $SPARK_DRIVER_URL + --executor-id $SPARK_EXECUTOR_ID + --cores $SPARK_EXECUTOR_CORES + --app-id $SPARK_APPLICATION_ID + --hostname $SPARK_EXECUTOR_POD_IP + ) + ;; + + init) + CMD=( + "$SPARK_HOME/bin/spark-class" + "org.apache.spark.deploy.k8s.SparkPodInitContainer" + "$@" + ) + ;; + + *) + echo "Unknown command: $SPARK_K8S_CMD" 1>&2 + exit 1 +esac + +# Execute the container CMD under tini for better hygiene +exec /sbin/tini -s -- "${CMD[@]}" From 6f7aaed805070d29dcba32e04ca7a1f581fa54b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 11 Jan 2018 10:52:12 -0800 Subject: [PATCH 0070/1161] [SPARK-22908] Add kafka source and sink for continuous processing. ## What changes were proposed in this pull request? Add kafka source and sink for continuous processing. This involves two small changes to the execution engine: * Bring data reader close() into the normal data reader thread to avoid thread safety issues. * Fix up the semantics of the RECONFIGURING StreamExecution state. State updates are now atomic, and we don't have to deal with swallowing an exception. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20096 from jose-torres/continuous-kafka. --- .../sql/kafka010/KafkaContinuousReader.scala | 232 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 +++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ++++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 64 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 +++++++++-------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1531 insertions(+), 383 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..928379544758c --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. + private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs = Long.MaxValue, + failOnDataLoss) + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. + checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..27da76068a66f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..dfc97b1c38bb5 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. + */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + writer.processAllAvailable() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. + .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. + r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..e713e6695d2bd --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.SparkContext +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..d66908f86ccc7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. * @@ -82,7 +94,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -97,16 +109,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +151,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +395,51 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) + .option("startingOffsets", s"earliest") .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock + .load() - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() @@ -328,7 +451,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -422,65 +545,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -540,34 +604,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - test("delete a topic when a Spark job is running") { KafkaSourceSuite.collectedData.clear() @@ -629,8 +665,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -676,6 +710,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -706,6 +744,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -723,47 +762,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -800,9 +798,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -843,9 +839,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -977,20 +971,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1004,6 +986,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..e700aa4f9aea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -201,6 +200,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 186bf8fb2e9ff8a80f3f6bcb5f2a0327fa79a1c9 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 11 Jan 2018 13:57:15 -0800 Subject: [PATCH 0071/1161] [SPARK-23046][ML][SPARKR] Have RFormula include VectorSizeHint in pipeline ## What changes were proposed in this pull request? Including VectorSizeHint in RFormula piplelines will allow them to be applied to streaming dataframes. ## How was this patch tested? Unit tests. Author: Bago Amirbekian Closes #20238 from MrBago/rFormulaVectorSize. --- R/pkg/R/mllib_utils.R | 1 + .../apache/spark/ml/feature/RFormula.scala | 18 +++++++-- .../spark/ml/feature/RFormulaSuite.scala | 37 ++++++++++++++++--- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..23dda42c325be 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,3 +130,4 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } + diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 7da3339f8b487..f384ffbf578bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol} import org.apache.spark.ml.util._ @@ -210,8 +210,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // First we index each string column referenced by the input terms. val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => - dataset.schema(term) match { - case column if column.dataType == StringType => + dataset.schema(term).dataType match { + case _: StringType => val indexCol = tmpColumn("stridx") encoderStages += new StringIndexer() .setInputCol(term) @@ -220,6 +220,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(dataset.schema(term)) + val size = if (group.size < 0) { + dataset.select(term).first().getAs[Vector](0).size + } else { + group.size + } + encoderStages += new VectorSizeHint(uid) + .setHandleInvalid("optimistic") + .setInputCol(term) + .setSize(size) + (term, term) case _ => (term, term) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90ec6dfa..f3f4b5a3d0233 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.types.DoubleType -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result3.collect() === expected3.collect()) assert(result4.collect() === expected4.collect()) } + + test("Use Vectors as inputs to formula.") { + val original = Seq( + (1, 4, Vectors.dense(0.0, 0.0, 4.0)), + (2, 4, Vectors.dense(1.0, 0.0, 4.0)), + (3, 5, Vectors.dense(1.0, 0.0, 5.0)), + (4, 5, Vectors.dense(0.0, 1.0, 5.0)) + ).toDF("id", "a", "b") + val formula = new RFormula().setFormula("id ~ a + b") + val (first +: rest) = Seq("id", "a", "b", "features", "label") + testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + + val group = new AttributeGroup("b", 3) + val vectorColWithMetadata = original("b").as("b", group.toMetadata()) + val dfWithMetadata = original.withColumn("b", vectorColWithMetadata) + val model = formula.fit(dfWithMetadata) + // model should work even when applied to dataframe without metadata. + testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + } } From b5042d75c2faa5f15bc1e160d75f06dfdd6eea37 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 11 Jan 2018 16:20:30 -0800 Subject: [PATCH 0072/1161] [SPARK-23008][ML] OnehotEncoderEstimator python API ## What changes were proposed in this pull request? OnehotEncoderEstimator python API. ## How was this patch tested? doctest Author: WeichenXu Closes #20209 from WeichenXu123/ohe_py. --- python/pyspark/ml/feature.py | 113 ++++++++++++++++++ .../ml/param/_shared_params_code_gen.py | 1 + python/pyspark/ml/param/shared.py | 23 ++++ 3 files changed, 137 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 13bf95cce40be..b963e45dd7cff 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -45,6 +45,7 @@ 'NGram', 'Normalizer', 'OneHotEncoder', + 'OneHotEncoderEstimator', 'OneHotEncoderModel', 'PCA', 'PCAModel', 'PolynomialExpansion', 'QuantileDiscretizer', @@ -1641,6 +1642,118 @@ def getDropLast(self): return self.getOrDefault(self.dropLast) +@inherit_doc +class OneHotEncoderEstimator(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + A one-hot encoder that maps a column of category indices to a column of binary vectors, with + at most a single one-value per row that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to an output vector of + `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via `dropLast`), + because it makes the vector entries sum up to one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + + Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories. + The output vectors are sparse. + + When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + vector. + + Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output + cols come in pairs, specified by the order in the arrays, and each pair is treated + independently. + + See `StringIndexer` for converting categorical values into category indices + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"]) + >>> ohe = OneHotEncoderEstimator(inputCols=["input"], outputCols=["output"]) + >>> model = ohe.fit(df) + >>> model.transform(df).head().output + SparseVector(2, {0: 1.0}) + >>> ohePath = temp_path + "/oheEstimator" + >>> ohe.save(ohePath) + >>> loadedOHE = OneHotEncoderEstimator.load(ohePath) + >>> loadedOHE.getInputCols() == ohe.getInputCols() + True + >>> modelPath = temp_path + "/ohe-model" + >>> model.save(modelPath) + >>> loadedModel = OneHotEncoderModel.load(modelPath) + >>> loadedModel.categorySizes == model.categorySizes + True + + .. versionadded:: 2.3.0 + """ + + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data during " + + "transform(). Options are 'keep' (invalid data presented as an extra " + + "categorical feature) or error (throw an error). Note that this Param " + + "is only used during transform; during fitting, invalid data will " + + "result in an error.", + typeConverter=TypeConverters.toString) + + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category", + typeConverter=TypeConverters.toBoolean) + + @keyword_only + def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + """ + super(OneHotEncoderEstimator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.OneHotEncoderEstimator", self.uid) + self._setDefault(handleInvalid="error", dropLast=True) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.3.0") + def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True): + """ + setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True) + Sets params for this OneHotEncoderEstimator. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.3.0") + def setDropLast(self, value): + """ + Sets the value of :py:attr:`dropLast`. + """ + return self._set(dropLast=value) + + @since("2.3.0") + def getDropLast(self): + """ + Gets the value of dropLast or its default value. + """ + return self.getOrDefault(self.dropLast) + + def _create_model(self, java_model): + return OneHotEncoderModel(java_model) + + +class OneHotEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable): + """ + Model fitted by :py:class:`OneHotEncoderEstimator`. + + .. versionadded:: 2.3.0 + """ + + @property + @since("2.3.0") + def categorySizes(self): + """ + Original number of categories for each feature being encoded. + The array contains one value for each input column, in order. + """ + return self._call_java("categorySizes") + + @inherit_doc class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 1d0f60acc6983..db951d81de1e7 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -119,6 +119,7 @@ def get$Name(self): ("inputCol", "input column name.", None, "TypeConverters.toString"), ("inputCols", "input column names.", None, "TypeConverters.toListString"), ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"), + ("outputCols", "output column names.", None, "TypeConverters.toListString"), ("numFeatures", "number of features.", None, "TypeConverters.toInt"), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " + diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 813f7a59f3fd1..474c38764e5a1 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -256,6 +256,29 @@ def getOutputCol(self): return self.getOrDefault(self.outputCol) +class HasOutputCols(Params): + """ + Mixin for param outputCols: output column names. + """ + + outputCols = Param(Params._dummy(), "outputCols", "output column names.", typeConverter=TypeConverters.toListString) + + def __init__(self): + super(HasOutputCols, self).__init__() + + def setOutputCols(self, value): + """ + Sets the value of :py:attr:`outputCols`. + """ + return self._set(outputCols=value) + + def getOutputCols(self): + """ + Gets the value of outputCols or its default value. + """ + return self.getOrDefault(self.outputCols) + + class HasNumFeatures(Params): """ Mixin for param numFeatures: number of features. From cbe7c6fbf9dc2fc422b93b3644c40d449a869eea Mon Sep 17 00:00:00 2001 From: ho3rexqj Date: Fri, 12 Jan 2018 15:27:00 +0800 Subject: [PATCH 0073/1161] [SPARK-22986][CORE] Use a cache to avoid instantiating multiple instances of broadcast variable values When resources happen to be constrained on an executor the first time a broadcast variable is instantiated it is persisted to disk by the BlockManager. Consequently, every subsequent call to TorrentBroadcast::readBroadcastBlock from other instances of that broadcast variable spawns another instance of the underlying value. That is, broadcast variables are spawned once per executor **unless** memory is constrained, in which case every instance of a broadcast variable is provided with a unique copy of the underlying value. This patch fixes the above by explicitly caching the underlying values using weak references in a ReferenceMap. Author: ho3rexqj Closes #20183 from ho3rexqj/fix/cache-broadcast-values. --- .../spark/broadcast/BroadcastManager.scala | 6 ++ .../spark/broadcast/TorrentBroadcast.scala | 72 +++++++++++-------- .../spark/broadcast/BroadcastSuite.scala | 34 +++++++++ 3 files changed, 83 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2e..8d7a4a353a792 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag +import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging @@ -52,6 +54,10 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) + private[broadcast] val cachedValues = { + new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK) + } + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7aecd3c9668ea..e125095cf4777 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -206,36 +206,50 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { - setConf(SparkEnv.get.conf) - val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId) match { - case Some(blockResult) => - if (blockResult.data.hasNext) { - val x = blockResult.data.next().asInstanceOf[T] - releaseLock(broadcastId) - x - } else { - throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") - } - case None => - logInfo("Started reading broadcast variable " + id) - val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks() - logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - - try { - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues + + Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { + setConf(SparkEnv.get.conf) + val blockManager = SparkEnv.get.blockManager + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + + if (x != null) { + broadcastCache.put(broadcastId, x) + } + + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") } - obj - } finally { - blocks.foreach(_.dispose()) - } + case None => + logInfo("Started reading broadcast variable " + id) + val startTimeMs = System.currentTimeMillis() + val blocks = readBlocks() + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) + + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + + if (obj != null) { + broadcastCache.put(broadcastId, obj) + } + + obj + } finally { + blocks.foreach(_.dispose()) + } + } } } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 159629825c677..9ad2e9a5e74ac 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -153,6 +153,40 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(broadcast.value.sum === 10) } + test("One broadcast value instance per executor") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + + test("One broadcast value instance per executor when memory is constrained") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val instances = sc.parallelize(1 to 10) + .map(x => System.identityHashCode(broadcast.value)) + .collect() + .toSet + + assert(instances.size === 1) + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * From a7d98d53ceaf69cabaecc6c9113f17438c4e61f6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 12 Jan 2018 11:27:02 +0200 Subject: [PATCH 0074/1161] [SPARK-23008][ML][FOLLOW-UP] mark OneHotEncoder python API deprecated ## What changes were proposed in this pull request? mark OneHotEncoder python API deprecated ## How was this patch tested? N/A Author: WeichenXu Closes #20241 from WeichenXu123/mark_ohe_deprecated. --- python/pyspark/ml/feature.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index b963e45dd7cff..eb79b193103e2 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1578,6 +1578,9 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories. The output vectors are sparse. + .. note:: Deprecated in 2.3.0. :py:class:`OneHotEncoderEstimator` will be renamed to + :py:class:`OneHotEncoder` and this :py:class:`OneHotEncoder` will be removed in 3.0.0. + .. seealso:: :py:class:`StringIndexer` for converting categorical values into From 505086806997b4331d4a8c2fc5e08345d869a23c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 18:04:44 +0800 Subject: [PATCH 0075/1161] [SPARK-23025][SQL] Support Null type in scala reflection ## What changes were proposed in this pull request? Add support for `Null` type in the `schemaFor` method for Scala reflection. ## How was this patch tested? Added UT Author: Marco Gaido Closes #20219 from mgaido91/SPARK-23025. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++++ .../apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 9 +++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 3 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 65040f1af4b04..9a4bf0075a178 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,6 +63,7 @@ object ScalaReflection extends ScalaReflection { private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { + case t if t <:< definitions.NullTpe => NullType case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType @@ -712,6 +713,9 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = cleanUpReflectionObjects { tpe.dealias match { + // this must be the first case, since all objects in scala are instances of Null, therefore + // Null type would wrongly match the first of them, which is Option as of now + case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 23e866cdf4917..8c3db48a01f12 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -356,4 +356,13 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } + + test("SPARK-23025: schemaFor should support Null type") { + val schema = schemaFor[(Int, Null)] + assert(schema === Schema( + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", NullType, nullable = true))), + nullable = true)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d535896723bd5..54893c184642b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1441,6 +1441,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) } } + + test("SPARK-23025: Add support for null type in scala reflection") { + val data = Seq(("a", null)) + checkDataset(data.toDS(), data: _*) + } } case class SingleData(id: Int) From f5300fbbe370af3741560f67bfb5ae6f0b0f7bb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20Beaup=C3=A8re?= Date: Fri, 12 Jan 2018 08:29:46 -0600 Subject: [PATCH 0076/1161] Update rdd-programming-guide.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Small typing correction - double word ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Matthias Beaupère Closes #20212 from matthiasbe/patch-1. --- docs/rdd-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 29af159510e46..2e29aef7f21a2 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -91,7 +91,7 @@ so C libraries like NumPy can be used. It also works with PyPy 2.3+. Python 2.6 support was removed in Spark 2.2.0. -Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including including it in your setup.py as: +Spark applications in Python can either be run with the `bin/spark-submit` script which includes Spark at runtime, or by including it in your setup.py as: {% highlight python %} install_requires=[ From 651f76153f5e9b185aaf593161d40cabe7994fea Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 13 Jan 2018 00:37:59 +0800 Subject: [PATCH 0077/1161] [SPARK-23028] Bump master branch version to 2.4.0-SNAPSHOT ## What changes were proposed in this pull request? This patch bumps the master branch version to `2.4.0-SNAPSHOT`. ## How was this patch tested? N/A Author: gatorsmile Closes #20222 from gatorsmile/bump24. --- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- common/kvstore/pom.xml | 2 +- common/network-common/pom.xml | 2 +- common/network-shuffle/pom.xml | 2 +- common/network-yarn/pom.xml | 2 +- common/sketch/pom.xml | 2 +- common/tags/pom.xml | 2 +- common/unsafe/pom.xml | 2 +- core/pom.xml | 2 +- dev/run-tests-jenkins.py | 4 ++-- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/docker-integration-tests/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-0-10-assembly/pom.xml | 2 +- external/kafka-0-10-sql/pom.xml | 2 +- external/kafka-0-10/pom.xml | 2 +- external/kafka-0-8-assembly/pom.xml | 2 +- external/kafka-0-8/pom.xml | 2 +- external/kinesis-asl-assembly/pom.xml | 2 +- external/kinesis-asl/pom.xml | 2 +- external/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- hadoop-cloud/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib-local/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- project/MimaExcludes.scala | 5 +++++ python/pyspark/version.py | 2 +- repl/pom.xml | 2 +- resource-managers/kubernetes/core/pom.xml | 2 +- resource-managers/mesos/pom.xml | 2 +- resource-managers/yarn/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- 43 files changed, 49 insertions(+), 44 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 6d46c31906260..855eb5bf77f16 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.3.0 +Version: 2.4.0 Title: R Frontend for Apache Spark Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/assembly/pom.xml b/assembly/pom.xml index b3b4239771bc3..a207dae5a74ff 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index cf93d41cd77cf..8c148359c3029 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 18cbdadd224ab..8ca7733507f1b 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 9968480ab7658..05335df61a664 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index ec2db6e5bb88c..564e6583c909e 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 2d59c71cc3757..2f04abe8c7e88 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index f7e586ee777e1..ba127408e1c59 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a3772a2620088..1527854730394 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 0a5bd958fc9c5..9258a856028a0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 914eb93622d51..3960a0de62530 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -181,8 +181,8 @@ def main(): short_commit_hash = ghprb_actual_commit[0:7] # format: http://linux.die.net/man/1/timeout - # must be less than the timeout configured on Jenkins (currently 300m) - tests_timeout = "250m" + # must be less than the timeout configured on Jenkins (currently 350m) + tests_timeout = "300m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. diff --git a/docs/_config.yml b/docs/_config.yml index dcc211204d766..095fadb93fe5d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.3.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.3.0 +SPARK_VERSION: 2.4.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.4.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.8" MESOS_VERSION: 1.0.0 diff --git a/examples/pom.xml b/examples/pom.xml index 1791dbaad775e..868110b8e35ef 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 485b562dce990..431339d412194 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 71016bc645ca7..7cd1ec4c9c09a 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 12630840e79dc..f810aa80e8780 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 87a09642405a7..498e88f665eb5 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index d6f97316b326a..a742b8d6dbddb 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 0c9f0aa765a39..16bbc6db641ca 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 6eb7ba5f0092d..3b124b2a69d50 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 786349474389b..41bc8b3e3ee1f 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 849c8b465f99e..6d1c4789f382d 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 48783d65826aa..37c7d1e604ec5 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 40a751a652fa9..4915893965595 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 36d555066b181..027157e53d511 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index cb30e4a4af4bc..fbe77fcb958d5 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index aa36dd4774d86..8e424b1c50236 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index e9b46c4cf0ffa..912eb6b6d2a08 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 043d13609fd26..53286fe93478d 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a906c9e02cd4c..f07d7f24fd312 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index 1b37164376460..d14594aa4ccb0 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b452f35c5ec1..32eb31f495979 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,10 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.4.x + lazy val v24excludes = v23excludes ++ Seq( + ) + // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( // [SPARK-22897] Expose stageAttemptId in TaskContext @@ -1082,6 +1086,7 @@ object MimaExcludes { } def excludes(version: String) = version match { + case v if v.startsWith("2.4") => v24excludes case v if v.startsWith("2.3") => v23excludes case v if v.startsWith("2.2") => v22excludes case v if v.startsWith("2.1") => v21excludes diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 12dd53b9d2902..b9c2c4ced71d5 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.3.0.dev0" +__version__ = "2.4.0.dev0" diff --git a/repl/pom.xml b/repl/pom.xml index 1cb0098d0eca3..6f4a863c48bc7 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 7d35aea8a4142..a62f271273465 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 70d0c1750b14e..3995d0afeb5f4 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 43a7ce95bd3de..37e25ceecb883 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9e2ced30407d4..839b929abd3cb 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 93010c606cf45..744daa6079779 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3135a8a275dae..9f247f9224c75 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 66fad85ea0263..c55ba32fa458c 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index fea882ad11230..4497e53b65984 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 37427e8da62d8..242219e29f50f 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.3.0-SNAPSHOT + 2.4.0-SNAPSHOT ../pom.xml From 7bd14cfd40500a0b6462cda647bdbb686a430328 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 12 Jan 2018 10:18:42 -0800 Subject: [PATCH 0078/1161] [MINOR][BUILD] Fix Java linter errors ## What changes were proposed in this pull request? This PR cleans up the java-lint errors (for v2.3.0-rc1 tag). Hopefully, this will be the final one. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[85] (sizes) LineLength: Line is longer than 100 characters (found 101). [ERROR] src/main/java/org/apache/spark/launcher/InProcessAppHandle.java:[20,8] (imports) UnusedImports: Unused import - java.io.IOException. [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java:[41,9] (modifier) ModifierOrder: 'private' modifier out of order with the JLS suggestions. [ERROR] src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java:[464] (sizes) LineLength: Line is longer than 100 characters (found 102). ``` ## How was this patch tested? Manual. ``` $ dev/lint-java Using `mvn` from path: /usr/local/bin/mvn Checkstyle checks passed. ``` Author: Dongjoon Hyun Closes #20242 from dongjoon-hyun/fix_lint_java_2.3_rc1. --- .../org/apache/spark/unsafe/memory/HeapMemoryAllocator.java | 3 ++- .../java/org/apache/spark/launcher/InProcessAppHandle.java | 1 - .../spark/sql/execution/datasources/orc/OrcColumnVector.java | 2 +- .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 3acfe3696cb1e..a9603c1aba051 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -82,7 +82,8 @@ public void free(MemoryBlock memory) { "page has already been freed"; assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : - "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator free()"; + "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + + "free()"; final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 0d6a73a3da3ed..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.io.IOException; import java.lang.reflect.Method; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index f94c55d860304..b6e792274da11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -38,7 +38,7 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto private BytesColumnVector bytesData; private DecimalColumnVector decimalData; private TimestampColumnVector timestampData; - final private boolean isTimestamp; + private final boolean isTimestamp; private int batchSize; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4f8a31f185724..69a2904f5f3fe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -461,7 +461,8 @@ public void testCircularReferenceBean() { public void testUDF() { UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); - String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)) + .toArray(String[]::new); String[] expected = spark.table("testData").collectAsList().stream() .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); Assert.assertArrayEquals(expected, result); From 54277398afbde92a38ba2802f4a7a3e5910533de Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 12 Jan 2018 11:25:37 -0800 Subject: [PATCH 0079/1161] [SPARK-22975][SS] MetricsReporter should not throw exception when there was no progress reported ## What changes were proposed in this pull request? `MetricsReporter ` assumes that there has been some progress for the query, ie. `lastProgress` is not null. If this is not true, as it might happen in particular conditions, a `NullPointerException` can be thrown. The PR checks whether there is a `lastProgress` and if this is not true, it returns a default value for the metrics. ## How was this patch tested? added UT Author: Marco Gaido Closes #20189 from mgaido91/SPARK-22975. --- .../execution/streaming/MetricsReporter.scala | 21 ++++++++--------- .../sql/streaming/StreamingQuerySuite.scala | 23 +++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index b84e6ce64c611..66b11ecddf233 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} - -import scala.collection.mutable - import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.{Source => CodahaleSource} -import org.apache.spark.util.Clock +import org.apache.spark.sql.streaming.StreamingQueryProgress /** * Serves metrics from a [[org.apache.spark.sql.streaming.StreamingQuery]] to @@ -39,14 +35,17 @@ class MetricsReporter( // Metric names should not have . in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group - registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) - registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) - - private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + registerGauge("inputRate-total", _.inputRowsPerSecond, 0.0) + registerGauge("processingRate-total", _.processedRowsPerSecond, 0.0) + registerGauge("latency", _.durationMs.get("triggerExecution").longValue(), 0L) + + private def registerGauge[T]( + name: String, + f: StreamingQueryProgress => T, + default: T): Unit = { synchronized { metricRegistry.register(name, new Gauge[T] { - override def getValue: T = f() + override def getValue: T = Option(stream.lastProgress).map(f).getOrElse(default) }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 2fa4595dab376..76201c63a2701 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -424,6 +424,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22975: MetricsReporter defaults when there was no progress reported") { + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + BlockingSource.latch = new CountDownLatch(1) + withTempDir { tempDir => + val sq = spark.readStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .load() + .writeStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .option("checkpointLocation", tempDir.toString) + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + + val gauges = sq.streamMetrics.metricRegistry.getGauges + assert(gauges.get("latency").getValue.asInstanceOf[Long] == 0) + assert(gauges.get("processingRate-total").getValue.asInstanceOf[Double] == 0.0) + assert(gauges.get("inputRate-total").getValue.asInstanceOf[Double] == 0.0) + sq.stop() + } + } + } + test("input row calculation with mixed batch and streaming sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") From 55dbfbca37ce4c05f83180777ba3d4fe2d96a02e Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 12 Jan 2018 15:00:00 -0800 Subject: [PATCH 0080/1161] Revert "[SPARK-22908] Add kafka source and sink for continuous processing." This reverts commit 6f7aaed805070d29dcba32e04ca7a1f581fa54b9. --- .../sql/kafka010/KafkaContinuousReader.scala | 232 --------- .../sql/kafka010/KafkaContinuousWriter.scala | 119 ----- .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +--- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 +-- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 474 ------------------ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ---- .../sql/kafka010/KafkaContinuousTest.scala | 64 --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 470 ++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 3 +- .../continuous/ContinuousExecution.scala | 67 +-- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 383 insertions(+), 1531 deletions(-) delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala delete mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala deleted file mode 100644 index 928379544758c..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.{util => ju} - -import org.apache.kafka.clients.consumer.ConsumerRecord -import org.apache.kafka.common.TopicPartition - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String - -/** - * A [[ContinuousReader]] for data from kafka. - * - * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be - * read by per-task consumers generated later. - * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which - * are not Kafka consumer params. - * @param metadataPath Path to a directory this reader can use for writing metadata. - * @param initialOffsets The Kafka offsets to start reading data at. - * @param failOnDataLoss Flag indicating whether reading should fail in data loss - * scenarios, where some offsets after the specified initial ones can't be - * properly read. - */ -class KafkaContinuousReader( - offsetReader: KafkaOffsetReader, - kafkaParams: ju.Map[String, Object], - sourceOptions: Map[String, String], - metadataPath: String, - initialOffsets: KafkaOffsetRangeLimit, - failOnDataLoss: Boolean) - extends ContinuousReader with SupportsScanUnsafeRow with Logging { - - private lazy val session = SparkSession.getActiveSession.get - private lazy val sc = session.sparkContext - - // Initialized when creating read tasks. If this diverges from the partitions at the latest - // offsets, we need to reconfigure. - // Exposed outside this object only for unit tests. - private[sql] var knownPartitions: Set[TopicPartition] = _ - - override def readSchema: StructType = KafkaOffsetReader.kafkaSchema - - private var offset: Offset = _ - override def setOffset(start: ju.Optional[Offset]): Unit = { - offset = start.orElse { - val offsets = initialOffsets match { - case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) - case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) - } - logInfo(s"Initial offsets: $offsets") - offsets - } - } - - override def getStartOffset(): Offset = offset - - override def deserializeOffset(json: String): Offset = { - KafkaSourceOffset(JsonUtils.partitionOffsets(json)) - } - - override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { - import scala.collection.JavaConverters._ - - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) - - val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet - val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) - val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) - - val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) - if (deletedPartitions.nonEmpty) { - reportDataLoss(s"Some partitions were deleted: $deletedPartitions") - } - - val startOffsets = newPartitionOffsets ++ - oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) - knownPartitions = startOffsets.keySet - - startOffsets.toSeq.map { - case (topicPartition, start) => - KafkaContinuousReadTask( - topicPartition, start, kafkaParams, failOnDataLoss) - .asInstanceOf[ReadTask[UnsafeRow]] - }.asJava - } - - /** Stop this source and free any resources it has allocated. */ - def stop(): Unit = synchronized { - offsetReader.close() - } - - override def commit(end: Offset): Unit = {} - - override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { - val mergedMap = offsets.map { - case KafkaSourcePartitionOffset(p, o) => Map(p -> o) - }.reduce(_ ++ _) - KafkaSourceOffset(mergedMap) - } - - override def needsReconfiguration(): Boolean = { - knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions - } - - override def toString(): String = s"KafkaSource[$offsetReader]" - - /** - * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. - * Otherwise, just log a warning. - */ - private def reportDataLoss(message: String): Unit = { - if (failOnDataLoss) { - throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") - } else { - logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") - } - } -} - -/** - * A read task for continuous Kafka processing. This will be serialized and transformed into a - * full reader on executors. - * - * @param topicPartition The (topic, partition) pair this task is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. - */ -case class KafkaContinuousReadTask( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss) - } -} - -/** - * A per-task data reader for continuous Kafka processing. - * - * @param topicPartition The (topic, partition) pair this data reader is responsible for. - * @param startOffset The offset to start reading from within the partition. - * @param kafkaParams Kafka consumer params to use. - * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets - * are skipped. - */ -class KafkaContinuousDataReader( - topicPartition: TopicPartition, - startOffset: Long, - kafkaParams: ju.Map[String, Object], - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) - - private var nextKafkaOffset = startOffset - private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ - - override def next(): Boolean = { - var r: ConsumerRecord[Array[Byte], Array[Byte]] = null - while (r == null) { - r = consumer.get( - nextKafkaOffset, - untilOffset = Long.MaxValue, - pollTimeoutMs = Long.MaxValue, - failOnDataLoss) - } - nextKafkaOffset = r.offset + 1 - currentRecord = r - true - } - - override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow - } - - override def getOffset(): KafkaSourcePartitionOffset = { - KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) - } - - override def close(): Unit = { - consumer.close() - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala deleted file mode 100644 index 9843f469c5b25..0000000000000 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} -import scala.collection.JavaConverters._ - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} -import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{BinaryType, StringType, StructType} - -/** - * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we - * don't need to really send one. - */ -case object KafkaWriterCommitMessage extends WriterCommitMessage - -/** - * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. - * @param topic The topic this writer is responsible for. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -class KafkaContinuousWriter( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends ContinuousWriter with SupportsWriteInternalRow { - - validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - - override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = - KafkaContinuousWriterFactory(topic, producerParams, schema) - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} -} - -/** - * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate - * the per-task data writers. - * @param topic The topic that should be written to. If None, topic will be inferred from - * a `topic` field in the incoming data. - * @param producerParams Parameters for Kafka producers in each task. - * @param schema The schema of the input data. - */ -case class KafkaContinuousWriterFactory( - topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends DataWriterFactory[InternalRow] { - - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) - } -} - -/** - * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to - * process incoming rows. - * - * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred - * from a `topic` field in the incoming data. - * @param producerParams Parameters to use for the Kafka producer. - * @param inputSchema The attributes in the input data. - */ -class KafkaContinuousDataWriter( - targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) - extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { - import scala.collection.JavaConverters._ - - private lazy val producer = CachedKafkaProducer.getOrCreate( - new java.util.HashMap[String, Object](producerParams.asJava)) - - def write(row: InternalRow): Unit = { - checkForErrors() - sendRow(row, producer) - } - - def commit(): WriterCommitMessage = { - // Send is asynchronous, but we can't commit until all rows are actually in Kafka. - // This requires flushing and then checking that no callbacks produced errors. - // We also check for errors before to fail as soon as possible - the check is cheap. - checkForErrors() - producer.flush() - checkForErrors() - KafkaWriterCommitMessage - } - - def abort(): Unit = {} - - def close(): Unit = { - checkForErrors() - if (producer != null) { - producer.flush() - checkForErrors() - CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) - } - } -} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..3e65949a6fd1b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,14 +117,10 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. - * - * @param partitionOffsets the specific offsets to resolve - * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long], - reportDataLoss: String => Unit): KafkaSourceOffset = { - val fetched = runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = + runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -149,19 +145,6 @@ private[kafka010] class KafkaOffsetReader( } } - partitionOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (fetched(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(fetched) - } - /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 27da76068a66f..e9cff04ba5f2e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) + case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,6 +138,21 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } + private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { + val result = kafkaReader.fetchSpecificOffsets(specificOffsets) + specificOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (result(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(result) + } + private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..b5da415b3097e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,22 +20,17 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } -private[kafka010] -case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) - extends PartitionOffset - /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3914370a96595..3cb4d8cad12cc 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, Optional, UUID} +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -27,12 +27,9 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,8 +43,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with ContinuousWriteSupport - with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -106,43 +101,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - override def createContinuousReader( - schema: Optional[StructType], - metadataPath: String, - options: DataSourceV2Options): KafkaContinuousReader = { - val parameters = options.asMap().asScala.toMap - validateStreamOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" - - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - val specifiedKafkaParams = - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap - - val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, - STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) - - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - - new KafkaContinuousReader( - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - metadataPath, - startingStreamOffsets, - failOnDataLoss(caseInsensitiveParams)) - } - /** * Returns a new base relation with the given parameters. * @@ -223,22 +181,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { - import scala.collection.JavaConverters._ - - val spark = SparkSession.getActiveSession.get - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. - val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - KafkaWriter.validateQuery( - schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } - Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -488,27 +450,4 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } - - private[kafka010] def kafkaParamsForProducer( - parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } - - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index baa60febf661d..6fd333e2f43ba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,8 +33,10 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { + topic: Option[String]) { // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -44,7 +46,23 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - sendRow(currentRow, producer) + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) } } @@ -56,49 +74,8 @@ private[kafka010] class KafkaWriteTask( producer = null } } -} - -private[kafka010] abstract class KafkaRowWriter( - inputSchema: Seq[Attribute], topic: Option[String]) { - - // used to synchronize with Kafka callbacks - @volatile protected var failedWrite: Exception = _ - protected val projection = createProjection - - private val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - /** - * Send the specified row to the producer, with a callback that will save any exception - * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before - * assuming the row is in Kafka. - */ - protected def sendRow( - row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { - val projectedRow = projection(row) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - producer.send(record, callback) - } - - protected def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } - - private def createProjection = { + private def createProjection: UnsafeProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -135,5 +112,11 @@ private[kafka010] abstract class KafkaRowWriter( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } + + private def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..5e9ae35b3f008 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,9 +43,10 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - schema: Seq[Attribute], + queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { + val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -83,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(schema, kafkaParameters, topic) + validateQuery(queryExecution, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala deleted file mode 100644 index dfc97b1c38bb5..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ /dev/null @@ -1,474 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Locale -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.kafka.clients.producer.ProducerConfig -import org.apache.kafka.common.serialization.ByteArraySerializer -import org.scalatest.time.SpanSugar._ -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{BinaryType, DataType} -import org.apache.spark.util.Utils - -/** - * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. - * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have - * to duplicate all the code. - */ -class KafkaContinuousSinkSuite extends KafkaContinuousTest { - import testImplicits._ - - override val streamingTimeout = 30.seconds - - override def beforeAll(): Unit = { - super.beforeAll() - testUtils = new KafkaTestUtils( - withBrokerProps = Map("auto.create.topics.enable" -> "false")) - testUtils.setup() - } - - override def afterAll(): Unit = { - if (testUtils != null) { - testUtils.teardown() - testUtils = null - } - super.afterAll() - } - - test("streaming - write to kafka with topic field") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = None, - withOutputMode = Some(OutputMode.Append))( - withSelectExpr = s"'$topic' as topic", "value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - write w/o topic field, with topic option") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = Some(OutputMode.Append()))() - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") - .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("streaming - topic field and topic option") { - /* The purpose of this test is to ensure that the topic option - * overrides the topic field. We begin by writing some data that - * includes a topic field and value (e.g., 'foo') along with a topic - * option. Then when we read from the topic specified in the option - * we should see the data i.e., the data was written to the topic - * option, and not to the topic in the data e.g., foo - */ - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - - val topic = newTopic() - testUtils.createTopic(topic) - - val writer = createKafkaWriter( - input.toDF(), - withTopic = Some(topic), - withOutputMode = Some(OutputMode.Append()))( - withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") - - val reader = createKafkaReader(topic) - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] - .map(_._2) - - try { - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) - testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) - failAfter(streamingTimeout) { - writer.processAllAvailable() - } - checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - } finally { - writer.stop() - } - } - - test("null topic attribute") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "CAST(null as STRING) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getCause.getCause.getMessage - .toLowerCase(Locale.ROOT) - .contains("null topic present in the data.")) - } - - test("streaming - write data with bad schema") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - testUtils.createTopic(topic) - - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "value as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage - .toLowerCase(Locale.ROOT) - .contains("topic option required when no 'topic' attribute is present")) - - try { - /* No value field */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value as key" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "required attribute 'value' not found")) - } - - test("streaming - write data with valid schema but wrong types") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - .selectExpr("CAST(value as STRING) value") - val topic = newTopic() - testUtils.createTopic(topic) - - var writer: StreamingQuery = null - var ex: Exception = null - try { - /* topic field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"CAST('1' as INT) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - - try { - /* value field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) - - try { - ex = intercept[StreamingQueryException] { - /* key field wrong type */ - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) - } - - test("streaming - write to non-existing topic") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .option("startingOffsets", "earliest") - .load() - val topic = newTopic() - - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - } - throw writer.exception.get - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) - } - - test("streaming - exception on config serializer") { - val inputTopic = newTopic() - testUtils.createTopic(inputTopic, partitions = 1) - testUtils.sendMessages(inputTopic, Array("0")) - - val input = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", inputTopic) - .load() - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.key.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - } finally { - writer.stop() - } - - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.value.serializer" -> "foo"))() - writer.processAllAvailable() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) - } finally { - writer.stop() - } - } - - test("generic - write big data with small producer buffer") { - /* This test ensures that we understand the semantics of Kafka when - * is comes to blocking on a call to send when the send buffer is full. - * This test will configure the smallest possible producer buffer and - * indicate that we should block when it is full. Thus, no exception should - * be thrown in the case of a full buffer. - */ - val topic = newTopic() - testUtils.createTopic(topic, 1) - val options = new java.util.HashMap[String, String] - options.put("bootstrap.servers", testUtils.brokerAddress) - options.put("buffer.memory", "16384") // min buffer size - options.put("block.on.buffer.full", "true") - options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) - val inputSchema = Seq(AttributeReference("value", BinaryType)()) - val data = new Array[Byte](15000) // large value - val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) - try { - val fieldTypes: Array[DataType] = Array(BinaryType) - val converter = UnsafeProjection.create(fieldTypes) - val row = new SpecificInternalRow(fieldTypes) - row.update(0, data) - val iter = Seq.fill(1000)(converter.apply(row)).iterator - iter.foreach(writeTask.write(_)) - writeTask.commit() - } finally { - writeTask.close() - } - } - - private def createKafkaReader(topic: String): DataFrame = { - spark.read - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("startingOffsets", "earliest") - .option("endingOffsets", "latest") - .option("subscribe", topic) - .load() - } - - private def createKafkaWriter( - input: DataFrame, - withTopic: Option[String] = None, - withOutputMode: Option[OutputMode] = None, - withOptions: Map[String, String] = Map[String, String]()) - (withSelectExpr: String*): StreamingQuery = { - var stream: DataStreamWriter[Row] = null - val checkpointDir = Utils.createTempDir() - var df = input.toDF() - if (withSelectExpr.length > 0) { - df = df.selectExpr(withSelectExpr: _*) - } - stream = df.writeStream - .format("kafka") - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - // We need to reduce blocking time to efficiently test non-existent partition behavior. - .option("kafka.max.block.ms", "1000") - .trigger(Trigger.Continuous(1000)) - .queryName("kafkaStream") - withTopic.foreach(stream.option("topic", _)) - withOutputMode.foreach(stream.outputMode(_)) - withOptions.foreach(opt => stream.option(opt._1, opt._2)) - stream.start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala deleted file mode 100644 index b3dade414f625..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} - -// Run tests in KafkaSourceSuiteBase in continuous execution mode. -class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest - -class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { - import testImplicits._ - - override val brokerProps = Map("auto.create.topics.enable" -> "false") - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Execute { query => - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists { r => - // Ensure the new topic is present and the old topic is gone. - r.knownPartitions.exists(_.topic == topic2) - }, - s"query never reconfigured to new topic $topic2") - } - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } -} - -class KafkaContinuousSourceStressForDontFailOnDataLossSuite - extends KafkaSourceStressForDontFailOnDataLossSuite { - override protected def startStream(ds: Dataset[Int]) = { - ds.writeStream - .format("memory") - .queryName("memory") - .start() - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala deleted file mode 100644 index e713e6695d2bd..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.apache.spark.SparkContext -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.Trigger -import org.apache.spark.sql.test.TestSparkSession - -// Trait to configure StreamTest for kafka continuous execution tests. -trait KafkaContinuousTest extends KafkaSourceTest { - override val defaultTrigger = Trigger.Continuous(1000) - override val defaultUseV2Sink = true - - // We need more than the default local[2] to be able to schedule all partitions simultaneously. - override protected def createSparkSession = new TestSparkSession( - new SparkContext( - "local[10]", - "continuous-stream-test-sql-context", - sparkConf.set("spark.sql.testkey", "true"))) - - // In addition to setting the partitions in Kafka, we have to wait until the query has - // reconfigured to the new count so the test framework can hook in properly. - override protected def setTopicPartitions( - topic: String, newCount: Int, query: StreamExecution) = { - testUtils.addPartitions(topic, newCount) - eventually(timeout(streamingTimeout)) { - assert( - query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r - }.exists(_.knownPartitions.size == newCount), - s"query never reconfigured to $newCount partitions") - } - } - - test("ensure continuous stream is being used") { - val query = spark.readStream - .format("rate") - .option("numPartitions", "1") - .option("rowsPerSecond", "1") - .load() - - testStream(query)( - Execute(q => assert(q.isInstanceOf[ContinuousExecution])) - ) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index d66908f86ccc7..2034b9be07f24 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,14 +34,11 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} +import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -52,11 +49,9 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds - protected val brokerProps = Map[String, Object]() - override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils(brokerProps) + testUtils = new KafkaTestUtils testUtils.setup() } @@ -64,25 +59,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null + super.afterAll() } - super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, + // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, // we don't know which data should be fetched when `startingOffsets` is latest. - q match { - case c: ContinuousExecution => c.awaitEpoch(0) - case m: MicroBatchExecution => m.processAllAvailable() - } + q.processAllAvailable() true } - protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { - testUtils.addPartitions(topic, newCount) - } - /** * Add data to Kafka. * @@ -94,7 +82,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { if (query.get.isActive) { // Make sure no Spark job is running when deleting a topic query.get.processAllAvailable() @@ -109,18 +97,16 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } + // Read all topics again in case some topics are delete. + val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => + source.asInstanceOf[KafkaSource] + } if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -151,158 +137,14 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } - - private val topicId = new AtomicInteger(0) - protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { - - import testImplicits._ - - test("(de)serialization of initial offsets") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) - } - - test("maxOffsetsPerTrigger") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) - .option("subscribe", topic) - .option("startingOffsets", "earliest") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } - } - if (q.exception.isDefined) { - throw q.exception.get - } - true - } - - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) - } - - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") +class KafkaSourceSuite extends KafkaSourceTest { - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) + import testImplicits._ - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } + private val topicId = new AtomicInteger(0) testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -395,51 +237,86 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { } } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() + test("(de)serialization of initial offsets") { val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) + testUtils.createTopic(topic, partitions = 64) - val kafka = spark + val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) } -} -class KafkaSourceSuiteBase extends KafkaSourceTest { + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) - import testImplicits._ + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } test("cannot stop Kafka stream") { val topic = newTopic() @@ -451,7 +328,7 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topic.*") + .option("subscribePattern", s"topic-.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -545,6 +422,65 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("starting offset is latest by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("0")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + val kafka = reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + val mapped = kafka.map(_.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(1, 2, 3) // should not have 0 + ) + } + test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -604,6 +540,34 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + test("delete a topic when a Spark job is running") { KafkaSourceSuite.collectedData.clear() @@ -665,6 +629,8 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { } } + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -710,10 +676,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, - Execute { q => - // wait to reach the last offset in every partition - q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) - }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -744,7 +706,6 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") - .trigger(defaultTrigger) .start() query.processAllAvailable() val rows = spark.table("kafkaColumnTypes").collect() @@ -762,6 +723,47 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { query.stop() } + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") + .option("subscribe", topic) + .load() + + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() + } + private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -798,7 +800,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -839,7 +843,9 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) setTopicPartitions(topic, 10, query) + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -971,23 +977,6 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - protected def startStream(ds: Dataset[Int]) = { - ds.writeStream.foreach(new ForeachWriter[Int] { - - override def open(partitionId: Long, version: Long): Boolean = { - true - } - - override def process(value: Int): Unit = { - // Slow down the processing speed so that messages may be aged out. - Thread.sleep(Random.nextInt(500)) - } - - override def close(errorOrNull: Throwable): Unit = { - } - }).start() - } - test("stress test for failOnDataLoss=false") { val reader = spark .readStream @@ -1001,7 +990,20 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val query = startStream(kafka.map(kv => kv._2.toInt)) + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + + override def open(partitionId: Long, version: Long): Boolean = { + true + } + + override def process(value: Int): Unit = { + // Slow down the processing speed so that messages may be aged out. + Thread.sleep(Random.nextInt(500)) + } + + override def close(errorOrNull: Throwable): Unit = { + } + }).start() val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b714a46b5f786..e8d683a578f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,9 +191,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -211,30 +208,23 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => null // fall back to v1 + case _ => + throw new AnalysisException(s"$cls does not support data reading.") } - if (reader == null) { - loadV1Source(paths: _*) - } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) - } + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } else { - loadV1Source(paths: _*) + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) } } - private def loadV1Source(paths: String*) = { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) - } - /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 97f12ff625c42..3304f368e1050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,24 +255,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. - case _ => saveToV1Source() + case _ => throw new AnalysisException(s"$cls does not support data writing.") } } else { - saveToV1Source() - } - } - - private def saveToV1Source(): Unit = { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..f0bdf84bb7a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,11 +81,9 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - if (!writer.isInstanceOf[ContinuousWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..24a8b000df0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -419,17 +418,11 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - if (sources == null) { - // sources might not be initialized yet - false - } else { - val source = sources(sourceIndex) - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset - } + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset } while (notDone) { @@ -443,7 +436,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") + logDebug(s"Unblocked at $newOffset for $source") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index e700aa4f9aea7..d79e4bd65f563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,6 +77,7 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { + reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -200,8 +201,6 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. - } finally { - reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..9657b5e26d770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.UUID import java.util.concurrent.TimeUnit -import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -54,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -80,17 +78,15 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - val stateUpdate = new UnaryOperator[State] { - override def apply(s: State) = s match { - // If we ended the query to reconfigure, reset the state to active. - case RECONFIGURING => ACTIVE - case _ => s - } - } - do { - runContinuous(sparkSessionForStream) - } while (state.updateAndGet(stateUpdate) == ACTIVE) + try { + runContinuous(sparkSessionForStream) + } catch { + case _: InterruptedException if state.get().equals(RECONFIGURING) => + // swallow exception and run again + state.set(ACTIVE) + } + } while (state.get() == ACTIVE) } /** @@ -124,16 +120,12 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + // Forcibly align commit and offset logs by slicing off any spurious offset logs from + // a previous run. We can't allow commits to an epoch that a previous run reached but + // this run has not. + offsetLog.purgeAfter(latestEpochId) + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -149,7 +141,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -234,11 +225,13 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + if (reader.needsReconfiguration()) { + state.set(RECONFIGURING) stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() + // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -266,7 +259,6 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { - epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -281,22 +273,17 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - val oldOffset = synchronized { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - offsetLog.get(epoch - 1) + if (partitionOffsets.contains(null)) { + // If any offset is null, that means the corresponding partition hasn't seen any data yet, so + // there's nothing meaningful to add to the offset log. } - - // If offset hasn't changed since last epoch, there's been no new data. - if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { - noNewData = true - } - - awaitProgressLock.lock() - try { - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + synchronized { + if (queryExecutionThread.isAlive) { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + } else { + return + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..98017c3ac6a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,15 +39,6 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage -/** - * The RpcEndpoint stop() will wait to clear out the message queue before terminating the - * object. This can lead to a race condition where the query restarts at epoch n, a new - * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. - * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous - * message to stop any writes to the ContinuousExecution object. - */ -private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage - // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -125,8 +116,6 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private var queryWritesStopped: Boolean = false - private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -158,16 +147,12 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + partitionCommits.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { - // If we just drop these messages, we won't do any writes to the query. The lame duck tasks - // won't shed errors or anything. - case _ if queryWritesStopped => () - case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -203,9 +188,5 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) - - case StopContinuousExecutionWrites => - queryWritesStopped = true - context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..db588ae282f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -280,29 +279,18 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) - } - + val dataSource = + DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - sink, + dataSource.createSink(outputMode), outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..d46461fa9bf6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,8 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -81,9 +80,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } - protected val defaultTrigger = Trigger.ProcessingTime(0) - protected val defaultUseV2Sink = false - /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -193,7 +189,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = defaultTrigger, + trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -280,7 +276,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -407,11 +403,18 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") + // Get the map of source index to the current source objects + val indexToSource = currentStream + .logicalPlan + .collect { case StreamingExecutionRelation(s, _) => s } + .zipWithIndex + .map(_.swap) + .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(sourceIndex, offset) + currentStream.awaitOffset(indexToSource(sourceIndex), offset) } } @@ -470,12 +473,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } - case _ => - } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -603,10 +600,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { - case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r - } + .collect { case StreamingExecutionRelation(s, _) => s } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -619,13 +613,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } }.getOrElse { throw new IllegalArgumentException( - "Could not find index of the source to which data was added") + "Could find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From cd9f49a2aed3799964976ead06080a0f7044a0c3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 13 Jan 2018 16:13:44 +0900 Subject: [PATCH 0081/1161] [SPARK-22980][PYTHON][SQL] Clarify the length of each series is of each batch within scalar Pandas UDF ## What changes were proposed in this pull request? This PR proposes to add a note that saying the length of a scalar Pandas UDF's `Series` is not of the whole input column but of the batch. We are fine for a group map UDF because the usage is different from our typical UDF but scalar UDFs might cause confusion with the normal UDF. For example, please consider this example: ```python from pyspark.sql.functions import pandas_udf, col, lit df = spark.range(1) f = pandas_udf(lambda x, y: len(x) + y, LongType()) df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 1| +------------------+ ``` ```python from pyspark.sql.functions import udf, col, lit df = spark.range(1) f = udf(lambda x, y: len(x) + y, "long") df.select(f(lit('text'), col('id'))).show() ``` ``` +------------------+ |(text, id)| +------------------+ | 4| +------------------+ ``` ## How was this patch tested? Manually built the doc and checked the output. Author: hyukjinkwon Closes #20237 from HyukjinKwon/SPARK-22980. --- python/pyspark/sql/functions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 733e32bd825b0..e1ad6590554cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2184,6 +2184,11 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 8| JOHN DOE| 22| +----------+--------------+------------+ + .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input + column, but is the length of an internal batch used for each call to the function. + Therefore, this can be used, for example, to ensure the length of each returned + `pandas.Series`, and can not be used as the column length. + 2. GROUP_MAP A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` From 628a1ca5a4d14397a90e9e96a7e03e8f63531b20 Mon Sep 17 00:00:00 2001 From: shimamoto Date: Sat, 13 Jan 2018 09:40:00 -0600 Subject: [PATCH 0082/1161] [SPARK-23043][BUILD] Upgrade json4s to 3.5.3 ## What changes were proposed in this pull request? Spark still use a few years old version 3.2.11. This change is to upgrade json4s to 3.5.3. Note that this change does not include the Jackson update because the Jackson version referenced in json4s 3.5.3 is 2.8.4, which has a security vulnerability ([see](https://issues.apache.org/jira/browse/SPARK-20433)). ## How was this patch tested? Existing unit tests and build. Author: shimamoto Closes #20233 from shimamoto/upgrade-json4s. --- .../deploy/history/HistoryServerSuite.scala | 2 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 19 ++++++++++--------- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- pom.xml | 13 +++++++------ 5 files changed, 26 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 3738f85da5831..87778dda0e2c8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -486,7 +486,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers json match { case JNothing => Seq() case apps: JArray => - apps.filter(app => { + apps.children.filter(app => { (app \ "attempts") match { case attempts: JArray => val state = (attempts.children.head \ "completed").asInstanceOf[JBool] diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 326546787ab6c..ed51fc445fdfb 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -131,7 +131,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val storageJson = getJson(ui, "storage/rdd") storageJson.children.length should be (1) - (storageJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) + (storageJson.children.head \ "storageLevel").extract[String] should be ( + StorageLevels.DISK_ONLY.description) val rddJson = getJson(ui, "storage/rdd/0") (rddJson \ "storageLevel").extract[String] should be (StorageLevels.DISK_ONLY.description) @@ -150,7 +151,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedStorageJson = getJson(ui, "storage/rdd") updatedStorageJson.children.length should be (1) - (updatedStorageJson \ "storageLevel").extract[String] should be ( + (updatedStorageJson.children.head \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( @@ -204,7 +205,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } val stageJson = getJson(sc.ui.get, "stages") stageJson.children.length should be (1) - (stageJson \ "status").extract[String] should be (StageStatus.FAILED.name()) + (stageJson.children.head \ "status").extract[String] should be (StageStatus.FAILED.name()) // Regression test for SPARK-2105 class NotSerializable @@ -325,11 +326,11 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") } val jobJson = getJson(sc.ui.get, "jobs") - (jobJson \ "numTasks").extract[Int]should be (2) - (jobJson \ "numCompletedTasks").extract[Int] should be (3) - (jobJson \ "numFailedTasks").extract[Int] should be (1) - (jobJson \ "numCompletedStages").extract[Int] should be (2) - (jobJson \ "numFailedStages").extract[Int] should be (1) + (jobJson \\ "numTasks").extract[Int]should be (2) + (jobJson \\ "numCompletedTasks").extract[Int] should be (3) + (jobJson \\ "numFailedTasks").extract[Int] should be (1) + (jobJson \\ "numCompletedStages").extract[Int] should be (2) + (jobJson \\ "numFailedStages").extract[Int] should be (1) val stageJson = getJson(sc.ui.get, "stages") for { @@ -656,7 +657,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) - val attempts = (appListJsonAst \ "attempts").children + val attempts = (appListJsonAst.children.head \ "attempts").children attempts.size should be (1) (attempts(0) \ "completed").extract[Boolean] should be (false) parseDate(attempts(0) \ "startTime") should be (sc.startTime) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index a7fce2ede0ea5..2a298769be44c 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -167,7 +168,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 94b2e98d85e74..abee326f283ab 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,9 +122,10 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json4s-ast_2.11-3.2.11.jar -json4s-core_2.11-3.2.11.jar -json4s-jackson_2.11-3.2.11.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -168,7 +169,6 @@ scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.5.jar -scalap-2.11.8.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/pom.xml b/pom.xml index d14594aa4ccb0..666d5d7169a15 100644 --- a/pom.xml +++ b/pom.xml @@ -705,7 +705,13 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.11 + 3.5.3 + + + com.fasterxml.jackson.core + * + + org.scala-lang @@ -732,11 +738,6 @@ scala-parser-combinators_${scala.binary.version} 1.0.4 - - org.scala-lang - scalap - ${scala.version} - jline From fc6fe8a1d0f161c4788f3db94de49a8669ba3bcc Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 13 Jan 2018 10:01:44 -0600 Subject: [PATCH 0083/1161] [SPARK-22870][CORE] Dynamic allocation should allow 0 idle time ## What changes were proposed in this pull request? This pr to make `0` as a valid value for `spark.dynamicAllocation.executorIdleTimeout`. For details, see the jira description: https://issues.apache.org/jira/browse/SPARK-22870. ## How was this patch tested? N/A Author: Yuming Wang Author: Yuming Wang Closes #20080 from wangyum/SPARK-22870. --- .../scala/org/apache/spark/ExecutorAllocationManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 2e00dc8b49dd5..6c59038f2a6c1 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -195,8 +195,11 @@ private[spark] class ExecutorAllocationManager( throw new SparkException( "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") } - if (executorIdleTimeoutS <= 0) { - throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + if (executorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be >= 0!") + } + if (cachedExecutorIdleTimeoutS < 0) { + throw new SparkException("spark.dynamicAllocation.cachedExecutorIdleTimeout must be >= 0!") } // Require external shuffle service for dynamic allocation // Otherwise, we may lose shuffle files when killing executors From bd4a21b4820c4ebaf750131574a6b2eeea36907e Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 14 Jan 2018 02:28:57 +0800 Subject: [PATCH 0084/1161] [SPARK-23036][SQL][TEST] Add withGlobalTempView for testing ## What changes were proposed in this pull request? Add withGlobalTempView when create global temp view, like withTempView and withView. And correct some improper usage. Please see jira. There are other similar place like that. I will fix it if community need. Please confirm it. ## How was this patch tested? no new test. Author: xubo245 <601450868@qq.com> Closes #20228 from xubo245/DropTempView. --- .../sql/execution/GlobalTempViewSuite.scala | 55 ++++++++----------- .../spark/sql/execution/SQLViewSuite.scala | 34 +++++++----- .../apache/spark/sql/test/SQLTestUtils.scala | 21 +++++-- 3 files changed, 59 insertions(+), 51 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index cc943e0356f2a..dcc6fa6403f31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -36,7 +36,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("basic semantic") { val expectedErrorMsg = "not found" - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") // If there is no database in table name, we should try local temp view first, if not found, @@ -79,19 +79,15 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { // We can also use Dataset API to replace global temp view Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) - } finally { - spark.catalog.dropGlobalTempView("src") } } test("global temp view is shared among all sessions") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, 2)) val newSession = spark.newSession() checkAnswer(newSession.table(s"$globalTempDB.src"), Row(1, 2)) - } finally { - spark.catalog.dropGlobalTempView("src") } } @@ -105,27 +101,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE GLOBAL TEMP VIEW USING") { withTempPath { path => - try { + withGlobalTempView("src") { Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) sql(s"CREATE GLOBAL TEMP VIEW src USING parquet OPTIONS (PATH '${path.toURI}')") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) sql(s"INSERT INTO $globalTempDB.src SELECT 2, 'b'") checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a") :: Row(2, "b") :: Nil) - } finally { - spark.catalog.dropGlobalTempView("src") } } } test("CREATE TABLE LIKE should work for global temp view") { - try { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") - val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) - assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) - } finally { - spark.catalog.dropGlobalTempView("src") - sql("DROP TABLE default.cloned") + withTable("cloned") { + withGlobalTempView("src") { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") + val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) + assert(tableMeta.schema == new StructType() + .add("a", "int", false).add("b", "string", false)) + } } } @@ -146,26 +140,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { } test("should lookup global temp view if and only if global temp db is specified") { - try { - sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") - sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") + withTempView("same_name") { + withGlobalTempView("same_name") { + sql("CREATE GLOBAL TEMP VIEW same_name AS SELECT 3, 4") + sql("CREATE TEMP VIEW same_name AS SELECT 1, 2") - checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) + checkAnswer(sql("SELECT * FROM same_name"), Row(1, 2)) - // we never lookup global temp views if database is not specified in table name - spark.catalog.dropTempView("same_name") - intercept[AnalysisException](sql("SELECT * FROM same_name")) + // we never lookup global temp views if database is not specified in table name + spark.catalog.dropTempView("same_name") + intercept[AnalysisException](sql("SELECT * FROM same_name")) - // Use qualified name to lookup a global temp view. - checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) - } finally { - spark.catalog.dropTempView("same_name") - spark.catalog.dropGlobalTempView("same_name") + // Use qualified name to lookup a global temp view. + checkAnswer(sql(s"SELECT * FROM $globalTempDB.same_name"), Row(3, 4)) + } } } test("public Catalog should recognize global temp view") { - try { + withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") assert(spark.catalog.tableExists(globalTempDB, "src")) @@ -175,8 +168,6 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { description = null, tableType = "TEMPORARY", isTemporary = true).toString) - } finally { - spark.catalog.dropGlobalTempView("src") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 08a4a21b20f61..8c55758cfe38d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -69,21 +69,25 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a permanent view on a temp view") { - withView("jtv1", "temp_jtv1", "global_temp_jtv1") { - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") - var e = intercept[AnalysisException] { - sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `jtv1` by " + - "referencing a temporary view `temp_jtv1`")) - - val globalTempDB = spark.sharedState.globalTempViewManager.database - sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") - e = intercept[AnalysisException] { - sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") - }.getMessage - assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + - s"a temporary view `global_temp`.`global_temp_jtv1`")) + withView("jtv1") { + withTempView("temp_jtv1") { + withGlobalTempView("global_temp_jtv1") { + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") + var e = intercept[AnalysisException] { + sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `jtv1` by " + + "referencing a temporary view `temp_jtv1`")) + + val globalTempDB = spark.sharedState.globalTempViewManager.database + sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0") + e = intercept[AnalysisException] { + sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6") + }.getMessage + assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " + + s"a temporary view `global_temp`.`global_temp_jtv1`")) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 904f9f2ad0b22..bc4a120f7042f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -254,13 +254,26 @@ private[sql] trait SQLTestUtilsBase } /** - * Drops temporary table `tableName` after calling `f`. + * Drops temporary view `viewNames` after calling `f`. */ - protected def withTempView(tableNames: String*)(f: => Unit): Unit = { + protected def withTempView(viewNames: String*)(f: => Unit): Unit = { try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - try tableNames.foreach(spark.catalog.dropTempView) catch { + // temp views that never got created. + try viewNames.foreach(spark.catalog.dropTempView) catch { + case _: NoSuchTableException => + } + } + } + + /** + * Drops global temporary view `viewNames` after calling `f`. + */ + protected def withGlobalTempView(viewNames: String*)(f: => Unit): Unit = { + try f finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // global temp views that never got created. + try viewNames.foreach(spark.catalog.dropGlobalTempView) catch { case _: NoSuchTableException => } } From ba891ec993c616dc4249fc786c56ea82ed04a827 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Sun, 14 Jan 2018 02:36:32 +0800 Subject: [PATCH 0085/1161] [SPARK-22790][SQL] add a configurable factor to describe HadoopFsRelation's size ## What changes were proposed in this pull request? as per discussion in https://github.com/apache/spark/pull/19864#discussion_r156847927 the current HadoopFsRelation is purely based on the underlying file size which is not accurate and makes the execution vulnerable to errors like OOM Users can enable CBO with the functionalities in https://github.com/apache/spark/pull/19864 to avoid this issue This JIRA proposes to add a configurable factor to sizeInBytes method in HadoopFsRelation class so that users can mitigate this problem without CBO ## How was this patch tested? Existing tests Author: CodingCat Author: Nan Zhu Closes #20072 from CodingCat/SPARK-22790. --- .../apache/spark/sql/internal/SQLConf.scala | 13 +++++- .../datasources/HadoopFsRelation.scala | 6 ++- .../datasources/HadoopFsRelationSuite.scala | 41 +++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 36e802a9faa6f..6746fbcaf2483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -249,7 +249,7 @@ object SQLConf { val CONSTRAINT_PROPAGATION_ENABLED = buildConf("spark.sql.constraintPropagation.enabled") .internal() .doc("When true, the query optimizer will infer and propagate data constraints in the query " + - "plan to optimize them. Constraint propagation can sometimes be computationally expensive" + + "plan to optimize them. Constraint propagation can sometimes be computationally expensive " + "for certain kinds of query plans (such as those with a large number of predicates and " + "aliases) which might negatively impact overall runtime.") .booleanConf @@ -263,6 +263,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val FILE_COMRESSION_FACTOR = buildConf("spark.sql.sources.fileCompressionFactor") + .internal() + .doc("When estimating the output data size of a table scan, multiply the file size with this " + + "factor as the estimated data size, in case the data is compressed in the file and lead to" + + " a heavily underestimated result.") + .doubleConf + .checkValue(_ > 0, "the value of fileDataSizeFactor must be larger than 0") + .createWithDefault(1.0) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -1255,6 +1264,8 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) + def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 89d8a85a9cbd2..6b34638529770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -82,7 +82,11 @@ case class HadoopFsRelation( } } - override def sizeInBytes: Long = location.sizeInBytes + override def sizeInBytes: Long = { + val compressionFactor = sqlContext.conf.fileCompressionFactor + (location.sizeInBytes * compressionFactor).toLong + } + override def inputFiles: Array[String] = location.inputFiles } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index caf03885e3873..c1f2c18d1417d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources import java.io.{File, FilenameFilter} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.test.SharedSQLContext class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { @@ -39,4 +40,44 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize)) } } + + test("SPARK-22790: spark.sql.sources.compressionFactor takes effect") { + import testImplicits._ + Seq(1.0, 0.5).foreach { compressionFactor => + withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString, + "spark.sql.autoBroadcastJoinThreshold" -> "400") { + withTempPath { workDir => + // the file size is 740 bytes + val workDirPath = workDir.getAbsolutePath + val data1 = Seq(100, 200, 300, 400).toDF("count") + data1.write.parquet(workDirPath + "/data1") + val df1FromFile = spark.read.parquet(workDirPath + "/data1") + val data2 = Seq(100, 200, 300, 400).toDF("count") + data2.write.parquet(workDirPath + "/data2") + val df2FromFile = spark.read.parquet(workDirPath + "/data2") + val joinedDF = df1FromFile.join(df2FromFile, Seq("count")) + if (compressionFactor == 0.5) { + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.nonEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.isEmpty) + } else { + // compressionFactor is 1.0 + val bJoinExec = joinedDF.queryExecution.executedPlan.collect { + case bJoin: BroadcastHashJoinExec => bJoin + } + assert(bJoinExec.isEmpty) + val smJoinExec = joinedDF.queryExecution.executedPlan.collect { + case smJoin: SortMergeJoinExec => smJoin + } + assert(smJoinExec.nonEmpty) + } + } + } + } + } } From 0066d6f6fa604817468471832968d4339f71c5cb Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 05:39:38 +0800 Subject: [PATCH 0086/1161] [SPARK-21213][SQL][FOLLOWUP] Use compatible types for comparisons in compareAndGetNewStats ## What changes were proposed in this pull request? This pr fixed code to compare values in `compareAndGetNewStats`. The test below fails in the current master; ``` val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) val newStats5 = CommandUtils.compareAndGetNewStats( Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) assert(newStats5.isEmpty) ``` ## How was this patch tested? Added some tests in `CommandUtilsSuite`. Author: Takeshi Yamamuro Closes #20245 from maropu/SPARK-21213-FOLLOWUP. --- .../sql/execution/command/CommandUtils.scala | 4 +- .../execution/command/CommandUtilsSuite.scala | 56 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 1a0d67fc71fbc..c27048626c8eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -116,8 +116,8 @@ object CommandUtils extends Logging { oldStats: Option[CatalogStatistics], newTotalSize: BigInt, newRowCount: Option[BigInt]): Option[CatalogStatistics] = { - val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L) - val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) + val oldTotalSize = oldStats.map(_.sizeInBytes).getOrElse(BigInt(-1)) + val oldRowCount = oldStats.flatMap(_.rowCount).getOrElse(BigInt(-1)) var newStats: Option[CatalogStatistics] = None if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala new file mode 100644 index 0000000000000..f3e15189a6418 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CommandUtilsSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics + +class CommandUtilsSuite extends SparkFunSuite { + + test("Check if compareAndGetNewStats returns correct results") { + val oldStats1 = CatalogStatistics(sizeInBytes = 10, rowCount = Some(100)) + val newStats1 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 10, newRowCount = Some(100)) + assert(newStats1.isEmpty) + val newStats2 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = None) + assert(newStats2.isEmpty) + val newStats3 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = 20, newRowCount = Some(-1)) + assert(newStats3.isDefined) + newStats3.foreach { stat => + assert(stat.sizeInBytes === 20) + assert(stat.rowCount.isEmpty) + } + val newStats4 = CommandUtils.compareAndGetNewStats( + Some(oldStats1), newTotalSize = -1, newRowCount = Some(200)) + assert(newStats4.isDefined) + newStats4.foreach { stat => + assert(stat.sizeInBytes === 10) + assert(stat.rowCount.isDefined && stat.rowCount.get === 200) + } + } + + test("Check if compareAndGetNewStats can handle large values") { + // Tests for large values + val oldStats2 = CatalogStatistics(sizeInBytes = BigInt(Long.MaxValue) * 2) + val newStats5 = CommandUtils.compareAndGetNewStats( + Some(oldStats2), newTotalSize = BigInt(Long.MaxValue) * 2, None) + assert(newStats5.isEmpty) + } +} From afae8f2bc82597593595af68d1aa2d802210ea8b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 14 Jan 2018 11:26:49 +0900 Subject: [PATCH 0087/1161] [SPARK-22959][PYTHON] Configuration to select the modules for daemon and worker in PySpark ## What changes were proposed in this pull request? We are now forced to use `pyspark/daemon.py` and `pyspark/worker.py` in PySpark. This doesn't allow a custom modification for it (well, maybe we can still do this in a super hacky way though, for example, setting Python executable that has the custom modification). Because of this, for example, it's sometimes hard to debug what happens inside Python worker processes. This is actually related with [SPARK-7721](https://issues.apache.org/jira/browse/SPARK-7721) too as somehow Coverage is unable to detect the coverage from `os.fork`. If we have some custom fixes to force the coverage, it works fine. This is also related with [SPARK-20368](https://issues.apache.org/jira/browse/SPARK-20368). This JIRA describes Sentry support which (roughly) needs some changes within worker side. With this configuration advanced users will be able to do a lot of pluggable workarounds and we can meet such potential needs in the future. As an example, let's say if I configure the module `coverage_daemon` and had `coverage_daemon.py` in the python path: ```python import os from pyspark import daemon if "COVERAGE_PROCESS_START" in os.environ: from pyspark.worker import main def _cov_wrapped(*args, **kwargs): import coverage cov = coverage.coverage( config_file=os.environ["COVERAGE_PROCESS_START"]) cov.start() try: main(*args, **kwargs) finally: cov.stop() cov.save() daemon.worker_main = _cov_wrapped if __name__ == '__main__': daemon.manager() ``` I can track the coverages in worker side too. More importantly, we can leave the main code intact but allow some workarounds. ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20151 from HyukjinKwon/configuration-daemon-worker. --- .../api/python/PythonWorkerFactory.scala | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index f53c6178047f5..30976ac752a8a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -34,10 +34,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String import PythonWorkerFactory._ - // Because forking processes from Java is expensive, we prefer to launch a single Python daemon - // (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently - // only works on UNIX-based systems now because it uses signals for child management, so we can - // also fall back to launching workers (pyspark/worker.py) directly. + // Because forking processes from Java is expensive, we prefer to launch a single Python daemon, + // pyspark/daemon.py (by default) and tell it to fork new workers for our tasks. This daemon + // currently only works on UNIX-based systems now because it uses signals for child management, + // so we can also fall back to launching workers, pyspark/worker.py (by default) directly. val useDaemon = { val useDaemonEnabled = SparkEnv.get.conf.getBoolean("spark.python.use.daemon", true) @@ -45,6 +45,28 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled } + // WARN: Both configurations, 'spark.python.daemon.module' and 'spark.python.worker.module' are + // for very advanced users and they are experimental. This should be considered + // as expert-only option, and shouldn't be used before knowing what it means exactly. + + // This configuration indicates the module to run the daemon to execute its Python workers. + val daemonModule = SparkEnv.get.conf.getOption("spark.python.daemon.module").map { value => + logInfo( + s"Python daemon module in PySpark is set to [$value] in 'spark.python.daemon.module', " + + "using this to start the daemon up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is enabled and the platform is not Windows.") + value + }.getOrElse("pyspark.daemon") + + // This configuration indicates the module to run each Python worker. + val workerModule = SparkEnv.get.conf.getOption("spark.python.worker.module").map { value => + logInfo( + s"Python worker module in PySpark is set to [$value] in 'spark.python.worker.module', " + + "using this to start the worker up. Note that this configuration only has an effect when " + + "'spark.python.use.daemon' is disabled or the platform is Windows.") + value + }.getOrElse("pyspark.worker") + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -74,8 +96,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Connect to a worker launched through pyspark/daemon.py, which forks python processes itself - * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. + * Connect to a worker launched through pyspark/daemon.py (by default), which forks python + * processes itself to avoid the high cost of forking from Java. This currently only works + * on UNIX-based systems. */ private def createThroughDaemon(): Socket = { @@ -108,7 +131,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } /** - * Launch a worker by executing worker.py directly and telling it to connect to us. + * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ private def createSimpleWorker(): Socket = { var serverSocket: ServerSocket = null @@ -116,7 +139,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -159,7 +182,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) From c3548d11c3c57e8f2c6ebd9d2d6a3924ddcd3cba Mon Sep 17 00:00:00 2001 From: foxish Date: Sat, 13 Jan 2018 21:34:28 -0800 Subject: [PATCH 0088/1161] [SPARK-23063][K8S] K8s changes for publishing scripts (and a couple of other misses) ## What changes were proposed in this pull request? Including the `-Pkubernetes` flag in a few places it was missed. ## How was this patch tested? checkstyle, mima through manual tests. Author: foxish Closes #20256 from foxish/SPARK-23063. --- dev/create-release/release-build.sh | 4 ++-- dev/create-release/releaseutils.py | 2 ++ dev/deps/spark-deps-hadoop-2.6 | 11 +++++++++++ dev/deps/spark-deps-hadoop-2.7 | 11 +++++++++++ dev/lint-java | 2 +- dev/mima | 2 +- dev/scalastyle | 1 + dev/test-dependencies.sh | 2 +- 8 files changed, 30 insertions(+), 5 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c71137468054f..a3579f21fc539 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -92,9 +92,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some builds diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 730138195e5fe..32f6cbb29f0be 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -185,6 +185,8 @@ def get_commits(tag): "graphx": "GraphX", "input/output": CORE_COMPONENT, "java api": "Java API", + "k8s": "Kubernetes", + "kubernetes": "Kubernetes", "mesos": "Mesos", "ml": "MLlib", "mllib": "MLlib", diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 2a298769be44c..48e54568e6fc6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -131,10 +135,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -147,6 +154,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -171,6 +180,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -186,5 +196,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index abee326f283ab..1807a77900e52 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -17,6 +17,7 @@ arpack_combined_all-0.1.jar arrow-format-0.8.0.jar arrow-memory-0.8.0.jar arrow-vector-0.8.0.jar +automaton-1.11-8.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -60,6 +61,7 @@ datanucleus-rdbms-3.2.9.jar derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar @@ -91,8 +93,10 @@ jackson-annotations-2.6.7.jar jackson-core-2.6.7.jar jackson-core-asl-1.9.13.jar jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar jackson-module-paranamer-2.7.9.jar jackson-module-scala_2.11-2.6.7.1.jar jackson-xc-1.9.13.jar @@ -132,10 +136,13 @@ jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar leveldbjni-all-1.8.jar libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar +logging-interceptor-3.8.1.jar lz4-java-1.4.0.jar machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar @@ -148,6 +155,8 @@ minlog-1.3.0.jar netty-3.9.9.Final.jar netty-all-4.1.17.Final.jar objenesis-2.1.jar +okhttp-3.8.1.jar +okio-1.13.0.jar opencsv-2.3.jar orc-core-1.4.1-nohive.jar orc-mapreduce-1.4.1-nohive.jar @@ -172,6 +181,7 @@ scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar snappy-0.2.jar snappy-java-1.1.2.6.jar spire-macros_2.11-0.13.0.jar @@ -187,5 +197,6 @@ xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar +zjsonpatch-0.3.0.jar zookeeper-3.4.6.jar zstd-jni-1.3.2-2.jar diff --git a/dev/lint-java b/dev/lint-java index c2e80538ef2a5..1f0b0c8379ed0 100755 --- a/dev/lint-java +++ b/dev/lint-java @@ -20,7 +20,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" -ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) +ERRORS=$($SCRIPT_DIR/../build/mvn -Pkinesis-asl -Pmesos -Pkubernetes -Pyarn -Phive -Phive-thriftserver checkstyle:check | grep ERROR) if test ! -z "$ERRORS"; then echo -e "Checkstyle checks failed at following occurrences:\n$ERRORS" diff --git a/dev/mima b/dev/mima index 1e3ca9700bc07..cd2694ff4d3de 100755 --- a/dev/mima +++ b/dev/mima @@ -24,7 +24,7 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" +SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" diff --git a/dev/scalastyle b/dev/scalastyle index 89ecc8abd6f8c..b8053df05fa2b 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -24,6 +24,7 @@ ERRORS=$(echo -e "q\n" \ -Pkinesis-asl \ -Pmesos \ -Pkafka-0-8 \ + -Pkubernetes \ -Pyarn \ -Pflume \ -Phive \ diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 58b295d4f6e00..3bf7618e1ea96 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pkubernetes -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 From 7a3d0aad2b89aef54f7dd580397302e9ff984d9d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 13 Jan 2018 23:26:12 -0800 Subject: [PATCH 0089/1161] [SPARK-23038][TEST] Update docker/spark-test (JDK/OS) ## What changes were proposed in this pull request? This PR aims to update the followings in `docker/spark-test`. - JDK7 -> JDK8 Spark 2.2+ supports JDK8 only. - Ubuntu 12.04.5 LTS(precise) -> Ubuntu 16.04.3 LTS(xeniel) The end of life of `precise` was April 28, 2017. ## How was this patch tested? Manual. * Master ``` $ cd external/docker $ ./build $ export SPARK_HOME=... $ docker run -v $SPARK_HOME:/opt/spark spark-test-master CONTAINER_IP=172.17.0.3 ... 18/01/11 06:50:25 INFO MasterWebUI: Bound MasterWebUI to 172.17.0.3, and started at http://172.17.0.3:8080 18/01/11 06:50:25 INFO Utils: Successfully started service on port 6066. 18/01/11 06:50:25 INFO StandaloneRestServer: Started REST server for submitting applications on port 6066 18/01/11 06:50:25 INFO Master: I have been elected leader! New state: ALIVE ``` * Slave ``` $ docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://172.17.0.3:7077 CONTAINER_IP=172.17.0.4 ... 18/01/11 06:51:54 INFO Worker: Successfully registered with master spark://172.17.0.3:7077 ``` After slave starts, master will show ``` 18/01/11 06:51:54 INFO Master: Registering worker 172.17.0.4:8888 with 4 cores, 1024.0 MB RAM ``` Author: Dongjoon Hyun Closes #20230 from dongjoon-hyun/SPARK-23038. --- external/docker/spark-test/base/Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/external/docker/spark-test/base/Dockerfile b/external/docker/spark-test/base/Dockerfile index 5a95a9387c310..c70cd71367679 100644 --- a/external/docker/spark-test/base/Dockerfile +++ b/external/docker/spark-test/base/Dockerfile @@ -15,14 +15,14 @@ # limitations under the License. # -FROM ubuntu:precise +FROM ubuntu:xenial # Upgrade package index -# install a few other useful packages plus Open Jdk 7 +# install a few other useful packages plus Open Jdk 8 # Remove unneeded /var/lib/apt/lists/* after install to reduce the # docker image size (by ~30MB) RUN apt-get update && \ - apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server && \ + apt-get install -y less openjdk-8-jre-headless iproute2 vim-tiny sudo openssh-server && \ rm -rf /var/lib/apt/lists/* ENV SCALA_VERSION 2.11.8 From 66738d29c59871b29d26fc3756772b95ef536248 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 14 Jan 2018 19:43:10 +0900 Subject: [PATCH 0090/1161] [SPARK-23069][DOCS][SPARKR] fix R doc for describe missing text ## What changes were proposed in this pull request? fix doc truncated ## How was this patch tested? manually Author: Felix Cheung Closes #20263 from felixcheung/r23docfix. --- R/pkg/R/DataFrame.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 9956f7eda91e6..6caa125e1e14a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3054,10 +3054,10 @@ setMethod("describe", #' \item stddev #' \item min #' \item max -#' \item arbitrary approximate percentiles specified as a percentage (eg, "75%") +#' \item arbitrary approximate percentiles specified as a percentage (eg, "75\%") #' } #' If no statistics are given, this function computes count, mean, stddev, min, -#' approximate quartiles (percentiles at 25%, 50%, and 75%), and max. +#' approximate quartiles (percentiles at 25\%, 50\%, and 75\%), and max. #' This function is meant for exploratory data analysis, as we make no guarantee about the #' backward compatibility of the schema of the resulting Dataset. If you want to #' programmatically compute summary statistics, use the \code{agg} function instead. @@ -4019,9 +4019,9 @@ setMethod("broadcast", #' #' Spark will use this watermark for several purposes: #' \itemize{ -#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' \item To know when a given time window aggregation can be finalized and thus can be emitted #' when using output modes that do not allow updates. -#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' \item To minimize the amount of state that we need to keep for on-going aggregations. #' } #' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across #' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost From 990f05c80347c6eec2ee06823cff587c9ea90b49 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 14 Jan 2018 22:26:21 +0800 Subject: [PATCH 0091/1161] [SPARK-23021][SQL] AnalysisBarrier should override innerChildren to print correct explain output ## What changes were proposed in this pull request? `AnalysisBarrier` in the current master cuts off explain results for parsed logical plans; ``` scala> Seq((1, 1)).toDF("a", "b").groupBy("a").count().sample(0.1).explain(true) == Parsed Logical Plan == Sample 0.0, 0.1, false, -7661439431999668039 +- AnalysisBarrier Aggregate [a#5], [a#5, count(1) AS count#14L] ``` To fix this, `AnalysisBarrier` needs to override `innerChildren` and this pr changed the output to; ``` == Parsed Logical Plan == Sample 0.0, 0.1, false, -5086223488015741426 +- AnalysisBarrier +- Aggregate [a#5], [a#5, count(1) AS count#14L] +- Project [_1#2 AS a#5, _2#3 AS b#6] +- LocalRelation [_1#2, _2#3] ``` ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20247 from maropu/SPARK-23021-2. --- .../plans/logical/basicLogicalOperators.scala | 1 + .../sql/hive/execution/HiveExplainSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 95e099c340af1..a4fca790dd086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -903,6 +903,7 @@ case class Deduplicate( * This analysis barrier will be removed at the end of analysis stage. */ case class AnalysisBarrier(child: LogicalPlan) extends LeafNode { + override protected def innerChildren: Seq[LogicalPlan] = Seq(child) override def output: Seq[Attribute] = child.output override def isStreaming: Boolean = child.isStreaming override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index dfabf1ec2a22a..a4273de5fe260 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -171,4 +171,21 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql("EXPLAIN EXTENDED CODEGEN SELECT 1") } } + + test("SPARK-23021 AnalysisBarrier should not cut off explain output for parsed logical plans") { + val df = Seq((1, 1)).toDF("a", "b").groupBy("a").count().limit(1) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + df.explain(true) + } + assert(outputStream.toString.replaceAll("""#\d+""", "#0").contains( + s"""== Parsed Logical Plan == + |GlobalLimit 1 + |+- LocalLimit 1 + | +- AnalysisBarrier + | +- Aggregate [a#0], [a#0, count(1) AS count#0L] + | +- Project [_1#0 AS a#0, _2#0 AS b#0] + | +- LocalRelation [_1#0, _2#0] + |""".stripMargin)) + } } From 60eeecd7760aee6ce2fd207c83ae40054eadaf83 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Sun, 14 Jan 2018 08:32:35 -0600 Subject: [PATCH 0092/1161] [SPARK-23051][CORE] Fix for broken job description in Spark UI ## What changes were proposed in this pull request? In 2.2, Spark UI displayed the stage description if the job description was not set. This functionality was broken, the GUI has shown no description in this case. In addition, the code uses jobName and jobDescription instead of stageName and stageDescription when JobTableRowData is created. In this PR the logic producing values for the job rows was modified to find the latest stage attempt for the job and use that as a fallback if job description was missing. StageName and stageDescription are also set using values from stage and jobName/description is used only as a fallback. ## How was this patch tested? Manual testing of the UI, using the code in the bug report. Author: Sandor Murakozi Closes #20251 from smurakozi/SPARK-23051. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 37e3b3b304a63..ff916bb6a5759 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -65,12 +65,10 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We }.map { job => val jobId = job.jobId val status = job.status - val displayJobDescription = - if (job.description.isEmpty) { - job.name - } else { - UIUtils.makeDescription(job.description.get, "", plainText = true).text - } + val jobDescription = store.lastStageAttempt(job.stageIds.max).description + val displayJobDescription = jobDescription + .map(UIUtils.makeDescription(_, "", plainText = true).text) + .getOrElse("") val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -429,20 +427,23 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(jobData.description.getOrElse(""), - basePath, plainText = false) + val lastStageAttempt = store.lastStageAttempt(jobData.stageIds.max) + val lastStageDescription = lastStageAttempt.description.getOrElse("") + + val formattedJobDescription = + UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - jobData.name, - jobData.description.getOrElse(jobData.name), + lastStageAttempt.name, + lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, - jobDescription, + formattedJobDescription, detailUrl ) } From 42a1a15d739890bdfbb367ef94198b19e98ffcb7 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Mon, 15 Jan 2018 02:02:49 +0800 Subject: [PATCH 0093/1161] [SPARK-22999][SQL] show databases like command' can remove the like keyword ## What changes were proposed in this pull request? SHOW DATABASES (LIKE pattern = STRING)? Can be like the back increase? When using this command, LIKE keyword can be removed. You can refer to the SHOW TABLES command, SHOW TABLES 'test *' and SHOW TABELS like 'test *' can be used. Similarly SHOW DATABASES 'test *' and SHOW DATABASES like 'test *' can be used. ## How was this patch tested? unit tests manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20194 from guoxiaolongzte/SPARK-22999. --- .../antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6daf01d98426c..39d5e4ed56628 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -141,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 591510c1d8283..2b4b7c137428a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -991,6 +991,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("SHOW DATABASES LIKE '*db1A'"), Row("showdb1a") :: Nil) + checkAnswer( + sql("SHOW DATABASES '*db1A'"), + Row("showdb1a") :: Nil) + checkAnswer( sql("SHOW DATABASES LIKE 'showdb1A'"), Row("showdb1a") :: Nil) From b98ffa4d6dabaf787177d3f14b200fc4b118c7ce Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 10:55:21 +0800 Subject: [PATCH 0094/1161] [SPARK-23054][SQL] Fix incorrect results of casting UserDefinedType to String ## What changes were proposed in this pull request? This pr fixed the issue when casting `UserDefinedType`s into strings; ``` >>> from pyspark.ml.classification import MultilayerPerceptronClassifier >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])), (1.0, Vectors.dense([0.0, 1.0]))], ["label", "features"]) >>> df.selectExpr("CAST(features AS STRING)").show(truncate = False) +-------------------------------------------+ |features | +-------------------------------------------+ |[6,1,0,0,2800000020,2,0,0,0] | |[6,1,0,0,2800000020,2,0,0,3ff0000000000000]| +-------------------------------------------+ ``` The root cause is that `Cast` handles input data as `UserDefinedType.sqlType`(this is underlying storage type), so we should pass data into `UserDefinedType.deserialize` then `toString`. This pr modified the result into; ``` +---------+ |features | +---------+ |[0.0,0.0]| |[0.0,1.0]| +---------+ ``` ## How was this patch tested? Added tests in `UserDefinedTypeSuite `. Author: Takeshi Yamamuro Closes #20246 from maropu/SPARK-23054. --- .../spark/sql/catalyst/expressions/Cast.scala | 7 +++++++ .../apache/spark/sql/UserDefinedTypeSuite.scala | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f21aa1e9e3135..a95ebe301b9d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case udt: UserDefinedType[_] => + buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -836,6 +838,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case udt: UserDefinedType[_] => + val udtRef = ctx.addReferenceObj("udt", udt) + (c, evPrim, evNull) => { + s"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());" + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a08433ba794d9..cc8b600efa46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,7 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.{Cast, ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ @@ -44,6 +44,8 @@ object UDT { case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) case _ => false } + + override def toString: String = data.mkString("(", ", ", ")") } private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { @@ -143,7 +145,8 @@ private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] override def userClass: Class[IExampleSubType] = classOf[IExampleSubType] } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest + with ExpressionEvalHelper { import testImplicits._ private lazy val pointsRDD = Seq( @@ -304,4 +307,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT pointsRDD.except(pointsRDD2), Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) } + + test("SPARK-23054 Cast UserDefinedType to string") { + val udt = new UDT.MyDenseVectorUDT() + val vector = new UDT.MyDenseVector(Array(1.0, 3.0, 5.0, 7.0, 9.0)) + val data = udt.serialize(vector) + val ret = Cast(Literal(data, udt), StringType, None) + checkEvaluation(ret, "(1.0, 3.0, 5.0, 7.0, 9.0)") + } } From 9a96bfc8bf021cb4b6c62fac6ce1bcf87affcd43 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 15 Jan 2018 12:06:56 +0800 Subject: [PATCH 0095/1161] [SPARK-23049][SQL] `spark.sql.files.ignoreCorruptFiles` should work for ORC files ## What changes were proposed in this pull request? When `spark.sql.files.ignoreCorruptFiles=true`, we should ignore corrupted ORC files. ## How was this patch tested? Pass the Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #20240 from dongjoon-hyun/SPARK-23049. --- .../execution/datasources/orc/OrcUtils.scala | 29 ++++++++---- .../datasources/orc/OrcQuerySuite.scala | 47 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 23 +++++++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 8 +++- .../spark/sql/hive/orc/OrcFileOperator.scala | 28 +++++++++-- 5 files changed, 117 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 13a23996f4ade..460194ba61c8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.orc.{OrcFile, Reader, TypeDescription} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -50,23 +51,35 @@ object OrcUtils extends Logging { paths } - def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = { + def readSchema(file: Path, conf: Configuration, ignoreCorruptFiles: Boolean) + : Option[TypeDescription] = { val fs = file.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val reader = OrcFile.createReader(file, readerOptions) - val schema = reader.getSchema - if (schema.getFieldNames.size == 0) { - None - } else { - Some(schema) + try { + val reader = OrcFile.createReader(file, readerOptions) + val schema = reader.getSchema + if (schema.getFieldNames.size == 0) { + None + } else { + Some(schema) + } + } catch { + case e: org.apache.orc.FileFormatException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $file", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $file", e) + } } } def readSchema(sparkSession: SparkSession, files: Seq[FileStatus]) : Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles val conf = sparkSession.sessionState.newHadoopConf() // TODO: We need to support merge schema. Please see SPARK-11412. - files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema => + files.map(_.getPath).flatMap(readSchema(_, conf, ignoreCorruptFiles)).headOption.map { schema => logDebug(s"Reading schema from file $files, got Hive schema string: $schema") CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index e00e057a18cc6..f58c331f33ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -31,6 +31,7 @@ import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} @@ -531,6 +532,52 @@ abstract class OrcQueryTest extends OrcTest { val df = spark.read.orc(path1.getCanonicalPath, path2.getCanonicalPath) assert(df.count() == 20) } + + test("Enabling/disabling ignoreCorruptFiles") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.orc(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.orc(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").orc( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { + testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() + } + + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val m1 = intercept[SparkException] { + testIgnoreCorruptFiles() + }.getMessage + assert(m1.contains("Could not read footer for file")) + val m2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + }.getMessage + assert(m2.contains("Malformed ORC file")) + } + } } class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 4c8c9ef6e0432..6ad88ed997ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -320,14 +320,27 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - checkAnswer( - df, - Seq(Row(0), Row(1))) + checkAnswer(df, Seq(Row(0), Row(1))) + } + } + + def testIgnoreCorruptFilesWithoutSchemaInfer(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.schema("a long").parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer(df, Seq(Row(0), Row(1))) } } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "true") { testIgnoreCorruptFiles() + testIgnoreCorruptFilesWithoutSchemaInfer() } withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { @@ -335,6 +348,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext testIgnoreCorruptFiles() } assert(exception.getMessage().contains("is not a Parquet file")) + val exception2 = intercept[SparkException] { + testIgnoreCorruptFilesWithoutSchemaInfer() + } + assert(exception2.getMessage().contains("is not a Parquet file")) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 95741c7b30289..237ed9bc05988 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -59,9 +59,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles OrcFileOperator.readSchema( files.map(_.getPath.toString), - Some(sparkSession.sessionState.newHadoopConf()) + Some(sparkSession.sessionState.newHadoopConf()), + ignoreCorruptFiles ) } @@ -129,6 +131,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -138,7 +141,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. - val isEmptyFile = OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf)).isEmpty + val isEmptyFile = + OrcFileOperator.readSchema(Seq(filePath.toString), Some(conf), ignoreCorruptFiles).isEmpty if (isEmptyFile) { Iterator.empty } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 5a3fcd7a759c0..80e44ca504356 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive.orc +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -46,7 +49,10 @@ private[hive] object OrcFileOperator extends Logging { * create the result reader from that file. If no such file is found, it returns `None`. * @todo Needs to consider all files when schema evolution is taken into account. */ - def getFileReader(basePath: String, config: Option[Configuration] = None): Option[Reader] = { + def getFileReader(basePath: String, + config: Option[Configuration] = None, + ignoreCorruptFiles: Boolean = false) + : Option[Reader] = { def isWithNonEmptySchema(path: Path, reader: Reader): Boolean = { reader.getObjectInspector match { case oi: StructObjectInspector if oi.getAllStructFieldRefs.size() == 0 => @@ -65,16 +71,28 @@ private[hive] object OrcFileOperator extends Logging { } listOrcFiles(basePath, conf).iterator.map { path => - path -> OrcFile.createReader(fs, path) + val reader = try { + Some(OrcFile.createReader(fs, path)) + } catch { + case e: IOException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $path", e) + None + } else { + throw new SparkException(s"Could not read footer for file: $path", e) + } + } + path -> reader }.collectFirst { - case (path, reader) if isWithNonEmptySchema(path, reader) => reader + case (path, Some(reader)) if isWithNonEmptySchema(path, reader) => reader } } - def readSchema(paths: Seq[String], conf: Option[Configuration]): Option[StructType] = { + def readSchema(paths: Seq[String], conf: Option[Configuration], ignoreCorruptFiles: Boolean) + : Option[StructType] = { // Take the first file where we can open a valid reader if we can find one. Otherwise just // return None to indicate we can't infer the schema. - paths.flatMap(getFileReader(_, conf)).headOption.map { reader => + paths.flatMap(getFileReader(_, conf, ignoreCorruptFiles)).headOption.map { reader => val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName logDebug(s"Reading schema from file $paths, got Hive schema string: $schema") From b59808385cfe24ce768e5b3098b9034e64b99a5a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 15 Jan 2018 16:26:52 +0800 Subject: [PATCH 0096/1161] [SPARK-23023][SQL] Cast field data to strings in showString ## What changes were proposed in this pull request? The current `Datset.showString` prints rows thru `RowEncoder` deserializers like; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------------------------------------------+ |a | +------------------------------------------------------------+ |[WrappedArray(1, 2), WrappedArray(3), WrappedArray(4, 5, 6)]| +------------------------------------------------------------+ ``` This result is incorrect because the correct one is; ``` scala> Seq(Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))).toDF("a").show(false) +------------------------+ |a | +------------------------+ |[[1, 2], [3], [4, 5, 6]]| +------------------------+ ``` So, this pr fixed code in `showString` to cast field data to strings before printing. ## How was this patch tested? Added tests in `DataFrameSuite`. Author: Takeshi Yamamuro Closes #20214 from maropu/SPARK-23023. --- python/pyspark/sql/functions.py | 32 +++++++++---------- .../scala/org/apache/spark/sql/Dataset.scala | 21 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 28 ++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e1ad6590554cf..f7b3f29764040 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1849,14 +1849,14 @@ def explode_outer(col): +---+----------+----+-----+ >>> df.select("id", "a_map", explode_outer("an_array")).show() - +---+-------------+----+ - | id| a_map| col| - +---+-------------+----+ - | 1|Map(x -> 1.0)| foo| - | 1|Map(x -> 1.0)| bar| - | 2| Map()|null| - | 3| null|null| - +---+-------------+----+ + +---+----------+----+ + | id| a_map| col| + +---+----------+----+ + | 1|[x -> 1.0]| foo| + | 1|[x -> 1.0]| bar| + | 2| []|null| + | 3| null|null| + +---+----------+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.explode_outer(_to_java_column(col)) @@ -1881,14 +1881,14 @@ def posexplode_outer(col): | 3| null|null|null| null| +---+----------+----+----+-----+ >>> df.select("id", "a_map", posexplode_outer("an_array")).show() - +---+-------------+----+----+ - | id| a_map| pos| col| - +---+-------------+----+----+ - | 1|Map(x -> 1.0)| 0| foo| - | 1|Map(x -> 1.0)| 1| bar| - | 2| Map()|null|null| - | 3| null|null|null| - +---+-------------+----+----+ + +---+----------+----+----+ + | id| a_map| pos| col| + +---+----------+----+----+ + | 1|[x -> 1.0]| 0| foo| + | 1|[x -> 1.0]| 1| bar| + | 2| []|null|null| + | 3| null|null|null| + +---+----------+----+----+ """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.posexplode_outer(_to_java_column(col)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 77e571272920a..34f0ab5aa6699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,13 +237,20 @@ class Dataset[T] private[sql]( private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val takeResult = toDF().take(numRows + 1) + val newDf = toDF() + val castCols = newDf.logicalPlan.output.map { col => + // Since binary types in top-level schema fields have a specific format to print, + // so we do not cast them to strings here. + if (col.dataType == BinaryType) { + Column(col) + } else { + Column(col).cast(StringType) + } + } + val takeResult = newDf.select(castCols: _*).take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - lazy val timeZone = - DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) - // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." @@ -252,12 +259,6 @@ class Dataset[T] private[sql]( val str = cell match { case null => "null" case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case d: Date => - DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) - case ts: Timestamp => - DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone) case _ => cell.toString } if (truncate > 0 && str.length > truncate) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5e4c1a6a484fb..33707080c1301 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1255,6 +1255,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(1, vertical = true) === expectedAnswer) } + test("SPARK-23023 Cast rows to strings in showString") { + val df1 = Seq(Seq(1, 2, 3, 4)).toDF("a") + assert(df1.showString(10) === + s"""+------------+ + || a| + |+------------+ + ||[1, 2, 3, 4]| + |+------------+ + |""".stripMargin) + val df2 = Seq(Map(1 -> "a", 2 -> "b")).toDF("a") + assert(df2.showString(10) === + s"""+----------------+ + || a| + |+----------------+ + ||[1 -> a, 2 -> b]| + |+----------------+ + |""".stripMargin) + val df3 = Seq(((1, "a"), 0), ((2, "b"), 0)).toDF("a", "b") + assert(df3.showString(10) === + s"""+------+---+ + || a| b| + |+------+---+ + ||[1, a]| 0| + ||[2, b]| 0| + |+------+---+ + |""".stripMargin) + } + test("SPARK-7327 show with empty dataFrame") { val expectedAnswer = """+---+-----+ ||key|value| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 54893c184642b..49c59cf695dc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -958,12 +958,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ).toDS() val expected = - """+-------+ - || f| - |+-------+ - ||[foo,1]| - ||[bar,2]| - |+-------+ + """+--------+ + || f| + |+--------+ + ||[foo, 1]| + ||[bar, 2]| + |+--------+ |""".stripMargin checkShowString(ds, expected) From a38c887ac093d7cf343d807515147d87ca931ce7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 15 Jan 2018 07:49:34 -0600 Subject: [PATCH 0097/1161] [SPARK-19550][BUILD][FOLLOW-UP] Remove MaxPermSize for sql module ## What changes were proposed in this pull request? Remove `MaxPermSize` for `sql` module ## How was this patch tested? Manually tested. Author: Yuming Wang Closes #20268 from wangyum/SPARK-19550-MaxPermSize. --- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 839b929abd3cb..7d23637e28342 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -134,7 +134,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 744daa6079779..ef41837f89d68 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -195,7 +195,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} From bd08a9e7af4137bddca638e627ad2ae531bce20f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 15 Jan 2018 22:32:38 +0800 Subject: [PATCH 0098/1161] [SPARK-23070] Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 ## What changes were proposed in this pull request? Bump previousSparkVersion in MimaBuild.scala to be 2.2.0 and add the missing exclusions to `v23excludes` in `MimaExcludes`. No item can be un-excluded in `v23excludes`. ## How was this patch tested? The existing tests. Author: gatorsmile Closes #20264 from gatorsmile/bump22. --- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 2ef0e7b40d940..adde213e361f0 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,7 +88,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "2.0.0" + val previousSparkVersion = "2.2.0" val project = projectRef.project val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 32eb31f495979..d35c50e1d00fe 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -102,7 +102,40 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-21728][CORE] Allow SparkSubmit to use Logging + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"), + + // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"), + + // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"), + + // [SPARK-20643][CORE] Add listener implementation to collect app state + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"), + + // [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"), + + // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json + // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), + + // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"), + + // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), + + // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") ) // Exclude rules for 2.2.x From 6c81fe227a6233f5d9665d2efadf8a1cf09f700d Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Mon, 15 Jan 2018 23:13:15 +0800 Subject: [PATCH 0099/1161] [SPARK-23035][SQL] Fix improper information of TempTableAlreadyExistsException ## What changes were proposed in this pull request? Problem: it throw TempTableAlreadyExistsException and output "Temporary table '$table' already exists" when we create temp view by using org.apache.spark.sql.catalyst.catalog.GlobalTempViewManager#create, it's improper. So fix improper information about TempTableAlreadyExistsException when create temp view: change "Temporary table" to "Temporary view" ## How was this patch tested? test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") Author: xubo245 <601450868@qq.com> Closes #20227 from xubo245/fixDeprecated. --- .../analysis/AlreadyExistException.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 6 +- .../spark/sql/execution/SQLViewSuite.scala | 2 +- .../sql/execution/command/DDLSuite.scala | 75 ++++++++++++++++++- 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 57f7a80bedc6c..6d587abd8fd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -31,7 +31,7 @@ class TableAlreadyExistsException(db: String, table: String) extends AnalysisException(s"Table or view '$table' already exists in database '$db'") class TempTableAlreadyExistsException(table: String) - extends AnalysisException(s"Temporary table '$table' already exists") + extends AnalysisException(s"Temporary view '$table' already exists") class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) extends AnalysisException( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 95c87ffa20cb7..6abab0073cca3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -279,7 +279,7 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } - test("create temp table") { + test("create temp view") { withBasicCatalog { catalog => val tempTable1 = Range(1, 10, 1, 10) val tempTable2 = Range(1, 20, 2, 10) @@ -288,11 +288,11 @@ abstract class SessionCatalogSuite extends AnalysisTest { assert(catalog.getTempView("tbl1") == Option(tempTable1)) assert(catalog.getTempView("tbl2") == Option(tempTable2)) assert(catalog.getTempView("tbl3").isEmpty) - // Temporary table already exists + // Temporary view already exists intercept[TempTableAlreadyExistsException] { catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } - // Temporary table already exists but we override it + // Temporary view already exists but we override it catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) assert(catalog.getTempView("tbl1") == Option(tempTable2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 8c55758cfe38d..14082197ba0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -293,7 +293,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") } - assert(e.message.contains("Temporary table") && e.message.contains("already exists")) + assert(e.message.contains("Temporary view") && e.message.contains("already exists")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2b4b7c137428a..6ca21b5aa1595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -835,6 +835,31 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table with database name,with:CREATE TEMPORARY view") { + withTempView("view1") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO default.tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`default`.`tab2`': " + + "cannot specify database name 'default' in the destination table")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("view1"))) + } + } + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") @@ -883,6 +908,42 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename temporary view - destination table already exists, with: CREATE TEMPORARY view") { + withTempView("view1", "view2") { + sql( + """ + |CREATE TEMPORARY VIEW view1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + sql( + """ + |CREATE TEMPORARY VIEW view2 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE view1 RENAME TO view2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY VIEW from '`view1`' to '`view2`': destination table already exists")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == + Seq(TableIdentifier("view1"), TableIdentifier("view2"))) + } + } + test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -1728,12 +1789,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("block creating duplicate temp table") { - withView("t_temp") { + withTempView("t_temp") { sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") val e = intercept[TempTableAlreadyExistsException] { sql("CREATE TEMPORARY TABLE t_temp (c3 int, c4 string) USING JSON") }.getMessage - assert(e.contains("Temporary table 't_temp' already exists")) + assert(e.contains("Temporary view 't_temp' already exists")) + } + } + + test("block creating duplicate temp view") { + withTempView("t_temp") { + sql("CREATE TEMPORARY VIEW t_temp AS SELECT 1, 2") + val e = intercept[TempTableAlreadyExistsException] { + sql("CREATE TEMPORARY VIEW t_temp (c3 int, c4 string) USING JSON") + }.getMessage + assert(e.contains("Temporary view 't_temp' already exists")) } } From 8ab2d7ea99b2cff8b54b2cb3a1dbf7580845986a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 11:47:42 +0900 Subject: [PATCH 0100/1161] [SPARK-23080][SQL] Improve error message for built-in functions ## What changes were proposed in this pull request? When a user puts the wrong number of parameters in a function, an AnalysisException is thrown. If the function is a UDF, he user is told how many parameters the function expected and how many he/she put. If the function, instead, is a built-in one, no information about the number of parameters expected and the actual one is provided. This can help in some cases, to debug the errors (eg. bad quotes escaping may lead to a different number of parameters than expected, etc. etc.) The PR adds the information about the number of parameters passed and the expected one, analogously to what happens for UDF. ## How was this patch tested? modified existing UT + manual test Author: Marco Gaido Closes #20271 from mgaido91/SPARK-23080. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 10 +++++++++- .../resources/sql-tests/results/json-functions.sql.out | 4 ++-- .../src/test/scala/org/apache/spark/sql/UDFSuite.scala | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5ddb39822617d..747016beb06e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -526,7 +526,15 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - throw new AnalysisException(s"Invalid number of arguments for function $name") + val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val expectedNumberOfParameters = if (validParametersCount.length == 1) { + validParametersCount.head.toString + } else { + validParametersCount.init.mkString("one of ", ", ", " and ") + + validParametersCount.last + } + throw new AnalysisException(s"Invalid number of arguments for function $name. " + + s"Expected: $expectedNumberOfParameters; Found: ${params.length}") } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index d9dc728a18e8d..581dddc89d0bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 -- !query 22 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index db37be68e42e6..af6a10b425b9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -80,7 +80,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function substr")) + assert(e.getMessage.contains("Invalid number of arguments for function substr. Expected:")) } test("error reporting for incorrect number of arguments - udf") { @@ -89,7 +89,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { spark.udf.register("foo", (_: String).length) df.selectExpr("foo(2, 3, 4)") } - assert(e.getMessage.contains("Invalid number of arguments for function foo")) + assert(e.getMessage.contains("Invalid number of arguments for function foo. Expected:")) } test("error reporting for undefined functions") { From c7572b79da0a29e502890d7618eaf805a1c9f474 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 11:20:18 +0800 Subject: [PATCH 0101/1161] [SPARK-23000] Use fully qualified table names in HiveMetastoreCatalogSuite ## What changes were proposed in this pull request? In another attempt to fix DataSourceWithHiveMetastoreCatalogSuite, this patch uses qualified table names (`default.t`) in the individual tests. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20273 from sameeragarwal/flaky-test. --- .../sql/hive/HiveMetastoreCatalogSuite.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ba9b944e4a055..83b4c862e2546 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -166,13 +166,13 @@ class DataSourceWithHiveMetastoreCatalogSuite )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { - withTable("t") { + withTable("default.t") { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { testDF .write .mode(SaveMode.Overwrite) .format(provider) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) @@ -187,14 +187,15 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1.1\t1", "2.1\t2")) } } test(s"Persist non-partitioned $provider relation into metastore as external table") { withTempPath { dir => - withTable("t") { + withTable("default.t") { val path = dir.getCanonicalFile withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { @@ -203,7 +204,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .mode(SaveMode.Overwrite) .format(provider) .option("path", path.toString) - .saveAsTable("t") + .saveAsTable("default.t") } val hiveTable = @@ -219,8 +220,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === + checkAnswer(table("default.t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === Seq("1.1\t1", "2.1\t2")) } } @@ -228,9 +229,9 @@ class DataSourceWithHiveMetastoreCatalogSuite test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { withTempPath { dir => - withTable("t") { + withTable("default.t") { sql( - s"""CREATE TABLE t USING $provider + s"""CREATE TABLE default.t USING $provider |OPTIONS (path '${dir.toURI}') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) @@ -248,8 +249,9 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) - checkAnswer(table("t"), Row(1, "val_1")) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + checkAnswer(table("default.t"), Row(1, "val_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + Seq("1\tval_1")) } } } From 07ae39d0ec1f03b1c73259373a8bb599694c7860 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 15 Jan 2018 22:01:14 -0800 Subject: [PATCH 0102/1161] [SPARK-22956][SS] Bug fix for 2 streams union failover scenario ## What changes were proposed in this pull request? This problem reported by yanlin-Lynn ivoson and LiangchangZ. Thanks! When we union 2 streams from kafka or other sources, while one of them have no continues data coming and in the same time task restart, this will cause an `IllegalStateException`. This mainly cause because the code in [MicroBatchExecution](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala#L190) , while one stream has no continues data, its comittedOffset same with availableOffset during `populateStartOffsets`, and `currentPartitionOffsets` not properly handled in KafkaSource. Also, maybe we should also consider this scenario in other Source. ## How was this patch tested? Add a UT in KafkaSourceSuite.scala Author: Yuanjian Li Closes #20150 from xuanyuanking/SPARK-22956. --- .../spark/sql/kafka010/KafkaSource.scala | 13 ++-- .../spark/sql/kafka010/KafkaSourceSuite.scala | 65 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 6 +- .../sql/execution/streaming/memory.scala | 6 ++ 4 files changed, 81 insertions(+), 9 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index e9cff04ba5f2e..864a92b8f813f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -223,6 +223,14 @@ private[kafka010] class KafkaSource( logInfo(s"GetBatch called with start = $start, end = $end") val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + if (start.isDefined && start.get == end) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } val fromPartitionOffsets = start match { case Some(prevBatchEndOffset) => KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) @@ -305,11 +313,6 @@ private[kafka010] class KafkaSource( logInfo("GetBatch generating RDD of offset range: " + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) - // On recovery, getBatch will get called before getOffset - if (currentPartitionOffsets.isEmpty) { - currentPartitionOffsets = Some(untilPartitionOffsets) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2034b9be07f24..a0f5695fc485c 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -318,6 +318,71 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { + def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, range.map(_.toString).toArray, Some(0)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 5) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + + reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(k => k.toInt) + } + + val df1 = getSpecificDF(0 to 9) + val df2 = getSpecificDF(100 to 199) + + val kafka = df1.union(df2) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(kafka)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((0 to 4) ++ (100 to 104): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // 5 from smaller topic, 5 from bigger one + CheckLastBatch((5 to 9) ++ (105 to 109): _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smaller topic empty, 5 from bigger one + CheckLastBatch(110 to 114: _*), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(115 to 119: _*), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 5 from bigger one + CheckLastBatch(120 to 124: _*) + ) + } + test("cannot stop Kafka stream") { val topic = newTopic() testUtils.createTopic(topic, partitions = 5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 42240eeb58d4b..70407f0580f97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -208,10 +208,8 @@ class MicroBatchExecution( * batch will be executed before getOffset is called again. */ availableOffsets.foreach { case (source: Source, end: Offset) => - if (committedOffsets.get(source).map(_ != end).getOrElse(true)) { - val start = committedOffsets.get(source) - source.getBatch(start, end) - } + val start = committedOffsets.get(source) + source.getBatch(start, end) case nonV1Tuple => // The V2 API does not have the same edge case requiring getBatch to be called // here, so we do nothing here. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 3041d4d703cb4..509a69dd922fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -119,9 +119,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) val newBlocks = synchronized { val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") batches.slice(sliceStart, sliceEnd) } + if (newBlocks.isEmpty) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) newBlocks From 66217dac4f8952a9923625908ad3dcb030763c81 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 15 Jan 2018 22:40:44 -0800 Subject: [PATCH 0103/1161] [SPARK-23020][CORE] Fix races in launcher code, test. The race in the code is because the handle might update its state to the wrong state if the connection handling thread is still processing incoming data; so the handle needs to wait for the connection to finish up before checking the final state. The race in the test is because when waiting for a handle to reach a final state, the waitFor() method needs to wait until all handle state is updated (which also includes waiting for the connection thread above to finish). Otherwise, waitFor() may return too early, which would cause a bunch of different races (like the listener not being yet notified of the state change, or being in the middle of being notified, or the handle not being properly disposed and causing postChecks() to assert). On top of that I found, by code inspection, a couple of potential races that could make a handle end up in the wrong state when being killed. Tested by running the existing unit tests a lot (and not seeing the errors I was seeing before). Author: Marcelo Vanzin Closes #20223 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 49 ++++++++++++------- .../spark/launcher/AbstractAppHandle.java | 22 +++++++-- .../spark/launcher/ChildProcAppHandle.java | 18 ++++--- .../spark/launcher/InProcessAppHandle.java | 17 ++++--- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 ++++++++++++++--- .../org/apache/spark/launcher/BaseSuite.java | 42 +++++++++++++--- .../spark/launcher/LauncherServerSuite.java | 20 +++----- 8 files changed, 156 insertions(+), 72 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..a042375c6ae91 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -31,6 +32,7 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -137,7 +139,9 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - TimeUnit.MILLISECONDS.sleep(500); + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -146,26 +150,35 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d4..daf0972f824dd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private boolean disposed; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,8 +70,7 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; + if (!isDisposed()) { if (connection != null) { try { connection.close(); @@ -79,7 +78,7 @@ public synchronized void disconnect() { // no-op. } } - server.unregister(this); + dispose(); } } @@ -95,6 +94,21 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + */ + synchronized void dispose() { + if (!isDisposed()) { + // Unregister first to make sure that the connection with the app has been really + // terminated. + server.unregister(this); + if (!getState().isFinal()) { + setState(State.LOST); + } + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750e..2b99461652e1f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -118,6 +118,8 @@ void monitorChild() { if (newState != null) { setState(newState, true); } + + disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index acd64c962604f..f04263cb74a58 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e26053..fd6f229b2349c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { + public synchronized void close() throws IOException { if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f4..660c4443b20b9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,33 @@ void unregister(AbstractAppHandle handle) { break; } } + + // If there is a live connection for this handle, we need to wait for it to finish before + // returning, otherwise there might be a race between the connection thread processing + // buffered data and the handle cleaning up after itself, leading to potentially the wrong + // state being reported for the handle. + ServerConnection conn = null; + synchronized (clients) { + for (ServerConnection c : clients) { + if (c.handle == handle) { + conn = c; + break; + } + } + } + + if (conn != null) { + synchronized (conn) { + if (conn.isOpen()) { + try { + conn.wait(); + } catch (InterruptedException ie) { + // Ignore. + } + } + } + } + unref(); } @@ -288,7 +315,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -338,16 +365,21 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } - super.close(); + + synchronized (this) { + super.close(); + notifyAll(); + } + if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); - } - handle.disconnect(); + handle.dispose(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d4..3722a59d9438e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -47,19 +48,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. + */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9b..75c1af0c71e2a 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -197,28 +199,20 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { From b85eb946ac298e711dad25db0d04eee41d7fd236 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 16 Jan 2018 20:20:33 +0900 Subject: [PATCH 0104/1161] [SPARK-22978][PYSPARK] Register Vectorized UDFs for SQL Statement ## What changes were proposed in this pull request? Register Vectorized UDFs for SQL Statement. For example, ```Python >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> pandas_udf("integer", PandasUDFType.SCALAR) ... def add_one(x): ... return x + 1 ... >>> _ = spark.udf.register("add_one", add_one) >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20171 from gatorsmile/supportVectorizedUDF. --- python/pyspark/sql/catalog.py | 75 ++++++++++++++++++++++++---------- python/pyspark/sql/context.py | 51 ++++++++++++++++------- python/pyspark/sql/tests.py | 76 ++++++++++++++++++++++++++++++----- 3 files changed, 155 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 156603128d063..35fbe9e669adb 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -226,18 +226,23 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. + :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() @@ -256,27 +261,55 @@ def registerFunction(self, name, f, returnType=StringType()): >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ # This is to check whether the input function is a wrapped/native UserDefinedFunction if hasattr(f, 'asNondeterministic'): - udf = UserDefinedFunction(f.func, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF, - deterministic=f.deterministic) + if returnType is not None: + raise TypeError( + "Invalid returnType: None is expected when f is a UserDefinedFunction, " + "but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f else: - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - self._jsparkSession.udf().registerPython(name, udf._judf) - return udf._wrapped() + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b8d86cc098e94..85479095af594 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -174,18 +174,23 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) - def registerFunction(self, name, f, returnType=StringType()): + def registerFunction(self, name, f, returnType=None): """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statement. + as a UDF. The registered UDF can be used in SQL statements. - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - :param name: name of the UDF - :param f: a Python function, or a wrapped/native UserDefinedFunction - :param returnType: a :class:`pyspark.sql.types.DataType` object - :return: a wrapped :class:`UserDefinedFunction` + In addition to a name and the function itself, `returnType` can be optionally specified. + 1) When f is a Python function, `returnType` defaults to a string. The produced object must + match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return + type of the given UDF as the return type of the registered UDF. The input parameter + `returnType` is None by default. If given by users, the value must be None. + + :param name: name of the UDF in SQL statements. + :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either + row-at-a-time or vectorized. + :param returnType: the return type of the registered UDF. + :return: a wrapped/native :class:`UserDefinedFunction` >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() @@ -204,15 +209,31 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = sqlContext.udf.register("slen", slen) + >>> sqlContext.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + >>> import random >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType, StringType + >>> from pyspark.sql.types import IntegerType >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=u'82')] - >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP - [Row(random_udf()=u'62')] + [Row(random_udf()=82)] + >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP + [Row(()=26)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP + >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) @@ -575,7 +596,7 @@ class UDFRegistration(object): def __init__(self, sqlContext): self.sqlContext = sqlContext - def register(self, name, f, returnType=StringType()): + def register(self, name, f, returnType=None): return self.sqlContext.registerFunction(name, f, returnType) def registerJavaFunction(self, name, javaClassName, returnType=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 80a94a91a87b3..8906618666b14 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -380,12 +380,25 @@ def test_udf2(self): self.assertEqual(4, res[0]) def test_udf3(self): - twoargs = self.spark.catalog.registerFunction( - "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) - self.assertEqual(twoargs.deterministic, True) + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)) + self.assertEqual(two_args.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], u'5') + + def test_udf_registration_return_type_none(self): + two_args = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None) + self.assertEqual(two_args.deterministic, True) [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf_registration_return_type_not_none(self): + with QuietTest(self.sc): + with self.assertRaisesRegexp(TypeError, "Invalid returnType"): + self.spark.catalog.registerFunction( + "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType()) + def test_nondeterministic_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf @@ -402,12 +415,12 @@ def test_nondeterministic_udf2(self): from pyspark.sql.functions import udf random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() self.assertEqual(random_udf.deterministic, False) - random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf) self.assertEqual(random_udf1.deterministic, False) [row] = self.spark.sql("SELECT randInt()").collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf1()).collect() - self.assertEqual(row[0], "6") + self.assertEqual(row[0], 6) [row] = self.spark.range(1).select(random_udf()).collect() self.assertEqual(row[0], 6) # render_doc() reproduces the help() exception without printing output @@ -3691,7 +3704,7 @@ def tearDownClass(cls): ReusedSQLTestCase.tearDownClass() @property - def random_udf(self): + def nondeterministic_vectorized_udf(self): from pyspark.sql.functions import pandas_udf @pandas_udf('double') @@ -3726,6 +3739,21 @@ def test_vectorized_udf_basic(self): bool_f(col('bool'))) self.assertEquals(df.collect(), res.collect()) + def test_register_nondeterministic_vectorized_udf_basic(self): + from pyspark.sql.functions import pandas_udf + from pyspark.rdd import PythonEvalType + import random + random_pandas_udf = pandas_udf( + lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() + self.assertEqual(random_pandas_udf.deterministic, False) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + nondeterministic_pandas_udf = self.spark.catalog.registerFunction( + "randomPandasUDF", random_pandas_udf) + self.assertEqual(nondeterministic_pandas_udf.deterministic, False) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() + self.assertEqual(row[0], 7) + def test_vectorized_udf_null_boolean(self): from pyspark.sql.functions import pandas_udf, col data = [(True,), (True,), (None,), (False,)] @@ -4085,14 +4113,14 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) - def test_nondeterministic_udf(self): + def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf, pandas_udf, col @pandas_udf('double') def plus_ten(v): return v + 10 - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() @@ -4100,11 +4128,11 @@ def plus_ten(v): self.assertEqual(random_udf.deterministic, False) self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) - def test_nondeterministic_udf_in_aggregate(self): + def test_nondeterministic_vectorized_udf_in_aggregate(self): from pyspark.sql.functions import pandas_udf, sum df = self.spark.range(10) - random_udf = self.random_udf + random_udf = self.nondeterministic_vectorized_udf with QuietTest(self.sc): with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): @@ -4112,6 +4140,23 @@ def test_nondeterministic_udf_in_aggregate(self): with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): df.agg(sum(random_udf(df.id))).collect() + def test_register_vectorized_udf_basic(self): + from pyspark.rdd import PythonEvalType + from pyspark.sql.functions import pandas_udf, col, expr + df = self.spark.range(10).select( + col('id').cast('int').alias('a'), + col('id').cast('int').alias('b')) + original_add = pandas_udf(lambda x, y: x + y, IntegerType()) + self.assertEqual(original_add.deterministic, True) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + new_add = self.spark.catalog.registerFunction("add1", original_add) + res1 = df.select(new_add(col('a'), col('b'))) + res2 = self.spark.sql( + "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t") + expected = df.select(expr('a + b')) + self.assertEquals(expected.collect(), res1.collect()) + self.assertEquals(expected.collect(), res2.collect()) + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): @@ -4147,6 +4192,15 @@ def test_simple(self): expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertFramesEqual(expected, result) + def test_register_group_map_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP) + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' + 'SQL_PANDAS_SCALAR_UDF'): + self.spark.catalog.registerFunction("foo_udf", foo_udf) + def test_decorator(self): from pyspark.sql.functions import pandas_udf, PandasUDFType df = self.data From 75db14864d2bd9b8e13154226e94d466e3a7e0a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 16 Jan 2018 22:41:30 +0800 Subject: [PATCH 0105/1161] [SPARK-22392][SQL] data source v2 columnar batch reader ## What changes were proposed in this pull request? a new Data Source V2 interface to allow the data source to return `ColumnarBatch` during the scan. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20153 from cloud-fan/columnar-reader. --- .../sources/v2/reader/DataSourceV2Reader.java | 5 +- .../v2/reader/SupportsScanColumnarBatch.java | 52 ++++++++ .../v2/reader/SupportsScanUnsafeRow.java | 2 +- .../sql/execution/ColumnarBatchScan.scala | 37 +++++- .../sql/execution/DataSourceScanExec.scala | 39 ++---- .../columnar/InMemoryTableScanExec.scala | 101 +++++++++------- .../datasources/v2/DataSourceRDD.scala | 20 ++-- .../datasources/v2/DataSourceV2ScanExec.scala | 72 ++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 112 ++++++++++++++++++ .../execution/WholeStageCodegenSuite.scala | 28 ++--- .../sql/sources/v2/DataSourceV2Suite.scala | 72 ++++++++++- 12 files changed, 400 insertions(+), 144 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index 95ee4a8278322..f23c3842bf1b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -38,7 +38,10 @@ * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. * Names of these interfaces start with `SupportsReporting`. * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. - * Names of these interfaces start with `SupportsScan`. + * Names of these interfaces start with `SupportsScan`. Note that a reader should only + * implement at most one of the special scans, if more than one special scans are implemented, + * only one of them would be respected, according to the priority list from high to low: + * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * * If an exception was throw when applying any of these query optimizations, the action would fail * and no Spark job was submitted. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java new file mode 100644 index 0000000000000..27cf3a77724f0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to output {@link ColumnarBatch} and make the scan faster. + */ +@InterfaceStability.Evolving +public interface SupportsScanColumnarBatch extends DataSourceV2Reader { + @Override + default List> createReadTasks() { + throw new IllegalStateException( + "createReadTasks not supported by default within SupportsScanColumnarBatch."); + } + + /** + * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + */ + List> createBatchReadTasks(); + + /** + * Returns true if the concrete data source reader can read data in batch according to the scan + * properties like required columns, pushes filters, etc. It's possible that the implementation + * can only support some certain columns with certain types. Users can overwrite this method and + * {@link #createReadTasks()} to fallback to normal read path under some conditions. + */ + default boolean enableBatchRead() { + return true; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index b90ec880dc85e..2d3ad0eee65ff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -35,7 +35,7 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override default List> createReadTasks() { throw new IllegalStateException( - "createReadTasks should not be called with SupportsScanUnsafeRow."); + "createReadTasks not supported by default within SupportsScanUnsafeRow"); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 5617046e1396e..dd68df9686691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType @@ -25,13 +25,16 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** - * Helper trait for abstracting scan functionality using - * [[ColumnarBatch]]es. + * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { def vectorTypes: Option[Seq[String]] = None + protected def supportsBatch: Boolean = true + + protected def needsUnsafeRowConversion: Boolean = true + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) @@ -71,7 +74,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // PhysicalRDD always just has one input val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") + if (supportsBatch) { + produceBatches(ctx, input) + } else { + produceRows(ctx, input) + } + } + private def produceBatches(ctx: CodegenContext, input: String): String = { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") @@ -137,4 +147,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { """.stripMargin } + private def produceRows(ctx: CodegenContext, input: String): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + val row = ctx.freshName("row") + + ctx.INPUT_ROW = row + ctx.currentVars = null + // Always provide `outputVars`, so that the framework can help us build unsafe row if the input + // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. + val outputVars = output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + val inputRow = if (needsUnsafeRowConversion) null else row + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, outputVars, inputRow).trim} + | if (shouldStop()) return; + |} + """.stripMargin + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index d1ff82c7c06bc..7c7d79c2bbd7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -164,13 +164,15 @@ case class FileSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { - val supportsBatch: Boolean = relation.fileFormat.supportBatch( + override val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { - SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled - } else { - false + override val needsUnsafeRowConversion: Boolean = { + if (relation.fileFormat.isInstanceOf[ParquetSource]) { + SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled + } else { + false + } } override def vectorTypes: Option[Seq[String]] = @@ -346,33 +348,6 @@ case class FileSourceScanExec( override val nodeNamePrefix: String = "File" - override protected def doProduce(ctx: CodegenContext): String = { - if (supportsBatch) { - return super.doProduce(ctx) - } - val numOutputRows = metricTerm(ctx, "numOutputRows") - // PhysicalRDD always just has one input - val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];") - val row = ctx.freshName("row") - - ctx.INPUT_ROW = row - ctx.currentVars = null - // Always provide `outputVars`, so that the framework can help us build unsafe row if the input - // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true. - val outputVars = output.zipWithIndex.map{ case (a, i) => - BoundReference(i, a.dataType, a.nullable).genCode(ctx) - } - val inputRow = if (needsUnsafeRowConversion) null else row - s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); - | $numOutputRows.add(1); - | ${consume(ctx, outputVars, inputRow).trim} - | if (shouldStop()) return; - |} - """.stripMargin - } - /** * Create an RDD for bucketed reads. * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 933b9753faa61..3565ee3af1b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -49,9 +49,9 @@ case class InMemoryTableScanExec( /** * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. - * If false, get data from UnsafeRow build from ColumnVector + * If false, get data from UnsafeRow build from CachedBatch */ - override val supportCodegen: Boolean = { + override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields relation.schema.fields.forall(f => f.dataType match { @@ -61,6 +61,8 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + override protected def needsUnsafeRowConversion: Boolean = false + private val columnIndices = attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray @@ -90,14 +92,56 @@ case class InMemoryTableScanExec( columnarBatch } - override def inputRDDs(): Seq[RDD[InternalRow]] = { - assert(supportCodegen) + private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() - // HACK ALERT: This is actually an RDD[ColumnarBatch]. - // We're taking advantage of Scala's type erasure here to pass these batches along. - Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + if (supportsBatch) { + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. + buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + } else { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => + // Find the ordinals and data types of the requested columns. + val (requestedColumnIndices, requestedColumnDataTypes) = + attributes.map { a => + relOutput.indexOf(a.exprId) -> a.dataType + }.unzip + + // update SQL metrics + val withMetrics = cachedBatchIterator.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + + val columnTypes = requestedColumnDataTypes.map { + case udt: UserDefinedType[_] => udt.sqlType + case other => other + }.toArray + val columnarIterator = GenerateColumnAccessor.generate(columnTypes) + columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) + if (enableAccumulatorsForTest && columnarIterator.hasNext) { + readPartitions.add(1) + } + columnarIterator + } + } } + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) + override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { @@ -185,7 +229,7 @@ case class InMemoryTableScanExec( } } - lazy val enableAccumulators: Boolean = + lazy val enableAccumulatorsForTest: Boolean = sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes @@ -230,43 +274,10 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - // Using these variables here to avoid serialization of entire objects (if referenced directly) - // within the map Partitions closure. - val relOutput: AttributeSeq = relation.output - - filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => - // Find the ordinals and data types of the requested columns. - val (requestedColumnIndices, requestedColumnDataTypes) = - attributes.map { a => - relOutput.indexOf(a.exprId) -> a.dataType - }.unzip - - // update SQL metrics - val withMetrics = cachedBatchIterator.map { batch => - if (enableAccumulators) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - - val columnTypes = requestedColumnDataTypes.map { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - }.toArray - val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) - if (enableAccumulators && columnarIterator.hasNext) { - readPartitions.add(1) - } - columnarIterator + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + inputRDD } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5f30be5ed4af1..ac104d7cd0cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.v2.reader.ReadTask -class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) extends Partition with Serializable -class DataSourceRDD( +class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { + @transient private val readTasks: java.util.List[ReadTask[T]]) + extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { readTasks.asScala.zipWithIndex.map { @@ -38,10 +38,10 @@ class DataSourceRDD( }.toArray } - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() context.addTaskCompletionListener(_ => reader.close()) - val iter = new Iterator[UnsafeRow] { + val iter = new Iterator[T] { private[this] var valuePrepared = false override def hasNext: Boolean = { @@ -51,7 +51,7 @@ class DataSourceRDD( valuePrepared } - override def next(): UnsafeRow = { + override def next(): T = { if (!hasNext) { throw new java.util.NoSuchElementException("End of stream") } @@ -63,6 +63,6 @@ class DataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 49c506bc560cf..8c64df080242f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,10 +24,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,40 +35,56 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + @transient reader: DataSourceV2Reader) + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def references: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + case _ => + reader.createReadTasks().asScala.map { + new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + }.asJava + } - override protected def doExecute(): RDD[InternalRow] = { - val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() - case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] - }.asJava - } + private lazy val inputRDD: RDD[InternalRow] = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + assert(!reader.isInstanceOf[ContinuousReader], + "continuous stream reader does not support columnar read yet.") + new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + + case _: ContinuousReader => + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + .askSync[Unit](SetReaderPartitions(readTasks.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .asInstanceOf[RDD[InternalRow]] + + case _ => + new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + } - val inputRDD = reader match { - case _: ContinuousReader => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) + override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + override val supportsBatch: Boolean = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() => true + case _ => false + } - case _ => - new DataSourceRDD(sparkContext, readTasks) - } + override protected def needsUnsafeRowConversion: Boolean = false - val numOutputRows = longMetric("numOutputRows") - inputRDD.asInstanceOf[RDD[InternalRow]].map { r => - numOutputRows += 1 - r + override protected def doExecute(): RDD[InternalRow] = { + if (supportsBatch) { + WholeStageCodegenExec(this).execute() + } else { + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index d79e4bd65f563..b3f1a1a1aaab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,7 +52,7 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) @@ -132,7 +132,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 0000000000000..44e5146d7c553 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createBatchReadTasks() { + return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); + } + } + + static class JavaBatchReadTask implements ReadTask, DataReader { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchReadTask(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public DataReader createDataReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index bc05dca578c47..22ca128c27768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -121,31 +121,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { import testImplicits._ - val dsInt = spark.range(3).cache - dsInt.count + val dsInt = spark.range(3).cache() + dsInt.count() val dsIntFilter = dsInt.filter(_ > 0) val planInt = dsIntFilter.queryExecution.executedPlan - assert(planInt.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec] && - p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined - ) + assert(planInt.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => () + }.length == 1) assert(dsIntFilter.collect() === Array(1, 2)) // cache for string type is not supported for InMemoryTableScanExec - val dsString = spark.range(3).map(_.toString).cache - dsString.count + val dsString = spark.range(3).map(_.toString).cache() + dsString.count() val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan - assert(planString.find(p => - p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && - !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child - .isInstanceOf[InMemoryTableScanExec]).isDefined - ) + assert(planString.collect { + case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + }.length == 1) assert(dsStringFilter.collect() === Array("1")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ab37e4984bd1f..a89f7c55bf4f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,10 +24,12 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -56,7 +58,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("unsafe row implementation") { + test("unsafe row scan implementation") { Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -67,6 +69,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("columnar batch scan implementation") { + Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 90).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + } + } + } + test("schema required data source") { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { @@ -275,7 +288,7 @@ class UnsafeRowReadTask(start: Int, end: Int) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + override def createDataReader(): DataReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -300,3 +313,56 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = new Reader(schema) } + +class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class BatchReadTask(start: Int, end: Int) + extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { + + private final val BATCH_SIZE = 20 + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch( + new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + + private var current = start + + override def createDataReader(): DataReader[ColumnarBatch] = this + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } + + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } + + override def get(): ColumnarBatch = { + batch + } + + override def close(): Unit = batch.close() +} From 12db365b4faf7a185708648d246fc4a2aae0c2c0 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 16 Jan 2018 11:41:08 -0800 Subject: [PATCH 0106/1161] [SPARK-16139][TEST] Add logging functionality for leaked threads in tests ## What changes were proposed in this pull request? Lots of our tests don't properly shutdown everything they create, and end up leaking lots of threads. For example, `TaskSetManagerSuite` doesn't stop the extra `TaskScheduler` and `DAGScheduler` it creates. There are a couple more instances, eg. in `DAGSchedulerSuite`. This PR adds the possibility to print out the not properly stopped thread list after a test suite executed. The format is the following: ``` ===== FINISHED o.a.s.scheduler.DAGSchedulerSuite: 'task end event should have updated accumulators (SPARK-20342)' ===== ... ===== Global thread whitelist loaded with name /thread_whitelist from classpath: rpc-client.*, rpc-server.*, shuffle-client.*, shuffle-server.*' ===== ScalaTest-run: ===== THREADS NOT STOPPED PROPERLY ===== ScalaTest-run: dag-scheduler-event-loop ScalaTest-run: globalEventExecutor-2-5 ScalaTest-run: ===== END OF THREAD DUMP ===== ScalaTest-run: ===== EITHER PUT THREAD NAME INTO THE WHITELIST FILE OR SHUT IT DOWN PROPERLY ===== ``` With the help of this leaking threads has been identified in TaskSetManagerSuite. My intention is to hunt down and fix such bugs in later PRs. ## How was this patch tested? Manual: TaskSetManagerSuite test executed and found out where are the leaking threads. Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #19893 from gaborgsomogyi/SPARK-16139. --- .../org/apache/spark/SparkFunSuite.scala | 34 +++++++ .../scala/org/apache/spark/ThreadAudit.scala | 99 +++++++++++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 7 +- .../apache/spark/sql/SessionStateSuite.scala | 1 + .../sql/sources/DataSourceAnalysisSuite.scala | 1 + .../spark/sql/test/SharedSQLContext.scala | 23 ++++- .../hive/HiveContextCompatibilitySuite.scala | 1 + .../sql/hive/HiveSessionStateSuite.scala | 1 + .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 + .../sql/hive/client/HiveClientSuite.scala | 1 + .../sql/hive/client/HiveVersionSuite.scala | 1 + .../spark/sql/hive/client/VersionsSuite.scala | 2 + .../hive/execution/HiveComparisonTest.scala | 2 + .../hive/orc/OrcHadoopFsRelationSuite.scala | 1 + .../sql/hive/test/TestHiveSingleton.scala | 1 + 15 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ThreadAudit.scala diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 18077c08c9dcc..3af9d82393bc4 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -27,19 +27,53 @@ import org.apache.spark.util.AccumulatorContext /** * Base abstract class for all unit tests in Spark for handling common functionality. + * + * Thread audit happens normally here automatically when a new test suite created. + * The only prerequisite for that is that the test class must extend [[SparkFunSuite]]. + * + * It is possible to override the default thread audit behavior by setting enableAutoThreadAudit + * to false and manually calling the audit methods, if desired. For example: + * + * class MyTestSuite extends SparkFunSuite { + * + * override val enableAutoThreadAudit = false + * + * protected override def beforeAll(): Unit = { + * doThreadPreAudit() + * super.beforeAll() + * } + * + * protected override def afterAll(): Unit = { + * super.afterAll() + * doThreadPostAudit() + * } + * } */ abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll + with ThreadAudit with Logging { // scalastyle:on + protected val enableAutoThreadAudit = true + + protected override def beforeAll(): Unit = { + if (enableAutoThreadAudit) { + doThreadPreAudit() + } + super.beforeAll() + } + protected override def afterAll(): Unit = { try { // Avoid leaking map entries in tests that use accumulators without SparkContext AccumulatorContext.clear() } finally { super.afterAll() + if (enableAutoThreadAudit) { + doThreadPostAudit() + } } } diff --git a/core/src/test/scala/org/apache/spark/ThreadAudit.scala b/core/src/test/scala/org/apache/spark/ThreadAudit.scala new file mode 100644 index 0000000000000..b3cea9de8f304 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ThreadAudit.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging + +/** + * Thread audit for test suites. + */ +trait ThreadAudit extends Logging { + + val threadWhiteList = Set( + /** + * Netty related internal threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "netty.*", + + /** + * Netty related internal threads. + * A Single-thread singleton EventExecutor inside netty which creates such threads. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "globalEventExecutor.*", + + /** + * Netty related internal threads. + * Checks if a thread is alive periodically and runs a task when a thread dies. + * These are excluded because their lifecycle is handled by the netty itself + * and spark has no explicit effect on them. + */ + "threadDeathWatcher.*", + + /** + * During [[SparkContext]] creation [[org.apache.spark.rpc.netty.NettyRpcEnv]] + * creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "rpc-client.*", + "rpc-server.*", + + /** + * During [[SparkContext]] creation BlockManager creates event loops. One is wrapped inside + * [[org.apache.spark.network.server.TransportServer]] + * the other one is inside [[org.apache.spark.network.client.TransportClient]]. + * The thread pools behind shut down asynchronously triggered by [[SparkContext#stop]]. + * Manually checked and all of them stopped properly. + */ + "shuffle-client.*", + "shuffle-server.*" + ) + private var threadNamesSnapshot: Set[String] = Set.empty + + protected def doThreadPreAudit(): Unit = { + threadNamesSnapshot = runningThreadNames() + } + + protected def doThreadPostAudit(): Unit = { + val shortSuiteName = this.getClass.getName.replaceAll("org.apache.spark", "o.a.s") + + if (threadNamesSnapshot.nonEmpty) { + val remainingThreadNames = runningThreadNames().diff(threadNamesSnapshot) + .filterNot { s => threadWhiteList.exists(s.matches(_)) } + if (remainingThreadNames.nonEmpty) { + logWarning(s"\n\n===== POSSIBLE THREAD LEAK IN SUITE $shortSuiteName, " + + s"thread names: ${remainingThreadNames.mkString(", ")} =====\n") + } + } else { + logWarning("\n\n===== THREAD AUDIT POST ACTION CALLED " + + s"WITHOUT PRE ACTION IN SUITE $shortSuiteName =====\n") + } + } + + private def runningThreadNames(): Set[String] = { + Thread.getAllStackTraces.keySet().asScala.map(_.getName).toSet + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 2ce81ae27daf6..ca6a7e5db3b17 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -683,7 +683,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val conf = new SparkConf().set("spark.speculation", "true") sc = new SparkContext("local", "test", conf) - val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) sched.initialize(new FakeSchedulerBackend() { override def killTask( taskId: Long, @@ -709,6 +709,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val singleTask = new ShuffleMapTask(0, 0, null, new Partition { @@ -754,7 +755,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc.conf.set("spark.speculation", "true") var killTaskCalled = false - val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) sched.initialize(new FakeSchedulerBackend() { override def killTask( @@ -789,6 +790,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } } + sched.dagScheduler.stop() sched.setDAGScheduler(dagScheduler) val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0, @@ -1183,6 +1185,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler.stop() sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index c01666770720c..5d75f5835bf9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -39,6 +39,7 @@ class SessionStateSuite extends SparkFunSuite protected var activeSession: SparkSession = _ override def beforeAll(): Unit = { + super.beforeAll() activeSession = SparkSession.builder().master("local").getOrCreate() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 735e07c21373a..e1022e377132c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -33,6 +33,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { private var targetPartitionSchema: StructType = _ override def beforeAll(): Unit = { + super.beforeAll() targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) targetPartitionSchema = new StructType() .add("b", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 4d578e21f5494..e6c7648c986ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,4 +17,25 @@ package org.apache.spark.sql.test -trait SharedSQLContext extends SQLTestUtils with SharedSparkSession +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession { + + /** + * Suites extending [[SharedSQLContext]] are sharing resources (eg. SparkSession) in their tests. + * That trait initializes the spark session in its [[beforeAll()]] implementation before the + * automatic thread snapshot is performed, so the audit code could fail to report threads leaked + * by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. + */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + super.afterAll() + doThreadPostAudit() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 8a7423663f28d..a80db765846e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEach { + override protected val enableAutoThreadAudit = false private var sc: SparkContext = null private var hc: HiveContext = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index 958ad3e1c3ce8..f7da3c4cbb0aa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -30,6 +30,7 @@ class HiveSessionStateSuite extends SessionStateSuite override def beforeAll(): Unit = { // Reuse the singleton session + super.beforeAll() activeSession = spark } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 21b3e281490cf..10204f4694663 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -44,6 +44,8 @@ class HiveSparkSubmitSuite with BeforeAndAfterEach with ResetSystemProperties { + override protected val enableAutoThreadAudit = false + // TODO: rewrite these or mark them as slow tests to be run sparingly override def beforeEach() { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index ce53acef51503..a5dfd89b3a574 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -67,6 +67,7 @@ class HiveClientSuite(version: String) } override def beforeAll() { + super.beforeAll() client = init(true) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index 951ebfad4590e..bb8a4697b0a13 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.HiveUtils private[client] abstract class HiveVersionSuite(version: String) extends SparkFunSuite { + override protected val enableAutoThreadAudit = false protected var client: HiveClient = null protected def buildClient(hadoopConf: Configuration): HiveClient = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index e64389e56b5a1..72536b833481a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -50,6 +50,8 @@ import org.apache.spark.util.{MutableURLClassLoader, Utils} @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { + override protected val enableAutoThreadAudit = false + import HiveClientBuilder.buildClient /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index cee82cda4628a..272e6f51f5002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -48,6 +48,8 @@ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} abstract class HiveComparisonTest extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { + override protected val enableAutoThreadAudit = false + /** * Path to the test datasets. We find this by looking up "hive-test-path-helper.txt" file. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index f87162f94c01a..a1f054b8e3f44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ + override protected val enableAutoThreadAudit = false override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index df7988f542b71..d3fff37c3424d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.hive.client.HiveClient trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { + override protected val enableAutoThreadAudit = false protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = From 4371466b3f06ca171b10568e776c9446f7bae6dd Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 16 Jan 2018 12:56:57 -0800 Subject: [PATCH 0107/1161] [SPARK-23045][ML][SPARKR] Update RFormula to use OneHotEncoderEstimator. ## What changes were proposed in this pull request? RFormula should use VectorSizeHint & OneHotEncoderEstimator in its pipeline to avoid using the deprecated OneHotEncoder & to ensure the model produced can be used in streaming. ## How was this patch tested? Unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #20229 from MrBago/rFormula. --- R/pkg/R/mllib_utils.R | 1 - .../apache/spark/ml/feature/RFormula.scala | 20 +++++-- .../spark/ml/feature/RFormulaSuite.scala | 53 +++++++++++-------- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 23dda42c325be..a53c92c2c4815 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,4 +130,3 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f384ffbf578bc..1155ea5fdd85b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -199,6 +199,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) val encoderStages = ArrayBuffer[PipelineStage]() + val oneHotEncodeColumns = ArrayBuffer[(String, String)]() val prefixesToRewrite = mutable.Map[String, String]() val tempColumns = ArrayBuffer[String]() @@ -242,16 +243,17 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - var encoder = new OneHotEncoder() - .setInputCol(indexed(term)) - .setOutputCol(encodedCol) // Formula w/o intercept, one of the categories in the first category feature is // being used as reference category, we will not drop any category for that feature. if (!hasIntercept && !keepReferenceCategory) { - encoder = encoder.setDropLast(false) + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(Array(indexed(term))) + .setOutputCols(Array(encodedCol)) + .setDropLast(false) keepReferenceCategory = true + } else { + oneHotEncodeColumns += indexed(term) -> encodedCol } - encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => @@ -265,6 +267,14 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) interactionCol } + if (oneHotEncodeColumns.nonEmpty) { + val (inputCols, outputCols) = oneHotEncodeColumns.toArray.unzip + encoderStages += new OneHotEncoderEstimator(uid) + .setInputCols(inputCols) + .setOutputCols(outputCols) + .setDropLast(true) + } + encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index f3f4b5a3d0233..bfe38d32dd77d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -29,6 +29,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ + def testRFormulaTransform[A: Encoder]( + dataframe: DataFrame, + formulaModel: RFormulaModel, + expected: DataFrame): Unit = { + val (first +: rest) = expected.schema.fieldNames.toSeq + val expectedRows = expected.collect() + testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows === expectedRows) + } + } + test("params") { ParamsSuite.checkParams(new RFormula()) } @@ -47,7 +58,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("features column already exists") { @@ -109,7 +120,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Double, Double)](original, model, expected) } test("encodes string terms") { @@ -126,7 +137,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("encodes string terms with string indexer order type") { @@ -167,7 +178,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected(idx).collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } } @@ -210,7 +221,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) } test("formula w/o intercept, we should output reference category when encoding string terms") { @@ -253,7 +264,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result1.schema.toString == resultSchema1.toString) - assert(result1.collect() === expected1.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( @@ -280,7 +291,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") assert(result2.schema.toString == resultSchema2.toString) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( @@ -302,7 +313,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), @@ -310,7 +320,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") // assert(result.schema.toString == resultSchema.toString) - assert(result.collect() === expected.collect()) + testRFormulaTransform[(String, String, Int)](original, model, expected) } test("force to index label even it is numeric type") { @@ -319,7 +329,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = spark.createDataFrame( Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), @@ -327,7 +336,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Double, String, Int)](original, model, expected) } test("attribute generation") { @@ -391,7 +400,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -414,7 +423,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, Int)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -436,7 +445,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) + testRFormulaTransform[(Int, String, String)](original, model, expected) val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -511,8 +520,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[SparkException] { formula1.fit(df1).transform(df2).collect() } - val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2) - val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2) + val model1 = formula1.setHandleInvalid("skip").fit(df1) + val model2 = formula1.setHandleInvalid("keep").fit(df1) val expected1 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0), @@ -524,16 +533,16 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") - assert(result1.collect() === expected1.collect()) - assert(result2.collect() === expected2.collect()) + testRFormulaTransform[(Int, String, String)](df2, model1, expected1) + testRFormulaTransform[(Int, String, String)](df2, model2, expected2) // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") intercept[SparkException] { formula2.fit(df1).transform(df2).collect() } - val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2) - val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2) + val model3 = formula2.setHandleInvalid("skip").fit(df1) + val model4 = formula2.setHandleInvalid("keep").fit(df1) val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), @@ -545,8 +554,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") - assert(result3.collect() === expected3.collect()) - assert(result4.collect() === expected4.collect()) + testRFormulaTransform[(Int, String, String)](df2, model3, expected3) + testRFormulaTransform[(Int, String, String)](df2, model4, expected4) } test("Use Vectors as inputs to formula.") { From 5ae333391bd73331b5b90af71a3de52cdbb24109 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 16 Jan 2018 16:25:10 -0800 Subject: [PATCH 0108/1161] [SPARK-23044] Error handling for jira assignment ## What changes were proposed in this pull request? * If there is any error while trying to assign the jira, prompt again * Filter out the "Apache Spark" choice * allow arbitrary user ids to be entered ## How was this patch tested? Couldn't really test the error case, just some testing of similar-ish code in python shell. Haven't run a merge yet. Author: Imran Rashid Closes #20236 from squito/SPARK-23044. --- dev/merge_spark_pr.py | 50 +++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 57ca8400b6f3d..6b244d8184b2c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -30,6 +30,7 @@ import re import subprocess import sys +import traceback import urllib2 try: @@ -298,24 +299,37 @@ def choose_jira_assignee(issue, asf_jira): Prompt the user to choose who to assign the issue to in jira, given a list of candidates, including the original reporter and all commentors """ - reporter = issue.fields.reporter - commentors = map(lambda x: x.author, issue.fields.comment.comments) - candidates = set(commentors) - candidates.add(reporter) - candidates = list(candidates) - print("JIRA is unassigned, choose assignee") - for idx, author in enumerate(candidates): - annotations = ["Reporter"] if author == reporter else [] - if author in commentors: - annotations.append("Commentor") - print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - assignee = raw_input("Enter number of user to assign to (blank to leave unassigned):") - if assignee == "": - return None - else: - assignee = candidates[int(assignee)] - asf_jira.assign_issue(issue.key, assignee.key) - return assignee + while True: + try: + reporter = issue.fields.reporter + commentors = map(lambda x: x.author, issue.fields.comment.comments) + candidates = set(commentors) + candidates.add(reporter) + candidates = list(candidates) + print("JIRA is unassigned, choose assignee") + for idx, author in enumerate(candidates): + if author.key == "apachespark": + continue + annotations = ["Reporter"] if author == reporter else [] + if author in commentors: + annotations.append("Commentor") + print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) + raw_assignee = raw_input( + "Enter number of user, or userid, to assign to (blank to leave unassigned):") + if raw_assignee == "": + return None + else: + try: + id = int(raw_assignee) + assignee = candidates[id] + except: + # assume it's a user id, and try to assign (might fail, we just prompt again) + assignee = asf_jira.user(raw_assignee) + asf_jira.assign_issue(issue.key, assignee.key) + return assignee + except: + traceback.print_exc() + print("Error assigning JIRA, try again (or leave blank and fix manually)") def resolve_jira_issues(title, merge_branches, comment): From 0c2ba427bc7323729e6ffb34f1f06a97f0bf0c1d Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 17 Jan 2018 09:57:30 +0800 Subject: [PATCH 0109/1161] [SPARK-23095][SQL] Decorrelation of scalar subquery fails with java.util.NoSuchElementException ## What changes were proposed in this pull request? The following SQL involving scalar correlated query returns a map exception. ``` SQL SELECT t1a FROM t1 WHERE t1a = (SELECT count(*) FROM t2 WHERE t2c = t1c HAVING count(*) >= 1) ``` ``` SQL key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) java.util.NoSuchElementException: key not found: ExprId(278,786682bb-41f9-4bd5-a397-928272cc8e4e) at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$.org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$evalSubqueryOnZeroTups(subquery.scala:378) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:430) at org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery$$anonfun$org$apache$spark$sql$catalyst$optimizer$RewriteCorrelatedScalarSubquery$$constructLeftJoins$1.apply(subquery.scala:426) ``` In this case, after evaluating the HAVING clause "count(*) > 1" statically against the binding of aggregtation result on empty input, we determine that this query will not have a the count bug. We should simply return the evalSubqueryOnZeroTups with empty value. (Please fill in changes proposed in this fix) ## How was this patch tested? A new test was added in the Subquery bucket. Author: Dilip Biswal Closes #20283 from dilipbiswal/scalar-count-defect. --- .../sql/catalyst/optimizer/subquery.scala | 5 +- .../scalar-subquery-predicate.sql | 10 ++++ .../scalar-subquery-predicate.sql.out | 57 ++++++++++++------- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 2673bea648d09..709db6d8bec7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -369,13 +369,14 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case ne => (ne.exprId, evalAggOnZeroTups(ne)) }.toMap - case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + case _ => + sys.error(s"Unexpected operator in scalar subquery: $lp") } val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field. - resultMap(plan.output.head.exprId) + resultMap.getOrElse(plan.output.head.exprId, None) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index fb0d07fbdace7..1661209093fc4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -173,6 +173,16 @@ WHERE t1a = (SELECT max(t2a) HAVING count(*) >= 0) OR t1i > '2014-12-31'; +-- TC 02.03.01 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31'; + -- TC 02.04 -- t1 on the right of an outer join -- can be reduced to inner join diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index 8b29300e71f90..a2b86db3e4f4c 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -293,6 +293,21 @@ val1d -- !query 19 +SELECT t1a +FROM t1 +WHERE t1a = (SELECT max(t2a) + FROM t2 + WHERE t2c = t1c + GROUP BY t2c + HAVING count(*) >= 1) +OR t1i > '2014-12-31' +-- !query 19 schema +struct +-- !query 19 output +val1c +val1d + +-- !query 22 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -300,13 +315,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 19 schema +-- !query 22 schema struct --- !query 19 output +-- !query 22 output 7 --- !query 20 +-- !query 23 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -317,14 +332,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 20 schema +-- !query 23 schema struct --- !query 20 output +-- !query 23 output val1b val1c --- !query 21 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -338,14 +353,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 21 schema +-- !query 24 schema struct --- !query 21 output +-- !query 24 output val1b val1c --- !query 22 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -359,9 +374,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 25 schema struct --- !query 22 output +-- !query 25 output val1a val1a val1b @@ -372,7 +387,7 @@ val1d val1d --- !query 23 +-- !query 26 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -386,16 +401,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 26 schema struct --- !query 23 output +-- !query 26 output val1a val1b val1c val1d --- !query 24 +-- !query 27 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -409,13 +424,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 24 schema +-- !query 27 schema struct --- !query 24 output +-- !query 27 output val1a --- !query 25 +-- !query 28 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -423,8 +438,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 25 schema +-- !query 28 schema struct --- !query 25 output +-- !query 28 output val1b val1c From a9b845ebb5b51eb619cfa7d73b6153024a6a420d Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 17 Jan 2018 10:03:25 +0800 Subject: [PATCH 0110/1161] [SPARK-22361][SQL][TEST] Add unit test for Window Frames ## What changes were proposed in this pull request? There are already quite a few integration tests using window frames, but the unit tests coverage is not ideal. In this PR the already existing tests are reorganized, extended and where gaps found additional cases added. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20019 from gaborgsomogyi/SPARK-22361. --- .../parser/ExpressionParserSuite.scala | 57 ++- .../sql/DataFrameWindowFramesSuite.scala | 405 ++++++++++++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 243 ----------- 3 files changed, 454 insertions(+), 251 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2b9783a3295c6..cb8a1fecb80a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -249,8 +249,8 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) @@ -263,21 +263,62 @@ class ExpressionParserSuite extends PlanTest { "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", WindowExpression('sum.function('product + 1), WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + } + + test("range/rows window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } - // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = Seq( - ("10 preceding", -Literal(10), CurrentRow), + // No between combinations + ("unbounded preceding", UnboundedPreceding, CurrentRow), ("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("10 preceding", -Literal(10), CurrentRow), + ("3 + 1 preceding", -Add(Literal(3), Literal(1)), CurrentRow), + ("0 preceding", -Literal(0), CurrentRow), + ("current row", CurrentRow, CurrentRow), + ("0 following", Literal(0), CurrentRow), ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), - ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("10 following", Literal(10), CurrentRow), + ("2147483649 following", Literal(2147483649L), CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + + // Between combinations + ("between unbounded preceding and 5 following", + UnboundedPreceding, Literal(5)), + ("between unbounded preceding and 3 + 1 following", + UnboundedPreceding, Add(Literal(3), Literal(1))), + ("between unbounded preceding and 2147483649 following", + UnboundedPreceding, Literal(2147483649L)), ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), + ("between 2147483648 preceding and current row", -Literal(2147483648L), CurrentRow), ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between 3 + 1 preceding and current row", -Add(Literal(3), Literal(1)), CurrentRow), + ("between 0 preceding and current row", -Literal(0), CurrentRow), + ("between current row and current row", CurrentRow, CurrentRow), + ("between current row and 0 following", CurrentRow, Literal(0)), ("between current row and 5 following", CurrentRow, Literal(5)), - ("between 10 preceding and 5 following", -Literal(10), Literal(5)) + ("between current row and 3 + 1 following", CurrentRow, Add(Literal(3), Literal(1))), + ("between current row and 2147483649 following", CurrentRow, Literal(2147483649L)), + ("between current row and unbounded following", CurrentRow, UnboundedFollowing), + ("between 2147483648 preceding and unbounded following", + -Literal(2147483648L), UnboundedFollowing), + ("between 10 preceding and unbounded following", + -Literal(10), UnboundedFollowing), + ("between 3 + 1 preceding and unbounded following", + -Add(Literal(3), Literal(1)), UnboundedFollowing), + ("between 0 preceding and unbounded following", -Literal(0), UnboundedFollowing), + + // Between partial and full range + ("between 10 preceding and 5 following", -Literal(10), Literal(5)), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing) ) frameTypes.foreach { case (frameTypeSql, frameType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala new file mode 100644 index 0000000000000..0ee9b0edc02b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Window frame testing for DataFrame API. + */ +class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("lead/lag with empty data frame") { + val df = Seq.empty[(Int, String)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + lead("value", 1).over(window), + lag("value", 1).over(window)), + Nil) + } + + test("lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) + } + + test("reverse lead/lag with positive offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", 1).over(window), + lag("value", 1).over(window)), + Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) + } + + test("lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "3") :: Row(1, "1", null) :: Row(2, null, "4") :: Row(2, "2", null) :: Nil) + } + + test("reverse lead/lag with negative offset") { + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value".desc) + + checkAnswer( + df.select( + $"key", + lead("value", -1).over(window), + lag("value", -1).over(window)), + Row(1, null, "1") :: Row(1, "3", null) :: Row(2, null, "2") :: Row(2, "4", null) :: Nil) + } + + test("lead/lag with default value") { + val default = "n/a" + val df = Seq((1, "1"), (2, "2"), (1, "3"), (2, "4"), (2, "5")).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + $"key", + lead("value", 2, default).over(window), + lag("value", 2, default).over(window), + lead("value", -2, default).over(window), + lag("value", -2, default).over(window)), + Row(1, default, default, default, default) :: Row(1, default, default, default, default) :: + Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) :: + Row(2, default, default, default, default) :: Nil) + } + + test("rows/range between with empty data frame") { + val df = Seq.empty[(String, Int)].toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + first("value").over( + window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + first("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Nil) + } + + test("rows between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept at most one ORDER BY expression when unbounded") { + val df = Seq((1, 1)).toDF("key", "value") + val window = Window.orderBy($"key", $"value") + + checkAnswer( + df.select( + $"key", + min("key").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq(Row(1, 1)) + ) + + val e1 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e2 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + + val e3 = intercept[AnalysisException]( + df.select( + min("key").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("A range window frame with value boundaries cannot be used in a " + + "window specification with multiple order by expressions")) + } + + test("range between should accept numeric values only when bounded") { + val df = Seq("non_numeric").toDF("value") + val window = Window.orderBy($"value") + + checkAnswer( + df.select( + $"value", + min("value").over( + window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("non_numeric", "non_numeric") :: Nil) + + val e1 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(Window.unboundedPreceding, 1)))) + assert(e1.message.contains("The data type of the upper bound 'string' " + + "does not match the expected data type")) + + val e2 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, Window.unboundedFollowing)))) + assert(e2.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + + val e3 = intercept[AnalysisException]( + df.select( + min("value").over(window.rangeBetween(-1, 1)))) + assert(e3.message.contains("The data type of the lower bound 'string' " + + "does not match the expected data type")) + } + + test("range between should accept int/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + + def dt(date: String): Date = Date.valueOf(date) + + val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), + (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)) + + checkAnswer( + df2.select( + $"key", + count("key").over(window)), + Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), + Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) + ) + } + + test("range between should accept double values as boundary") { + val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), (3.3D, "2"), (2.02D, "1"), + (100.001D, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(currentRow, lit(2.5D)) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) + ) + } + + test("range between should accept interval values as boundary") { + def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) + + val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), + (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) + .toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + .rangeBetween(currentRow, lit(CalendarInterval.fromString("interval 23 days 4 hours"))) + + checkAnswer( + df.select( + $"key", + count("key").over(window)), + Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), + Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) + ) + } + + test("unbounded rows/range between with aggregation") { + val df = Seq(("one", 1), ("two", 2), ("one", 3), ("two", 4)).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy($"value") + + checkAnswer( + df.select( + 'key, + sum("value").over(window. + rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + sum("value").over(window. + rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil) + } + + test("unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key") + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(2, 3, 2) :: Row(3, 3, 3) :: Row(1, 4, 1) :: Row(2, 4, 2) :: + Row(4, 4, 4) :: Nil) + } + + test("reverse unbounded preceding/following rows between with aggregation") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc) + + checkAnswer( + df.select( + $"key", + last("key").over( + window.rowsBetween(Window.currentRow, Window.unboundedFollowing)), + last("key").over( + window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), + Row(1, 1, 1) :: Row(3, 2, 3) :: Row(2, 2, 2) :: Row(4, 1, 4) :: Row(2, 1, 2) :: + Row(1, 1, 1) :: Nil) + } + + test("unbounded preceding/following range between with aggregation") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy("value").orderBy("key") + + checkAnswer( + df.select( + $"key", + avg("key").over(window.rangeBetween(Window.unboundedPreceding, 1)) + .as("avg_key1"), + avg("key").over(window.rangeBetween(Window.currentRow, Window.unboundedFollowing)) + .as("avg_key2")), + Row(3, 3.0d, 4.0d) :: Row(5, 4.0d, 5.0d) :: Row(2, 2.0d, 17.0d / 4.0d) :: + Row(4, 11.0d / 3.0d, 5.0d) :: Row(5, 17.0d / 4.0d, 11.0d / 2.0d) :: + Row(6, 17.0d / 4.0d, 6.0d) :: Nil) + } + + // This is here to illustrate the fact that reverse order also reverses offsets. + test("reverse preceding/following range between with aggregation") { + val df = Seq(1, 2, 4, 3, 2, 1).toDF("value") + val window = Window.orderBy($"value".desc) + + checkAnswer( + df.select( + $"value", + sum($"value").over(window.rangeBetween(Window.unboundedPreceding, 1)), + sum($"value").over(window.rangeBetween(1, Window.unboundedFollowing))), + Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: + Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 3.0d / 2.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("reverse sliding rows between with aggregation") { + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key".desc).rowsBetween(-1, 2) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 1.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 4.0d / 3.0d) :: Row(2, 2.0d) :: + Row(2, 2.0d) :: Nil) + } + + test("sliding range between with aggregation") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") + val window = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1) + + checkAnswer( + df.select( + $"key", + avg("key").over(window)), + Row(1, 4.0d / 3.0d) :: Row(1, 4.0d / 3.0d) :: Row(2, 7.0d / 4.0d) :: Row(3, 5.0d / 2.0d) :: + Row(2, 2.0d) :: Row(2, 2.0d) :: Nil) + } + + test("reverse sliding range between with aggregation") { + val df = Seq( + (1, "Thin", "Cell Phone", 6000), + (2, "Normal", "Tablet", 1500), + (3, "Mini", "Tablet", 5500), + (4, "Ultra thin", "Cell Phone", 5500), + (5, "Very thin", "Cell Phone", 6000), + (6, "Big", "Tablet", 2500), + (7, "Bendable", "Cell Phone", 3000), + (8, "Foldable", "Cell Phone", 3000), + (9, "Pro", "Tablet", 4500), + (10, "Pro2", "Tablet", 6500)). + toDF("id", "product", "category", "revenue") + val window = Window.partitionBy($"category").orderBy($"revenue".desc). + rangeBetween(-2000L, 1000L) + + checkAnswer( + df.select( + $"id", + avg($"revenue").over(window).cast("int")), + Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: + Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: + Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: + Row(10, 6000) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 01c988ecc3726..281147835abde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -55,56 +55,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } - test("Window.rowsBetween") { - val df = Seq(("one", 1), ("two", 2)).toDF("key", "value") - // Running (cumulative) sum - checkAnswer( - df.select('key, sum("value").over( - Window.rowsBetween(Window.unboundedPreceding, Window.currentRow))), - Row("one", 1) :: Row("two", 3) :: Nil - ) - } - - test("lead") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) - } - - test("lag") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - - checkAnswer( - df.select( - lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) - } - - test("lead with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) - } - - test("lag with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) - } - test("rank functions in unspecific window") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") df.createOrReplaceTempView("window_table") @@ -136,199 +86,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } - test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) - } - - test("aggregation and range between") { - val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), - Row(2.0d), Row(2.0d))) - } - - test("row between should accept integer values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), - Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - - val e = intercept[AnalysisException]( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) - assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) - } - - test("range between should accept int/long values as boundary") { - val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), - (3L, "2"), (2L, "1"), (2147483650L, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), - Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) - ) - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), - Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) - ) - - def dt(date: String): Date = Date.valueOf(date) - - val df2 = Seq((dt("2017-08-01"), "1"), (dt("2017-08-01"), "1"), (dt("2020-12-31"), "1"), - (dt("2017-08-03"), "2"), (dt("2017-08-02"), "1"), (dt("2020-12-31"), "2")) - .toDF("key", "value") - checkAnswer( - df2.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(lit(0), lit(2)))), - Seq(Row(dt("2017-08-01"), 3), Row(dt("2017-08-01"), 3), Row(dt("2020-12-31"), 1), - Row(dt("2017-08-03"), 1), Row(dt("2017-08-02"), 1), Row(dt("2020-12-31"), 1)) - ) - } - - test("range between should accept double values as boundary") { - val df = Seq((1.0D, "1"), (1.0D, "1"), (100.001D, "1"), - (3.3D, "2"), (2.02D, "1"), (100.001D, "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, lit(2.5D)))), - Seq(Row(1.0, 3), Row(1.0, 3), Row(100.001, 1), Row(3.3, 1), Row(2.02, 1), Row(100.001, 1)) - ) - } - - test("range between should accept interval values as boundary") { - def ts(timestamp: Long): Timestamp = new Timestamp(timestamp * 1000) - - val df = Seq((ts(1501545600), "1"), (ts(1501545600), "1"), (ts(1609372800), "1"), - (ts(1503000000), "2"), (ts(1502000000), "1"), (ts(1609372800), "2")) - .toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - count("key").over( - Window.partitionBy($"value").orderBy($"key") - .rangeBetween(currentRow, - lit(CalendarInterval.fromString("interval 23 days 4 hours"))))), - Seq(Row(ts(1501545600), 3), Row(ts(1501545600), 3), Row(ts(1609372800), 1), - Row(ts(1503000000), 1), Row(ts(1502000000), 1), Row(ts(1609372800), 1)) - ) - } - - test("aggregation and rows between with unbounded") { - val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.currentRow, Window.unboundedFollowing)), - last("key").over( - Window.partitionBy($"value").orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.currentRow)), - last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), - Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), - Row(4, 4, 4, 4))) - } - - test("aggregation and range between with unbounded") { - val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.createOrReplaceTempView("window_table") - checkAnswer( - df.select( - $"key", - last("value").over( - Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)) - .equalTo("2") - .as("last_v"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) - .as("avg_key1"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) - .as("avg_key2"), - avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) - .as("avg_key3") - ), - Seq(Row(3, null, 3.0d, 4.0d, 3.0d), - Row(5, false, 4.0d, 5.0d, 5.0d), - Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), - Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), - Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), - Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) - } - - test("reverse sliding range frame") { - val df = Seq( - (1, "Thin", "Cell Phone", 6000), - (2, "Normal", "Tablet", 1500), - (3, "Mini", "Tablet", 5500), - (4, "Ultra thin", "Cell Phone", 5500), - (5, "Very thin", "Cell Phone", 6000), - (6, "Big", "Tablet", 2500), - (7, "Bendable", "Cell Phone", 3000), - (8, "Foldable", "Cell Phone", 3000), - (9, "Pro", "Tablet", 4500), - (10, "Pro2", "Tablet", 6500)). - toDF("id", "product", "category", "revenue") - val window = Window. - partitionBy($"category"). - orderBy($"revenue".desc). - rangeBetween(-2000L, 1000L) - checkAnswer( - df.select( - $"id", - avg($"revenue").over(window).cast("int")), - Row(1, 5833) :: Row(2, 2000) :: Row(3, 5500) :: - Row(4, 5833) :: Row(5, 5833) :: Row(6, 2833) :: - Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: - Row(10, 6000) :: Nil) - } - - // This is here to illustrate the fact that reverse order also reverses offsets. - test("reverse unbounded range frame") { - val df = Seq(1, 2, 4, 3, 2, 1). - map(Tuple1.apply). - toDF("value") - val window = Window.orderBy($"value".desc) - checkAnswer( - df.select( - $"value", - sum($"value").over(window.rangeBetween(Long.MinValue, 1)), - sum($"value").over(window.rangeBetween(1, Long.MaxValue))), - Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: - Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) - } - test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") From 16670578519a7b787b0c63888b7d2873af12d5b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 18:11:27 -0800 Subject: [PATCH 0111/1161] [SPARK-22908][SS] Roll forward continuous processing Kafka support with fix to continuous Kafka data reader ## What changes were proposed in this pull request? The Kafka reader is now interruptible and can close itself. ## How was this patch tested? I locally ran one of the ContinuousKafkaSourceSuite tests in a tight loop. Before the fix, my machine ran out of open file descriptors a few iterations in; now it works fine. Author: Jose Torres Closes #20253 from jose-torres/fix-data-reader. --- .../sql/kafka010/KafkaContinuousReader.scala | 260 +++++++++ .../sql/kafka010/KafkaContinuousWriter.scala | 119 ++++ .../sql/kafka010/KafkaOffsetReader.scala | 21 +- .../spark/sql/kafka010/KafkaSource.scala | 17 +- .../sql/kafka010/KafkaSourceOffset.scala | 7 +- .../sql/kafka010/KafkaSourceProvider.scala | 105 +++- .../spark/sql/kafka010/KafkaWriteTask.scala | 71 ++- .../spark/sql/kafka010/KafkaWriter.scala | 5 +- .../kafka010/KafkaContinuousSinkSuite.scala | 476 ++++++++++++++++ .../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++ .../sql/kafka010/KafkaContinuousTest.scala | 94 +++ .../spark/sql/kafka010/KafkaSourceSuite.scala | 539 +++++++++--------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../apache/spark/sql/DataFrameWriter.scala | 25 +- .../datasources/v2/WriteToDataSourceV2.scala | 8 +- .../execution/streaming/StreamExecution.scala | 15 +- .../ContinuousDataSourceRDDIter.scala | 4 +- .../continuous/ContinuousExecution.scala | 67 ++- .../continuous/EpochCoordinator.scala | 21 +- .../sql/streaming/DataStreamWriter.scala | 26 +- .../spark/sql/streaming/StreamTest.scala | 36 +- 21 files changed, 1628 insertions(+), 416 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala new file mode 100644 index 0000000000000..fc977977504f7 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.concurrent.TimeoutException + +import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetOutOfRangeException} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +/** + * A [[ContinuousReader]] for data from kafka. + * + * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be + * read by per-task consumers generated later. + * @param kafkaParams String params for per-task Kafka consumers. + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * are not Kafka consumer params. + * @param metadataPath Path to a directory this reader can use for writing metadata. + * @param initialOffsets The Kafka offsets to start reading data at. + * @param failOnDataLoss Flag indicating whether reading should fail in data loss + * scenarios, where some offsets after the specified initial ones can't be + * properly read. + */ +class KafkaContinuousReader( + offsetReader: KafkaOffsetReader, + kafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + initialOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends ContinuousReader with SupportsScanUnsafeRow with Logging { + + private lazy val session = SparkSession.getActiveSession.get + private lazy val sc = session.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + // Initialized when creating read tasks. If this diverges from the partitions at the latest + // offsets, we need to reconfigure. + // Exposed outside this object only for unit tests. + private[sql] var knownPartitions: Set[TopicPartition] = _ + + override def readSchema: StructType = KafkaOffsetReader.kafkaSchema + + private var offset: Offset = _ + override def setOffset(start: ju.Optional[Offset]): Unit = { + offset = start.orElse { + val offsets = initialOffsets match { + case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + logInfo(s"Initial offsets: $offsets") + offsets + } + } + + override def getStartOffset(): Offset = offset + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + import scala.collection.JavaConverters._ + + val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) + + val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet + val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) + val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"Some partitions were deleted: $deletedPartitions") + } + + val startOffsets = newPartitionOffsets ++ + oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_)) + knownPartitions = startOffsets.keySet + + startOffsets.toSeq.map { + case (topicPartition, start) => + KafkaContinuousReadTask( + topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) + .asInstanceOf[ReadTask[UnsafeRow]] + }.asJava + } + + /** Stop this source and free any resources it has allocated. */ + def stop(): Unit = synchronized { + offsetReader.close() + } + + override def commit(end: Offset): Unit = {} + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => Map(p -> o) + }.reduce(_ ++ _) + KafkaSourceOffset(mergedMap) + } + + override def needsReconfiguration(): Boolean = { + knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions + } + + override def toString(): String = s"KafkaSource[$offsetReader]" + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } +} + +/** + * A read task for continuous Kafka processing. This will be serialized and transformed into a + * full reader on executors. + * + * @param topicPartition The (topic, partition) pair this task is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +case class KafkaContinuousReadTask( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + override def createDataReader(): KafkaContinuousDataReader = { + new KafkaContinuousDataReader( + topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } +} + +/** + * A per-task data reader for continuous Kafka processing. + * + * @param topicPartition The (topic, partition) pair this data reader is responsible for. + * @param startOffset The offset to start reading from within the partition. + * @param kafkaParams Kafka consumer params to use. + * @param pollTimeoutMs The timeout for Kafka consumer polling. + * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets + * are skipped. + */ +class KafkaContinuousDataReader( + topicPartition: TopicPartition, + startOffset: Long, + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + private val topic = topicPartition.topic + private val kafkaPartition = topicPartition.partition + private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) + + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + private var nextKafkaOffset = startOffset + private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ + + override def next(): Boolean = { + var r: ConsumerRecord[Array[Byte], Array[Byte]] = null + while (r == null) { + if (TaskContext.get().isInterrupted() || TaskContext.get().isCompleted()) return false + // Our consumer.get is not interruptible, so we have to set a low poll timeout, leaving + // interrupt points to end the query rather than waiting for new data that might never come. + try { + r = consumer.get( + nextKafkaOffset, + untilOffset = Long.MaxValue, + pollTimeoutMs, + failOnDataLoss) + } catch { + // We didn't read within the timeout. We're supposed to block indefinitely for new data, so + // swallow and ignore this. + case _: TimeoutException => + + // This is a failOnDataLoss exception. Retry if nextKafkaOffset is within the data range, + // or if it's the endpoint of the data range (i.e. the "true" next offset). + case e: IllegalStateException if e.getCause.isInstanceOf[OffsetOutOfRangeException] => + val range = consumer.getAvailableOffsetRange() + if (range.latest >= nextKafkaOffset && range.earliest <= nextKafkaOffset) { + // retry + } else { + throw e + } + } + } + nextKafkaOffset = r.offset + 1 + currentRecord = r + true + } + + override def get(): UnsafeRow = { + bufferHolder.reset() + + if (currentRecord.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, currentRecord.key) + } + rowWriter.write(1, currentRecord.value) + rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) + rowWriter.write(3, currentRecord.partition) + rowWriter.write(4, currentRecord.offset) + rowWriter.write(5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) + rowWriter.write(6, currentRecord.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset) + } + + override def close(): Unit = { + consumer.close() + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala new file mode 100644 index 0000000000000..9843f469c5b25 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{BinaryType, StringType, StructType} + +/** + * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we + * don't need to really send one. + */ +case object KafkaWriterCommitMessage extends WriterCommitMessage + +/** + * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * @param topic The topic this writer is responsible for. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +class KafkaContinuousWriter( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends ContinuousWriter with SupportsWriteInternalRow { + + validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) + + override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = + KafkaContinuousWriterFactory(topic, producerParams, schema) + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(messages: Array[WriterCommitMessage]): Unit = {} +} + +/** + * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate + * the per-task data writers. + * @param topic The topic that should be written to. If None, topic will be inferred from + * a `topic` field in the incoming data. + * @param producerParams Parameters for Kafka producers in each task. + * @param schema The schema of the input data. + */ +case class KafkaContinuousWriterFactory( + topic: Option[String], producerParams: Map[String, String], schema: StructType) + extends DataWriterFactory[InternalRow] { + + override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + } +} + +/** + * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to + * process incoming rows. + * + * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred + * from a `topic` field in the incoming data. + * @param producerParams Parameters to use for the Kafka producer. + * @param inputSchema The attributes in the input data. + */ +class KafkaContinuousDataWriter( + targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) + extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { + import scala.collection.JavaConverters._ + + private lazy val producer = CachedKafkaProducer.getOrCreate( + new java.util.HashMap[String, Object](producerParams.asJava)) + + def write(row: InternalRow): Unit = { + checkForErrors() + sendRow(row, producer) + } + + def commit(): WriterCommitMessage = { + // Send is asynchronous, but we can't commit until all rows are actually in Kafka. + // This requires flushing and then checking that no callbacks produced errors. + // We also check for errors before to fail as soon as possible - the check is cheap. + checkForErrors() + producer.flush() + checkForErrors() + KafkaWriterCommitMessage + } + + def abort(): Unit = {} + + def close(): Unit = { + checkForErrors() + if (producer != null) { + producer.flush() + checkForErrors() + CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava)) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 3e65949a6fd1b..551641cfdbca8 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader( * Resolves the specific offsets based on Kafka seek positions. * This method resolves offset value -1 to the latest and -2 to the * earliest Kafka seek position. + * + * @param partitionOffsets the specific offsets to resolve + * @param reportDataLoss callback to either report or log data loss depending on setting */ def fetchSpecificOffsets( - partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = - runUninterruptibly { + partitionOffsets: Map[TopicPartition, Long], + reportDataLoss: String => Unit): KafkaSourceOffset = { + val fetched = runUninterruptibly { withRetriesWithoutInterrupt { // Poll to get the latest assigned partitions consumer.poll(0) @@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader( } } + partitionOffsets.foreach { + case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && + off != KafkaOffsetRangeLimit.EARLIEST => + if (fetched(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + KafkaSourceOffset(fetched) + } + /** * Fetch the earliest offsets for the topic partitions that are indicated * in the [[ConsumerStrategy]]. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 864a92b8f813f..169a5d006fb04 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -130,7 +130,7 @@ private[kafka010] class KafkaSource( val offsets = startingOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets()) case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets()) - case SpecificOffsetRangeLimit(p) => fetchAndVerify(p) + case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss) } metadataLog.add(0, offsets) logInfo(s"Initial offsets: $offsets") @@ -138,21 +138,6 @@ private[kafka010] class KafkaSource( }.partitionToOffsets } - private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = { - val result = kafkaReader.fetchSpecificOffsets(specificOffsets) - specificOffsets.foreach { - case (tp, off) if off != KafkaOffsetRangeLimit.LATEST && - off != KafkaOffsetRangeLimit.EARLIEST => - if (result(tp) != off) { - reportDataLoss( - s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") - } - case _ => - // no real way to check that beginning or end is reasonable - } - KafkaSourceOffset(result) - } - private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5da415b3097e..c82154cfbad7f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} +import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } +private[kafka010] +case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long) + extends PartitionOffset + /** Companion object of the [[KafkaSourceOffset]] */ private[kafka010] object KafkaSourceOffset { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3cb4d8cad12cc..3914370a96595 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Locale, Optional, UUID} import scala.collection.JavaConverters._ @@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider + with ContinuousWriteSupport + with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + override def createContinuousReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceV2Options): KafkaContinuousReader = { + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaContinuousReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + parameters, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + /** * Returns a new base relation with the given parameters. * @@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { - val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " - + "are serialized with ByteArraySerializer.") - } + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + import scala.collection.JavaConverters._ - if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) - { - throw new IllegalArgumentException( - s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are serialized with ByteArraySerializer.") - } - parameters - .keySet - .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) - .map { k => k.drop(6).toString -> parameters(k) } - .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, - ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + val spark = SparkSession.getActiveSession.get + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + + KafkaWriter.validateQuery( + schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) + + Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) } private def strategy(caseInsensitiveParams: Map[String, String]) = @@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { def build(): ju.Map[String, Object] = map } + + private[kafka010] def kafkaParamsForProducer( + parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 6fd333e2f43ba..baa60febf661d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType} private[kafka010] class KafkaWriteTask( producerConfiguration: ju.Map[String, Object], inputSchema: Seq[Attribute], - topic: Option[String]) { + topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) { // used to synchronize with Kafka callbacks - @volatile private var failedWrite: Exception = null - private val projection = createProjection private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ /** @@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask( producer = CachedKafkaProducer.getOrCreate(producerConfiguration) while (iterator.hasNext && failedWrite == null) { val currentRow = iterator.next() - val projectedRow = projection(currentRow) - val topic = projectedRow.getUTF8String(0) - val key = projectedRow.getBinary(1) - val value = projectedRow.getBinary(2) - if (topic == null) { - throw new NullPointerException(s"null topic present in the data. Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") - } - val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) - val callback = new Callback() { - override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { - if (failedWrite == null && e != null) { - failedWrite = e - } - } - } - producer.send(record, callback) + sendRow(currentRow, producer) } } @@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask( producer = null } } +} + +private[kafka010] abstract class KafkaRowWriter( + inputSchema: Seq[Attribute], topic: Option[String]) { + + // used to synchronize with Kafka callbacks + @volatile protected var failedWrite: Exception = _ + protected val projection = createProjection + + private val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } - private def createProjection: UnsafeProjection = { + /** + * Send the specified row to the producer, with a callback that will save any exception + * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before + * assuming the row is in Kafka. + */ + protected def sendRow( + row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = { + val projectedRow = projection(row) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + producer.send(record, callback) + } + + protected def checkForErrors(): Unit = { + if (failedWrite != null) { + throw failedWrite + } + } + + private def createProjection = { val topicExpression = topic.map(Literal(_)).orElse { inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) }.getOrElse { @@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask( Seq(topicExpression, Cast(keyExpression, BinaryType), Cast(valueExpression, BinaryType)), inputSchema) } - - private def checkForErrors(): Unit = { - if (failedWrite != null) { - throw failedWrite - } - } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5e9ae35b3f008..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging { override def toString: String = "KafkaWriter" def validateQuery( - queryExecution: QueryExecution, + schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic.isEmpty) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging { kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output - validateQuery(queryExecution, kafkaParameters, topic) + validateQuery(schema, kafkaParameters, topic) queryExecution.toRdd.foreachPartition { iter => val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) Utils.tryWithSafeFinally(block = writeTask.execute(iter))( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala new file mode 100644 index 0000000000000..8487a69851237 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Locale +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.types.{BinaryType, DataType} +import org.apache.spark.util.Utils + +/** + * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream. + * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have + * to duplicate all the code. + */ +class KafkaContinuousSinkSuite extends KafkaContinuousTest { + import testImplicits._ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("streaming - write to kafka with topic field") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - write w/o topic field, with topic option") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))() + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("streaming - topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Append()))( + withSelectExpr = "'foo' as topic", "CAST(value as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + .map(_._2) + + try { + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + } + testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10")) + eventually(timeout(streamingTimeout)) { + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } + } finally { + writer.stop() + } + } + + test("null topic attribute") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "CAST(null as STRING) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getCause.getCause.getMessage + .toLowerCase(Locale.ROOT) + .contains("null topic present in the data.")) + } + + test("streaming - write data with bad schema") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase(Locale.ROOT) + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "required attribute 'value' not found")) + } + + test("streaming - write data with valid schema but wrong types") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value as STRING) value") + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) + + try { + /* value field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "value attribute type must be a string or binarytype")) + + try { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .option("startingOffsets", "earliest") + .load() + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 1) + testUtils.sendMessages(inputTopic, Array("0")) + + val input = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", inputTopic) + .load() + var writer: StreamingQuery = null + var ex: Exception = null + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) + } finally { + writer.stop() + } + + try { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + ex = writer.exception.get + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) + } finally { + writer.stop() + } + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. + */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, String] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + iter.foreach(writeTask.write(_)) + writeTask.commit() + } finally { + writeTask.close() + } + } + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + val checkpointDir = Utils.createTempDir() + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + // We need to reduce blocking time to efficiently test non-existent partition behavior. + .option("kafka.max.block.ms", "1000") + .trigger(Trigger.Continuous(1000)) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + stream.start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala new file mode 100644 index 0000000000000..b3dade414f625 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} + +// Run tests in KafkaSourceSuiteBase in continuous execution mode. +class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest + +class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { + import testImplicits._ + + override val brokerProps = Map("auto.create.topics.enable" -> "false") + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Execute { query => + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists { r => + // Ensure the new topic is present and the old topic is gone. + r.knownPartitions.exists(_.topic == topic2) + }, + s"query never reconfigured to new topic $topic2") + } + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } +} + +class KafkaContinuousSourceStressForDontFailOnDataLossSuite + extends KafkaSourceStressForDontFailOnDataLossSuite { + override protected def startStream(ds: Dataset[Int]) = { + ds.writeStream + .format("memory") + .queryName("memory") + .start() + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala new file mode 100644 index 0000000000000..5a1a14f7a307a --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.test.TestSparkSession + +// Trait to configure StreamTest for kafka continuous execution tests. +trait KafkaContinuousTest extends KafkaSourceTest { + override val defaultTrigger = Trigger.Continuous(1000) + override val defaultUseV2Sink = true + + // We need more than the default local[2] to be able to schedule all partitions simultaneously. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // In addition to setting the partitions in Kafka, we have to wait until the query has + // reconfigured to the new count so the test framework can hook in properly. + override protected def setTopicPartitions( + topic: String, newCount: Int, query: StreamExecution) = { + testUtils.addPartitions(topic, newCount) + eventually(timeout(streamingTimeout)) { + assert( + query.lastExecution.logical.collectFirst { + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + }.exists(_.knownPartitions.size == newCount), + s"query never reconfigured to $newCount partitions") + } + } + + // Continuous processing tasks end asynchronously, so test that they actually end. + private val tasksEndedListener = new SparkListener() { + val activeTaskIdCount = new AtomicInteger(0) + + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + activeTaskIdCount.incrementAndGet() + } + + override def onTaskEnd(end: SparkListenerTaskEnd): Unit = { + activeTaskIdCount.decrementAndGet() + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + spark.sparkContext.addSparkListener(tasksEndedListener) + } + + override def afterEach(): Unit = { + eventually(timeout(streamingTimeout)) { + assert(tasksEndedListener.activeTaskIdCount.get() == 0) + } + spark.sparkContext.removeSparkListener(tasksEndedListener) + super.afterEach() + } + + + test("ensure continuous stream is being used") { + val query = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "1") + .load() + + testStream(query)( + Execute(q => assert(q.isInstanceOf[ContinuousExecution])) + ) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index a0f5695fc485c..1acff61e11d2a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils @@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override val streamingTimeout = 30.seconds + protected val brokerProps = Map[String, Object]() + override def beforeAll(): Unit = { super.beforeAll() - testUtils = new KafkaTestUtils + testUtils = new KafkaTestUtils(brokerProps) testUtils.setup() } @@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { if (testUtils != null) { testUtils.teardown() testUtils = null - super.afterAll() } + super.afterAll() } protected def makeSureGetOffsetCalled = AssertOnQuery { q => // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure - // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // its "getOffset" is called before pushing any data. Otherwise, because of the race condition, // we don't know which data should be fetched when `startingOffsets` is latest. - q.processAllAvailable() + q match { + case c: ContinuousExecution => c.awaitEpoch(0) + case m: MicroBatchExecution => m.processAllAvailable() + } true } + protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = { + testUtils.addPartitions(topic, newCount) + } + /** * Add data to Kafka. * @@ -82,10 +94,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - if (query.get.isActive) { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + query match { // Make sure no Spark job is running when deleting a topic - query.get.processAllAvailable() + case Some(m: MicroBatchExecution) => m.processAllAvailable() + case _ => } val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap @@ -97,16 +110,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) } - // Read all topics again in case some topics are delete. - val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys require( query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => - source.asInstanceOf[KafkaSource] - } + case StreamingExecutionRelation(source: KafkaSource, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -137,14 +152,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { override def toString: String = s"AddKafkaData(topics = $topics, data = $data, message = $message)" } -} + private val topicId = new AtomicInteger(0) + protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" +} -class KafkaSourceSuite extends KafkaSourceTest { +class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { import testImplicits._ - private val topicId = new AtomicInteger(0) + test("(de)serialization of initial offsets") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + testStream(reader.load)( + makeSureGetOffsetCalled, + StopStream, + StartStream(), + StopStream) + } + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + StartStream(trigger = ProcessingTime(1)), + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnQuery { query => + val recordsRead = query.recentProgress.map(_.numInputRows).sum + recordsRead == 3 + } + ) + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } testWithUninterruptibleThread( "deserialization of initial offset with Spark 2.1.0") { @@ -237,86 +396,94 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("(de)serialization of initial offsets") { + test("KafkaSource with watermark") { + val now = System.currentTimeMillis() val topic = newTopic() - testUtils.createTopic(topic, partitions = 64) + testUtils.createTopic(newTopic(), partitions = 1) + testUtils.sendMessages(topic, Array(1).map(_.toString)) - val reader = spark + val kafka = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", s"earliest") .option("subscribe", topic) + .load() - testStream(reader.load)( - makeSureGetOffsetCalled, - StopStream, - StartStream(), - StopStream) + val windowedAggregation = kafka + .withWatermark("timestamp", "10 seconds") + .groupBy(window($"timestamp", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start") as 'window, $"count") + + val query = windowedAggregation + .writeStream + .format("memory") + .outputMode("complete") + .queryName("kafkaWatermark") + .start() + query.processAllAvailable() + val rows = spark.table("kafkaWatermark").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + val row = rows(0) + // We cannot check the exact window start time as it depands on the time that messages were + // inserted by the producer. So here we just use a low bound to make sure the internal + // conversion works. + assert( + row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, + s"Unexpected results: $row") + assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") + query.stop() } - test("maxOffsetsPerTrigger") { + test("delete a topic when a Spark job is running") { + KafkaSourceSuite.collectedData.clear() + val topic = newTopic() - testUtils.createTopic(topic, partitions = 3) - testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) - testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) - testUtils.sendMessages(topic, Array("1"), Some(2)) + testUtils.createTopic(topic, partitions = 1) + testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) val reader = spark .readStream .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("maxOffsetsPerTrigger", 10) .option("subscribe", topic) + // If a topic is deleted and we try to poll data starting from offset 0, + // the Kafka consumer will just block until timeout and return an empty result. + // So set the timeout to 1 second to make this test fast. + .option("kafkaConsumer.pollTimeoutMs", "1000") .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] - val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) - - val clock = new StreamManualClock - - val waitUntilBatchProcessed = AssertOnQuery { q => - eventually(Timeout(streamingTimeout)) { - if (!q.exception.isDefined) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + KafkaSourceSuite.globalTestUtils = testUtils + // The following ForeachWriter will delete the topic before fetching data from Kafka + // in executors. + val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + override def open(partitionId: Long, version: Long): Boolean = { + KafkaSourceSuite.globalTestUtils.deleteTopic(topic) + true } - if (q.exception.isDefined) { - throw q.exception.get + + override def process(value: Int): Unit = { + KafkaSourceSuite.collectedData.add(value) } - true - } - testStream(mapped)( - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // 1 from smallest, 1 from middle, 8 from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 - ), - StopStream, - StartStream(ProcessingTime(100), clock), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 - ), - AdvanceManualClock(100), - waitUntilBatchProcessed, - // smallest now empty, 1 more from middle, 9 more from biggest - CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, - 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, - 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, - 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 - ) - ) + override def close(errorOrNull: Throwable): Unit = {} + }).start() + query.processAllAvailable() + query.stop() + // `failOnDataLoss` is `false`, we should not fail the query + assert(query.exception.isEmpty) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -393,7 +560,7 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("kafka") .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"topic-.*") + .option("subscribePattern", s"$topic.*") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") @@ -487,65 +654,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - test("subscribing topic by pattern with topic deletions") { - val topicPrefix = newTopic() - val topic = topicPrefix + "-seems" - val topic2 = topicPrefix + "-bad" - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", s"$topicPrefix-.*") - .option("failOnDataLoss", "false") - - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val mapped = kafka.map(kv => kv._2.toInt + 1) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - Assert { - testUtils.deleteTopic(topic) - testUtils.createTopic(topic2, partitions = 5) - true - }, - AddKafkaData(Set(topic2), 4, 5, 6), - CheckAnswer(2, 3, 4, 5, 6, 7) - ) - } - - test("starting offset is latest by default") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("0")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("subscribe", topic) - - val kafka = reader.load() - .selectExpr("CAST(value AS STRING)") - .as[String] - val mapped = kafka.map(_.toInt) - - testStream(mapped)( - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(1, 2, 3) // should not have 0 - ) - } - test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { @@ -605,77 +713,6 @@ class KafkaSourceSuite extends KafkaSourceTest { testUnsupportedConfig("kafka.auto.offset.reset", "latest") } - test("input row metrics") { - val topic = newTopic() - testUtils.createTopic(topic, partitions = 5) - testUtils.sendMessages(topic, Array("-1")) - require(testUtils.getLatestOffsets(Set(topic)).size === 5) - - val kafka = spark - .readStream - .format("kafka") - .option("subscribe", topic) - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - - val mapped = kafka.map(kv => kv._2.toInt + 1) - testStream(mapped)( - StartStream(trigger = ProcessingTime(1)), - makeSureGetOffsetCalled, - AddKafkaData(Set(topic), 1, 2, 3), - CheckAnswer(2, 3, 4), - AssertOnQuery { query => - val recordsRead = query.recentProgress.map(_.numInputRows).sum - recordsRead == 3 - } - ) - } - - test("delete a topic when a Spark job is running") { - KafkaSourceSuite.collectedData.clear() - - val topic = newTopic() - testUtils.createTopic(topic, partitions = 1) - testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray) - - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribe", topic) - // If a topic is deleted and we try to poll data starting from offset 0, - // the Kafka consumer will just block until timeout and return an empty result. - // So set the timeout to 1 second to make this test fast. - .option("kafkaConsumer.pollTimeoutMs", "1000") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - KafkaSourceSuite.globalTestUtils = testUtils - // The following ForeachWriter will delete the topic before fetching data from Kafka - // in executors. - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { - override def open(partitionId: Long, version: Long): Boolean = { - KafkaSourceSuite.globalTestUtils.deleteTopic(topic) - true - } - - override def process(value: Int): Unit = { - KafkaSourceSuite.collectedData.add(value) - } - - override def close(errorOrNull: Throwable): Unit = {} - }).start() - query.processAllAvailable() - query.stop() - // `failOnDataLoss` is `false`, we should not fail the query - assert(query.exception.isEmpty) - } - test("get offsets from case insensitive parameters") { for ((optionKey, optionValue, answer) <- Seq( (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit), @@ -694,8 +731,6 @@ class KafkaSourceSuite extends KafkaSourceTest { } } - private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" - private def assignString(topic: String, partitions: Iterable[Int]): String = { JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) } @@ -741,6 +776,10 @@ class KafkaSourceSuite extends KafkaSourceTest { testStream(mapped)( makeSureGetOffsetCalled, + Execute { q => + // wait to reach the last offset in every partition + q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L))) + }, CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), StopStream, StartStream(), @@ -771,10 +810,13 @@ class KafkaSourceSuite extends KafkaSourceTest { .format("memory") .outputMode("append") .queryName("kafkaColumnTypes") + .trigger(defaultTrigger) .start() - query.processAllAvailable() - val rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + var rows: Array[Row] = Array() + eventually(timeout(streamingTimeout)) { + rows = spark.table("kafkaColumnTypes").collect() + assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + } val row = rows(0) assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") @@ -788,47 +830,6 @@ class KafkaSourceSuite extends KafkaSourceTest { query.stop() } - test("KafkaSource with watermark") { - val now = System.currentTimeMillis() - val topic = newTopic() - testUtils.createTopic(newTopic(), partitions = 1) - testUtils.sendMessages(topic, Array(1).map(_.toString)) - - val kafka = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("startingOffsets", s"earliest") - .option("subscribe", topic) - .load() - - val windowedAggregation = kafka - .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") - - val query = windowedAggregation - .writeStream - .format("memory") - .outputMode("complete") - .queryName("kafkaWatermark") - .start() - query.processAllAvailable() - val rows = spark.table("kafkaWatermark").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") - val row = rows(0) - // We cannot check the exact window start time as it depands on the time that messages were - // inserted by the producer. So here we just use a low bound to make sure the internal - // conversion works. - assert( - row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000, - s"Unexpected results: $row") - assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row") - query.stop() - } - private def testFromLatestOffsets( topic: String, addPartitions: Boolean, @@ -865,9 +866,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 7, 8), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -908,9 +907,7 @@ class KafkaSourceSuite extends KafkaSourceTest { StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), AssertOnQuery("Add partitions") { query: StreamExecution => - if (addPartitions) { - testUtils.addPartitions(topic, 10) - } + if (addPartitions) setTopicPartitions(topic, 10, query) true }, AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), @@ -1042,20 +1039,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - test("stress test for failOnDataLoss=false") { - val reader = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", testUtils.brokerAddress) - .option("kafka.metadata.max.age.ms", "1") - .option("subscribePattern", "failOnDataLoss.*") - .option("startingOffsets", "earliest") - .option("failOnDataLoss", "false") - .option("fetchOffset.retryIntervalMs", "3000") - val kafka = reader.load() - .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] - val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] { + protected def startStream(ds: Dataset[Int]) = { + ds.writeStream.foreach(new ForeachWriter[Int] { override def open(partitionId: Long, version: Long): Boolean = { true @@ -1069,6 +1054,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared override def close(errorOrNull: Throwable): Unit = { } }).start() + } + + test("stress test for failOnDataLoss=false") { + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "failOnDataLoss.*") + .option("startingOffsets", "earliest") + .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val query = startStream(kafka.map(kv => kv._2.toInt)) val testTime = 1.minutes val startTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e8d683a578f35..b714a46b5f786 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading + // the dataframe as a v1 source. val reader = (ds, userSpecifiedSchema) match { case (ds: ReadSupportWithSchema, Some(schema)) => ds.createReader(schema, options) @@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } reader - case _ => - throw new AnalysisException(s"$cls does not support data reading.") + case _ => null // fall back to v1 } - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + if (reader == null) { + loadV1Source(paths: _*) + } else { + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } } else { - // Code path for data source v1. - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + loadV1Source(paths: _*) } } + private def loadV1Source(paths: String*) = { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } + /** * Construct a `DataFrame` representing the database table accessible via JDBC URL * url named table and connection properties. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3304f368e1050..97f12ff625c42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case _ => throw new AnalysisException(s"$cls does not support data writing.") + // Streaming also uses the data source V2 API. So it may be that the data source implements + // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving + // as though it's a V1 source. + case _ => saveToV1Source() } } else { - // Code path for data source v1. - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) - } + saveToV1Source() + } + } + + private def saveToV1Source(): Unit = { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index f0bdf84bb7a84..a4a857f2d4d9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + if (!writer.isInstanceOf[ContinuousWriter]) { + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } } catch { case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 24a8b000df0c1..cf27e1a70650a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,7 +142,8 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override val runId: UUID = UUID.randomUUID + override def runId: UUID = currentRunId + protected var currentRunId = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is @@ -418,11 +419,17 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + if (sources == null) { + // sources might not be initialized yet + false + } else { + val source = sources(sourceIndex) + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset + } } while (notDone) { @@ -436,7 +443,7 @@ abstract class StreamExecution( awaitProgressLock.unlock() } } - logDebug(s"Unblocked at $newOffset for $source") + logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}") } /** A flag to indicate that a batch has completed with no new data available. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index b3f1a1a1aaab3..66eb42d4658f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -77,7 +77,6 @@ class ContinuousDataSourceRDD( dataReaderThread.start() context.addTaskCompletionListener(_ => { - reader.close() dataReaderThread.interrupt() epochPollExecutor.shutdown() }) @@ -177,6 +176,7 @@ class DataReaderThread( private[continuous] var failureReason: Throwable = _ override def run(): Unit = { + TaskContext.setTaskContext(context) val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) try { while (!context.isInterrupted && !context.isCompleted()) { @@ -201,6 +201,8 @@ class DataReaderThread( failedFlag.set(true) // Don't rethrow the exception in this thread. It's not needed, and the default Spark // exception handler will kill the executor. + } finally { + reader.close() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9657b5e26d770..667410ef9f1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.streaming.continuous +import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} @@ -52,7 +54,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty + @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources override lazy val logicalPlan: LogicalPlan = { @@ -78,15 +80,17 @@ class ContinuousExecution( } override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - do { - try { - runContinuous(sparkSessionForStream) - } catch { - case _: InterruptedException if state.get().equals(RECONFIGURING) => - // swallow exception and run again - state.set(ACTIVE) + val stateUpdate = new UnaryOperator[State] { + override def apply(s: State) = s match { + // If we ended the query to reconfigure, reset the state to active. + case RECONFIGURING => ACTIVE + case _ => s } - } while (state.get() == ACTIVE) + } + + do { + runContinuous(sparkSessionForStream) + } while (state.updateAndGet(stateUpdate) == ACTIVE) } /** @@ -120,12 +124,16 @@ class ContinuousExecution( } committedOffsets = nextOffsets.toStreamProgress(sources) - // Forcibly align commit and offset logs by slicing off any spurious offset logs from - // a previous run. We can't allow commits to an epoch that a previous run reached but - // this run has not. - offsetLog.purgeAfter(latestEpochId) + // Get to an epoch ID that has definitely never been sent to a sink before. Since sink + // commit happens between offset log write and commit log write, this means an epoch ID + // which is not in the offset log. + val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { + throw new IllegalStateException( + s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + + s"an element.") + } + currentBatchId = latestOffsetEpoch + 1 - currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets case None => @@ -141,6 +149,7 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { + currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -225,13 +234,11 @@ class ContinuousExecution( triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration()) { - state.set(RECONFIGURING) + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { stopSources() if (queryExecutionThread.isAlive) { sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() - // No need to join - this thread is about to end anyway. } false } else if (isActive) { @@ -259,6 +266,7 @@ class ContinuousExecution( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } } finally { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() @@ -273,17 +281,22 @@ class ContinuousExecution( epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - if (partitionOffsets.contains(null)) { - // If any offset is null, that means the corresponding partition hasn't seen any data yet, so - // there's nothing meaningful to add to the offset log. - } val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) - synchronized { - if (queryExecutionThread.isAlive) { - offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) - } else { - return - } + val oldOffset = synchronized { + offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) + offsetLog.get(epoch - 1) + } + + // If offset hasn't changed since last epoch, there's been no new data. + if (oldOffset.contains(OffsetSeq.fill(globalOffset))) { + noNewData = true + } + + awaitProgressLock.lock() + try { + awaitProgressLockCondition.signalAll() + } finally { + awaitProgressLock.unlock() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 98017c3ac6a33..40dcbecade814 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable */ private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage +/** + * The RpcEndpoint stop() will wait to clear out the message queue before terminating the + * object. This can lead to a race condition where the query restarts at epoch n, a new + * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1. + * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous + * message to stop any writes to the ContinuousExecution object. + */ +private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage + // Init messages /** * Set the reader and writer partition counts. Tasks may not be started until the coordinator @@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private var queryWritesStopped: Boolean = false + private var numReaderPartitions: Int = _ private var numWriterPartitions: Int = _ @@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator( partitionCommits.remove(k) } for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) + partitionOffsets.remove(k) } } } override def receive: PartialFunction[Any, Unit] = { + // If we just drop these messages, we won't do any writes to the query. The lame duck tasks + // won't shed errors or anything. + case _ if queryWritesStopped => () + case CommitPartitionEpoch(partitionId, epoch, message) => logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message") if (!partitionCommits.isDefinedAt((epoch, partitionId))) { @@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator( case SetWriterPartitions(numPartitions) => numWriterPartitions = numPartitions context.reply(()) + + case StopContinuousExecutionWrites => + queryWritesStopped = true + context.reply(()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index db588ae282f38..b5b4a05ab4973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) + val sink = trigger match { + case _: ContinuousTrigger => + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + ds.newInstance() match { + case w: ContinuousWriteSupport => w + case _ => throw new AnalysisException( + s"Data source $source does not support continuous writing") + } + case _ => + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), df, extraOptions.toMap, - dataSource.createSink(outputMode), + sink, outputMode, useTempCheckpointLocation = source == "console", recoverFromCheckpointLocation = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d46461fa9bf6d..0762895fdc620 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be StateStore.stop() // stop the state store maintenance thread and unload store providers } + protected val defaultTrigger = Trigger.ProcessingTime(0) + protected val defaultUseV2Sink = false + /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds @@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Starts the stream, resuming if data has already been processed. It must not be running. */ case class StartStream( - trigger: Trigger = Trigger.ProcessingTime(0), + trigger: Trigger = defaultTrigger, triggerClock: Clock = new SystemClock, additionalConfs: Map[String, String] = Map.empty, checkpointLocation: String = null) @@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def testStream( _stream: Dataset[_], outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized { + useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { verify(currentStream != null, "stream not running") - // Get the map of source index to the current source objects - val indexToSource = currentStream - .logicalPlan - .collect { case StreamingExecutionRelation(s, _) => s } - .zipWithIndex - .map(_.swap) - .toMap // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { - currentStream.awaitOffset(indexToSource(sourceIndex), offset) + currentStream.awaitOffset(sourceIndex, offset) } } @@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be // after starting the query. try { currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + s.lastExecution.executedPlan // will fail if lastExecution is null + } + case _ => + } } catch { case _: StreamingQueryException => // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. @@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan - .collect { case StreamingExecutionRelation(s, _) => s } + .collect { + case StreamingExecutionRelation(s, _) => s + case DataSourceV2Relation(_, r) => r + } .zipWithIndex .find(_._1 == source) .map(_._2) @@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be findSourceIndex(query.logicalPlan) }.orElse { findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) + } }.getOrElse { throw new IllegalArgumentException( - "Could find index of the source to which data was added") + "Could not find index of the source to which data was added") } // Store the expected offset of added data to wait for it later From 50345a2aa59741c511d555edbbad2da9611e7d16 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 16 Jan 2018 22:14:47 -0800 Subject: [PATCH 0112/1161] Revert "[SPARK-23020][CORE] Fix races in launcher code, test." This reverts commit 66217dac4f8952a9923625908ad3dcb030763c81. --- .../spark/launcher/SparkLauncherSuite.java | 49 +++++++------------ .../spark/launcher/AbstractAppHandle.java | 22 ++------- .../spark/launcher/ChildProcAppHandle.java | 18 +++---- .../spark/launcher/InProcessAppHandle.java | 17 +++---- .../spark/launcher/LauncherConnection.java | 14 +++--- .../apache/spark/launcher/LauncherServer.java | 46 +++-------------- .../org/apache/spark/launcher/BaseSuite.java | 42 +++------------- .../spark/launcher/LauncherServerSuite.java | 20 +++++--- 8 files changed, 72 insertions(+), 156 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index a042375c6ae91..9d2f563b2e367 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -32,7 +31,6 @@ import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; -import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -139,9 +137,7 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { - assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); - }); + TimeUnit.MILLISECONDS.sleep(500); } } @@ -150,35 +146,26 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - synchronized (transitions) { - transitions.add(h.getState()); - } + transitions.add(h.getState()); return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = null; - try { - handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); - } finally { - if (handle != null) { - handle.kill(); - } - } + SparkAppHandle handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index daf0972f824dd..df1e7316861d4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private volatile boolean disposed; + private boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,7 +70,8 @@ public void stop() { @Override public synchronized void disconnect() { - if (!isDisposed()) { + if (!disposed) { + disposed = true; if (connection != null) { try { connection.close(); @@ -78,7 +79,7 @@ public synchronized void disconnect() { // no-op. } } - dispose(); + server.unregister(this); } } @@ -94,21 +95,6 @@ boolean isDisposed() { return disposed; } - /** - * Mark the handle as disposed, and set it as LOST in case the current state is not final. - */ - synchronized void dispose() { - if (!isDisposed()) { - // Unregister first to make sure that the connection with the app has been really - // terminated. - server.unregister(this); - if (!getState().isFinal()) { - setState(State.LOST); - } - this.disposed = true; - } - } - void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 2b99461652e1f..8b3f427b7750e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,16 +48,14 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - if (!isDisposed()) { - setState(State.KILLED); - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); - } - childProc = null; + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); } + childProc = null; } + setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -96,6 +94,8 @@ void monitorChild() { return; } + disconnect(); + int ec; try { ec = proc.exitValue(); @@ -118,8 +118,6 @@ void monitorChild() { if (newState != null) { setState(newState, true); } - - disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index f04263cb74a58..acd64c962604f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,16 +39,15 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - if (!isDisposed()) { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - setState(State.KILLED); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); - } + LOG.warning("kill() may leave the underlying app running in in-process mode."); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); } + + setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index fd6f229b2349c..b4a8719e26053 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public synchronized void close() throws IOException { + public void close() throws IOException { if (!closed) { - closed = true; - socket.close(); + synchronized (this) { + if (!closed) { + closed = true; + socket.close(); + } + } } } - boolean isOpen() { - return !closed; - } - } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 660c4443b20b9..b8999a1d7a4f4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,33 +217,6 @@ void unregister(AbstractAppHandle handle) { break; } } - - // If there is a live connection for this handle, we need to wait for it to finish before - // returning, otherwise there might be a race between the connection thread processing - // buffered data and the handle cleaning up after itself, leading to potentially the wrong - // state being reported for the handle. - ServerConnection conn = null; - synchronized (clients) { - for (ServerConnection c : clients) { - if (c.handle == handle) { - conn = c; - break; - } - } - } - - if (conn != null) { - synchronized (conn) { - if (conn.isOpen()) { - try { - conn.wait(); - } catch (InterruptedException ie) { - // Ignore. - } - } - } - } - unref(); } @@ -315,7 +288,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - volatile AbstractAppHandle handle; + private AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -365,21 +338,16 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { - if (!isOpen()) { - return; - } - synchronized (clients) { clients.remove(this); } - - synchronized (this) { - super.close(); - notifyAll(); - } - + super.close(); if (handle != null) { - handle.dispose(); + if (!handle.getState().isFinal()) { + LOG.log(Level.WARNING, "Lost connection to spark application."); + handle.setState(SparkAppHandle.State.LOST); + } + handle.disconnect(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3722a59d9438e..3e1a90eae98d4 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.launcher; -import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -48,46 +47,19 @@ public void postChecks() { assertNull(server); } - protected void waitFor(final SparkAppHandle handle) throws Exception { + protected void waitFor(SparkAppHandle handle) throws Exception { + long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); try { - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is not in final state.", handle.getState().isFinal()); - }); + while (!handle.getState().isFinal()) { + assertTrue("Timed out waiting for handle to transition to final state.", + System.nanoTime() < deadline); + TimeUnit.MILLISECONDS.sleep(10); + } } finally { if (!handle.getState().isFinal()) { handle.kill(); } } - - // Wait until the handle has been marked as disposed, to make sure all cleanup tasks - // have been performed. - AbstractAppHandle ahandle = (AbstractAppHandle) handle; - eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { - assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); - }); - } - - /** - * Call a closure that performs a check every "period" until it succeeds, or the timeout - * elapses. - */ - protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { - assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); - long deadline = System.nanoTime() + timeout.toNanos(); - int count = 0; - while (true) { - try { - count++; - check.run(); - return; - } catch (Throwable t) { - if (System.nanoTime() >= deadline) { - String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); - throw new IllegalStateException(msg, t); - } - Thread.sleep(period.toMillis()); - } - } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 75c1af0c71e2a..7e2b09ce25c9b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,14 +23,12 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; -import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -199,20 +197,28 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - final AtomicBoolean helloSent = new AtomicBoolean(); - eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + boolean helloSent = false; + int maxTries = 10; + for (int i = 0; i < maxTries; i++) { try { - if (!helloSent.get()) { + if (!helloSent) { client.send(new Hello(secret, "1.4.0")); - helloSent.set(true); + helloSent = true; } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. + break; + } catch (AssertionError e) { + if (i < maxTries - 1) { + Thread.sleep(100); + } else { + throw new AssertionError("Test failed after " + maxTries + " attempts.", e); + } } - }); + } } private static class TestClient extends LauncherConnection { From a963980a6d2b4bef2c546aa33acf0aa501d2507b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 16 Jan 2018 22:27:28 -0800 Subject: [PATCH 0113/1161] Fix merge between 07ae39d0ec and 1667057851 ## What changes were proposed in this pull request? The first commit added a new test, and the second refactored the class the test was in. The automatic merge put the test in the wrong place. ## How was this patch tested? - Author: Jose Torres Closes #20289 from jose-torres/fix. --- .../apache/spark/sql/kafka010/KafkaSourceSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 1acff61e11d2a..62f6a34a6b67a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -479,11 +479,6 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { // `failOnDataLoss` is `false`, we should not fail the query assert(query.exception.isEmpty) } -} - -class KafkaSourceSuiteBase extends KafkaSourceTest { - - import testImplicits._ test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") { def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = { @@ -549,6 +544,11 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { CheckLastBatch(120 to 124: _*) ) } +} + +class KafkaSourceSuiteBase extends KafkaSourceTest { + + import testImplicits._ test("cannot stop Kafka stream") { val topic = newTopic() From a0aedb0ded4183cc33b27e369df1cbf862779e26 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 14:32:18 +0800 Subject: [PATCH 0114/1161] [SPARK-23072][SQL][TEST] Add a Unicode schema test for file-based data sources ## What changes were proposed in this pull request? After [SPARK-20682](https://github.com/apache/spark/pull/19651), Apache Spark 2.3 is able to read ORC files with Unicode schema. Previously, it raises `org.apache.spark.sql.catalyst.parser.ParseException`. This PR adds a Unicode schema test for CSV/JSON/ORC/Parquet file-based data sources. Note that TEXT data source only has [a single column with a fixed name 'value'](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala#L71). ## How was this patch tested? Pass the newly added test case. Author: Dongjoon Hyun Closes #20266 from dongjoon-hyun/SPARK-23072. --- .../spark/sql/FileBasedDataSourceSuite.scala | 81 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ---- .../sql/hive/MetastoreDataSourcesSuite.scala | 14 ---- .../sql/hive/execution/SQLQuerySuite.scala | 8 -- 4 files changed, 81 insertions(+), 38 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala new file mode 100644 index 0000000000000..22fb496bc838e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + + allFileBasedDataSources.foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempPath { dir => + Seq("str").toDS().limit(0).write.format(format).save(dir.getCanonicalPath) + } + } + } + + // `TEXT` data source always has a single column whose name is `value`. + allFileBasedDataSources.filterNot(_ == "text").foreach { format => + test(s"SPARK-23072 Write and read back unicode column names - $format") { + withTempPath { path => + val dir = path.getCanonicalPath + + // scalastyle:off nonascii + val df = Seq("a").toDF("한글") + // scalastyle:on nonascii + + df.write.format(format).option("header", "true").save(dir) + val answerDf = spark.read.format(format).option("header", "true").load(dir) + + assert(df.schema.sameType(answerDf.schema)) + checkAnswer(df, answerDf) + } + } + } + + // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. + // `TEXT` data source always has a single column whose name is `value`. + Seq("orc", "parquet").foreach { format => + test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + withTempPath { file => + val path = file.getCanonicalPath + val emptyDf = Seq((true, 1, "str")).toDF().limit(0) + emptyDf.write.format(format).save(path) + + val df = spark.read.format(format).load(path) + assert(df.schema.sameType(emptyDf.schema)) + checkAnswer(df, emptyDf) + } + } + } + + allFileBasedDataSources.foreach { format => + test(s"SPARK-22146 read files containing special characters using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val fileContent = spark.read.format(format).load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 96bf65fce9c4a..7c9840a34eaa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2757,20 +2757,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } - - // Only New OrcFileFormat supports this - Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName, - "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { - withTempPath { file => - val path = file.getCanonicalPath - val emptyDf = Seq((true, 1, "str")).toDF.limit(0) - emptyDf.write.format(format).save(path) - - val df = spark.read.format(format).load(path) - assert(df.schema.sameType(emptyDf.schema)) - checkAnswer(df, emptyDf) - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c8caba83bf365..fade143a1755e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,14 +23,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.execution.command.CreateTableCommand import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.HiveExternalCatalog._ -import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf._ @@ -1344,18 +1342,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"SPARK-22146: read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" - withTempDir { dir => - val tmpFile = s"$dir/$nameWithSpecialChars" - spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) - val fileContent = spark.read.format(format).load(tmpFile) - checkAnswer(fileContent, Seq(Row("a"), Row("b"))) - } - } - } - private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 47adc77a52d51..33bcae91fdaf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2159,12 +2159,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } - - Seq("orc", "parquet", "csv", "json", "text").foreach { format => - test(s"Writing empty datasets should not fail - $format") { - withTempDir { dir => - Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") - } - } - } } From 1f3d933e0bd2b1e934a233ed699ad39295376e71 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 17 Jan 2018 16:01:41 +0800 Subject: [PATCH 0115/1161] [SPARK-23062][SQL] Improve EXCEPT documentation ## What changes were proposed in this pull request? Make the default behavior of EXCEPT (i.e. EXCEPT DISTINCT) more explicit in the documentation, and call out the change in behavior from 1.x. Author: Henry Robinson Closes #20254 from henryr/spark-23062. --- R/pkg/R/DataFrame.R | 2 +- python/pyspark/sql/dataframe.py | 3 ++- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6caa125e1e14a..29f3e986eaab6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2853,7 +2853,7 @@ setMethod("intersect", #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame -#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT} in SQL. +#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT DISTINCT} in SQL. #' #' @param x a SparkDataFrame. #' @param y a SparkDataFrame. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 95eca76fa9888..2d5e9b91468cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1364,7 +1364,8 @@ def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. - This is equivalent to `EXCEPT` in SQL. + This is equivalent to `EXCEPT DISTINCT` in SQL. + """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 34f0ab5aa6699..912f411fa3845 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1903,7 +1903,7 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. + * This is equivalent to `EXCEPT DISTINCT` in SQL. * * @note Equality checking is performed directly on the encoded representation of the data * and thus is not affected by a custom `equals` function defined on `T`. From 0f8a28617a0742d5a99debfbae91222c2e3b5cec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 17 Jan 2018 21:53:36 +0800 Subject: [PATCH 0116/1161] [SPARK-21783][SQL] Turn on ORC filter push-down by default ## What changes were proposed in this pull request? ORC filter push-down is disabled by default from the beginning, [SPARK-2883](https://github.com/apache/spark/commit/aa31e431fc09f0477f1c2351c6275769a31aca90#diff-41ef65b9ef5b518f77e2a03559893f4dR149 ). Now, Apache Spark starts to depend on Apache ORC 1.4.1. For Apache Spark 2.3, this PR turns on ORC filter push-down by default like Parquet ([SPARK-9207](https://issues.apache.org/jira/browse/SPARK-21783)) as a part of [SPARK-20901](https://issues.apache.org/jira/browse/SPARK-20901), "Feature parity for ORC with Parquet". ## How was this patch tested? Pass the existing tests. Author: Dongjoon Hyun Closes #20265 from dongjoon-hyun/SPARK-21783. --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/FilterPushdownBenchmark.scala | 243 ++++++++++++++++++ 2 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6746fbcaf2483..16fbb0c3e9e21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -410,7 +410,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..c6dd7dadc9d93 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. + */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + conf.set("orc.compression", "snappy") + conf.set("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder() + .master("local[1]") + .appName("FilterPushdownBenchmark") + .config(conf) + .getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("id", monotonically_increasing_id()) + + val dirORC = dir.getCanonicalPath + "/orc" + val dirParquet = dir.getCanonicalPath + "/parquet" + + df.write.mode("overwrite").orc(dirORC) + df.write.mode("overwrite").parquet(dirParquet) + + spark.read.orc(dirORC).createOrReplaceTempView("orcTable") + spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 + Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz + + Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X + Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X + Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X + Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X + + Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X + Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X + Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X + Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X + + Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X + Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X + Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X + Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X + + Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X + Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X + Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X + + Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X + Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X + Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X + Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X + + Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X + Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X + Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X + Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X + + Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X + Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X + Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X + Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X + + Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X + Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X + Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X + Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X + + Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X + Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X + Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X + Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X + + Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X + Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X + Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X + Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X + + Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X + Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X + Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X + Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 767.9 1.2X + + Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ----------------------------------------------------------------------------------------------- + Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X + Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X + Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X + Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X + */ + benchmark.run() + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + val mid = numRows / 2 + + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width) + + Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => + val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"id = $mid", + s"id <=> $mid", + s"$mid <= id AND id <= $mid", + s"${mid - 1} < id AND id < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% rows (id < ${numRows * percent / 100})", + s"id < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + } + } +} From 8598a982b4147abe5f1aae005fea0fd5ae395ac4 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 18 Jan 2018 00:05:26 +0800 Subject: [PATCH 0117/1161] [SPARK-23079][SQL] Fix query constraints propagation with aliases ## What changes were proposed in this pull request? Previously, PR #19201 fix the problem of non-converging constraints. After that PR #19149 improve the loop and constraints is inferred only once. So the problem of non-converging constraints is gone. However, the case below will fail. ``` spark.range(5).write.saveAsTable("t") val t = spark.read.table("t") val left = t.withColumn("xid", $"id" + lit(1)).as("x") val right = t.withColumnRenamed("id", "xid").as("y") val df = left.join(right, "xid").filter("id = 3").toDF() checkAnswer(df, Row(4, 3)) ``` Because `aliasMap` replace all the aliased child. See the test case in PR for details. This PR is to fix this bug by removing useless code for preventing non-converging constraints. It can be also fixed with #20270, but this is much simpler and clean up the code. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20278 from gengliangwang/FixConstraintSimple. --- .../catalyst/plans/logical/LogicalPlan.scala | 1 + .../plans/logical/QueryPlanConstraints.scala | 37 +----------- .../InferFiltersFromConstraintsSuite.scala | 59 +------------------ .../plans/ConstraintPropagationSuite.scala | 2 + .../org/apache/spark/sql/SQLQuerySuite.scala | 11 ++++ 5 files changed, 17 insertions(+), 93 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ff2a0ec588567..c8ccd9bd03994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -255,6 +255,7 @@ abstract class UnaryNode extends LogicalPlan { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) + allConstraints += EqualNullSafe(e, a.toAttribute) case _ => // Don't change. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 9c0a30a47f839..5c7b8e5b97883 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -94,25 +94,16 @@ trait QueryPlanConstraints { self: LogicalPlan => case _ => Seq.empty[Attribute] } - // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so - // we may avoid producing recursive constraints. - private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( - expressions.collect { - case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) - } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) - // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. - /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an * additional constraint of the form `b = 5`. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints) var inferredConstraints = Set.empty[Expression] - aliasedConstraints.foreach { + constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = aliasedConstraints - eq + val candidateConstraints = constraints - eq inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) case _ => // No inference @@ -120,30 +111,6 @@ trait QueryPlanConstraints { self: LogicalPlan => inferredConstraints -- constraints } - /** - * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints. - * Thus non-converging inference can be prevented. - * E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions. - * Also, the size of constraints is reduced without losing any information. - * When the inferred filters are pushed down the operators that generate the alias, - * the alias names used in filters are replaced by the aliased expressions. - */ - private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression]) - : Set[Expression] = { - val attributesInEqualTo = constraints.flatMap { - case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil - case _ => Nil - } - var aliasedConstraints = constraints - attributesInEqualTo.foreach { a => - if (aliasMap.contains(a)) { - val child = aliasMap.get(a).get - aliasedConstraints = replaceConstraints(aliasedConstraints, child, a) - } - } - aliasedConstraints - } - private def replaceConstraints( constraints: Set[Expression], source: Expression, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index a0708bf7eee9a..178c4b8c270a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -34,6 +34,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { PushDownPredicate, InferFiltersFromConstraints, CombineFilters, + SimplifyBinaryComparison, BooleanSimplification) :: Nil } @@ -160,64 +161,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join with alias: don't generate constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `Coalese(a, b)` from recursively creating complicated constraints through - // the constraint inference procedure. - val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - // We hide an `Alias` inside the child's child's expressions, to cover the situation reported - // in [SPARK-20700]. - .select('int_col, 'd, 'a).as("t") - .join(t2, Inner, - Some("t.a".attr === "t2.a".attr - && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a))) - && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b))) - && 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b)) - && 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b)) - && 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b))) - .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) - .select('int_col, 'd, 'a).as("t") - .join( - t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && - 'a === Coalesce(Seq('a, 'a))), - Inner, - Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - - test("inner join with EqualTo expressions containing part of each other: don't generate " + - "constraints for recursive functions") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - - // We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating - // complicated constraints through the constraint inference procedure. - val originalQuery = t1 - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .where('a === 'd && 'c === 'e) - .join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) && - 'c === Coalesce(Seq('a, 'b))) - .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e)) - .join(t2.where(IsNotNull('a) && IsNotNull('c)), - Inner, - Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr)) - .analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) - } - test("generate correct filters for alias that don't produce recursive constraints") { val t1 = testRelation.subquery('t1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 866ff0d33cbb2..a37e06d922642 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -134,6 +134,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7c9840a34eaa3..d4d0aa4f5f5eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2717,6 +2717,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23079: constraints should be inferred correctly with aliases") { + withTable("t") { + spark.range(5).write.saveAsTable("t") + val t = spark.read.table("t") + val left = t.withColumn("xid", $"id" + lit(1)).as("x") + val right = t.withColumnRenamed("id", "xid").as("y") + val df = left.join(right, "xid").filter("id = 3").toDF() + checkAnswer(df, Row(4, 3)) + } + } + test("SRARK-22266: the same aggregate function was calculated multiple times") { val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" val df = sql(query) From c132538a164cd8b55dbd7e8ffdc0c0782a0b588c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 17 Jan 2018 09:27:49 -0800 Subject: [PATCH 0118/1161] [SPARK-23020] Ignore Flaky Test: SparkLauncherSuite.testInProcessLauncher ## What changes were proposed in this pull request? Temporarily ignoring flaky test `SparkLauncherSuite.testInProcessLauncher` to de-flake the builds. This should be re-enabled when SPARK-23020 is merged. ## How was this patch tested? N/A (Test Only Change) Author: Sameer Agarwal Closes #20291 from sameeragarwal/disable-test-2. --- .../java/org/apache/spark/launcher/SparkLauncherSuite.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 9d2f563b2e367..dffa609f1cbdf 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -25,6 +25,7 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; +import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static org.junit.Assume.*; @@ -120,7 +121,8 @@ public void testChildProcLauncher() throws Exception { assertEquals(0, app.waitFor()); } - @Test + // TODO: [SPARK-23020] Re-enable this + @Ignore public void testInProcessLauncher() throws Exception { // Because this test runs SparkLauncher in process and in client mode, it pollutes the system // properties, and that can cause test failures down the test pipeline. So restore the original From 86a845031824a5334db6a5299c6f5dcc982bc5b8 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:52:51 -0800 Subject: [PATCH 0119/1161] [SPARK-23033][SS] Don't use task level retry for continuous processing ## What changes were proposed in this pull request? Continuous processing tasks will fail on any attempt number greater than 0. ContinuousExecution will catch these failures and restart globally from the last recorded checkpoints. ## How was this patch tested? unit test Author: Jose Torres Closes #20225 from jose-torres/no-retry. --- .../spark/sql/kafka010/KafkaSourceSuite.scala | 8 +-- .../ContinuousDataSourceRDDIter.scala | 5 ++ .../continuous/ContinuousExecution.scala | 2 +- .../ContinuousTaskRetryException.scala | 26 +++++++ .../spark/sql/streaming/StreamTest.scala | 9 ++- .../continuous/ContinuousSuite.scala | 71 +++++++++++-------- 6 files changed, 84 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 62f6a34a6b67a..27dbb3f7a8f31 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -808,16 +808,14 @@ class KafkaSourceSuiteBase extends KafkaSourceTest { val query = kafka .writeStream .format("memory") - .outputMode("append") .queryName("kafkaColumnTypes") .trigger(defaultTrigger) .start() - var rows: Array[Row] = Array() eventually(timeout(streamingTimeout)) { - rows = spark.table("kafkaColumnTypes").collect() - assert(rows.length === 1, s"Unexpected results: ${rows.toList}") + assert(spark.table("kafkaColumnTypes").count == 1, + s"Unexpected results: ${spark.table("kafkaColumnTypes").collectAsList()}") } - val row = rows(0) + val row = spark.table("kafkaColumnTypes").head() assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row") assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row") assert(row.getAs[String]("topic") === topic, s"Unexpected results: $row") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 66eb42d4658f6..dcb3b54c4e160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -52,6 +52,11 @@ class ContinuousDataSourceRDD( } override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 667410ef9f1c6..45b794c70a50a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -24,7 +24,7 @@ import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala new file mode 100644 index 0000000000000..e0a6f6dd50bb3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTaskRetryException.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.SparkException + +/** + * An exception thrown when a continuous processing task runs with a nonzero attempt ID. + */ +class ContinuousTaskRetryException + extends SparkException("Continuous execution does not support task retry", null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 0762895fdc620..c75247e0f6ed8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -472,8 +472,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitInitialization(streamingTimeout.toMillis) currentStream match { case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null - } + s.lastExecution.executedPlan // will fail if lastExecution is null + } case _ => } } catch { @@ -645,7 +645,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9562c10feafe9..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,36 +17,18 @@ package org.apache.spark.sql.streaming.continuous -import java.io.{File, InterruptedIOException, IOException, UncheckedIOException} -import java.nio.channels.ClosedByInterruptException -import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit} +import java.util.UUID -import scala.reflect.ClassTag -import scala.util.control.ControlThrowable - -import com.google.common.util.concurrent.UncheckedExecutionException -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.{SparkContext, SparkEnv} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes -import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.TestSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils class ContinuousSuiteBase extends StreamTest { // We need more than the default local[2] to be able to schedule all partitions simultaneously. @@ -219,6 +201,41 @@ class ContinuousSuite extends ContinuousSuiteBase { StopStream) } + test("task failure kills the query") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .select('value) + + // Get an arbitrary task from this query to kill. It doesn't matter which one. + var taskId: Long = -1 + val listener = new SparkListener() { + override def onTaskStart(start: SparkListenerTaskStart): Unit = { + taskId = start.taskInfo.taskId + } + } + spark.sparkContext.addSparkListener(listener) + try { + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(100)), + Execute(waitForRateSourceTriggers(_, 2)), + Execute { _ => + // Wait until a task is started, then kill its first attempt. + eventually(timeout(streamingTimeout)) { + assert(taskId != -1) + } + spark.sparkContext.killTaskAttempt(taskId) + }, + ExpectFailure[SparkException] { e => + e.getCause != null && e.getCause.getCause.isInstanceOf[ContinuousTaskRetryException] + }) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + test("query without test harness") { val df = spark.readStream .format("rate") @@ -258,13 +275,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), - Execute { query => - val data = query.sink.asInstanceOf[MemorySinkV2].allData - val vals = data.map(_.getLong(0)).toSet - assert(scala.Range(0, 25000).forall { i => - vals.contains(i) - }) - }) + StopStream, + CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))) + ) } test("automatic epoch advancement") { @@ -280,6 +293,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 201)), IncrementEpoch(), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } @@ -311,6 +325,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { StopStream, StartStream(Trigger.Continuous(2012)), AwaitEpoch(50), + StopStream, CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } From e946c63dd56d121cf898084ed7e9b5b0868b226e Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 13:58:44 -0800 Subject: [PATCH 0120/1161] [SPARK-23093][SS] Don't change run id when reconfiguring a continuous processing query. ## What changes were proposed in this pull request? Keep the run ID static, using a different ID for the epoch coordinator to avoid cross-execution message contamination. ## How was this patch tested? new and existing unit tests Author: Jose Torres Closes #20282 from jose-torres/fix-runid. --- .../datasources/v2/DataSourceV2ScanExec.scala | 3 ++- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++-- .../execution/streaming/StreamExecution.scala | 3 +-- .../ContinuousDataSourceRDDIter.scala | 10 ++++---- .../continuous/ContinuousExecution.scala | 18 ++++++++----- .../continuous/EpochCoordinator.scala | 9 ++++--- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 25 +++++++++++++++++++ 8 files changed, 54 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 8c64df080242f..beb66738732be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -58,7 +58,8 @@ case class DataSourceV2ScanExec( case _: ContinuousReader => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetReaderPartitions(readTasks.size())) new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) .asInstanceOf[RDD[InternalRow]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index a4a857f2d4d9b..3dbdae7b4df9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -64,7 +64,8 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) val runTask = writer match { case w: ContinuousWriter => EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) (context: TaskContext, iter: Iterator[InternalRow]) => @@ -135,7 +136,7 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow]): WriterCommitMessage = { val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) val currentMsg: WriterCommitMessage = null var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cf27e1a70650a..e7982d7880ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -142,8 +142,7 @@ abstract class StreamExecution( override val id: UUID = UUID.fromString(streamMetadata.id) - override def runId: UUID = currentRunId - protected var currentRunId = UUID.randomUUID + override val runId: UUID = UUID.randomUUID /** * Pretty identified string of printing in logs. Format is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index dcb3b54c4e160..cd7065f5e6601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -59,7 +59,7 @@ class ContinuousDataSourceRDD( val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() - val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY) + val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) // This queue contains two types of messages: // * (null, null) representing an epoch boundary. @@ -68,7 +68,7 @@ class ContinuousDataSourceRDD( val epochPollFailed = new AtomicBoolean(false) val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--${runId}--${context.partitionId()}") + s"epoch-poll--$coordinatorId--${context.partitionId()}") val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) epochPollExecutor.scheduleWithFixedDelay( epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) @@ -86,7 +86,7 @@ class ContinuousDataSourceRDD( epochPollExecutor.shutdown() }) - val epochEndpoint = EpochCoordinatorRef.get(runId, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) new Iterator[UnsafeRow] { private val POLL_TIMEOUT_MS = 1000 @@ -150,7 +150,7 @@ class EpochPollRunnable( private[continuous] var failureReason: Throwable = _ private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.RUN_ID_KEY), SparkEnv.get) + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong override def run(): Unit = { @@ -177,7 +177,7 @@ class DataReaderThread( failedFlag: AtomicBoolean) extends Thread( s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)}") { + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { private[continuous] var failureReason: Throwable = _ override def run(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 45b794c70a50a..c0507224f9be8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -57,6 +57,9 @@ class ContinuousExecution( @volatile protected var continuousSources: Seq[ContinuousReader] = _ override protected def sources: Seq[BaseStreamingSource] = continuousSources + // For use only in test harnesses. + private[sql] var currentEpochCoordinatorId: String = _ + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in StreamExecutionThread " + @@ -149,7 +152,6 @@ class ContinuousExecution( * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with. */ private def runContinuous(sparkSessionForQuery: SparkSession): Unit = { - currentRunId = UUID.randomUUID // A list of attributes that will need to be updated. val replacements = new ArrayBuffer[(Attribute, Attribute)] // Translate from continuous relation to the underlying data source. @@ -219,15 +221,19 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - sparkSession.sparkContext.setLocalProperty( + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) - sparkSession.sparkContext.setLocalProperty( - ContinuousExecution.RUN_ID_KEY, runId.toString) + // Add another random ID on top of the run ID, to distinguish epoch coordinators across + // reconfigurations. + val epochCoordinatorId = s"$runId--${UUID.randomUUID}" + currentEpochCoordinatorId = epochCoordinatorId + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, currentBatchId, sparkSession, SparkEnv.get) + writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { @@ -359,5 +365,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" - val RUN_ID_KEY = "__run_id" + val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 40dcbecade814..90b3584aa0436 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -79,7 +79,7 @@ private[sql] case class ReportPartitionOffset( /** Helper object used to create reference to [[EpochCoordinator]]. */ private[sql] object EpochCoordinatorRef extends Logging { - private def endpointName(runId: String) = s"EpochCoordinator-$runId" + private def endpointName(id: String) = s"EpochCoordinator-$id" /** * Create a reference to a new [[EpochCoordinator]]. @@ -88,18 +88,19 @@ private[sql] object EpochCoordinatorRef extends Logging { writer: ContinuousWriter, reader: ContinuousReader, query: ContinuousExecution, + epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( writer, reader, query, startEpoch, session, env.rpcEnv) - val ref = env.rpcEnv.setupEndpoint(endpointName(query.runId.toString()), coordinator) + val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref } - def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized { - val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv) + def get(id: String, env: SparkEnv): RpcEndpointRef = synchronized { + val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(id), env.conf, env.rpcEnv) logDebug("Retrieved existing EpochCoordinator endpoint") rpcEndpointRef } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index c75247e0f6ed8..efdb0e0e7cf1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -263,7 +263,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def apply(): AssertOnQuery = Execute { case s: ContinuousExecution => - val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get) + val newEpoch = EpochCoordinatorRef.get(s.currentEpochCoordinatorId, SparkEnv.get) .askSync[Long](IncrementAndGetEpoch) s.awaitEpoch(newEpoch - 1) case _ => throw new IllegalStateException("microbatch cannot increment epoch") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 9ff02dee288fb..79d65192a14aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -174,6 +174,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("continuous processing listeners should receive QueryTerminatedEvent") { + val df = spark.readStream.format("rate").load() + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append, useV2Sink = true)( + StartStream(Trigger.Continuous(1000)), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + } + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset() From 4e6f8fb150ae09c7d1de6beecb2b98e5afa5da19 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 18 Jan 2018 07:26:43 +0900 Subject: [PATCH 0121/1161] [SPARK-23047][PYTHON][SQL] Change MapVector to NullableMapVector in ArrowColumnVector ## What changes were proposed in this pull request? This PR changes usage of `MapVector` in Spark codebase to use `NullableMapVector`. `MapVector` is an internal Arrow class that is not supposed to be used directly. We should use `NullableMapVector` instead. ## How was this patch tested? Existing test. Author: Li Jin Closes #20239 from icexelloss/arrow-map-vector. --- .../sql/vectorized/ArrowColumnVector.java | 13 +++++-- .../vectorized/ArrowColumnVectorSuite.scala | 36 +++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 708333213f3f1..eb69001fe677e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -247,8 +247,8 @@ public ArrowColumnVector(ValueVector vector) { childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; + } else if (vector instanceof NullableMapVector) { + NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); childColumns = new ArrowColumnVector[mapVector.size()]; @@ -553,9 +553,16 @@ final int getArrayOffset(int rowId) { } } + /** + * Any call to "get" method will throw UnsupportedOperationException. + * + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined + * in the parent class. Any call to "get" method in this class is a bug in the code. + * + */ private static class StructAccessor extends ArrowVectorAccessor { - StructAccessor(MapVector vector) { + StructAccessor(NullableMapVector vector) { super(vector); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 7304803a092c0..53432669e215d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -322,6 +322,42 @@ class ArrowColumnVectorSuite extends SparkFunSuite { allocator.close() } + test("non nullable struct") { + val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) + val schema = new StructType().add("int", IntegerType).add("long", LongType) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = false, null) + .createVector(allocator).asInstanceOf[NullableMapVector] + + vector.allocateNew() + val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector] + val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector] + + vector.setIndexDefined(0) + intVector.setSafe(0, 1) + longVector.setSafe(0, 1L) + + vector.setIndexDefined(1) + intVector.setSafe(1, 2) + longVector.setNull(1) + + vector.setValueCount(2) + + val columnVector = new ArrowColumnVector(vector) + assert(columnVector.dataType === schema) + assert(columnVector.numNulls === 0) + + val row0 = columnVector.getStruct(0, 2) + assert(row0.getInt(0) === 1) + assert(row0.getLong(1) === 1L) + + val row1 = columnVector.getStruct(1, 2) + assert(row1.getInt(0) === 2) + assert(row1.isNullAt(1)) + + columnVector.close() + allocator.close() + } + test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) From 45ad97df87c89cb94ce9564e5773897b6d9326f5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 07:30:54 +0900 Subject: [PATCH 0122/1161] [SPARK-23132][PYTHON][ML] Run doctests in ml.image when testing ## What changes were proposed in this pull request? This PR proposes to actually run the doctests in `ml/image.py`. ## How was this patch tested? doctests in `python/pyspark/ml/image.py`. Author: hyukjinkwon Closes #20294 from HyukjinKwon/trigger-image. --- python/pyspark/ml/image.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index c9b840276f675..2d86c7f03860c 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -194,9 +194,9 @@ def readImages(self, path, recursive=False, numPartitions=-1, :return: a :class:`DataFrame` with a single column of "images", see ImageSchema for details. - >>> df = ImageSchema.readImages('python/test_support/image/kittens', recursive=True) + >>> df = ImageSchema.readImages('data/mllib/images/kittens', recursive=True) >>> df.count() - 4 + 5 .. versionadded:: 2.3.0 """ @@ -216,3 +216,25 @@ def readImages(self, path, recursive=False, numPartitions=-1, def _disallow_instance(_): raise RuntimeError("Creating instance of _ImageSchema class is disallowed.") _ImageSchema.__init__ = _disallow_instance + + +def _test(): + import doctest + import pyspark.ml.image + globs = pyspark.ml.image.__dict__.copy() + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.image tests")\ + .getOrCreate() + globs['spark'] = spark + + (failure_count, test_count) = doctest.testmod( + pyspark.ml.image, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 7823d43ec0e9c4b8284bb4529b0e624c43bc9bb7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 17 Jan 2018 17:16:57 -0600 Subject: [PATCH 0123/1161] [MINOR] Fix typos in ML scaladocs ## What changes were proposed in this pull request? Fixed some typos found in ML scaladocs ## How was this patch tested? NA Author: Bryan Cutler Closes #20300 from BryanCutler/ml-doc-typos-MINOR. --- .../src/main/scala/org/apache/spark/ml/stat/Summarizer.scala | 2 +- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 9bed74a9f2c05..d40827edb6d64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -75,7 +75,7 @@ sealed abstract class SummaryBuilder { * val Row(meanVec) = meanDF.first() * }}} * - * Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + * Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD * interface. */ @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8826ef3271bc1..88ff0dfd75e96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -93,7 +93,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setSeed(value: Long): this.type = set(seed, value) /** - * Set the mamixum level of parallelism to evaluate models in parallel. + * Set the maximum level of parallelism to evaluate models in parallel. * Default is 1 for serial evaluation * * @group expertSetParam @@ -112,7 +112,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St * for more information. * * @group expertSetParam - */@Since("2.3.0") + */ + @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") From bac0d661af6092dd26638223156827aceb901229 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:40:02 -0800 Subject: [PATCH 0124/1161] [SPARK-23119][SS] Minor fixes to V2 streaming APIs ## What changes were proposed in this pull request? - Added `InterfaceStability.Evolving` annotations - Improved docs. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20286 from tdas/SPARK-23119. --- .../v2/streaming/ContinuousReadSupport.java | 2 ++ .../streaming/reader/ContinuousDataReader.java | 2 ++ .../v2/streaming/reader/ContinuousReader.java | 9 +++++++-- .../v2/streaming/reader/MicroBatchReader.java | 5 +++++ .../sources/v2/streaming/reader/Offset.java | 18 +++++++++++++----- .../v2/streaming/reader/PartitionOffset.java | 3 +++ .../sources/v2/writer/DataSourceV2Writer.java | 5 ++++- 7 files changed, 36 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 3136cee1f655f..9a93a806b0efc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -19,6 +19,7 @@ import java.util.Optional; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; @@ -28,6 +29,7 @@ * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability for continuous stream processing. */ +@InterfaceStability.Evolving public interface ContinuousReadSupport extends DataSourceV2 { /** * Creates a {@link ContinuousReader} to scan the data from this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java index ca9a290e97a02..3f13a4dbf5793 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java @@ -17,11 +17,13 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A variation on {@link DataReader} for use with streaming in continuous processing mode. */ +@InterfaceStability.Evolving public interface ContinuousDataReader extends DataReader { /** * Get the offset of the current record, or the start offset if no records have been read. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index f0b205869ed6c..745f1ce502443 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; @@ -27,11 +28,15 @@ * interface to allow reading in a continuous processing mode stream. * * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ +@InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { /** - * Merge offsets coming from {@link ContinuousDataReader} instances in each partition to - * a single global offset. + * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 70ff756806032..02f37cebc7484 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; @@ -25,7 +26,11 @@ /** * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. + * + * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with + * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ +@InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { /** * Set the desired offset range for read tasks created from this reader. Read tasks will diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java index 60b87f2ac0756..abba3e7188b13 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java @@ -17,12 +17,20 @@ package org.apache.spark.sql.sources.v2.streaming.reader; +import org.apache.spark.annotation.InterfaceStability; + /** - * An abstract representation of progress through a [[MicroBatchReader]] or [[ContinuousReader]]. - * During execution, Offsets provided by the data source implementation will be logged and used as - * restart checkpoints. Sources should provide an Offset implementation which they can use to - * reconstruct the stream position where the offset was taken. + * An abstract representation of progress through a {@link MicroBatchReader} or + * {@link ContinuousReader}. + * During execution, offsets provided by the data source implementation will be logged and used as + * restart checkpoints. Each source should provide an offset implementation which the source can use + * to reconstruct a position in the stream up to which data has been seen/processed. + * + * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to + * maintain compatibility with DataSource V1 APIs. This extension will be removed once we + * get rid of V1 completely. */ +@InterfaceStability.Evolving public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { /** * A JSON-serialized representation of an Offset that is @@ -37,7 +45,7 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of /** * Equality based on JSON string representation. We leverage the * JSON representation for normalization between the Offset's - * in memory and on disk representations. + * in deserialized and serialized representations. */ @Override public boolean equals(Object obj) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java index eca0085c8a8ce..4688b85f49f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java @@ -19,11 +19,14 @@ import java.io.Serializable; +import org.apache.spark.annotation.InterfaceStability; + /** * Used for per-partition offsets in continuous processing. ContinuousReader implementations will * provide a method to merge these into a global Offset. * * These offsets must be serializable. */ +@InterfaceStability.Evolving public interface PartitionOffset extends Serializable { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index fc37b9a516f82..317ac45bcfd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -22,11 +22,14 @@ import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * From 1002bd6b23ff78a010ca259ea76988ef4c478c6e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 17 Jan 2018 16:41:43 -0800 Subject: [PATCH 0125/1161] [SPARK-23064][DOCS][SS] Added documentation for stream-stream joins ## What changes were proposed in this pull request? Added documentation for stream-stream joins ![image](https://user-images.githubusercontent.com/663212/35018744-e999895a-fad7-11e7-9d6a-8c7a73e6eb9c.png) ![image](https://user-images.githubusercontent.com/663212/35018775-157eb464-fad8-11e7-879e-47a2fcbd8690.png) ![image](https://user-images.githubusercontent.com/663212/35018784-27791a24-fad8-11e7-98f4-7ff246f62a74.png) ![image](https://user-images.githubusercontent.com/663212/35018791-36a80334-fad8-11e7-9791-f85efa7c6ba2.png) ## How was this patch tested? N/a Author: Tathagata Das Closes #20255 from tdas/join-docs. --- .../structured-streaming-programming-guide.md | 338 +++++++++++++++++- 1 file changed, 326 insertions(+), 12 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index de13e281916db..1779a4215e085 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1051,7 +1051,19 @@ output mode. ### Join Operations -Streaming DataFrames can be joined with static DataFrames to create new streaming DataFrames. Here are a few examples. +Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame +as well as another streaming Dataset/DataFrame. The result of the streaming join is generated +incrementally, similar to the results of streaming aggregations in the previous section. In this +section we will explore what type of joins (i.e. inner, outer, etc.) are supported in the above +cases. Note that in all the supported join types, the result of the join with a streaming +Dataset/DataFrame will be the exactly the same as if it was with a static Dataset/DataFrame +containing the same data in the stream. + + +#### Stream-static joins + +Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some +type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example.
    @@ -1089,6 +1101,300 @@ streamingDf.join(staticDf, "type", "right_join") # right outer join with a stat
    +Note that stream-static joins are not stateful, so no state management is necessary. +However, a few types of stream-static outer joins are not yet supported. +These are listed at the [end of this Join section](#support-matrix-for-joins-in-streaming-queries). + +#### Stream-stream Joins +In Spark 2.3, we have added support for stream-stream joins, that is, you can join two streaming +Datasets/DataFrames. The challenge of generating join results between two data streams is that, +at any point of time, the view of the dataset is incomplete for both sides of the join making +it much harder to find matches between inputs. Any row received from one input stream can match +with any future, yet-to-be-received row from the other input stream. Hence, for both the input +streams, we buffer past input as streaming state, so that we can match every future input with +past input and accordingly generate joined results. Furthermore, similar to streaming aggregations, +we automatically handle late, out-of-order data and can limit the state using watermarks. +Let’s discuss the different types of supported stream-stream joins and how to use them. + +##### Inner Joins with optional Watermarking +Inner joins on any kind of columns along with any kind of join conditions are supported. +However, as the stream runs, the size of streaming state will keep growing indefinitely as +*all* past input must be saved as the any new input can match with any input from the past. +To avoid unbounded state, you have to define additional join conditions such that indefinitely +old inputs cannot match with future inputs and therefore can be cleared from the state. +In other words, you will have to do the following additional steps in the join. + +1. Define watermark delays on both inputs such that the engine knows how delayed the input can be +(similar to streaming aggregations) + +1. Define a constraint on event-time across the two inputs such that the engine can figure out when +old rows of one input is not going to be required (i.e. will not satisfy the time constraint) for +matches with the other input. This constraint can be defined in one of the two ways. + + 1. Time range join conditions (e.g. `...JOIN ON leftTime BETWEN rightTime AND rightTime + INTERVAL 1 HOUR`), + + 1. Join on event-time windows (e.g. `...JOIN ON leftTimeWindow = rightTimeWindow`). + +Let’s understand this with an example. + +Let’s say we want to join a stream of advertisement impressions (when an ad was shown) with +another stream of user clicks on advertisements to correlate when impressions led to +monetizable clicks. To allow the state cleanup in this stream-stream join, you will have to +specify the watermarking delays and the time constraints as follows. + +1. Watermark delays: Say, the impressions and the corresponding clicks can be late/out-of-order +in event-time by at most 2 and 3 hours, respectively. + +1. Event-time range condition: Say, a click can occur within a time range of 0 seconds to 1 hour +after the corresponding impression. + +The code would look like this. + +
    +
    + +{% highlight scala %} +import org.apache.spark.sql.functions.expr + +val impressions = spark.readStream. ... +val clicks = spark.readStream. ... + +// Apply watermarks on event-time columns +val impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +val clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +import static org.apache.spark.sql.functions.expr + +Dataset impressions = spark.readStream(). ... +Dataset clicks = spark.readStream(). ... + +// Apply watermarks on event-time columns +Dataset impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours"); +Dataset clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours"); + +// Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour ") +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +from pyspark.sql.functions import expr + +impressions = spark.readStream. ... +clicks = spark.readStream. ... + +# Apply watermarks on event-time columns +impressionsWithWatermark = impressions.withWatermark("impressionTime", "2 hours") +clicksWithWatermark = clicks.withWatermark("clickTime", "3 hours") + +# Join with event-time constraints +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """) +) + +{% endhighlight %} + +
    +
    + +##### Outer Joins with Watermarking +While the watermark + event-time constraints is optional for inner joins, for left and right outer +joins they must be specified. This is because for generating the NULL results in outer join, the +engine must know when an input row is not going to match with anything in future. Hence, the +watermark + event-time constraints must be specified for generating correct results. Therefore, +a query with outer-join will look quite like the ad-monetization example earlier, except that +there will be an additional parameter specifying it to be an outer-join. + +
    +
    + +{% highlight scala %} + +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + joinType = "leftOuter" // can be "inner", "leftOuter", "rightOuter" + ) + +{% endhighlight %} + +
    +
    + +{% highlight java %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr( + "clickAdId = impressionAdId AND " + + "clickTime >= impressionTime AND " + + "clickTime <= impressionTime + interval 1 hour "), + "leftOuter" // can be "inner", "leftOuter", "rightOuter" +); + +{% endhighlight %} + + +
    +
    + +{% highlight python %} +impressionsWithWatermark.join( + clicksWithWatermark, + expr(""" + clickAdId = impressionAdId AND + clickTime >= impressionTime AND + clickTime <= impressionTime + interval 1 hour + """), + "leftOuter" # can be "inner", "leftOuter", "rightOuter" +) + +{% endhighlight %} + +
    +
    + +However, note that the outer NULL results will be generated with a delay (depends on the specified +watermark delay and the time range condition) because the engine has to wait for that long to ensure +there were no matches and there will be no more matches in future. + +##### Support matrix for joins in streaming queries + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Left InputRight InputJoin Type
    StaticStaticAll types + Supported, since its not on streaming data even though it + can be present in a streaming query +
    StreamStaticInnerSupported, not stateful
    Left OuterSupported, not stateful
    Right OuterNot supported
    Full OuterNot supported
    StaticStreamInnerSupported, not stateful
    Left OuterNot supported
    Right OuterSupported, not stateful
    Full OuterNot supported
    StreamStreamInner + Supported, optionally specify watermark on both sides + + time constraints for state cleanup +
    Left Outer + Conditionally supported, must specify watermark on right + time constraints for correct + results, optionally specify watermark on left for all state cleanup +
    Right Outer + Conditionally supported, must specify watermark on left + time constraints for correct + results, optionally specify watermark on right for all state cleanup +
    Full OuterNot supported
    + +Additional details on supported joins: + +- Joins can be cascaded, that is, you can do `df1.join(df2, ...).join(df3, ...).join(df4, ....)`. + +- As of Spark 2.3, you can use joins only when the query is in Append output mode. Other output modes are not yet supported. + +- As of Spark 2.3, you cannot use other non-map-like operations before joins. Here are a few examples of + what cannot be used. + + - Cannot use streaming aggregations before joins. + + - Cannot use mapGroupsWithState and flatMapGroupsWithState in Update mode before joins. + + ### Streaming Deduplication You can deduplicate records in data streams using a unique identifier in the events. This is exactly same as deduplication on static using a unique identifier column. The query will store the necessary amount of data from previous records such that it can filter duplicate records. Similar to aggregations, you can use deduplication with or without watermarking. @@ -1160,15 +1466,9 @@ Some of them are as follows. - Sorting operations are supported on streaming Datasets only after an aggregation and in Complete Output Mode. -- Outer joins between a streaming and a static Datasets are conditionally supported. - - + Full outer join with a streaming Dataset is not supported - - + Left outer join with a streaming Dataset on the right is not supported - - + Right outer join with a streaming Dataset on the left is not supported - -- Any kind of joins between two streaming Datasets is not yet supported. +- Few types of outer joins on streaming Datasets are not supported. See the + support matrix in the Join Operations section + for more details. In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be done by explicitly starting a streaming query (see the next section regarding that). @@ -1276,6 +1576,15 @@ Here is the compatibility matrix. Aggregations not allowed after flatMapGroupsWithState. + + Queries with joins + Append + + Update and Complete mode not supported yet. See the + support matrix in the Join Operations section + for more details on what types of joins are supported. + + Other queries Append, Update @@ -2142,6 +2451,11 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat **Talks** -- Spark Summit 2017 Talk - [Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark](https://spark-summit.org/2017/events/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark/) -- Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) +- Spark Summit Europe 2017 + - Easy, Scalable, Fault-tolerant Stream Processing with Structured Streaming in Apache Spark - + [Part 1 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark), [Part 2 slides/video](https://databricks.com/session/easy-scalable-fault-tolerant-stream-processing-with-structured-streaming-in-apache-spark-continues) + - Deep Dive into Stateful Stream Processing in Structured Streaming - [slides/video](https://databricks.com/session/deep-dive-into-stateful-stream-processing-in-structured-streaming) +- Spark Summit 2016 + - A Deep Dive into Structured Streaming - [slides/video](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) + From 02194702068291b3af77486d01029fb848c36d7b Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Wed, 17 Jan 2018 16:42:38 -0800 Subject: [PATCH 0126/1161] [SPARK-21996][SQL] read files with space in name for streaming ## What changes were proposed in this pull request? Structured streaming is now able to read files with space in file name (previously it would skip the file and output a warning) ## How was this patch tested? Added new unit test. Author: Xiayun Sun Closes #19247 from xysun/SPARK-21996. --- .../streaming/FileStreamSource.scala | 2 +- .../sql/streaming/FileStreamSourceSuite.scala | 50 ++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 0debd7db84757..8c016abc5b643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -166,7 +166,7 @@ class FileStreamSource( val newDataSource = DataSource( sparkSession, - paths = files.map(_.path), + paths = files.map(f => new Path(new URI(f.path)).toString), userSpecifiedSchema = Some(schema), partitionColumns = partitionColumns, className = fileFormatClassName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 39bb572740617..5bb0f4d643bbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -74,11 +74,11 @@ abstract class FileStreamSourceTest protected def addData(source: FileStreamSource): Unit } - case class AddTextFileData(content: String, src: File, tmp: File) + case class AddTextFileData(content: String, src: File, tmp: File, tmpFilePrefix: String = "text") extends AddFileData { override def addData(source: FileStreamSource): Unit = { - val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val tempFile = Utils.tempFileWith(new File(tmp, tmpFilePrefix)) val finalFile = new File(src, tempFile.getName) src.mkdirs() require(stringToFile(tempFile, content).renameTo(finalFile)) @@ -408,6 +408,52 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("SPARK-21996 read from text files -- file name has space") { + withTempDirs { case (src, tmp) => + val textStream = createFileStream("text", src.getCanonicalPath) + val filtered = textStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp, "text text"), + CheckAnswer("keep2", "keep3") + ) + } + } + + test("SPARK-21996 read from text files generated by file sink -- file name has space") { + val testTableName = "FileStreamSourceTest" + withTable(testTableName) { + withTempDirs { case (src, checkpoint) => + val output = new File(src, "text text") + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val query = ds.writeStream + .option("checkpointLocation", checkpoint.getCanonicalPath) + .format("text") + .start(output.getCanonicalPath) + + try { + inputData.addData("foo") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } finally { + query.stop() + } + + val df2 = spark.readStream.format("text").load(output.getCanonicalPath) + val query2 = df2.writeStream.format("memory").queryName(testTableName).start() + try { + query2.processAllAvailable() + checkDatasetUnorderly(spark.table(testTableName).as[String], "foo") + } finally { + query2.stop() + } + } + } + } + test("read from textfile") { withTempDirs { case (src, tmp) => val textStream = spark.readStream.textFile(src.getCanonicalPath) From 39d244d921d8d2d3ed741e8e8f1175515a74bdbd Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 18 Jan 2018 14:51:05 +0900 Subject: [PATCH 0127/1161] [SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark ## What changes were proposed in this pull request? This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0. These are inconsistent with Scala / Java APIs and also these basically do the same things with `spark.udf.register*`. Also, this PR moves the logcis from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuse the docstring. This PR also handles minor doc corrections. It also includes https://github.com/apache/spark/pull/20158 ## How was this patch tested? Manually tested, manually checked the API documentation and tests added to check if deprecated APIs call the aliases correctly. Author: hyukjinkwon Closes #20288 from HyukjinKwon/deprecate-udf. --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/catalog.py | 91 ++-------------- python/pyspark/sql/context.py | 137 ++++-------------------- python/pyspark/sql/functions.py | 4 +- python/pyspark/sql/group.py | 3 +- python/pyspark/sql/session.py | 6 +- python/pyspark/sql/tests.py | 20 ++++ python/pyspark/sql/udf.py | 182 +++++++++++++++++++++++++++++++- 8 files changed, 234 insertions(+), 210 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7164180a6a7b0..b900f0bd913c3 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -400,6 +400,7 @@ def __hash__(self): "pyspark.sql.functions", "pyspark.sql.readwriter", "pyspark.sql.streaming", + "pyspark.sql.udf", "pyspark.sql.window", "pyspark.sql.tests", ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 35fbe9e669adb..6aef0f22340be 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -224,92 +224,17 @@ def dropGlobalTempView(self, viewName): """ self._jcatalog.dropGlobalTempView(viewName) - @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. - :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = spark.catalog.registerFunction("stringLengthString", len) - >>> spark.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) - >>> spark.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = spark.udf.register("slen", slen) - >>> spark.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf) - >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP - >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - """ + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. - # This is to check whether the input function is a wrapped/native UserDefinedFunction - if hasattr(f, 'asNondeterministic'): - if returnType is not None: - raise TypeError( - "Invalid returnType: None is expected when f is a UserDefinedFunction, " - "but got %s." % returnType) - if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: - raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") - register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, - evalType=f.evalType, - deterministic=f.deterministic) - return_udf = f - else: - if returnType is None: - returnType = StringType() - register_udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) - return_udf = register_udf._wrapped() - self._jsparkSession.udf().registerPython(name, register_udf._judf) - return return_udf + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. + """ + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self._sparkSession.udf.register(name, f, returnType) @since(2.0) def isCached(self, tableName): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 85479095af594..cc1cd1a5842d9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -29,9 +29,10 @@ from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import IntegerType, Row, StringType +from pyspark.sql.udf import UDFRegistration from pyspark.sql.utils import install_exception_handler -__all__ = ["SQLContext", "HiveContext", "UDFRegistration"] +__all__ = ["SQLContext", "HiveContext"] class SQLContext(object): @@ -147,7 +148,7 @@ def udf(self): :return: :class:`UDFRegistration` """ - return UDFRegistration(self) + return self.sparkSession.udf @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): @@ -172,113 +173,29 @@ def range(self, start, end=None, step=1, numPartitions=None): """ return self.sparkSession.range(start, end, step, numPartitions) - @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` - as a UDF. The registered UDF can be used in SQL statements. - - :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`. - - In addition to a name and the function itself, `returnType` can be optionally specified. - 1) When f is a Python function, `returnType` defaults to a string. The produced object must - match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return - type of the given UDF as the return type of the registered UDF. The input parameter - `returnType` is None by default. If given by users, the value must be None. - - :param name: name of the UDF in SQL statements. - :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either - row-at-a-time or vectorized. - :param returnType: the return type of the registered UDF. - :return: a wrapped/native :class:`UserDefinedFunction` - - >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) - >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] - - >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect() - [Row(stringLengthString(text)=u'3')] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) - >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(stringLengthInt(test)=4)] - - >>> from pyspark.sql.types import IntegerType - >>> from pyspark.sql.functions import udf - >>> slen = udf(lambda s: len(s), IntegerType()) - >>> _ = sqlContext.udf.register("slen", slen) - >>> sqlContext.sql("SELECT slen('test')").collect() - [Row(slen(test)=4)] - - >>> import random - >>> from pyspark.sql.functions import udf - >>> from pyspark.sql.types import IntegerType - >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() - >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf) - >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP - [Row(random_udf()=82)] - >>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP - [Row(()=26)] - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP - >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP - [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead. """ - return self.sparkSession.catalog.registerFunction(name, f, returnType) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.register instead.", + DeprecationWarning) + return self.sparkSession.udf.register(name, f, returnType) - @ignore_unicode_prefix @since(2.1) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a java UDF so it can be used in SQL statements. - - In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not specified we would infer it via reflection. - :param name: name of the UDF - :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object - - >>> sqlContext.registerJavaFunction("javaStringLength", - ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) - >>> sqlContext.sql("SELECT javaStringLength('test')").collect() - [Row(UDF:javaStringLength(test)=4)] - >>> sqlContext.registerJavaFunction("javaStringLength2", - ... "test.org.apache.spark.sql.JavaStringLength") - >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() - [Row(UDF:javaStringLength2(test)=4)] + """An alias for :func:`spark.udf.registerJavaFunction`. + See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`. + .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead. """ - jdt = None - if returnType is not None: - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) - - @ignore_unicode_prefix - @since(2.3) - def registerJavaUDAF(self, name, javaClassName): - """Register a java UDAF so it can be used in SQL statements. - - :param name: name of the UDAF - :param javaClassName: fully qualified name of java class - - >>> sqlContext.registerJavaUDAF("javaUDAF", - ... "test.org.apache.spark.sql.MyDoubleAvg") - >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") - >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() - [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] - """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + warnings.warn( + "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.", + DeprecationWarning) + return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType) # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): @@ -590,24 +507,6 @@ def refreshTable(self, tableName): self._ssql_ctx.refreshTable(tableName) -class UDFRegistration(object): - """Wrapper for user-defined function registration.""" - - def __init__(self, sqlContext): - self.sqlContext = sqlContext - - def register(self, name, f, returnType=None): - return self.sqlContext.registerFunction(name, f, returnType) - - def registerJavaFunction(self, name, javaClassName, returnType=None): - self.sqlContext.registerJavaFunction(name, javaClassName, returnType) - - def registerJavaUDAF(self, name, javaClassName): - self.sqlContext.registerJavaUDAF(name, javaClassName) - - register.__doc__ = SQLContext.registerFunction.__doc__ - - def _test(): import os import doctest diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f7b3f29764040..988c1d25259bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()): >>> import random >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. @@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): ... return pd.Series(np.random.randn(len(v)) >>> random = random.asNondeterministic() # doctest: +SKIP - .. note:: The user-defined functions do not support conditional expressions or short curcuiting + .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. """ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 09fae46adf014..22061b83eb78c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -212,7 +212,8 @@ def apply(self, udf): This function does not support partial aggregation, and requires shuffling all the data in the :class:`DataFrame`. - :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + :param udf: a group map user-defined function returned by + :meth:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 604021c1f45cc..6c84023c43fb6 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -29,7 +29,6 @@ from pyspark import since from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.sql.catalog import Catalog from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -280,6 +279,7 @@ def catalog(self): :return: :class:`Catalog` """ + from pyspark.sql.catalog import Catalog if not hasattr(self, "_catalog"): self._catalog = Catalog(self) return self._catalog @@ -291,8 +291,8 @@ def udf(self): :return: :class:`UDFRegistration` """ - from pyspark.sql.context import UDFRegistration - return UDFRegistration(self._wrapped) + from pyspark.sql.udf import UDFRegistration + return UDFRegistration(self) @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8906618666b14..f84aa3d68b808 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -372,6 +372,12 @@ def test_udf(self): [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias. + sqlContext = self.spark._wrapped + sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType()) + [row] = sqlContext.sql("SELECT oneArg('test')").collect() + self.assertEqual(row[0], 4) + def test_udf2(self): self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ @@ -577,11 +583,25 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + # This is to check if a 'SQLContext.udf' can call its alias. + sqlContext = self.spark._wrapped + add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_four(id) AS plus_four").collect(), + df.select(add_four("id").alias("plus_four")).collect() + ) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias. + sqlContext = spark._wrapped + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf")) + def test_non_existed_udaf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e80ab9165867..1943bb73f9ac2 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -19,11 +19,13 @@ """ import functools -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark import SparkContext, since +from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +__all__ = ["UDFRegistration"] + def _wrap_function(sc, func, returnType): command = (func, returnType) @@ -181,3 +183,179 @@ def asNondeterministic(self): """ self.deterministic = False return self + + +class UDFRegistration(object): + """ + Wrapper for user-defined function registration. This instance can be accessed by + :attr:`spark.udf` or :attr:`sqlContext.udf`. + + .. versionadded:: 1.3.1 + """ + + def __init__(self, sparkSession): + self.sparkSession = sparkSession + + @ignore_unicode_prefix + @since("1.3.1") + def register(self, name, f, returnType=None): + """Registers a Python function (including lambda function) or a user-defined function + in SQL statements. + + :param name: name of the user-defined function in SQL statements. + :param f: a Python function, or a user-defined function. The user-defined function can + be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and + :meth:`pyspark.sql.functions.pandas_udf`. + :param returnType: the return type of the registered user-defined function. + :return: a user-defined function. + + `returnType` can be optionally specified when `f` is a Python function but not + when `f` is a user-defined function. Please see below. + + 1. When `f` is a Python function: + + `returnType` defaults to string type and can be optionally specified. The produced + object must match the specified type. In this case, this API works as if + `register(name, f, returnType=StringType())`. + + >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x)) + >>> spark.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + >>> from pyspark.sql.types import IntegerType + >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> spark.sql("SELECT stringLengthInt('test')").collect() + [Row(stringLengthInt(test)=4)] + + 2. When `f` is a user-defined function: + + Spark uses the return type of the given user-defined function as the return type of + the registered user-defined function. `returnType` should not be specified. + In this case, this API works as if `register(name, f)`. + + >>> from pyspark.sql.types import IntegerType + >>> from pyspark.sql.functions import udf + >>> slen = udf(lambda s: len(s), IntegerType()) + >>> _ = spark.udf.register("slen", slen) + >>> spark.sql("SELECT slen('test')").collect() + [Row(slen(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> new_random_udf = spark.udf.register("random_udf", random_udf) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=82)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP + ... def add_one(x): + ... return x + 1 + ... + >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP + >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP + [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] + + .. note:: Registration for a user-defined function (case 2.) was added from + Spark 2.3.0. + """ + + # This is to check whether the input function is from a user-defined function or + # Python function. + if hasattr(f, 'asNondeterministic'): + if returnType is not None: + raise TypeError( + "Invalid returnType: data type can not be specified when f is" + "a user-defined function, but got %s." % returnType) + if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + raise ValueError( + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, + evalType=f.evalType, + deterministic=f.deterministic) + return_udf = f + else: + if returnType is None: + returnType = StringType() + register_udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) + return_udf = register_udf._wrapped() + self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf) + return return_udf + + @ignore_unicode_prefix + @since(2.3) + def registerJavaFunction(self, name, javaClassName, returnType=None): + """Register a Java user-defined function so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not specified we would infer it via reflection. + + :param name: name of the user-defined function + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> from pyspark.sql.types import IntegerType + >>> spark.udf.registerJavaFunction( + ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) + >>> spark.sql("SELECT javaStringLength('test')").collect() + [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( + ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") + >>> spark.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF:javaStringLength2(test)=4)] + """ + + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a Java user-defined aggregate function so it can be used in SQL statements. + + :param name: name of the user-defined aggregate function + :param javaClassName: fully qualified name of java class + + >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + + +def _test(): + import doctest + from pyspark.sql import SparkSession + import pyspark.sql.udf + globs = pyspark.sql.udf.__dict__.copy() + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.udf tests")\ + .getOrCreate() + globs['spark'] = spark + (failure_count, test_count) = doctest.testmod( + pyspark.sql.udf, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) + spark.stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() From 1c76a91e5fae11dcb66c453889e587b48039fdc9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 17 Jan 2018 22:36:29 -0800 Subject: [PATCH 0128/1161] [SPARK-23052][SS] Migrate ConsoleSink to data source V2 api. ## What changes were proposed in this pull request? Migrate ConsoleSink to data source V2 api. Note that this includes a missing piece in DataStreamWriter required to specify a data source V2 writer. Note also that I've removed the "Rerun batch" part of the sink, because as far as I can tell this would never have actually happened. A MicroBatchExecution object will only commit each batch once for its lifetime, and a new MicroBatchExecution object would have a new ConsoleSink object which doesn't know it's retrying a batch. So I think this represents an anti-feature rather than a weakness in the V2 API. ## How was this patch tested? new unit test Author: Jose Torres Closes #20243 from jose-torres/console-sink. --- .../streaming/MicroBatchExecution.scala | 7 +- .../sql/execution/streaming/console.scala | 62 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../streaming/sources/ConsoleWriter.scala | 64 +++++ .../sources/PackedRowWriterFactory.scala | 60 +++++ .../sql/streaming/DataStreamWriter.scala | 16 +- ...pache.spark.sql.sources.DataSourceRegister | 8 + .../sources/ConsoleWriterSuite.scala | 135 ++++++++++ .../sources/StreamingDataSourceV2Suite.scala | 249 ++++++++++++++++++ .../test/DataStreamReaderWriterSuite.scala | 25 -- 10 files changed, 551 insertions(+), 84 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 70407f0580f97..7c3804547b736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -91,11 +91,14 @@ class MicroBatchExecution( nextSourceId += 1 StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, _, _, output, v1Relation) => + case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable") + if (v1Relation.isEmpty) { + throw new UnsupportedOperationException( + s"Data source $sourceName does not support microbatch processing.") + } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 StreamingExecutionRelation(source, output)(sparkSession) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 71eaabe273fea..94820376ff7e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,58 +17,36 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType - -class ConsoleSink(options: Map[String, String]) extends Sink with Logging { - // Number of rows to display, by default 20 rows - private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20) - - // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true) +import java.util.Optional - // Track the batch id - private var lastBatchId = -1L - - override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - val batchIdStr = if (batchId <= lastBatchId) { - s"Rerun batch: $batchId" - } else { - lastBatchId = batchId - s"Batch: $batchId" - } - - // scalastyle:off println - println("-------------------------------------------") - println(batchIdStr) - println("-------------------------------------------") - // scalastyle:off println - data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - .show(numRowsToShow, isTruncated) - } +import scala.collection.JavaConverters._ - override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]" -} +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) extends BaseRelation { override def schema: StructType = data.schema } -class ConsoleSinkProvider extends StreamSinkProvider +class ConsoleSinkProvider extends DataSourceV2 + with MicroBatchWriteSupport with DataSourceRegister with CreatableRelationProvider { - def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { - new ConsoleSink(parameters) + + override def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + Optional.of(new ConsoleWriter(epochId, schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c0507224f9be8..462e7d9721d28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -54,16 +54,13 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = _ + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. private[sql] var currentEpochCoordinatorId: String = _ - override lazy val logicalPlan: LogicalPlan = { - assert(queryExecutionThread eq Thread.currentThread, - "logicalPlan must be initialized in StreamExecutionThread " + - s"but the current thread was ${Thread.currentThread}") + override val logicalPlan: LogicalPlan = { val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( @@ -72,7 +69,7 @@ class ContinuousExecution( ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) case StreamingRelationV2(_, sourceName, _, _, _) => - throw new AnalysisException( + throw new UnsupportedOperationException( s"Data source $sourceName does not support continuous processing.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala new file mode 100644 index 0000000000000..361979984bbec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.types.StructType + +/** + * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. + * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) + extends DataSourceV2Writer with Logging { + // Number of rows to display, by default 20 rows + private val numRowsToShow = options.getInt("numRows", 20) + + // Truncate the displayed data if it is too long, by default it is true + private val isTruncated = options.getBoolean("truncate", true) + + assert(SparkSession.getActiveSession.isDefined) + private val spark = SparkSession.getActiveSession.get + + override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + + override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { + val batch = messages.collect { + case PackedRowCommitMessage(rows) => rows + }.flatten + + // scalastyle:off println + println("-------------------------------------------") + println(s"Batch: $batchId") + println("-------------------------------------------") + // scalastyle:off println + spark.createDataFrame( + spark.sparkContext.parallelize(batch), schema) + .show(numRowsToShow, isTruncated) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + + override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala new file mode 100644 index 0000000000000..9282ba05bdb7b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} + +/** + * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery + * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver. + * + * Note that, because it sends all rows to the driver, this factory will generally be unsuitable + * for production-quality sinks. It's intended for use in tests. + */ +case object PackedRowWriterFactory extends DataWriterFactory[Row] { + def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + new PackedRowDataWriter() + } +} + +/** + * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most + * recent interval. + */ +case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage + +/** + * A simple [[DataWriter]] that just sends all the rows it's received as a commit message. + */ +class PackedRowDataWriter() extends DataWriter[Row] with Logging { + private val data = mutable.Buffer[Row]() + + override def write(row: Row): Unit = data.append(row) + + override def commit(): PackedRowCommitMessage = { + val msg = PackedRowCommitMessage(data.toArray) + data.clear() + msg + } + + override def abort(): Unit = data.clear() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b5b4a05ab4973..d24f0ddeab4de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -280,14 +280,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val sink = trigger match { - case _: ContinuousTrigger => - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - ds.newInstance() match { - case w: ContinuousWriteSupport => w - case _ => throw new AnalysisException( - s"Data source $source does not support continuous writing") - } + val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val sink = (ds.newInstance(), trigger) match { + case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w + case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException( + s"Data source $source does not support continuous writing") + case (w: MicroBatchWriteSupport, _) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index c6973bf41d34b..a0b25b4e82364 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,11 @@ org.apache.spark.sql.sources.FakeSourceFour org.apache.fakesource.FakeExternalSourceOne org.apache.fakesource.FakeExternalSourceTwo org.apache.fakesource.FakeExternalSourceThree +org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly +org.apache.spark.sql.streaming.sources.FakeReadBothModes +org.apache.spark.sql.streaming.sources.FakeReadNeitherMode +org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly +org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly +org.apache.spark.sql.streaming.sources.FakeWriteBothModes +org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala new file mode 100644 index 0000000000000..60ffee9b9b42c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.StreamTest + +class ConsoleWriterSuite extends StreamTest { + import testImplicits._ + + test("console") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + input.addData(4, 5, 6) + query.processAllAvailable() + input.addData() + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + || 3| + |+-----+ + | + |------------------------------------------- + |Batch: 1 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 4| + || 5| + || 6| + |+-----+ + | + |------------------------------------------- + |Batch: 2 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + |+-----+ + | + |""".stripMargin) + } + + test("console with numRows") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + |+-----+ + |only showing top 2 rows + | + |""".stripMargin) + } + + test("console with truncation") { + val input = MemoryStream[String] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start() + try { + input.addData("123456789012345678901234567890") + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+--------------------+ + || value| + |+--------------------+ + ||12345678901234567...| + |+--------------------+ + | + |""".stripMargin) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala new file mode 100644 index 0000000000000..f152174b0a7f0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.sources + +import java.util.Optional + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +case class FakeReader() extends MicroBatchReader with ContinuousReader { + def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {} + def getStartOffset: Offset = RateStreamOffset(Map()) + def getEndOffset: Offset = RateStreamOffset(Map()) + def deserializeOffset(json: String): Offset = RateStreamOffset(Map()) + def commit(end: Offset): Unit = {} + def readSchema(): StructType = StructType(Seq()) + def stop(): Unit = {} + def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) + def setOffset(start: Optional[Offset]): Unit = {} + + def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = { + throw new IllegalStateException("fake source - cannot actually read") + } +} + +trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = FakeReader() +} + +trait FakeContinuousReadSupport extends ContinuousReadSupport { + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): ContinuousReader = FakeReader() +} + +trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport { + def createMicroBatchWriter( + queryId: String, + epochId: Long, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +trait FakeContinuousWriteSupport extends ContinuousWriteSupport { + def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + +class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport { + override def shortName(): String = "fake-read-microbatch-only" +} + +class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-continuous-only" +} + +class FakeReadBothModes extends DataSourceRegister + with FakeMicroBatchReadSupport with FakeContinuousReadSupport { + override def shortName(): String = "fake-read-microbatch-continuous" +} + +class FakeReadNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-read-neither-mode" +} + +class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport { + override def shortName(): String = "fake-write-microbatch-only" +} + +class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-continuous-only" +} + +class FakeWriteBothModes extends DataSourceRegister + with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" +} + +class FakeWriteNeitherMode extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" +} + +class StreamingDataSourceV2Suite extends StreamTest { + + override def beforeAll(): Unit = { + super.beforeAll() + val fakeCheckpoint = Utils.createTempDir() + spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath) + } + + val readFormats = Seq( + "fake-read-microbatch-only", + "fake-read-continuous-only", + "fake-read-microbatch-continuous", + "fake-read-neither-mode") + val writeFormats = Seq( + "fake-write-microbatch-only", + "fake-write-continuous-only", + "fake-write-microbatch-continuous", + "fake-write-neither-mode") + val triggers = Seq( + Trigger.Once(), + Trigger.ProcessingTime(1000), + Trigger.Continuous(1000)) + + private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + query.stop() + } + + private def testNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val ex = intercept[UnsupportedOperationException] { + testPositiveCase(readFormat, writeFormat, trigger) + } + assert(ex.getMessage.contains(errorMsg)) + } + + private def testPostCreationNegativeCase( + readFormat: String, + writeFormat: String, + trigger: Trigger, + errorMsg: String) = { + val query = spark.readStream + .format(readFormat) + .load() + .writeStream + .format(writeFormat) + .trigger(trigger) + .start() + + eventually(timeout(streamingTimeout)) { + assert(query.exception.isDefined) + assert(query.exception.get.cause != null) + assert(query.exception.get.cause.getMessage.contains(errorMsg)) + } + } + + // Get a list of (read, write, trigger) tuples for test cases. + val cases = readFormats.flatMap { read => + writeFormats.flatMap { write => + triggers.map(t => (write, t)) + }.map { + case (write, t) => (read, write, t) + } + } + + for ((read, write, trigger) <- cases) { + testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { + val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance() + val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() + (readSource, writeSource, trigger) match { + // Valid microbatch queries. + case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t) + if !t.isInstanceOf[ContinuousTrigger] => + testPositiveCase(read, write, trigger) + + // Valid continuous queries. + case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) => + testPositiveCase(read, write, trigger) + + // Invalid - can't read at all + case (r, _, _) + if !r.isInstanceOf[MicroBatchReadSupport] + && !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support streamed reading") + + // Invalid - trigger is continuous but writer is not + case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support continuous writing") + + // Invalid - can't write at all + case (_, w, _) + if !w.isInstanceOf[MicroBatchWriteSupport] + && !w.isInstanceOf[ContinuousWriteSupport] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are continuous but reader is not + case (r, _: ContinuousWriteSupport, _: ContinuousTrigger) + if !r.isInstanceOf[ContinuousReadSupport] => + testNegativeCase(read, write, trigger, + s"Data source $read does not support continuous processing") + + // Invalid - trigger is microbatch but writer is not + case (_, w, t) + if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] => + testNegativeCase(read, write, trigger, + s"Data source $write does not support streamed writing") + + // Invalid - trigger and writer are microbatch but reader is not + case (r, _, t) + if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => + testPostCreationNegativeCase(read, write, trigger, + s"Data source $read does not support microbatch processing") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index aa163d2211c38..8212fb912ec57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -422,21 +422,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink can be correctly loaded") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream - .format("console") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime(2.seconds)) - .start() - - sq.awaitTermination(2000L) - } - test("prevent all column partitioning") { withTempDir { dir => val path = dir.getCanonicalPath @@ -450,16 +435,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } - test("ConsoleSink should not require checkpointLocation") { - LastOptions.clear() - val df = spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - - val sq = df.writeStream.format("console").start() - sq.stop() - } - private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = { import testImplicits._ val ms = new MemoryStream[Int](0, sqlContext) From 7a2248341396840628eef398aa512cac3e3bd55f Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 19:18:55 +0800 Subject: [PATCH 0129/1161] [SPARK-23140][SQL] Add DataSourceV2Strategy to Hive Session state's planner ## What changes were proposed in this pull request? `DataSourceV2Strategy` is missing in `HiveSessionStateBuilder`'s planner, which will throw exception as described in [SPARK-23140](https://issues.apache.org/jira/browse/SPARK-23140). ## How was this patch tested? Manual test. Author: jerryshao Closes #20305 from jerryshao/SPARK-23140. --- .../sql/hive/HiveSessionStateBuilder.scala | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index dc92ad3b0c1ac..12c74368dd184 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -96,22 +96,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val sparkSession: SparkSession = session override def extraPlanningStrategies: Seq[Strategy] = - super.extraPlanningStrategies ++ customPlanningStrategies - - override def strategies: Seq[Strategy] = { - experimentalMethods.extraStrategies ++ - extraPlanningStrategies ++ Seq( - FileSourceStrategy, - DataSourceStrategy(conf), - SpecialLimits, - InMemoryScans, - HiveTableScans, - Scripts, - Aggregation, - JoinSelection, - BasicOperators - ) - } + super.extraPlanningStrategies ++ customPlanningStrategies ++ Seq(HiveTableScans, Scripts) } } From e28eb431146bcdcaf02a6f6c406ca30920592a6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 18 Jan 2018 21:24:39 +0800 Subject: [PATCH 0130/1161] [SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL ## What changes were proposed in this pull request? When there is an operation between Decimals and the result is a number which is not representable exactly with the result's precision and scale, Spark is returning `NULL`. This was done to reflect Hive's behavior, but it is against SQL ANSI 2011, which states that "If the result cannot be represented exactly in the result type, then whether it is rounded or truncated is implementation-defined". Moreover, Hive now changed its behavior in order to respect the standard, thanks to HIVE-15331. Therefore, the PR propose to: - update the rules to determine the result precision and scale according to the new Hive's ones introduces in HIVE-15331; - round the result of the operations, when it is not representable exactly with the result's precision and scale, instead of returning `NULL` - introduce a new config `spark.sql.decimalOperations.allowPrecisionLoss` which default to `true` (ie. the new behavior) in order to allow users to switch back to the previous one. Hive behavior reflects SQLServer's one. The only difference is that the precision and scale are adjusted for all the arithmetic operations in Hive, while SQL Server is said to do so only for multiplications and divisions in the documentation. This PR follows Hive's behavior. A more detailed explanation is available here: https://mail-archives.apache.org/mod_mbox/spark-dev/201712.mbox/%3CCAEorWNAJ4TxJR9NBcgSFMD_VxTg8qVxusjP%2BAJP-x%2BJV9zH-yA%40mail.gmail.com%3E. ## How was this patch tested? modified and added UTs. Comparisons with results of Hive and SQLServer. Author: Marco Gaido Closes #20023 from mgaido91/SPARK-22036. --- docs/sql-programming-guide.md | 5 + .../catalyst/analysis/DecimalPrecision.scala | 114 +++++--- .../sql/catalyst/expressions/literals.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 12 + .../apache/spark/sql/types/DecimalType.scala | 45 +++- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../analysis/DecimalPrecisionSuite.scala | 20 +- .../native/decimalArithmeticOperations.sql | 47 ++++ .../decimalArithmeticOperations.sql.out | 245 ++++++++++++++++-- .../native/decimalPrecision.sql.out | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 18 -- 11 files changed, 434 insertions(+), 82 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 258c769ff593b..3e2e48a0ef249 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1793,6 +1793,11 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). + - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. + - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index a8100b9b24aac..ab63131b07573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -42,8 +43,10 @@ import org.apache.spark.sql.types._ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 + * + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. * * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited * precision, do the math on unlimited-precision numbers, then introduce casts back to the @@ -56,6 +59,7 @@ import org.apache.spark.sql.types._ * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value */ // scalastyle:on object DecimalPrecision extends TypeCoercionRule { @@ -93,41 +97,76 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultScale = max(s1, s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) + } + CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) } - val resultType = DecimalType.bounded(intDig + decDig, decDig) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), @@ -137,9 +176,6 @@ object DecimalPrecision extends TypeCoercionRule { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // TODO: MaxOf, MinOf, etc might want other rules - // SUM and AVERAGE are handled by the implementations of those expressions } /** @@ -243,17 +279,35 @@ object DecimalPrecision extends TypeCoercionRule { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] + && l.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] + && r.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b } } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 383203a209833..cd176d941819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -58,7 +58,7 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 16fbb0c3e9e21..cc4f4bf332459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,6 +1064,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = + buildConf("spark.sql.decimalOperations.allowPrecisionLoss") + .internal() + .doc("When true (default), establishing the result type of an arithmetic operation " + + "happens according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the " + + "decimal part of the result if an exact representation is not possible. Otherwise, NULL " + + "is returned in those cases, as previously.") + .booleanConf + .createWithDefault(true) + val SQL_STRING_REDACTION_PATTERN = ConfigBuilder("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + @@ -1441,6 +1451,8 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6e050c18b8acb..ef3b67c0d48d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} /** @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) @@ -136,10 +137,52 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } + private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match { + case v: Short => fromBigDecimal(BigDecimal(v)) + case v: Int => fromBigDecimal(BigDecimal(v)) + case v: Long => fromBigDecimal(BigDecimal(v)) + case _ => forType(literal.dataType) + } + + private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = { + DecimalType(Math.max(d.precision, d.scale), d.scale) + } + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } + /** + * Scale adjustment implementation is based on Hive's one, which is itself inspired to + * SQLServer's one. In particular, when a result precision is greater than + * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a + * result from being truncated. + * + * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. + */ + private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { + // Assumptions: + assert(precision >= scale) + assert(scale >= 0) + + if (precision <= MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + DecimalType(precision, scale) + } else { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. + val intDigits = precision - scale + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) + + DecimalType(MAX_PRECISION, adjustedScale) + } + } + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f4514205d3ae0..cd8579584eada 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) - assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) - assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6)) assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 60e46a9910a8b..c86dc18dfa680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) - checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) } - checkType(Multiply(d1, u), DecimalType(38, 19)) - checkType(Multiply(d2, u), DecimalType(38, 20)) - checkType(Multiply(i, u), DecimalType(38, 18)) - checkType(Multiply(u, u), DecimalType(38, 36)) + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) - checkType(Divide(u, d1), DecimalType(38, 18)) - checkType(Divide(u, d2), DecimalType(38, 19)) - checkType(Divide(u, i), DecimalType(38, 23)) - checkType(Divide(u, u), DecimalType(38, 18)) + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c8e108ac2c45e..c6d8a49d4b93a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -22,6 +22,51 @@ select a / b from t; select a % b from t; select pmod(a, b) from t; +-- tests for decimals handling in operations +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; + +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss are truncated +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 + +-- return NULL instead of rounding, according to old Spark versions' behavior +set spark.sql.decimalOperations.allowPrecisionLoss=false; + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; @@ -31,3 +76,5 @@ select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL select 123456789123456789.1234567890 * 1.123456789123456789; select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index ce02f6adc456c..4d70fe19d539f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 32 -- !query 0 @@ -35,48 +35,257 @@ NULL -- !query 4 -select (5e36 + 0.1) + 5e36 +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet -- !query 4 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 4 output -NULL + -- !query 5 -select (-4e36 - 0.1) - 7e36 +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) -- !query 5 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 5 output -NULL + -- !query 6 -select 12345678901234567890.0 * 12345678901234567890.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 6 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct -- !query 6 output -NULL +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 -- !query 7 -select 1e35 / 0.1 +select id, a*10, b/10 from decimals_test order by id -- !query 7 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct -- !query 7 output -NULL +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 -- !query 8 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 10.3 * 3.0 -- !query 8 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 8 output -NULL +30.9 -- !query 9 -select 0.001 / 9876543210987654321098765432109876543.2 +select 10.3000 * 3.0 -- !query 9 schema -struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 9 output +30.9 + + +-- !query 10 +select 10.30000 * 30.0 +-- !query 10 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 10 output +309 + + +-- !query 11 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 11 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 11 output +30.9 + + +-- !query 12 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 12 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 12 output +30.9 + + +-- !query 13 +select (5e36 + 0.1) + 5e36 +-- !query 13 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 13 output +NULL + + +-- !query 14 +select (-4e36 - 0.1) - 7e36 +-- !query 14 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 14 output +NULL + + +-- !query 15 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 15 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 15 output NULL + + +-- !query 16 +select 1e35 / 0.1 +-- !query 16 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 16 output +NULL + + +-- !query 17 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 17 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 17 output +138698367904130467.654320988515622621 + + +-- !query 18 +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'spark' expecting (line 3, pos 4) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +----^^^ + + +-- !query 19 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 19 schema +struct +-- !query 19 output +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 + + +-- !query 20 +select id, a*10, b/10 from decimals_test order by id +-- !query 20 schema +struct +-- !query 20 output +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 + + +-- !query 21 +select 10.3 * 3.0 +-- !query 21 schema +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +-- !query 21 output +30.9 + + +-- !query 22 +select 10.3000 * 3.0 +-- !query 22 schema +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +-- !query 22 output +30.9 + + +-- !query 23 +select 10.30000 * 30.0 +-- !query 23 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 23 output +309 + + +-- !query 24 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 24 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 24 output +30.9 + + +-- !query 25 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 25 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 25 output +30.9 + + +-- !query 26 +select (5e36 + 0.1) + 5e36 +-- !query 26 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 26 output +NULL + + +-- !query 27 +select (-4e36 - 0.1) - 7e36 +-- !query 27 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 27 output +NULL + + +-- !query 28 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 28 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 28 output +NULL + + +-- !query 29 +select 1e35 / 0.1 +-- !query 29 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 29 output +NULL + + +-- !query 30 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 30 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 30 output +138698367904130467.654320988515622621 + + +-- !query 31 +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'table' expecting (line 3, pos 5) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-----^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out index ebc8201ed5a1d..6ee7f59d69877 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out @@ -2329,7 +2329,7 @@ struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(C -- !query 280 SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t -- !query 280 schema -struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,18)> -- !query 280 output 1 @@ -2661,7 +2661,7 @@ struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BI -- !query 320 SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t -- !query 320 schema -struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,18)> -- !query 320 output 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d4d0aa4f5f5eb..083a0c0b1b9a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1517,24 +1517,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(null)) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) - } - test("SPARK-10215 Div of Decimal returns null") { val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") From 5063b7481173ad72bd0dc941b5cf3c9b26a591e4 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 18 Jan 2018 22:33:04 +0900 Subject: [PATCH 0131/1161] [SPARK-23141][SQL][PYSPARK] Support data type string as a returnType for registerJavaFunction. ## What changes were proposed in this pull request? Currently `UDFRegistration.registerJavaFunction` doesn't support data type string as a `returnType` whereas `UDFRegistration.register`, `udf`, or `pandas_udf` does. We can support it for `UDFRegistration.registerJavaFunction` as well. ## How was this patch tested? Added a doctest and existing tests. Author: Takuya UESHIN Closes #20307 from ueshin/issues/SPARK-23141. --- python/pyspark/sql/functions.py | 6 ++++-- python/pyspark/sql/udf.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 988c1d25259bc..961b3267b44cf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2108,7 +2108,8 @@ def udf(f=None, returnType=StringType()): can fail on special rows, the workaround is to incorporate the condition into the functions. :param f: python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) @@ -2148,7 +2149,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): Creates a vectorized user defined function (UDF). :param f: user-defined function. A python function if used as a standalone function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the user-defined function. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 1943bb73f9ac2..c77f19f89a442 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -206,7 +206,8 @@ def register(self, name, f, returnType=None): :param f: a Python function, or a user-defined function. The user-defined function can be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and :meth:`pyspark.sql.functions.pandas_udf`. - :param returnType: the return type of the registered user-defined function. + :param returnType: the return type of the registered user-defined function. The value can + be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :return: a user-defined function. `returnType` can be optionally specified when `f` is a Python function but not @@ -303,21 +304,30 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): :param name: name of the user-defined function :param javaClassName: fully qualified name of java class - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: the return type of the registered Java function. The value can be either + a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. >>> from pyspark.sql.types import IntegerType >>> spark.udf.registerJavaFunction( ... "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()) >>> spark.sql("SELECT javaStringLength('test')").collect() [Row(UDF:javaStringLength(test)=4)] + >>> spark.udf.registerJavaFunction( ... "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength") >>> spark.sql("SELECT javaStringLength2('test')").collect() [Row(UDF:javaStringLength2(test)=4)] + + >>> spark.udf.registerJavaFunction( + ... "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer") + >>> spark.sql("SELECT javaStringLength3('test')").collect() + [Row(UDF:javaStringLength3(test)=4)] """ jdt = None if returnType is not None: + if not isinstance(returnType, DataType): + returnType = _parse_datatype_string(returnType) jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) From cf7ee1767ddadce08dce050fc3b40c77cdd187da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 18 Jan 2018 10:19:36 -0800 Subject: [PATCH 0132/1161] [SPARK-23147][UI] Fix task page table IndexOutOfBound Exception ## What changes were proposed in this pull request? Stage's task page table will throw an exception when there's no complete tasks. Furthermore, because the `dataSize` doesn't take running tasks into account, so sometimes UI cannot show the running tasks. Besides table will only be displayed when first task is finished according to the default sortColumn("index"). ![screen shot 2018-01-18 at 8 50 08 pm](https://user-images.githubusercontent.com/850797/35100052-470b4cae-fc95-11e7-96a2-ad9636e732b3.png) To reproduce this issue, user could try `sc.parallelize(1 to 20, 20).map { i => Thread.sleep(10000); i }.collect()` or `sc.parallelize(1 to 20, 20).map { i => Thread.sleep((20 - i) * 1000); i }.collect` to reproduce the above issue. Here propose a solution to fix it. Not sure if it is a right fix, please help to review. ## How was this patch tested? Manual test. Author: jerryshao Closes #20315 from jerryshao/SPARK-23147. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 7c6e06cf183ba..af78373ddb4b2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -676,7 +676,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numCompleteTasks + stage.numFailedTasks + stage.numKilledTasks + override def dataSize: Int = stage.numTasks override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { From 9678941f54ebc5db935ed8d694e502086e2a31c0 Mon Sep 17 00:00:00 2001 From: Fernando Pereira Date: Thu, 18 Jan 2018 13:02:03 -0600 Subject: [PATCH 0133/1161] [SPARK-23029][DOCS] Specifying default units of configuration entries ## What changes were proposed in this pull request? This PR completes the docs, specifying the default units assumed in configuration entries of type size. This is crucial since unit-less values are accepted and the user might assume the base unit is bytes, which in most cases it is not, leading to hard-to-debug problems. ## How was this patch tested? This patch updates only documentation only. Author: Fernando Pereira Closes #20269 from ferdonline/docs_units. --- .../scala/org/apache/spark/SparkConf.scala | 6 +- .../spark/internal/config/package.scala | 47 ++++---- docs/configuration.md | 100 ++++++++++-------- 3 files changed, 85 insertions(+), 68 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d77303e6fdf8b..f53b2bed74c6e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -640,9 +640,9 @@ private[spark] object SparkConf extends Logging { translation = s => s"${s.toLong * 10}s")), "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), - "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${(s.toDouble * 1000).toInt}k")), + "spark.kryoserializer.buffer" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index eb12ddf961314..bbfcfbaa7363c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -38,10 +38,13 @@ package object config { ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + .doc("Amount of memory to use for the driver process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.driver.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per driver in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -62,6 +65,7 @@ package object config { .createWithDefault(false) private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .doc("Buffer size to use when writing to output streams, in KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .createWithDefaultString("100k") @@ -81,10 +85,13 @@ package object config { ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + .doc("Amount of memory to use per executor process, in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.executor.memoryOverhead") + .doc("The amount of off-heap memory to be allocated per executor in cluster mode, " + + "in MiB unless otherwise specified.") .bytesConf(ByteUnit.MiB) .createOptional @@ -353,7 +360,7 @@ package object config { private[spark] val BUFFER_WRITE_CHUNK_SIZE = ConfigBuilder("spark.buffer.write.chunkSize") .internal() - .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.") + .doc("The chunk size in bytes during writing out the bytes of ChunkedByteBuffer.") .bytesConf(ByteUnit.BYTE) .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" + " ChunkedByteBuffer should not larger than Int.MaxValue.") @@ -368,9 +375,9 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = ConfigBuilder("spark.shuffle.accurateBlockThreshold") - .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + - "record the size accurately if it's above this config. This helps to prevent OOM by " + - "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .doc("Threshold in bytes above which the size of shuffle blocks in " + + "HighlyCompressedMapStatus is accurately recorded. This helps to prevent OOM " + + "by avoiding underestimating shuffle block size when fetch shuffle blocks.") .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) @@ -389,23 +396,23 @@ package object config { private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") - .doc("This configuration limits the number of remote blocks being fetched per reduce task" + - " from a given host port. When a large number of blocks are being requested from a given" + - " address in a single fetch or simultaneously, this could crash the serving executor or" + - " Node Manager. This is especially useful to reduce the load on the Node Manager when" + - " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") + .doc("This configuration limits the number of remote blocks being fetched per reduce task " + + "from a given host port. When a large number of blocks are being requested from a given " + + "address in a single fetch or simultaneously, this could crash the serving executor or " + + "Node Manager. This is especially useful to reduce the load on the Node Manager when " + + "external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") .intConf .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") - .doc("Remote block will be fetched to disk when size of the block is " + - "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + - "affect both shuffle fetch and block manager remote block fetch. For users who " + - "enabled external shuffle service, this feature can only be worked when external shuffle" + - " service is newer than Spark 2.2.") + .doc("Remote block will be fetched to disk when size of the block is above this threshold " + + "in bytes. This is to avoid a giant request takes too much memory. We can enable this " + + "config by setting a specific value(e.g. 200m). Note this configuration will affect " + + "both shuffle fetch and block manager remote block fetch. For users who enabled " + + "external shuffle service, this feature can only be worked when external shuffle" + + "service is newer than Spark 2.2.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -419,9 +426,9 @@ package object config { private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") - .doc("Size of the in-memory buffer for each shuffle file output stream. " + - "These buffers reduce the number of disk seeks and system calls made " + - "in creating intermediate shuffle files.") + .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + + "otherwise specified. These buffers reduce the number of disk seeks and system calls " + + "made in creating intermediate shuffle files.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -430,7 +437,7 @@ package object config { private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.shuffle.unsafe.file.output.buffer") .doc("The file system for this buffer size after each partition " + - "is written in unsafe shuffle writer.") + "is written in unsafe shuffle writer. In KiB unless otherwise specified.") .bytesConf(ByteUnit.KiB) .checkValue(v => v > 0 && v <= Int.MaxValue / 1024, s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.") @@ -438,7 +445,7 @@ package object config { private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize") - .doc("The buffer size to use when writing the sorted records to an on-disk file.") + .doc("The buffer size, in bytes, to use when writing the sorted records to an on-disk file.") .bytesConf(ByteUnit.BYTE) .checkValue(v => v > 0 && v <= Int.MaxValue, s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") diff --git a/docs/configuration.md b/docs/configuration.md index 1189aea2aa71f..eecb39dcafc9e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -58,6 +58,10 @@ The following format is accepted: 1t or 1tb (tebibytes = 1024 gibibytes) 1p or 1pb (pebibytes = 1024 tebibytes) +While numbers without units are generally interpreted as bytes, a few are interpreted as KiB or MiB. +See documentation of individual configuration properties. Specifying units is desirable where +possible. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For @@ -136,9 +140,9 @@ of the most common options to set are: spark.driver.maxResultSize 1g - Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). - Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size - is above this limit. + Limit of total size of serialized results of all partitions for each Spark action (e.g. + collect) in bytes. Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total + size is above this limit. Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory and memory overhead of objects in JVM). Setting a proper limit can protect the driver from out-of-memory errors. @@ -148,10 +152,10 @@ of the most common options to set are: spark.driver.memory 1g - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
    Note: In client mode, this config must not be set through the SparkConf + Amount of memory to use for the driver process, i.e. where SparkContext is initialized, in MiB + unless otherwise specified (e.g. 1g, 2g). +
    + Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option or in your default properties file. @@ -161,27 +165,28 @@ of the most common options to set are: spark.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is - memory that accounts for things like VM overheads, interned strings, other native overheads, etc. - This tends to grow with the container size (typically 6-10%). This option is currently supported - on YARN and Kubernetes. + The amount of off-heap memory to be allocated per driver in cluster mode, in MiB unless + otherwise specified. This is memory that accounts for things like VM overheads, interned strings, + other native overheads, etc. This tends to grow with the container size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. spark.executor.memory 1g - Amount of memory to use per executor process (e.g. 2g, 8g). + Amount of memory to use per executor process, in MiB unless otherwise specified. + (e.g. 2g, 8g). spark.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that - accounts for things like VM overheads, interned strings, other native overheads, etc. This tends - to grow with the executor size (typically 6-10%). This option is currently supported on YARN and - Kubernetes. + The amount of off-heap memory to be allocated per executor, in MiB unless otherwise specified. + This is memory that accounts for things like VM overheads, interned strings, other native + overheads, etc. This tends to grow with the executor size (typically 6-10%). + This option is currently supported on YARN and Kubernetes. @@ -431,8 +436,9 @@ Apart from these, the following properties are also available, and may be useful 512m Amount of memory to use per python worker process during aggregation, in the same - format as JVM memory strings (e.g. 512m, 2g). If the memory - used during aggregation goes above this amount, it will spill the data into disks. + format as JVM memory strings with a size unit suffix ("k", "m", "g" or "t") + (e.g. 512m, 2g). + If the memory used during aggregation goes above this amount, it will spill the data into disks. @@ -540,9 +546,10 @@ Apart from these, the following properties are also available, and may be useful spark.reducer.maxSizeInFlight 48m - Maximum size of map outputs to fetch simultaneously from each reduce task. Since - each output requires us to create a buffer to receive it, this represents a fixed memory - overhead per reduce task, so keep it small unless you have a large amount of memory. + Maximum size of map outputs to fetch simultaneously from each reduce task, in MiB unless + otherwise specified. Since each output requires us to create a buffer to receive it, this + represents a fixed memory overhead per reduce task, so keep it small unless you have a + large amount of memory. @@ -570,9 +577,9 @@ Apart from these, the following properties are also available, and may be useful spark.maxRemoteBlockSizeFetchToMem Long.MaxValue - The remote block will be fetched to disk when size of the block is above this threshold. + The remote block will be fetched to disk when size of the block is above this threshold in bytes. This is to avoid a giant request takes too much memory. We can enable this config by setting - a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch + a specific value(e.g. 200m). Note this configuration will affect both shuffle fetch and block manager remote block fetch. For users who enabled external shuffle service, this feature can only be worked when external shuffle service is newer than Spark 2.2. @@ -589,8 +596,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.file.buffer 32k - Size of the in-memory buffer for each shuffle file output stream. These buffers - reduce the number of disk seeks and system calls made in creating intermediate shuffle files. + Size of the in-memory buffer for each shuffle file output stream, in KiB unless otherwise + specified. These buffers reduce the number of disk seeks and system calls made in creating + intermediate shuffle files. @@ -651,7 +659,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.index.cache.size 100m - Cache entries limited to the specified memory footprint. + Cache entries limited to the specified memory footprint in bytes. @@ -685,9 +693,9 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.accurateBlockThreshold 100 * 1024 * 1024 - When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the - size accurately if it's above this config. This helps to prevent OOM by avoiding - underestimating shuffle block size when fetch shuffle blocks. + Threshold in bytes above which the size of shuffle blocks in HighlyCompressedMapStatus is + accurately recorded. This helps to prevent OOM by avoiding underestimating shuffle + block size when fetch shuffle blocks. @@ -779,7 +787,7 @@ Apart from these, the following properties are also available, and may be useful spark.eventLog.buffer.kb 100k - Buffer size in KB to use when writing to output streams. + Buffer size to use when writing to output streams, in KiB unless otherwise specified. @@ -917,7 +925,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.lz4.blockSize 32k - Block size used in LZ4 compression, in the case when LZ4 compression codec + Block size in bytes used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used. @@ -925,7 +933,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.snappy.blockSize 32k - Block size used in Snappy compression, in the case when Snappy compression codec + Block size in bytes used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. @@ -941,7 +949,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.zstd.bufferSize 32k - Buffer size used in Zstd compression, in the case when Zstd compression codec + Buffer size in bytes used in Zstd compression, in the case when Zstd compression codec is used. Lowering this size will lower the shuffle memory usage when Zstd is used, but it might increase the compression cost because of excessive JNI call overhead. @@ -1001,8 +1009,8 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer.max 64m - Maximum allowable size of Kryo serialization buffer. This must be larger than any - object you attempt to serialize and must be less than 2048m. + Maximum allowable size of Kryo serialization buffer, in MiB unless otherwise specified. + This must be larger than any object you attempt to serialize and must be less than 2048m. Increase this if you get a "buffer limit exceeded" exception inside Kryo. @@ -1010,9 +1018,9 @@ Apart from these, the following properties are also available, and may be useful spark.kryoserializer.buffer 64k - Initial size of Kryo's serialization buffer. Note that there will be one buffer - per core on each worker. This buffer will grow up to - spark.kryoserializer.buffer.max if needed. + Initial size of Kryo's serialization buffer, in KiB unless otherwise specified. + Note that there will be one buffer per core on each worker. This buffer will grow up to + spark.kryoserializer.buffer.max if needed. @@ -1086,7 +1094,8 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled false - If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. + If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory + use is enabled, then spark.memory.offHeap.size must be positive. @@ -1094,7 +1103,8 @@ Apart from these, the following properties are also available, and may be useful 0 The absolute amount of memory in bytes which can be used for off-heap allocation. - This setting has no impact on heap memory usage, so if your executors' total memory consumption must fit within some hard limit then be sure to shrink your JVM heap size accordingly. + This setting has no impact on heap memory usage, so if your executors' total memory consumption + must fit within some hard limit then be sure to shrink your JVM heap size accordingly. This must be set to a positive value when spark.memory.offHeap.enabled=true. @@ -1202,9 +1212,9 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.blockSize 4m - Size of each piece of a block for TorrentBroadcastFactory. - Too large a value decreases parallelism during broadcast (makes it slower); however, if it is - too small, BlockManager might take a performance hit. + Size of each piece of a block for TorrentBroadcastFactory, in KiB unless otherwise + specified. Too large a value decreases parallelism during broadcast (makes it slower); however, + if it is too small, BlockManager might take a performance hit. @@ -1312,7 +1322,7 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryMapThreshold 2m - Size of a block above which Spark memory maps when reading a block from disk. + Size in bytes of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system. @@ -2490,4 +2500,4 @@ Also, you can modify or add configurations at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" \ --conf spark.hadoop.abc.def=xyz \ myApp.jar -{% endhighlight %} \ No newline at end of file +{% endhighlight %} From 2d41f040a34d6483919fd5d491cf90eee5429290 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:25:52 -0800 Subject: [PATCH 0134/1161] [SPARK-23143][SS][PYTHON] Added python API for setting continuous trigger ## What changes were proposed in this pull request? Self-explanatory. ## How was this patch tested? New python tests. Author: Tathagata Das Closes #20309 from tdas/SPARK-23143. --- python/pyspark/sql/streaming.py | 23 +++++++++++++++++++---- python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 24ae3776a217b..e2a97acb5e2a7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -786,7 +786,7 @@ def queryName(self, queryName): @keyword_only @since(2.0) - def trigger(self, processingTime=None, once=None): + def trigger(self, processingTime=None, once=None, continuous=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. @@ -802,23 +802,38 @@ def trigger(self, processingTime=None, once=None): >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') >>> # trigger the query for just once batch of data >>> writer = sdf.writeStream.trigger(once=True) + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(continuous='5 seconds') """ + params = [processingTime, once, continuous] + + if params.count(None) == 3: + raise ValueError('No trigger provided') + elif params.count(None) < 2: + raise ValueError('Multiple triggers not allowed.') + jTrigger = None if processingTime is not None: - if once is not None: - raise ValueError('Multiple triggers not allowed.') if type(processingTime) != str or len(processingTime.strip()) == 0: raise ValueError('Value for processingTime must be a non empty string. Got: %s' % processingTime) interval = processingTime.strip() jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime( interval) + elif once is not None: if once is not True: raise ValueError('Value for once must be True. Got: %s' % once) jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once() + else: - raise ValueError('No trigger provided') + if type(continuous) != str or len(continuous.strip()) == 0: + raise ValueError('Value for continuous must be a non empty string. Got: %s' % + continuous) + interval = continuous.strip() + jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous( + interval) + self._jwrite = self._jwrite.trigger(jTrigger) return self diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f84aa3d68b808..25483594f2725 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1538,6 +1538,12 @@ def test_stream_trigger(self): except ValueError: pass + # Should not take multiple args + try: + df.writeStream.trigger(processingTime='5 seconds', continuous='1 second') + except ValueError: + pass + # Should take only keyword args try: df.writeStream.trigger('5 seconds') From bf34d665b9c865e00fac7001500bf6d521c2dff9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 12:33:39 -0800 Subject: [PATCH 0135/1161] [SPARK-23144][SS] Added console sink for continuous processing ## What changes were proposed in this pull request? Refactored ConsoleWriter into ConsoleMicrobatchWriter and ConsoleContinuousWriter. ## How was this patch tested? new unit test Author: Tathagata Das Closes #20311 from tdas/SPARK-23144. --- .../sql/execution/streaming/console.scala | 20 +++-- .../streaming/sources/ConsoleWriter.scala | 80 ++++++++++++++----- .../sources/ConsoleWriterSuite.scala | 26 +++++- 3 files changed, 96 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 94820376ff7e7..f2aa3259731d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional -import scala.collection.JavaConverters._ - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport +import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -37,16 +36,25 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) class ConsoleSinkProvider extends DataSourceV2 with MicroBatchWriteSupport + with ContinuousWriteSupport with DataSourceRegister with CreatableRelationProvider { override def createMicroBatchWriter( queryId: String, - epochId: Long, + batchId: Long, schema: StructType, mode: OutputMode, options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleWriter(epochId, schema, options)) + Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) + } + + override def createContinuousWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceV2Options): Optional[ContinuousWriter] = { + Optional.of(new ConsoleContinuousWriter(schema, options)) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 361979984bbec..6fb61dff60045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,45 +20,85 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType -/** - * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console. - * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options) - extends DataSourceV2Writer with Logging { +/** Common methods used to create writes for the the console sink */ +trait ConsoleWriter extends Logging { + + def options: DataSourceV2Options + // Number of rows to display, by default 20 rows - private val numRowsToShow = options.getInt("numRows", 20) + protected val numRowsToShow = options.getInt("numRows", 20) // Truncate the displayed data if it is too long, by default it is true - private val isTruncated = options.getBoolean("truncate", true) + protected val isTruncated = options.getBoolean("truncate", true) assert(SparkSession.getActiveSession.isDefined) - private val spark = SparkSession.getActiveSession.get + protected val spark = SparkSession.getActiveSession.get + + def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory + def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized { - val batch = messages.collect { + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { case PackedRowCommitMessage(rows) => rows }.flatten // scalastyle:off println println("-------------------------------------------") - println(s"Batch: $batchId") + println(printMessage) println("-------------------------------------------") // scalastyle:off println - spark.createDataFrame( - spark.sparkContext.parallelize(batch), schema) + spark + .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } +} + + +/** + * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) + extends DataSourceV2Writer with ConsoleWriter { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Batch: $batchId") + } + + override def toString(): String = { + s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} - override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" +/** + * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and + * prints them in the console. Created by + * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. + * + * This sink should not be used for production, as it requires sending all rows to the driver + * and does not support recovery. + */ +class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) + extends ContinuousWriter with ConsoleWriter { + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + printRows(messages, schema, s"Continuous processing epoch $epochId") + } + + override def toString(): String = { + s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala index 60ffee9b9b42c..55acf2ba28d2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.ByteArrayOutputStream +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamTest, Trigger} class ConsoleWriterSuite extends StreamTest { import testImplicits._ - test("console") { + test("microbatch - default") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -77,7 +79,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with numRows") { + test("microbatch - with numRows") { val input = MemoryStream[Int] val captured = new ByteArrayOutputStream() @@ -106,7 +108,7 @@ class ConsoleWriterSuite extends StreamTest { |""".stripMargin) } - test("console with truncation") { + test("microbatch - truncation") { val input = MemoryStream[String] val captured = new ByteArrayOutputStream() @@ -132,4 +134,20 @@ class ConsoleWriterSuite extends StreamTest { | |""".stripMargin) } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } } From f568e9cf76f657d094f1d036ab5a95f2531f5761 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Thu, 18 Jan 2018 14:00:12 -0800 Subject: [PATCH 0136/1161] [SPARK-23133][K8S] Fix passing java options to Executor Pass through spark java options to the executor in context of docker image. Closes #20296 andrusha: Deployed two version of containers to local k8s, checked that java options were present in the updated image on the running executor. Manual test Author: Andrew Korzhuev Closes #20322 from foxish/patch-1. --- .../docker/src/main/dockerfiles/spark/entrypoint.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 0c28c75857871..b9090dc2852a5 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -42,7 +42,7 @@ shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_DRIVER_JAVA_OPTS < /tmp/java_opts.txt +readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" fi @@ -54,7 +54,7 @@ case "$SPARK_K8S_CMD" in driver) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_DRIVER_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -cp "$SPARK_CLASSPATH" -Xms$SPARK_DRIVER_MEMORY -Xmx$SPARK_DRIVER_MEMORY @@ -67,7 +67,7 @@ case "$SPARK_K8S_CMD" in executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_EXECUTOR_JAVA_OPTS[@]}" + "${SPARK_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" From e01919e834d301e13adc8919932796ebae900576 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 19 Jan 2018 07:36:06 +0900 Subject: [PATCH 0137/1161] [SPARK-23094] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? There were two related fixes regarding `from_json`, `get_json_object` and `json_tuple` ([Fix #1](https://github.com/apache/spark/commit/c8803c06854683c8761fdb3c0e4c55d5a9e22a95), [Fix #2](https://github.com/apache/spark/commit/86174ea89b39a300caaba6baffac70f3dc702788)), but they weren't comprehensive it seems. I wanted to extend those fixes to all the parsers, and add tests for each case. ## How was this patch tested? Regression tests Author: Burak Yavuz Closes #20302 from brkyvz/json-invfix. --- .../catalyst/json/CreateJacksonParser.scala | 5 +-- .../sources/JsonHadoopFsRelationSuite.scala | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 025a388aacaa5..b1672e7e2fca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,10 +40,11 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - jsonFactory.createParser(record.getBytes, 0, record.getLength) + val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(record) + jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 49be30435ad2f..27f398ebf301a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -105,4 +107,36 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } + + test("invalid json with leading nulls - from file (multiLine=true)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val expected = s"""$badJson\n{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("invalid json with leading nulls - from file (multiLine=false)") { + import testImplicits._ + withTempDir { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = + spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("invalid json with leading nulls - from dataset") { + import testImplicits._ + checkAnswer( + spark.read.json(Seq(badJson).toDS()), + Row(badJson)) + } } From 5d7c4ba4d73a72f26d591108db3c20b4a6c84f3f Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 18 Jan 2018 14:44:22 -0800 Subject: [PATCH 0138/1161] [SPARK-22962][K8S] Fail fast if submission client local files are used ## What changes were proposed in this pull request? In the Kubernetes mode, fails fast in the submission process if any submission client local dependencies are used as the use case is not supported yet. ## How was this patch tested? Unit tests, integration tests, and manual tests. vanzin foxish Author: Yinan Li Closes #20320 from liyinan926/master. --- docs/running-on-kubernetes.md | 5 ++- .../k8s/submit/DriverConfigOrchestrator.scala | 14 ++++++++- .../DriverConfigOrchestratorSuite.scala | 31 ++++++++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 08ec34c63ba3f..d6b1735ce5550 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -117,7 +117,10 @@ This URI is the location of the example jar that is already in the Docker image. If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to by their appropriate remote URIs. Also, application dependencies can be pre-mounted into custom-built Docker images. Those dependencies can be added to the classpath by referencing them with `local://` URIs and/or setting the -`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. +`SPARK_EXTRA_CLASSPATH` environment variable in your Dockerfiles. The `local://` scheme is also required when referring to +dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission +client's local file system is currently not yet supported. + ### Using Remote Dependencies When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index c9cc300d65569..ae70904621184 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -20,7 +20,7 @@ import java.util.UUID import com.google.common.primitives.Longs -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -117,6 +117,12 @@ private[spark] class DriverConfigOrchestrator( .map(_.split(",")) .getOrElse(Array.empty[String]) + // TODO(SPARK-23153): remove once submission client local dependencies are supported. + if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { + throw new SparkException("The Kubernetes mode does not yet support referencing application " + + "dependencies in the local file system.") + } + val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, @@ -162,6 +168,12 @@ private[spark] class DriverConfigOrchestrator( initContainerBootstrapStep } + private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { + files.exists { uri => + Utils.resolveURI(uri).getScheme == "file" + } + } + private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { files.exists { uri => Utils.resolveURI(uri).getScheme != "local" diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 65274c6f50e01..033d303e946fd 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.deploy.k8s.submit -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.submit.steps._ @@ -117,6 +117,35 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[DriverMountSecretsStep]) } + test("Submission using client local dependencies") { + val sparkConf = new SparkConf(false) + .set(CONTAINER_IMAGE, DRIVER_IMAGE) + var orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + + sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") + orchestrator = new DriverConfigOrchestrator( + APP_ID, + LAUNCH_TIME, + Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), + APP_NAME, + MAIN_CLASS, + APP_ARGS, + sparkConf) + assertThrows[SparkException] { + orchestrator.getAllConfigurationSteps + } + } + private def validateStepTypes( orchestrator: DriverConfigOrchestrator, types: Class[_ <: DriverConfigurationStep]*): Unit = { From 4cd2ecc0c7222fef1337e04f1948333296c3be86 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 18 Jan 2018 16:29:45 -0800 Subject: [PATCH 0139/1161] [SPARK-23142][SS][DOCS] Added docs for continuous processing ## What changes were proposed in this pull request? Added documentation for continuous processing. Modified two locations. - Modified the overview to have a mention of Continuous Processing. - Added a new section on Continuous Processing at the end. ![image](https://user-images.githubusercontent.com/663212/35083551-a3dd23f6-fbd4-11e7-9e7e-90866f131ca9.png) ![image](https://user-images.githubusercontent.com/663212/35083618-d844027c-fbd4-11e7-9fde-75992cc517bd.png) ## How was this patch tested? N/A Author: Tathagata Das Closes #20308 from tdas/SPARK-23142. --- .../structured-streaming-programming-guide.md | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 1779a4215e085..2ddba2f0d942e 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,9 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able choose the mode based on your application requirements. + +In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. # Quick Example Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in @@ -2434,6 +2436,100 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat
+# Continuous Processing [Experimental] +**Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations). + +To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, + +
+
+{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +spark + .readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("") + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start() +{% endhighlight %} +
+
+{% highlight java %} +import org.apache.spark.sql.streaming.Trigger; + +spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("topic", "topic1") + .trigger(Trigger.Continuous("1 second")) // only change in query + .start(); +{% endhighlight %} +
+
+{% highlight python %} +spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ + .load() \ + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \ + .writeStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("topic", "topic1") \ + .trigger(continuous="1 second") \ # only change in query + .start() + +{% endhighlight %} +
+
+ +A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. + +## Supported Queries +As of Spark 2.3, only the following type of queries are supported in the continuous processing mode. + +- *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). + + All SQL functions are supported except aggregation functions (since aggregations are not yet supported), `current_timestamp()` and `current_date()` (deterministic computations using time is challenging). + +- *Sources*: + + Kafka source: All options are supported. + + Rate source: Good for testing. Only options that are supported in the continuous mode are `numPartitions` and `rowsPerSecond`. + +- *Sinks*: + + Kafka sink: All options are supported. + + Memory sink: Good for debugging. + + Console sink: Good for debugging. All options are supported. Note that the console will print every checkpoint interval that you have specified in the continuous trigger. + +See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. + +## Caveats +- Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. +- Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. +- There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. + # Additional Information **Further Reading** From 6121e91b7f5c9513d68674e4d5edbc3a4a5fd5fd Mon Sep 17 00:00:00 2001 From: brandonJY Date: Thu, 18 Jan 2018 18:57:49 -0600 Subject: [PATCH 0140/1161] [DOCS] change to dataset for java code in structured-streaming-kafka-integration document ## What changes were proposed in this pull request? In latest structured-streaming-kafka-integration document, Java code example for Kafka integration is using `DataFrame`, shouldn't it be changed to `DataSet`? ## How was this patch tested? manual test has been performed to test the updated example Java code in Spark 2.2.1 with Kafka 1.0 Author: brandonJY Closes #20312 from brandonJY/patch-2. --- docs/structured-streaming-kafka-integration.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index bab0be8ddeb9f..461c29ce1ba89 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -61,7 +61,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -70,7 +70,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to multiple topics -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -79,7 +79,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") // Subscribe to a pattern -DataFrame df = spark +Dataset df = spark .readStream() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -171,7 +171,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") {% highlight java %} // Subscribe to 1 topic defaults to the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -180,7 +180,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to multiple topics, specifying explicit Kafka offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") @@ -191,7 +191,7 @@ DataFrame df = spark df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); // Subscribe to a pattern, at the earliest and latest offsets -DataFrame df = spark +Dataset df = spark .read() .format("kafka") .option("kafka.bootstrap.servers", "host1:port1,host2:port2") From 568055da93049c207bb830f244ff9b60c638837c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 19 Jan 2018 11:37:08 +0800 Subject: [PATCH 0141/1161] [SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when casting PythonUserDefinedType to String. ## What changes were proposed in this pull request? This is a follow-up of #20246. If a UDT in Python doesn't have its corresponding Scala UDT, cast to string will be the raw string of the internal value, e.g. `"org.apache.spark.sql.catalyst.expressions.UnsafeArrayDataxxxxxxxx"` if the internal type is `ArrayType`. This pr fixes it by using its `sqlType` casting. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20306 from ueshin/issues/SPARK-23054/fup1. --- python/pyspark/sql/tests.py | 11 +++++++++++ .../apache/spark/sql/catalyst/expressions/Cast.scala | 2 ++ .../org/apache/spark/sql/test/ExamplePointUDT.scala | 2 ++ 3 files changed, 15 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 25483594f2725..4fee2ecde391b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1189,6 +1189,17 @@ def test_union_with_udt(self): ] ) + def test_cast_to_string_with_udt(self): + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + from pyspark.sql.functions import col + row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) + schema = StructType([StructField("point", ExamplePointUDT(), False), + StructField("pypoint", PythonOnlyUDT(), False)]) + df = self.spark.createDataFrame([row], schema) + + result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head() + self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a95ebe301b9d1..79b051670e9e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -282,6 +282,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case pudt: PythonUserDefinedType => castToString(pudt.sqlType) case udt: UserDefinedType[_] => buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString)) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) @@ -838,6 +839,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx) case udt: UserDefinedType[_] => val udtRef = ctx.addReferenceObj("udt", udt) (c, evPrim, evNull) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index a73e4272950a4..8bab7e1c58762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab case that: ExamplePoint => this.x == that.x && this.y == that.y case _ => false } + + override def toString(): String = s"($x, $y)" } /** From 9c4b99861cda3f9ec44ca8c1adc81a293508190c Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 19 Jan 2018 01:38:08 -0800 Subject: [PATCH 0142/1161] [BUILD][MINOR] Fix java style check issues ## What changes were proposed in this pull request? This patch fixes a few recently introduced java style check errors in master and release branch. As an aside, given that [java linting currently fails](https://github.com/apache/spark/pull/10763 ) on machines with a clean maven cache, it'd be great to find another workaround to [re-enable the java style checks](https://github.com/apache/spark/blob/3a07eff5af601511e97a05e6fea0e3d48f74c4f0/dev/run-tests.py#L577) as part of Spark PRB. /cc zsxwing JoshRosen srowen for any suggestions ## How was this patch tested? Manual Check Author: Sameer Agarwal Closes #20323 from sameeragarwal/java. --- .../spark/sql/sources/v2/writer/DataSourceV2Writer.java | 6 ++++-- .../org/apache/spark/sql/vectorized/ArrowColumnVector.java | 5 +++-- .../apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index 317ac45bcfd74..f1ef411423162 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -28,8 +28,10 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter(String, long, StructType, OutputMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter(String, StructType, OutputMode, DataSourceV2Options)}. + * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter( + * String, long, StructType, OutputMode, DataSourceV2Options)}/ + * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter( + * String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index eb69001fe677e..bfd1b4cb0ef12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -556,8 +556,9 @@ final int getArrayOffset(int rowId) { /** * Any call to "get" method will throw UnsupportedOperationException. * - * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses getStruct() method defined - * in the parent class. Any call to "get" method in this class is a bug in the code. + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses + * getStruct() method defined in the parent class. Any call to "get" method in this class is a + * bug in the code. * */ private static class StructAccessor extends ArrowVectorAccessor { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 44e5146d7c553..98d6a53b54d28 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,7 +69,8 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = + new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); return this; } From 60203fca6a605ad158184e1e0ce5187e144a3ea7 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 19 Jan 2018 12:43:23 +0200 Subject: [PATCH 0143/1161] [SPARK-23127][DOC] Update FeatureHasher guide for categoricalCols parameter Update user guide entry for `FeatureHasher` to match the Scala / Python doc, to describe the `categoricalCols` parameter. ## How was this patch tested? Doc only Author: Nick Pentreath Closes #20293 from MLnick/SPARK-23127-catCol-userguide. --- docs/ml-features.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 72643137d96b1..10183c3e78c76 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -222,9 +222,9 @@ The `FeatureHasher` transformer operates on multiple columns. Each column may co numeric or categorical features. Behavior and handling of column data types is as follows: - Numeric columns: For numeric features, the hash value of the column name is used to map the -feature value to its index in the feature vector. Numeric features are never treated as -categorical, even when they are integers. You must explicitly convert numeric columns containing -categorical features to strings first. +feature value to its index in the feature vector. By default, numeric features are not treated +as categorical (even when they are integers). To treat them as categorical, specify the relevant +columns using the `categoricalCols` parameter. - String columns: For categorical features, the hash value of the string "column_name=value" is used to map to the vector index, with an indicator value of `1.0`. Thus, categorical features are "one-hot" encoded (similarly to using [OneHotEncoder](ml-features.html#onehotencoder) with From b74366481cc87490adf4e69d26389ec737548c15 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Jan 2018 12:48:42 +0200 Subject: [PATCH 0144/1161] [SPARK-23048][ML] Add OneHotEncoderEstimator document and examples ## What changes were proposed in this pull request? We have `OneHotEncoderEstimator` now and `OneHotEncoder` will be deprecated since 2.3.0. We should add `OneHotEncoderEstimator` into mllib document. We also need to provide corresponding examples for `OneHotEncoderEstimator` which are used in the document too. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20257 from viirya/SPARK-23048. --- docs/ml-features.md | 28 ++++++++----- ...=> JavaOneHotEncoderEstimatorExample.java} | 41 ++++++++----------- ...py => onehot_encoder_estimator_example.py} | 29 +++++++------ ...la => OneHotEncoderEstimatorExample.scala} | 40 ++++++++---------- 4 files changed, 68 insertions(+), 70 deletions(-) rename examples/src/main/java/org/apache/spark/examples/ml/{JavaOneHotEncoderExample.java => JavaOneHotEncoderEstimatorExample.java} (62%) rename examples/src/main/python/ml/{onehot_encoder_example.py => onehot_encoder_estimator_example.py} (65%) rename examples/src/main/scala/org/apache/spark/examples/ml/{OneHotEncoderExample.scala => OneHotEncoderEstimatorExample.scala} (65%) diff --git a/docs/ml-features.md b/docs/ml-features.md index 10183c3e78c76..466a8fbe99cf6 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -775,35 +775,43 @@ for more details on the API.
-## OneHotEncoder +## OneHotEncoder (Deprecated since 2.3.0) -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. +Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces an `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030). + +`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. + +## OneHotEncoderEstimator + +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first. + +`OneHotEncoderEstimator` can transform multiple columns, returning an one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler). + +`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). **Examples**
-Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} +{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %}
-Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) +Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) for more details on the API. -{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} +{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %}
-Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) -for more details on the API. +Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. -{% include_example python/ml/onehot_encoder_example.py %} +{% include_example python/ml/onehot_encoder_estimator_example.py %}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java similarity index 62% rename from examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java rename to examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java index 99af37676ba98..6f93cff94b725 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java @@ -23,9 +23,8 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.ml.feature.OneHotEncoder; -import org.apache.spark.ml.feature.StringIndexer; -import org.apache.spark.ml.feature.StringIndexerModel; +import org.apache.spark.ml.feature.OneHotEncoderEstimator; +import org.apache.spark.ml.feature.OneHotEncoderModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -35,41 +34,37 @@ import org.apache.spark.sql.types.StructType; // $example off$ -public class JavaOneHotEncoderExample { +public class JavaOneHotEncoderEstimatorExample { public static void main(String[] args) { SparkSession spark = SparkSession .builder() - .appName("JavaOneHotEncoderExample") + .appName("JavaOneHotEncoderEstimatorExample") .getOrCreate(); + // Note: categorical features are usually first encoded with StringIndexer // $example on$ List data = Arrays.asList( - RowFactory.create(0, "a"), - RowFactory.create(1, "b"), - RowFactory.create(2, "c"), - RowFactory.create(3, "a"), - RowFactory.create(4, "a"), - RowFactory.create(5, "c") + RowFactory.create(0.0, 1.0), + RowFactory.create(1.0, 0.0), + RowFactory.create(2.0, 1.0), + RowFactory.create(0.0, 2.0), + RowFactory.create(0.0, 1.0), + RowFactory.create(2.0, 0.0) ); StructType schema = new StructType(new StructField[]{ - new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), - new StructField("category", DataTypes.StringType, false, Metadata.empty()) + new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()) }); Dataset df = spark.createDataFrame(data, schema); - StringIndexerModel indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df); - Dataset indexed = indexer.transform(df); + OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); - OneHotEncoder encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec"); - - Dataset encoded = encoder.transform(indexed); + OneHotEncoderModel model = encoder.fit(df); + Dataset encoded = model.transform(df); encoded.show(); // $example off$ diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_estimator_example.py similarity index 65% rename from examples/src/main/python/ml/onehot_encoder_example.py rename to examples/src/main/python/ml/onehot_encoder_estimator_example.py index e1996c7f0a55b..2723e681cea7c 100644 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ b/examples/src/main/python/ml/onehot_encoder_estimator_example.py @@ -18,32 +18,31 @@ from __future__ import print_function # $example on$ -from pyspark.ml.feature import OneHotEncoder, StringIndexer +from pyspark.ml.feature import OneHotEncoderEstimator # $example off$ from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ .builder\ - .appName("OneHotEncoderExample")\ + .appName("OneHotEncoderEstimatorExample")\ .getOrCreate() + # Note: categorical features are usually first encoded with StringIndexer # $example on$ df = spark.createDataFrame([ - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - ], ["id", "category"]) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + ], ["categoryIndex1", "categoryIndex2"]) - stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex") - model = stringIndexer.fit(df) - indexed = model.transform(df) - - encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec") - encoded = encoder.transform(indexed) + encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryVec1", "categoryVec2"]) + model = encoder.fit(df) + encoded = model.transform(df) encoded.show() # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala similarity index 65% rename from examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala index 274cc1268f4d1..45d816808ed8e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala @@ -19,38 +19,34 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} +import org.apache.spark.ml.feature.OneHotEncoderEstimator // $example off$ import org.apache.spark.sql.SparkSession -object OneHotEncoderExample { +object OneHotEncoderEstimatorExample { def main(args: Array[String]): Unit = { val spark = SparkSession .builder - .appName("OneHotEncoderExample") + .appName("OneHotEncoderEstimatorExample") .getOrCreate() + // Note: categorical features are usually first encoded with StringIndexer // $example on$ val df = spark.createDataFrame(Seq( - (0, "a"), - (1, "b"), - (2, "c"), - (3, "a"), - (4, "a"), - (5, "c") - )).toDF("id", "category") - - val indexer = new StringIndexer() - .setInputCol("category") - .setOutputCol("categoryIndex") - .fit(df) - val indexed = indexer.transform(df) - - val encoder = new OneHotEncoder() - .setInputCol("categoryIndex") - .setOutputCol("categoryVec") - - val encoded = encoder.transform(indexed) + (0.0, 1.0), + (1.0, 0.0), + (2.0, 1.0), + (0.0, 2.0), + (0.0, 1.0), + (2.0, 0.0) + )).toDF("categoryIndex1", "categoryIndex2") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("categoryIndex1", "categoryIndex2")) + .setOutputCols(Array("categoryVec1", "categoryVec2")) + val model = encoder.fit(df) + + val encoded = model.transform(df) encoded.show() // $example off$ From e41400c3c8aace9eb72e6134173f222627fb0faf Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 19 Jan 2018 19:46:48 +0800 Subject: [PATCH 0145/1161] [SPARK-23089][STS] Recreate session log directory if it doesn't exist ## What changes were proposed in this pull request? When creating a session directory, Thrift should create the parent directory (i.e. /tmp/base_session_log_dir) if it is not present. It is common that many tools delete empty directories, so the directory may be deleted. This can cause the session log to be disabled. This was fixed in HIVE-12262: this PR brings it in Spark too. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20281 from mgaido91/SPARK-23089. --- .../hive/service/cli/session/HiveSessionImpl.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 47bfaa86021d6..108074cce3d6d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -223,6 +223,18 @@ private void configureSession(Map sessionConfMap) throws HiveSQL @Override public void setOperationLogSessionDir(File operationLogRootDir) { + if (!operationLogRootDir.exists()) { + LOG.warn("The operation log root directory is removed, recreating: " + + operationLogRootDir.getAbsolutePath()); + if (!operationLogRootDir.mkdirs()) { + LOG.warn("Unable to create operation log root directory: " + + operationLogRootDir.getAbsolutePath()); + } + } + if (!operationLogRootDir.canWrite()) { + LOG.warn("The operation log root directory is not writable: " + + operationLogRootDir.getAbsolutePath()); + } sessionLogDir = new File(operationLogRootDir, sessionHandle.getHandleIdentifier().toString()); isOperationLogEnabled = true; if (!sessionLogDir.exists()) { From e1c33b6cd14e4e1123814f4d040e3520db7d1ec9 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Fri, 19 Jan 2018 08:22:24 -0600 Subject: [PATCH 0146/1161] [SPARK-23024][WEB-UI] Spark ui about the contents of the form need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark ui about the contents of the form need to have hidden and show features, when the table records very much. Because sometimes you do not care about the record of the table, you just want to see the contents of the next table, but you have to scroll the scroll bar for a long time to see the contents of the next table. Currently we have about 500 workers, but I just wanted to see the logs for the running applications table. I had to scroll through the scroll bars for a long time to see the logs for the running applications table. In order to ensure functional consistency, I modified the Master Page, Worker Page, Job Page, Stage Page, Task Page, Configuration Page, Storage Page, Pool Page. fix before: ![1](https://user-images.githubusercontent.com/26266482/34805936-601ed628-f6bb-11e7-8dd3-d8413573a076.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/34805949-6af8afba-f6bb-11e7-89f4-ab16584916fb.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20216 from guoxiaolongzte/SPARK-23024. --- .../org/apache/spark/ui/static/webui.js | 30 ++++++++ .../deploy/master/ui/ApplicationPage.scala | 25 +++++-- .../spark/deploy/master/ui/MasterPage.scala | 63 ++++++++++++++--- .../spark/deploy/worker/ui/WorkerPage.scala | 52 +++++++++++--- .../apache/spark/ui/env/EnvironmentPage.scala | 48 +++++++++++-- .../apache/spark/ui/jobs/AllJobsPage.scala | 39 +++++++++-- .../apache/spark/ui/jobs/AllStagesPage.scala | 67 +++++++++++++++--- .../org/apache/spark/ui/jobs/JobPage.scala | 68 ++++++++++++++++--- .../org/apache/spark/ui/jobs/PoolPage.scala | 13 +++- .../org/apache/spark/ui/jobs/StagePage.scala | 12 +++- .../apache/spark/ui/storage/StoragePage.scala | 12 +++- 11 files changed, 373 insertions(+), 56 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 0fa1fcf25f8b9..e575c4c78970d 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -50,4 +50,34 @@ function collapseTable(thisName, table){ // to remember if it's collapsed on each page reload $(function() { collapseTablePageLoad('collapse-aggregated-metrics','aggregated-metrics'); + collapseTablePageLoad('collapse-aggregated-executors','aggregated-executors'); + collapseTablePageLoad('collapse-aggregated-removedExecutors','aggregated-removedExecutors'); + collapseTablePageLoad('collapse-aggregated-workers','aggregated-workers'); + collapseTablePageLoad('collapse-aggregated-activeApps','aggregated-activeApps'); + collapseTablePageLoad('collapse-aggregated-activeDrivers','aggregated-activeDrivers'); + collapseTablePageLoad('collapse-aggregated-completedApps','aggregated-completedApps'); + collapseTablePageLoad('collapse-aggregated-completedDrivers','aggregated-completedDrivers'); + collapseTablePageLoad('collapse-aggregated-runningExecutors','aggregated-runningExecutors'); + collapseTablePageLoad('collapse-aggregated-runningDrivers','aggregated-runningDrivers'); + collapseTablePageLoad('collapse-aggregated-finishedExecutors','aggregated-finishedExecutors'); + collapseTablePageLoad('collapse-aggregated-finishedDrivers','aggregated-finishedDrivers'); + collapseTablePageLoad('collapse-aggregated-runtimeInformation','aggregated-runtimeInformation'); + collapseTablePageLoad('collapse-aggregated-sparkProperties','aggregated-sparkProperties'); + collapseTablePageLoad('collapse-aggregated-systemProperties','aggregated-systemProperties'); + collapseTablePageLoad('collapse-aggregated-classpathEntries','aggregated-classpathEntries'); + collapseTablePageLoad('collapse-aggregated-activeJobs','aggregated-activeJobs'); + collapseTablePageLoad('collapse-aggregated-completedJobs','aggregated-completedJobs'); + collapseTablePageLoad('collapse-aggregated-failedJobs','aggregated-failedJobs'); + collapseTablePageLoad('collapse-aggregated-poolTable','aggregated-poolTable'); + collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); + collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); + collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); + collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); + collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); + collapseTablePageLoad('collapse-aggregated-completedStages','aggregated-completedStages'); + collapseTablePageLoad('collapse-aggregated-failedStages','aggregated-failedStages'); + collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); + collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); + collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); }); \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 68e57b7564ad1..f699c75085fe1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -100,12 +100,29 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
-

Executor Summary ({allExecutors.length})

- {executorsTable} + +

+ + Executor Summary ({allExecutors.length}) +

+
+
+ {executorsTable} +
{ if (removedExecutors.nonEmpty) { -

Removed Executors ({removedExecutors.length})

++ - removedExecutorsTable + +

+ + Removed Executors ({removedExecutors.length}) +

+
++ +
+ {removedExecutorsTable} +
} }
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index bc0bf6a1d9700..c629937606b51 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -128,15 +128,31 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-

Workers ({workers.length})

- {workerTable} + +

+ + Workers ({workers.length}) +

+
+
+ {workerTable} +
-

Running Applications ({activeApps.length})

- {activeAppsTable} + +

+ + Running Applications ({activeApps.length}) +

+
+
+ {activeAppsTable} +
@@ -144,8 +160,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {if (hasDrivers) {
-

Running Drivers ({activeDrivers.length})

- {activeDriversTable} + +

+ + Running Drivers ({activeDrivers.length}) +

+
+
+ {activeDriversTable} +
} @@ -154,8 +179,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
-

Completed Applications ({completedApps.length})

- {completedAppsTable} + +

+ + Completed Applications ({completedApps.length}) +

+
+
+ {completedAppsTable} +
@@ -164,8 +198,17 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { if (hasDrivers) {
-

Completed Drivers ({completedDrivers.length})

- {completedDriversTable} + +

+ + Completed Drivers ({completedDrivers.length}) +

+
+
+ {completedDriversTable} +
} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index ce84bc4dae32c..8b98ae56fc108 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -77,24 +77,60 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
-

Running Executors ({runningExecutors.size})

- {runningExecutorTable} + +

+ + Running Executors ({runningExecutors.size}) +

+
+
+ {runningExecutorTable} +
{ if (runningDrivers.nonEmpty) { -

Running Drivers ({runningDrivers.size})

++ - runningDriverTable + +

+ + Running Drivers ({runningDrivers.size}) +

+
++ +
+ {runningDriverTable} +
} } { if (finishedExecutors.nonEmpty) { -

Finished Executors ({finishedExecutors.size})

++ - finishedExecutorTable + +

+ + Finished Executors ({finishedExecutors.size}) +

+
++ +
+ {finishedExecutorTable} +
} } { if (finishedDrivers.nonEmpty) { -

Finished Drivers ({finishedDrivers.size})

++ - finishedDriverTable + +

+ + Finished Drivers ({finishedDrivers.size}) +

+
++ +
+ {finishedDriverTable} +
} }
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index 43adab7a35d65..902eb92b854f2 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -48,10 +48,50 @@ private[ui] class EnvironmentPage( classPathHeaders, classPathRow, appEnv.classpathEntries, fixedWidth = true) val content = -

Runtime Information

{runtimeInformationTable} -

Spark Properties

{sparkPropertiesTable} -

System Properties

{systemPropertiesTable} -

Classpath Entries

{classpathEntriesTable} + +

+ + Runtime Information +

+
+
+ {runtimeInformationTable} +
+ +

+ + Spark Properties +

+
+
+ {sparkPropertiesTable} +
+ +

+ + System Properties +

+
+
+ {systemPropertiesTable} +
+ +

+ + Classpath Entries +

+
+
+ {classpathEntriesTable} +
UIUtils.headerSparkPage("Environment", content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index ff916bb6a5759..e3b72f1f34859 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -363,16 +363,43 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We store.executorList(false), startTime) if (shouldShowActiveJobs) { - content ++=

Active Jobs ({activeJobs.size})

++ - activeJobsTable + content ++= + +

+ + Active Jobs ({activeJobs.size}) +

+
++ +
+ {activeJobsTable} +
} if (shouldShowCompletedJobs) { - content ++=

Completed Jobs ({completedJobNumStr})

++ - completedJobsTable + content ++= + +

+ + Completed Jobs ({completedJobNumStr}) +

+
++ +
+ {completedJobsTable} +
} if (shouldShowFailedJobs) { - content ++=

Failed Jobs ({failedJobs.size})

++ - failedJobsTable + content ++= + +

+ + Failed Jobs ({failedJobs.size}) +

+
++ +
+ {failedJobsTable} +
} val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index b1e343451e28e..606dc1e180e5b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -116,26 +116,75 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { var content = summary ++ { if (sc.isDefined && isFairScheduler) { -

Fair Scheduler Pools ({pools.size})

++ poolTable.toNodeSeq + +

+ + Fair Scheduler Pools ({pools.size}) +

+
++ +
+ {poolTable.toNodeSeq} +
} else { Seq.empty[Node] } } if (shouldShowActiveStages) { - content ++=

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} if (shouldShowPendingStages) { - content ++=

Pending Stages ({pendingStages.size})

++ - pendingStagesTable.toNodeSeq + content ++= + +

+ + Pending Stages ({pendingStages.size}) +

+
++ +
+ {pendingStagesTable.toNodeSeq} +
} if (shouldShowCompletedStages) { - content ++=

Completed Stages ({completedStageNumStr})

++ - completedStagesTable.toNodeSeq + content ++= + +

+ + Completed Stages ({completedStageNumStr}) +

+
++ +
+ {completedStagesTable.toNodeSeq} +
} if (shouldShowFailedStages) { - content ++=

Failed Stages ({numFailedStages})

++ - failedStagesTable.toNodeSeq + content ++= + +

+ + Failed Stages ({numFailedStages}) +

+
++ +
+ {failedStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage("Stages for All Jobs", content, parent) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index bf59152c8c0cd..c27f30c21a843 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -340,24 +340,72 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP jobId, store.operationGraphForJob(jobId)) if (shouldShowActiveStages) { - content ++=

Active Stages ({activeStages.size})

++ - activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} if (shouldShowPendingStages) { - content ++=

Pending Stages ({pendingOrSkippedStages.size})

++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

+ + Pending Stages ({pendingOrSkippedStages.size}) +

+
++ +
+ {pendingOrSkippedStagesTable.toNodeSeq} +
} if (shouldShowCompletedStages) { - content ++=

Completed Stages ({completedStages.size})

++ - completedStagesTable.toNodeSeq + content ++= + +

+ + Completed Stages ({completedStages.size}) +

+
++ +
+ {completedStagesTable.toNodeSeq} +
} if (shouldShowSkippedStages) { - content ++=

Skipped Stages ({pendingOrSkippedStages.size})

++ - pendingOrSkippedStagesTable.toNodeSeq + content ++= + +

+ + Skipped Stages ({pendingOrSkippedStages.size}) +

+
++ +
+ {pendingOrSkippedStagesTable.toNodeSeq} +
} if (shouldShowFailedStages) { - content ++=

Failed Stages ({failedStages.size})

++ - failedStagesTable.toNodeSeq + content ++= + +

+ + Failed Stages ({failedStages.size}) +

+
++ +
+ {failedStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 98fbd7aceaa11..a3e1f13782e30 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -51,7 +51,18 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { val poolTable = new PoolTable(Map(pool -> uiPool), parent) var content =

Summary

++ poolTable.toNodeSeq if (activeStages.nonEmpty) { - content ++=

Active Stages ({activeStages.size})

++ activeStagesTable.toNodeSeq + content ++= + +

+ + Active Stages ({activeStages.size}) +

+
++ +
+ {activeStagesTable.toNodeSeq} +
} UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index af78373ddb4b2..25bee33028393 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -486,8 +486,16 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++ aggMetrics ++ maybeAccumulableTable ++ -

Tasks ({totalTasksNumStr})

++ - taskTableHTML ++ jsForScrollingDownToTaskTable + +

+ + Tasks ({totalTasksNumStr}) +

+
++ +
+ {taskTableHTML ++ jsForScrollingDownToTaskTable} +
UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b8aec9890247a..68d946574a37b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -41,8 +41,16 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends Nil } else {
-

RDDs

- {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + +

+ + RDDs +

+
+
+ {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} +
} } From 6c39654efcb2aa8cb4d082ab7277a6fa38fb48e4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 19 Jan 2018 22:47:18 +0800 Subject: [PATCH 0147/1161] [SPARK-23000][TEST] Keep Derby DB Location Unchanged After Session Cloning ## What changes were proposed in this pull request? After session cloning in `TestHive`, the conf of the singleton SparkContext for derby DB location is changed to a new directory. The new directory is created in `HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false)`. This PR is to keep the conf value of `ConfVars.METASTORECONNECTURLKEY.varname` unchanged during the session clone. ## How was this patch tested? The issue can be reproduced by the command: > build/sbt -Phive "hive/test-only org.apache.spark.sql.hive.HiveSessionStateSuite org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite" Also added a test case. Author: gatorsmile Closes #20328 from gatorsmile/fixTestFailure. --- .../org/apache/spark/sql/SessionStateSuite.scala | 5 +---- .../apache/spark/sql/hive/test/TestHive.scala | 8 +++++++- .../spark/sql/hive/HiveSessionStateSuite.scala | 16 +++++++++++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 5d75f5835bf9e..4efae4c46c2e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll -import org.scalatest.BeforeAndAfterEach import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite @@ -28,8 +26,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener -class SessionStateSuite extends SparkFunSuite - with BeforeAndAfterEach with BeforeAndAfterAll { +class SessionStateSuite extends SparkFunSuite { /** * A shared SparkSession for all tests in this suite. Make sure you reset any changes to this diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index b6be00dbb3a73..c84131fc3212a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -180,7 +180,13 @@ private[hive] class TestHiveSparkSession( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", // scratch directory used by Hive's metastore client ConfVars.SCRATCHDIR.varname -> TestHiveContext.makeScratchDir().toURI.toString, - ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") + ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") ++ + // After session cloning, the JDBC connect string for a JDBC metastore should not be changed. + existingSharedState.map { state => + val connKey = + state.sparkContext.hadoopConfiguration.get(ConfVars.METASTORECONNECTURLKEY.varname) + ConfVars.METASTORECONNECTURLKEY.varname -> connKey + } metastoreTempConf.foreach { case (k, v) => sc.hadoopConfiguration.set(k, v) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index f7da3c4cbb0aa..ecc09cdcdbeaf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterEach +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -25,8 +25,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton /** * Run all tests from `SessionStateSuite` with a Hive based `SessionState`. */ -class HiveSessionStateSuite extends SessionStateSuite - with TestHiveSingleton with BeforeAndAfterEach { +class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { override def beforeAll(): Unit = { // Reuse the singleton session @@ -39,4 +38,15 @@ class HiveSessionStateSuite extends SessionStateSuite activeSession = null super.afterAll() } + + test("Clone then newSession") { + val sparkSession = hiveContext.sparkSession + val conf = sparkSession.sparkContext.hadoopConfiguration + val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + sparkSession.cloneSession() + sparkSession.sharedState.externalCatalog.client.newSession() + val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) + assert(oldValue == newValue, + "cloneSession and then newSession should not affect the Derby directory") + } } From 606a7485f12c5d5377c50258006c353ba5e49c3f Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 19 Jan 2018 09:28:35 -0600 Subject: [PATCH 0148/1161] [SPARK-23085][ML] API parity for mllib.linalg.Vectors.sparse ## What changes were proposed in this pull request? `ML.Vectors#sparse(size: Int, elements: Seq[(Int, Double)])` support zero-length ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20275 from zhengruifeng/SparseVector_size. --- .../scala/org/apache/spark/ml/linalg/Vectors.scala | 2 +- .../org/apache/spark/ml/linalg/VectorsSuite.scala | 14 ++++++++++++++ .../org/apache/spark/mllib/linalg/Vectors.scala | 3 +-- .../apache/spark/mllib/linalg/VectorsSuite.scala | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 941b6eca568d3..5824e463ca1aa 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -565,7 +565,7 @@ class SparseVector @Since("2.0.0") ( // validate the data { - require(size >= 0, "The size of the requested sparse vector must be greater than 0.") + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 79acef8214d88..0a316f57f811b 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -366,4 +366,18 @@ class VectorsSuite extends SparkMLFunSuite { assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2))) assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4))) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fd9605c013625..6e68d9684a672 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -326,8 +326,6 @@ object Vectors { */ @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0, "The size of the requested sparse vector must be greater than 0.") - val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 indices.foreach { i => @@ -758,6 +756,7 @@ class SparseVector @Since("1.0.0") ( @Since("1.0.0") val indices: Array[Int], @Since("1.0.0") val values: Array[Double]) extends Vector { + require(size >= 0, "The size of the requested sparse vector must be no less than 0.") require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 4074bead421e6..217b4a35438fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -495,4 +495,18 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(mlDenseVectorToArray(dv) === mlDenseVectorToArray(newDV)) assert(mlSparseVectorToArray(sv) === mlSparseVectorToArray(newSV)) } + + test("sparse vector only support non-negative length") { + val v1 = Vectors.sparse(0, Array.emptyIntArray, Array.emptyDoubleArray) + val v2 = Vectors.sparse(0, Array.empty[(Int, Double)]) + assert(v1.size === 0) + assert(v2.size === 0) + + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array(1), Array(2.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(-1, Array((1, 2.0))) + } + } } From d8aaa771e249b3f54b57ce24763e53fd65a0dbf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 19 Jan 2018 08:58:21 -0800 Subject: [PATCH 0149/1161] [SPARK-23149][SQL] polish ColumnarBatch ## What changes were proposed in this pull request? Several cleanups in `ColumnarBatch` * remove `schema`. The `ColumnVector`s inside `ColumnarBatch` already have the data type information, we don't need this `schema`. * remove `capacity`. `ColumnarBatch` is just a wrapper of `ColumnVector`s, not builders, it doesn't need a capacity property. * remove `DEFAULT_BATCH_SIZE`. As a wrapper, `ColumnarBatch` can't decide the batch size, it should be decided by the reader, e.g. parquet reader, orc reader, cached table reader. The default batch size should also be defined by the reader. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20316 from cloud-fan/columnar-batch. --- .../orc/OrcColumnarBatchReader.java | 49 +++++++------------ .../SpecificParquetRecordReaderBase.java | 12 ++--- .../VectorizedParquetRecordReader.java | 24 ++++----- .../vectorized/ColumnVectorUtils.java | 18 +++---- .../spark/sql/vectorized/ColumnarBatch.java | 20 +------- .../VectorizedHashMapGenerator.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 5 +- .../python/ArrowEvalPythonExec.scala | 8 +-- .../execution/python/ArrowPythonRunner.scala | 2 +- .../sql/sources/v2/JavaBatchDataSourceV2.java | 3 +- .../vectorized/ColumnarBatchSuite.scala | 7 ++- .../sql/sources/v2/DataSourceV2Suite.scala | 3 +- 13 files changed, 61 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 36fdf2bdf84d2..89bae4326e93b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,18 +49,8 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - - /** - * The default size of batch. We use this value for ORC reader to make it consistent with Spark's - * columnar batch, because their default batch sizes are different like the following: - * - * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024 - * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024 - */ - private static final int DEFAULT_SIZE = 4 * 1024; - - // ORC File Reader - private Reader reader; + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -98,22 +88,22 @@ public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @Override - public ColumnarBatch getCurrentValue() throws IOException, InterruptedException { + public ColumnarBatch getCurrentValue() { return columnarBatch; } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() throws IOException { return recordReader.getProgress(); } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { return nextBatch(); } @@ -134,16 +124,15 @@ public void close() throws IOException { * Please note that `initBatch` is needed to be called after this. */ @Override - public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) - throws IOException, InterruptedException { + public void initialize( + InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { FileSplit fileSplit = (FileSplit)inputSplit; Configuration conf = taskAttemptContext.getConfiguration(); - reader = OrcFile.createReader( + Reader reader = OrcFile.createReader( fileSplit.getPath(), OrcFile.readerOptions(conf) .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) .filesystem(fileSplit.getPath().getFileSystem(conf))); - Reader.Options options = OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); recordReader = reader.rows(options); @@ -159,7 +148,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(DEFAULT_SIZE); + batch = orcSchema.createRowBatch(CAPACITY); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -171,19 +160,17 @@ public void initBatch( resultSchema = resultSchema.add(f); } - int capacity = DEFAULT_SIZE; - if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, capacity); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } @@ -196,7 +183,7 @@ public void initBatch( } } - columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); } else { // Just wrap the ORC column vector instead of copying it to Spark column vector. orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; @@ -206,8 +193,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); - missingCol.putNulls(0, capacity); + OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); + missingCol.putNulls(0, CAPACITY); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -219,14 +206,14 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; } } - columnarBatch = new ColumnarBatch(resultSchema, orcVectorWrappers, capacity); + columnarBatch = new ColumnarBatch(orcVectorWrappers); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 80c2f491b48ce..e65cd252c3ddf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -170,7 +170,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont * Returns the list of files at 'path' recursively. This skips files that are ignored normally * by MapReduce. */ - public static List listDirectory(File path) throws IOException { + public static List listDirectory(File path) { List result = new ArrayList<>(); if (path.isDirectory()) { for (File f: path.listFiles()) { @@ -231,7 +231,7 @@ protected void initialize(String path, List columns) throws IOException } @Override - public Void getCurrentKey() throws IOException, InterruptedException { + public Void getCurrentKey() { return null; } @@ -259,7 +259,7 @@ public ValuesReaderIntIterator(ValuesReader delegate) { } @Override - int nextInt() throws IOException { + int nextInt() { return delegate.readInteger(); } } @@ -279,15 +279,15 @@ int nextInt() throws IOException { protected static final class NullIntIterator extends IntIterator { @Override - int nextInt() throws IOException { return 0; } + int nextInt() { return 0; } } /** * Creates a reader for definition and repetition levels, returning an optimized one if * the levels are not needed. */ - protected static IntIterator createRLEIterator(int maxLevel, BytesInput bytes, - ColumnDescriptor descriptor) throws IOException { + protected static IntIterator createRLEIterator( + int maxLevel, BytesInput bytes, ColumnDescriptor descriptor) throws IOException { try { if (maxLevel == 0) return new NullIntIterator(); return new RLEIntIterator( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index cd745b1f0e4e3..bb1b23611a7d7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,6 +50,9 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { + // TODO: make this configurable. + private static final int CAPACITY = 4 * 1024; + /** * Batch of rows that we assemble and the current index we've returned. Every time this * batch is used up (batchIdx == numBatched), we populated the batch. @@ -152,7 +155,7 @@ public void close() throws IOException { } @Override - public boolean nextKeyValue() throws IOException, InterruptedException { + public boolean nextKeyValue() throws IOException { resultBatch(); if (returnColumnarBatch) return nextBatch(); @@ -165,13 +168,13 @@ public boolean nextKeyValue() throws IOException, InterruptedException { } @Override - public Object getCurrentValue() throws IOException, InterruptedException { + public Object getCurrentValue() { if (returnColumnarBatch) return columnarBatch; return columnarBatch.getRow(batchIdx - 1); } @Override - public float getProgress() throws IOException, InterruptedException { + public float getProgress() { return (float) rowsReturned / totalRowCount; } @@ -181,7 +184,7 @@ public float getProgress() throws IOException, InterruptedException { // Columns 0,1: data columns // Column 2: partitionValues[0] // Column 3: partitionValues[1] - public void initBatch( + private void initBatch( MemoryMode memMode, StructType partitionColumns, InternalRow partitionValues) { @@ -195,13 +198,12 @@ public void initBatch( } } - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); } - columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity); + columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { @@ -213,13 +215,13 @@ public void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].putNulls(0, CAPACITY); columnVectors[i].setIsConstant(); } } } - public void initBatch() { + private void initBatch() { initBatch(MEMORY_MODE, null, null); } @@ -255,7 +257,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index b5cbe8e2839ba..5ee8cc8da2309 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -118,19 +118,19 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } } else { if (t == DataTypes.BooleanType) { - dst.appendBoolean(((Boolean)o).booleanValue()); + dst.appendBoolean((Boolean) o); } else if (t == DataTypes.ByteType) { - dst.appendByte(((Byte) o).byteValue()); + dst.appendByte((Byte) o); } else if (t == DataTypes.ShortType) { - dst.appendShort(((Short)o).shortValue()); + dst.appendShort((Short) o); } else if (t == DataTypes.IntegerType) { - dst.appendInt(((Integer)o).intValue()); + dst.appendInt((Integer) o); } else if (t == DataTypes.LongType) { - dst.appendLong(((Long)o).longValue()); + dst.appendLong((Long) o); } else if (t == DataTypes.FloatType) { - dst.appendFloat(((Float)o).floatValue()); + dst.appendFloat((Float) o); } else if (t == DataTypes.DoubleType) { - dst.appendDouble(((Double)o).doubleValue()); + dst.appendDouble((Double) o); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); dst.appendByteArray(b, 0, b.length); @@ -192,7 +192,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i */ public static ColumnarBatch toBatch( StructType schema, MemoryMode memMode, Iterator row) { - int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + int capacity = 4 * 1024; WritableColumnVector[] columnVectors; if (memMode == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema); @@ -208,7 +208,7 @@ public static ColumnarBatch toBatch( } n++; } - ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity); + ColumnarBatch batch = new ColumnarBatch(columnVectors); batch.setNumRows(n); return batch; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 9ae1c6d9993f0..4dc826cf60c15 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; -import org.apache.spark.sql.types.StructType; /** * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this @@ -28,10 +27,6 @@ * the entire data loading process. */ public final class ColumnarBatch { - public static final int DEFAULT_BATCH_SIZE = 4 * 1024; - - private final StructType schema; - private final int capacity; private int numRows; private final ColumnVector[] columns; @@ -82,7 +77,6 @@ public void remove() { * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { - assert(numRows <= this.capacity); this.numRows = numRows; } @@ -96,16 +90,6 @@ public void setNumRows(int numRows) { */ public int numRows() { return numRows; } - /** - * Returns the schema that makes up this batch. - */ - public StructType schema() { return schema; } - - /** - * Returns the max capacity (in number of rows) for this batch. - */ - public int capacity() { return capacity; } - /** * Returns the column at `ordinal`. */ @@ -120,10 +104,8 @@ public InternalRow getRow(int rowId) { return row; } - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { - this.schema = schema; + public ColumnarBatch(ColumnVector[] columns) { this.columns = columns; - this.capacity = capacity; this.row = new MutableColumnarRow(columns); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0cf9b53ce1d5d..eb48584d0c1ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -94,7 +94,7 @@ class VectorizedHashMapGenerator( | | public $generatedClassName() { | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); - | batch = new ${classOf[ColumnarBatch].getName}(schema, vectors, capacity); + | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. | ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcd1aa0890ba3..7487564ed64da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -175,7 +175,7 @@ private[sql] object ArrowConverters { new ArrowColumnVector(vector).asInstanceOf[ColumnVector] }.toArray - val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount) + val batch = new ColumnarBatch(columns) batch.setNumRows(root.getRowCount) batch.rowIterator().asScala } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3565ee3af1b9f..28b3875505cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,11 +78,10 @@ case class InMemoryTableScanExec( } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } - val columnarBatch = new ColumnarBatch( - columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + val columnarBatch = new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]]) columnarBatch.setNumRows(rowCount) - for (i <- 0 until attributes.length) { + for (i <- attributes.indices) { ColumnAccessor.decompress( cachedColumnarBatch.buffers(columnIndices(i)), columnarBatch.column(i).asInstanceOf[WritableColumnVector], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c06bc7b66ff39..47b146f076b62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -74,8 +74,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) + val outputTypes = output.drop(child.output.length).map(_.dataType) // DO NOT use iter.grouped(). See BatchIterator. val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) @@ -90,8 +89,9 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private var currentIter = if (columnarBatchIter.hasNext) { val batch = columnarBatchIter.next() - assert(schemaOut.equals(batch.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) + assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " + + s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") batch.rowIterator.asScala } else { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index dc5ba96e69aec..5fcdcddca7d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -138,7 +138,7 @@ class ArrowPythonRunner( if (reader != null && batchLoaded) { batchLoaded = reader.loadNextBatch() if (batchLoaded) { - val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + val batch = new ColumnarBatch(vectors) batch.setNumRows(root.getRowCount) batch } else { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 98d6a53b54d28..a5d77a90ece42 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -69,8 +69,7 @@ public DataReader createDataReader() { ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; - this.batch = - new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE); + this.batch = new ColumnarBatch(vectors); return this; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 675f06b31b970..cd90681ecabc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -875,14 +875,13 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val capacity = 4 * 1024 val columns = schema.fields.map { field => allocate(capacity, field.dataType, memMode) } - val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) + val batch = new ColumnarBatch(columns.toArray) assert(batch.numCols() == 4) assert(batch.numRows() == 0) - assert(batch.capacity() > 0) assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] @@ -1153,7 +1152,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2)) val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType))) - val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11) + val batch = new ColumnarBatch(columnVectors.toArray) batch.setNumRows(11) assert(batch.numCols() == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a89f7c55bf4f7..0ca29524c6d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -333,8 +333,7 @@ class BatchReadTask(start: Int, end: Int) private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch( - new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE) + private lazy val batch = new ColumnarBatch(Array(i, j)) private var current = start From 73d3b230f3816a854a181c0912d87b180e347271 Mon Sep 17 00:00:00 2001 From: foxish Date: Fri, 19 Jan 2018 10:23:13 -0800 Subject: [PATCH 0150/1161] [SPARK-23104][K8S][DOCS] Changes to Kubernetes scheduler documentation ## What changes were proposed in this pull request? Docs changes: - Adding a warning that the backend is experimental. - Removing a defunct internal-only option from documentation - Clarifying that node selectors can be used right away, and other minor cosmetic changes ## How was this patch tested? Docs only change Author: foxish Closes #20314 from foxish/ambiguous-docs. --- docs/cluster-overview.md | 4 ++-- docs/running-on-kubernetes.md | 43 ++++++++++++++++------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 658e67f99dd71..7277e2fb2731d 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -52,8 +52,8 @@ The system currently supports three cluster managers: * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -* [Kubernetes](running-on-kubernetes.html) -- [Kubernetes](https://kubernetes.io/docs/concepts/overview/what-is-kubernetes/) -is an open-source platform that provides container-centric infrastructure. +* [Kubernetes](running-on-kubernetes.html) -- an open-source system for automating deployment, scaling, + and management of containerized applications. A third-party project (not supported by the Spark project) exists to add support for [Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d6b1735ce5550..3c7586e8544ba 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -8,6 +8,10 @@ title: Running Spark on Kubernetes Spark can run on clusters managed by [Kubernetes](https://kubernetes.io). This feature makes use of native Kubernetes scheduler that has been added to Spark. +**The Kubernetes scheduler is currently experimental. +In future versions, there may be behavioral changes around configuration, +container images and entrypoints.** + # Prerequisites * A runnable distribution of Spark 2.3 or above. @@ -41,11 +45,10 @@ logs and remains in "completed" state in the Kubernetes API until it's eventuall Note that in the completed state, the driver pod does *not* use any computational or memory resources. -The driver and executor pod scheduling is handled by Kubernetes. It will be possible to affect Kubernetes scheduling -decisions for driver and executor pods using advanced primitives like -[node selectors](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) -and [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) -in a future release. +The driver and executor pod scheduling is handled by Kubernetes. It is possible to schedule the +driver and executor pods on a subset of available nodes through a [node selector](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#nodeselector) +using the configuration property for it. It will be possible to use more advanced +scheduling hints like [node/pod affinities](https://kubernetes.io/docs/concepts/configuration/assign-pod-node/#affinity-and-anti-affinity) in a future release. # Submitting Applications to Kubernetes @@ -62,8 +65,10 @@ use with the Kubernetes backend. Example usage is: - ./bin/docker-image-tool.sh -r -t my-tag build - ./bin/docker-image-tool.sh -r -t my-tag push +```bash +$ ./bin/docker-image-tool.sh -r -t my-tag build +$ ./bin/docker-image-tool.sh -r -t my-tag push +``` ## Cluster Mode @@ -94,7 +99,7 @@ must consist of lower case alphanumeric characters, `-`, and `.` and must start If you have a Kubernetes cluster setup, one way to discover the apiserver URL is by executing `kubectl cluster-info`. ```bash -kubectl cluster-info +$ kubectl cluster-info Kubernetes master is running at http://127.0.0.1:6443 ``` @@ -105,7 +110,7 @@ authenticating proxy, `kubectl proxy` to communicate to the Kubernetes API. The local proxy can be started by: ```bash -kubectl proxy +$ kubectl proxy ``` If the local proxy is running at localhost:8001, `--master k8s://http://127.0.0.1:8001` can be used as the argument to @@ -173,7 +178,7 @@ Logs can be accessed using the Kubernetes API and the `kubectl` CLI. When a Spar to stream logs from the application using: ```bash -kubectl -n= logs -f +$ kubectl -n= logs -f ``` The same logs can also be accessed through the @@ -186,7 +191,7 @@ The UI associated with any application can be accessed locally using [`kubectl port-forward`](https://kubernetes.io/docs/tasks/access-application-cluster/port-forward-access-application-cluster/#forward-a-local-port-to-a-port-on-the-pod). ```bash -kubectl port-forward 4040:4040 +$ kubectl port-forward 4040:4040 ``` Then, the Spark driver UI can be accessed on `http://localhost:4040`. @@ -200,13 +205,13 @@ are errors during the running of the application, often, the best way to investi To get some basic information about the scheduling decisions made around the driver pod, you can run: ```bash -kubectl describe pod +$ kubectl describe pod ``` If the pod has encountered a runtime error, the status can be probed further using: ```bash -kubectl logs +$ kubectl logs ``` Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark @@ -254,7 +259,7 @@ To create a custom service account, a user can use the `kubectl create serviceac following command creates a service account named `spark`: ```bash -kubectl create serviceaccount spark +$ kubectl create serviceaccount spark ``` To grant a service account a `Role` or `ClusterRole`, a `RoleBinding` or `ClusterRoleBinding` is needed. To create @@ -263,7 +268,7 @@ for `ClusterRoleBinding`) command. For example, the following command creates an namespace and grants it to the `spark` service account created above: ```bash -kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default +$ kubectl create clusterrolebinding spark-role --clusterrole=edit --serviceaccount=default:spark --namespace=default ``` Note that a `Role` can only be used to grant access to resources (like pods) within a single namespace, whereas a @@ -543,14 +548,6 @@ specific to Spark on Kubernetes. to avoid name conflicts. - - spark.kubernetes.executor.podNamePrefix - (none) - - Prefix for naming the executor pods. - If not set, the executor pod name is set to driver pod name suffixed by an integer. - - spark.kubernetes.executor.lostCheck.maxAttempts 10 From 07296a61c29eb074553956b6c0f92810ecf7bab2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 10:25:18 -0800 Subject: [PATCH 0151/1161] [INFRA] Close stale PR. Closes #20185. From fed2139f053fac4a9a6952ff0ab1cc2a5f657bd0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:26:37 -0600 Subject: [PATCH 0152/1161] [SPARK-20664][CORE] Delete stale application data from SHS. Detect the deletion of event log files from storage, and remove data about the related application attempt in the SHS. Also contains code to fix SPARK-21571 based on code by ericvandenbergfb. Author: Marcelo Vanzin Closes #20138 from vanzin/SPARK-20664. --- .../deploy/history/FsHistoryProvider.scala | 297 +++++++++++------- .../history/FsHistoryProviderSuite.scala | 117 ++++++- .../deploy/history/HistoryServerSuite.scala | 4 +- 3 files changed, 306 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 94c80ebd55e74..f9d0b5ee4e23e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} import java.util.{Date, ServiceLoader, UUID} -import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} +import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ @@ -29,7 +29,7 @@ import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import com.google.common.util.concurrent.MoreExecutors import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem @@ -116,8 +116,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs // and applications between check task and clean task. - private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder() - .setNameFormat("spark-history-task-%d").setDaemon(true).build()) + private val pool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-history-task-%d") // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) @@ -174,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (!conf.contains("spark.testing")) { + if (Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() @@ -275,7 +274,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { Some(load(appId).toApplicationInfo()) } catch { - case e: NoSuchElementException => + case _: NoSuchElementException => None } } @@ -405,49 +404,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val newLastScanTime = getNewLastScanTime() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") - // scan for modified applications, replay and merge them - val logInfos = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) + + val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) && - recordedFileSize(entry.getPath()) < entry.getLen() + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + } + .filter { entry => + try { + val info = listing.read(classOf[LogInfo], entry.getPath().toString()) + if (info.fileSize < entry.getLen()) { + // Log size has changed, it should be parsed. + true + } else { + // If the SHS view has a valid application, update the time the file was last seen so + // that the entry is not deleted from the SHS listing. + if (info.appId.isDefined) { + listing.write(info.copy(lastProcessed = newLastScanTime)) + } + false + } + } catch { + case _: NoSuchElementException => + // If the file is currently not being tracked by the SHS, add an entry for it and try + // to parse it. This will allow the cleaner code to detect the file as stale later on + // if it was not possible to parse it. + listing.write(LogInfo(entry.getPath().toString(), newLastScanTime, None, None, + entry.getLen())) + entry.getLen() > 0 + } } .sortWith { case (entry1, entry2) => entry1.getModificationTime() > entry2.getModificationTime() } - if (logInfos.nonEmpty) { - logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") + if (updated.nonEmpty) { + logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - var tasks = mutable.ListBuffer[Future[_]]() - - try { - for (file <- logInfos) { - tasks += replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(file) + val tasks = updated.map { entry => + try { + replayExecutor.submit(new Runnable { + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) }) + } catch { + // let the iteration over the updated entries break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + null } - } catch { - // let the iteration over logInfos break, since an exception on - // replayExecutor.submit (..) indicates the ExecutorService is unable - // to take any more submissions at this time - - case e: Exception => - logError(s"Exception while submitting event log for replay", e) - } + }.filter(_ != null) pendingReplayTasksCount.addAndGet(tasks.size) + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. tasks.foreach { task => try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. task.get() } catch { case e: InterruptedException => @@ -459,13 +479,70 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + // Delete all information about applications whose log files disappeared from storage. + // This is done by identifying the event logs which were not touched by the current + // directory scan. + // + // Only entries with valid applications are cleaned up here. Cleaning up invalid log + // files is done by the periodic cleaner task. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .last(newLastScanTime - 1) + .asScala + .toList + stale.foreach { log => + log.appId.foreach { appId => + cleanAppData(appId, log.attemptId, log.logPath) + listing.delete(classOf[LogInfo], log.logPath) + } + } + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } - private def getNewLastScanTime(): Long = { + private def cleanAppData(appId: String, attemptId: Option[String], logPath: String): Unit = { + try { + val app = load(appId) + val (attempt, others) = app.attempts.partition(_.info.attemptId == attemptId) + + assert(attempt.isEmpty || attempt.size == 1) + val isStale = attempt.headOption.exists { a => + if (a.logPath != new Path(logPath).getName()) { + // If the log file name does not match, then probably the old log file was from an + // in progress application. Just return that the app should be left alone. + false + } else { + val maybeUI = synchronized { + activeUIs.remove(appId -> attemptId) + } + + maybeUI.foreach { ui => + ui.invalidate() + ui.ui.store.close() + } + + diskManager.foreach(_.release(appId, attemptId, delete = true)) + true + } + } + + if (isStale) { + if (others.nonEmpty) { + val newAppInfo = new ApplicationInfoWrapper(app.info, others) + listing.write(newAppInfo) + } else { + listing.delete(classOf[ApplicationInfoWrapper], appId) + } + } + } catch { + case _: NoSuchElementException => + } + } + + private[history] def getNewLastScanTime(): Long = { val fileName = "." + UUID.randomUUID().toString val path = new Path(logDir, fileName) val fos = fs.create(path) @@ -530,7 +607,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -544,73 +621,78 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) bus.addListener(listener) replay(fileStatus, bus, eventsFilter = eventsFilter) - listener.applicationInfo.foreach { app => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + val (appId, attemptId) = listener.applicationInfo match { + case Some(app) => + // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a + // discussion on the UI lifecycle. + synchronized { + activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - } - addListing(app) + addListing(app) + (Some(app.info.id), app.attempts.head.info.attemptId) + + case _ => + // If the app hasn't written down its app ID to the logs, still record the entry in the + // listing db, with an empty ID. This will make the log eligible for deletion if the app + // does not make progress after the configured max log age. + (None, None) } - listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) + listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** * Delete event logs from the log directory according to the clean policy defined by the user. */ - private[history] def cleanLogs(): Unit = { - var iterator: Option[KVStoreIterator[ApplicationInfoWrapper]] = None - try { - val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - - // Iterate descending over all applications whose oldest attempt happened before maxTime. - iterator = Some(listing.view(classOf[ApplicationInfoWrapper]) - .index("oldestAttempt") - .reverse() - .first(maxTime) - .closeableIterator()) - - iterator.get.asScala.foreach { app => - // Applications may have multiple attempts, some of which may not need to be deleted yet. - val (remaining, toDelete) = app.attempts.partition { attempt => - attempt.info.lastUpdated.getTime() >= maxTime - } + private[history] def cleanLogs(): Unit = Utils.tryLog { + val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 - if (remaining.nonEmpty) { - val newApp = new ApplicationInfoWrapper(app.info, remaining) - listing.write(newApp) - } + val expired = listing.view(classOf[ApplicationInfoWrapper]) + .index("oldestAttempt") + .reverse() + .first(maxTime) + .asScala + .toList + expired.foreach { app => + // Applications may have multiple attempts, some of which may not need to be deleted yet. + val (remaining, toDelete) = app.attempts.partition { attempt => + attempt.info.lastUpdated.getTime() >= maxTime + } - toDelete.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) - try { - listing.delete(classOf[LogInfo], logPath.toString()) - } catch { - case _: NoSuchElementException => - logDebug(s"Log info entry for $logPath not found.") - } - try { - fs.delete(logPath, true) - } catch { - case e: AccessControlException => - logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") - case t: IOException => - logError(s"IOException in cleaning ${attempt.logPath}", t) - } - } + if (remaining.nonEmpty) { + val newApp = new ApplicationInfoWrapper(app.info, remaining) + listing.write(newApp) + } - if (remaining.isEmpty) { - listing.delete(app.getClass(), app.id) - } + toDelete.foreach { attempt => + logInfo(s"Deleting expired event log for ${attempt.logPath}") + val logPath = new Path(logDir, attempt.logPath) + listing.delete(classOf[LogInfo], logPath.toString()) + cleanAppData(app.id, attempt.info.attemptId, logPath.toString()) + deleteLog(logPath) + } + + if (remaining.isEmpty) { + listing.delete(app.getClass(), app.id) + } + } + + // Delete log files that don't have a valid application and exceed the configured max age. + val stale = listing.view(classOf[LogInfo]) + .index("lastProcessed") + .reverse() + .first(maxTime) + .asScala + .toList + stale.foreach { log => + if (log.appId.isEmpty) { + logInfo(s"Deleting invalid / corrupt event log ${log.logPath}") + deleteLog(new Path(log.logPath)) + listing.delete(classOf[LogInfo], log.logPath) } - } catch { - case t: Exception => logError("Exception while cleaning logs", t) - } finally { - iterator.foreach(_.close()) } } @@ -631,12 +713,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // an error the other way -- if we report a size bigger (ie later) than the file that is // actually read, we may never refresh the app. FileStatus is guaranteed to be static // after it's created, so we get a file size that is no bigger than what is actually read. - val logInput = EventLoggingListener.openEventLog(logPath, fs) - try { - bus.replay(logInput, logPath.toString, !isCompleted, eventsFilter) + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !isCompleted, eventsFilter) logInfo(s"Finished parsing $logPath") - } finally { - logInput.close() } } @@ -703,18 +782,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) | application count=$count}""".stripMargin } - /** - * Return the last known size of the given event log, recorded the last time the file - * system scanner detected a change in the file. - */ - private def recordedFileSize(log: Path): Long = { - try { - listing.read(classOf[LogInfo], log.toString()).fileSize - } catch { - case _: NoSuchElementException => 0L - } - } - private def load(appId: String): ApplicationInfoWrapper = { listing.read(classOf[ApplicationInfoWrapper], appId) } @@ -773,11 +840,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logInfo(s"Leasing disk manager space for app $appId / ${attempt.info.attemptId}...") val lease = dm.lease(status.getLen(), isCompressed) val newStorePath = try { - val store = KVUtils.open(lease.tmpPath, metadata) - try { + Utils.tryWithResource(KVUtils.open(lease.tmpPath, metadata)) { store => rebuildAppStore(store, status, attempt.info.lastUpdated.getTime()) - } finally { - store.close() } lease.commit(appId, attempt.info.attemptId) } catch { @@ -806,6 +870,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) } + private def deleteLog(log: Path): Unit = { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } + } + } private[history] object FsHistoryProvider { @@ -832,8 +907,16 @@ private[history] case class FsHistoryProviderMetadata( uiVersion: Long, logDir: String) +/** + * Tracking info for event logs detected in the configured log directory. Tracks both valid and + * invalid logs (e.g. unparseable logs, recorded as logs with no app ID) so that the cleaner + * can know what log files are safe to delete. + */ private[history] case class LogInfo( @KVIndexParam logPath: String, + @KVIndexParam("lastProcessed") lastProcessed: Long, + appId: Option[String], + attemptId: Option[String], fileSize: Long) private[history] class AttemptInfoWrapper( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 84ee01c7f5aaf..787de59edf465 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{mock, spy, verify} +import org.mockito.Mockito.{doReturn, mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -149,8 +149,10 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { var mergeApplicationListingCall = 0 - override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { - super.mergeApplicationListing(fileStatus) + override protected def mergeApplicationListing( + fileStatus: FileStatus, + lastSeen: Long): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen) mergeApplicationListingCall += 1 } } @@ -663,6 +665,115 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc freshUI.get.ui.store.job(0) } + test("clean up stale app information") { + val storeDir = Utils.createTempDir() + val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) + val provider = spy(new FsHistoryProvider(conf)) + val appId = "new1" + + // Write logs for two app attempts. + doReturn(1L).when(provider).getNewLastScanTime() + val attempt1 = newLogFile(appId, Some("1"), inProgress = false) + writeFile(attempt1, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + val attempt2 = newLogFile(appId, Some("2"), inProgress = false) + writeFile(attempt2, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("2")), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 2) + } + + // Load the app's UI. + val ui = provider.getAppUI(appId, Some("1")) + assert(ui.isDefined) + + // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since + // attempt 2 still exists, listing data should be there. + doReturn(2L).when(provider).getNewLastScanTime() + attempt1.delete() + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).id === appId) + assert(list(0).attempts.size === 1) + } + assert(!ui.get.valid) + assert(provider.getAppUI(appId, None) === None) + + // Delete the second attempt's log file. Now everything should go away. + doReturn(3L).when(provider).getNewLastScanTime() + attempt2.delete() + updateAndCheck(provider) { list => + assert(list.isEmpty) + } + } + + test("SPARK-21571: clean up removes invalid history files") { + // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that + // until we figure out what's causing the problem. + val clock = new ManualClock(TimeUnit.DAYS.toMillis(120)) + val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") + val provider = new FsHistoryProvider(conf, clock) { + override def getNewLastScanTime(): Long = clock.getTimeMillis() + } + + // Create 0-byte size inprogress and complete files + var logCount = 0 + var validLogCount = 0 + + val emptyInProgress = newLogFile("emptyInprogressLogFile", None, inProgress = true) + emptyInProgress.createNewFile() + emptyInProgress.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val slowApp = newLogFile("slowApp", None, inProgress = true) + slowApp.createNewFile() + slowApp.setLastModified(clock.getTimeMillis()) + logCount += 1 + + val emptyFinished = newLogFile("emptyFinishedLogFile", None, inProgress = false) + emptyFinished.createNewFile() + emptyFinished.setLastModified(clock.getTimeMillis()) + logCount += 1 + + // Create an incomplete log file, has an end record but no start record. + val corrupt = newLogFile("nonEmptyCorruptLogFile", None, inProgress = false) + writeFile(corrupt, true, None, SparkListenerApplicationEnd(0)) + corrupt.setLastModified(clock.getTimeMillis()) + logCount += 1 + + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Move the clock forward 1 day and scan the files again. They should still be there. + clock.advance(TimeUnit.DAYS.toMillis(1)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === logCount) + + // Update the slow app to contain valid info. Code should detect the change and not clean + // it up. + writeFile(slowApp, true, None, + SparkListenerApplicationStart(slowApp.getName(), Some(slowApp.getName()), 1L, "test", None)) + slowApp.setLastModified(clock.getTimeMillis()) + validLogCount += 1 + + // Move the clock forward another 2 days and scan the files again. This time the cleaner should + // pick up the invalid files and get rid of them. + clock.advance(TimeUnit.DAYS.toMillis(2)) + provider.checkForLogs() + provider.cleanLogs() + assert(new File(testDir.toURI).listFiles().size === validLogCount) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87778dda0e2c8..7aa60f2b60796 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -48,7 +48,7 @@ import org.apache.spark.deploy.history.config._ import org.apache.spark.status.api.v1.ApplicationInfo import org.apache.spark.status.api.v1.JobData import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{ResetSystemProperties, ShutdownHookManager, Utils} /** * A collection of tests against the historyserver, including comparing responses from the json @@ -564,7 +564,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit() + ShutdownHookManager.registerShutdownDeleteDir(logDir) } test("ui and api authorization checks") { From aa3a1276f9e23ffbb093d00743e63cd4369f9f57 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:32:20 -0600 Subject: [PATCH 0153/1161] [SPARK-23103][CORE] Ensure correct sort order for negative values in LevelDB. The code was sorting "0" as "less than" negative values, which is a little wrong. Fix is simple, most of the changes are the added test and related cleanup. Author: Marcelo Vanzin Closes #20284 from vanzin/SPARK-23103. --- .../spark/util/kvstore/LevelDBTypeInfo.java | 2 +- .../spark/util/kvstore/DBIteratorSuite.java | 7 +- .../spark/util/kvstore/LevelDBSuite.java | 77 ++++++++++--------- .../spark/status/AppStatusListenerSuite.scala | 8 +- 4 files changed, 52 insertions(+), 42 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java index 232ee41dd0b1f..f4d359234cb9e 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java @@ -493,7 +493,7 @@ byte[] toKey(Object value, byte prefix) { byte[] key = new byte[bytes * 2 + 2]; long longValue = ((Number) value).longValue(); key[0] = prefix; - key[1] = longValue > 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; + key[1] = longValue >= 0 ? POSITIVE_MARKER : NEGATIVE_MARKER; for (int i = 0; i < key.length - 2; i++) { int masked = (int) ((longValue >>> (4 * i)) & 0xF); diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java index 9a81f86812cde..1e062437d1803 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java @@ -73,7 +73,9 @@ default BaseComparator reverse() { private static final BaseComparator NATURAL_ORDER = (t1, t2) -> t1.key.compareTo(t2.key); private static final BaseComparator REF_INDEX_ORDER = (t1, t2) -> t1.id.compareTo(t2.id); private static final BaseComparator COPY_INDEX_ORDER = (t1, t2) -> t1.name.compareTo(t2.name); - private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> t1.num - t2.num; + private static final BaseComparator NUMERIC_INDEX_ORDER = (t1, t2) -> { + return Integer.valueOf(t1.num).compareTo(t2.num); + }; private static final BaseComparator CHILD_INDEX_ORDER = (t1, t2) -> t1.child.compareTo(t2.child); /** @@ -112,7 +114,8 @@ public void setup() throws Exception { t.key = "key" + i; t.id = "id" + i; t.name = "name" + RND.nextInt(MAX_ENTRIES); - t.num = RND.nextInt(MAX_ENTRIES); + // Force one item to have an integer value of zero to test the fix for SPARK-23103. + t.num = (i != 0) ? (int) RND.nextLong() : 0; t.child = "child" + (i % MIN_ENTRIES); allEntries.add(t); } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 2b07d249d2022..b8123ac81d29a 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.apache.commons.io.FileUtils; import org.iq80.leveldb.DBIterator; @@ -74,11 +76,7 @@ public void testReopenAndVersionCheckDb() throws Exception { @Test public void testObjectWriteReadDelete() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); try { db.read(CustomType1.class, t.key); @@ -106,17 +104,9 @@ public void testObjectWriteReadDelete() throws Exception { @Test public void testMultipleObjectWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "key1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; - - CustomType1 t2 = new CustomType1(); - t2.key = "key2"; - t2.id = "id"; - t2.name = "name2"; - t2.child = "child2"; + CustomType1 t1 = createCustomType1(1); + CustomType1 t2 = createCustomType1(2); + t2.id = t1.id; db.write(t1); db.write(t2); @@ -142,11 +132,7 @@ public void testMultipleObjectWriteReadDelete() throws Exception { @Test public void testMultipleTypesWriteReadDelete() throws Exception { - CustomType1 t1 = new CustomType1(); - t1.key = "1"; - t1.id = "id"; - t1.name = "name1"; - t1.child = "child1"; + CustomType1 t1 = createCustomType1(1); IntKeyType t2 = new IntKeyType(); t2.key = 2; @@ -188,10 +174,7 @@ public void testMultipleTypesWriteReadDelete() throws Exception { public void testMetadata() throws Exception { assertNull(db.getMetadata(CustomType1.class)); - CustomType1 t = new CustomType1(); - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.setMetadata(t); assertEquals(t, db.getMetadata(CustomType1.class)); @@ -202,11 +185,7 @@ public void testMetadata() throws Exception { @Test public void testUpdate() throws Exception { - CustomType1 t = new CustomType1(); - t.key = "key"; - t.id = "id"; - t.name = "name"; - t.child = "child"; + CustomType1 t = createCustomType1(1); db.write(t); @@ -222,13 +201,7 @@ public void testUpdate() throws Exception { @Test public void testSkip() throws Exception { for (int i = 0; i < 10; i++) { - CustomType1 t = new CustomType1(); - t.key = "key" + i; - t.id = "id" + i; - t.name = "name" + i; - t.child = "child" + i; - - db.write(t); + db.write(createCustomType1(i)); } KVStoreIterator it = db.view(CustomType1.class).closeableIterator(); @@ -240,6 +213,36 @@ public void testSkip() throws Exception { assertFalse(it.hasNext()); } + @Test + public void testNegativeIndexValues() throws Exception { + List expected = Arrays.asList(-100, -50, 0, 50, 100); + + expected.stream().forEach(i -> { + try { + db.write(createCustomType1(i)); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + List results = StreamSupport + .stream(db.view(CustomType1.class).index("int").spliterator(), false) + .map(e -> e.num) + .collect(Collectors.toList()); + + assertEquals(expected, results); + } + + private CustomType1 createCustomType1(int i) { + CustomType1 t = new CustomType1(); + t.key = "key" + i; + t.id = "id" + i; + t.name = "name" + i; + t.num = i; + t.child = "child" + i; + return t; + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index ca66b6b9db890..e7981bec6d64b 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -894,15 +894,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val dropped = stages.drop(1).head // Cache some quantiles by calling AppStatusStore.taskSummary(). For quantiles to be - // calculcated, we need at least one finished task. + // calculated, we need at least one finished task. The code in AppStatusStore uses + // `executorRunTime` to detect valid tasks, so that metric needs to be updated in the + // task end event. time += 1 val task = createTasks(1, Array("1")).head listener.onTaskStart(SparkListenerTaskStart(dropped.stageId, dropped.attemptId, task)) time += 1 task.markFinished(TaskState.FINISHED, time) + val metrics = TaskMetrics.empty + metrics.setExecutorRunTime(42L) listener.onTaskEnd(SparkListenerTaskEnd(dropped.stageId, dropped.attemptId, - "taskType", Success, task, null)) + "taskType", Success, task, metrics)) new AppStatusStore(store) .taskSummary(dropped.stageId, dropped.attemptId, Array(0.25d, 0.50d, 0.75d)) From f6da41b0150725fe96ccb2ee3b48840b207f47eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 19 Jan 2018 13:14:24 -0800 Subject: [PATCH 0154/1161] [SPARK-23135][UI] Fix rendering of accumulators in the stage page. This follows the behavior of 2.2: only named accumulators with a value are rendered. Screenshot: ![accs](https://user-images.githubusercontent.com/1694083/35065700-df409114-fb82-11e7-87c1-550c3f674371.png) Author: Marcelo Vanzin Closes #20299 from vanzin/SPARK-23135. --- .../org/apache/spark/ui/jobs/StagePage.scala | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 25bee33028393..0eb3190205c3e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -260,7 +260,11 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo): Seq[Node] = { - {acc.name}{acc.value} + if (acc.name != null && acc.value != null) { + {acc.name}{acc.value} + } else { + Nil + } } val accumulableTable = UIUtils.listingTable( accumulableHeaders, @@ -864,7 +868,7 @@ private[ui] class TaskPagedTable( {formatBytes(task.taskMetrics.map(_.peakExecutionMemory))} {if (hasAccumulators(stage)) { - accumulatorsInfo(task) + {accumulatorsInfo(task)} }} {if (hasInput(stage)) { metricInfo(task) { m => @@ -920,8 +924,12 @@ private[ui] class TaskPagedTable( } private def accumulatorsInfo(task: TaskData): Seq[Node] = { - task.accumulatorUpdates.map { acc => - Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update}")) + task.accumulatorUpdates.flatMap { acc => + if (acc.name != null && acc.update.isDefined) { + Unparsed(StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")) ++
+ } else { + Nil + } } } @@ -985,7 +993,9 @@ private object ApiHelper { "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, "Errors" -> TaskIndexNames.ERROR) - def hasAccumulators(stageData: StageData): Boolean = stageData.accumulatorUpdates.size > 0 + def hasAccumulators(stageData: StageData): Boolean = { + stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } + } def hasInput(stageData: StageData): Boolean = stageData.inputBytes > 0 From 793841c6b8b98b918dcf241e29f60ef125914db9 Mon Sep 17 00:00:00 2001 From: Kent Yao <11215016@zju.edu.cn> Date: Fri, 19 Jan 2018 15:49:29 -0800 Subject: [PATCH 0155/1161] [SPARK-21771][SQL] remove useless hive client in SparkSQLEnv ## What changes were proposed in this pull request? Once a meta hive client is created, it generates its SessionState which creates a lot of session related directories, some deleteOnExit, some does not. if a hive client is useless we may not create it at the very start. ## How was this patch tested? N/A cc hvanhovell cloud-fan Author: Kent Yao <11215016@zju.edu.cn> Closes #18983 from yaooqinn/patch-1. --- .../org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 6b19f971b73bb..cbd75ad12d430 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,8 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] - .client.newSession() + .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) From 396cdfbea45232bacbc03bfaf8be4ea85d47d3fd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 19 Jan 2018 22:46:34 -0800 Subject: [PATCH 0156/1161] [SPARK-23091][ML] Incorrect unit test for approxQuantile ## What changes were proposed in this pull request? Narrow bound on approx quantile test to epsilon from 2*epsilon to match paper ## How was this patch tested? Existing tests. Author: Sean Owen Closes #20324 from srowen/SPARK-23091. --- .../apache/spark/sql/DataFrameStatSuite.scala | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 5169d2b5fc6b2..8eae35325faea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -154,24 +154,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon) val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon) - val error_single = 2 * 1000 * epsilon - val error_double = 2 * 2000 * epsilon + val errorSingle = 1000 * epsilon + val errorDouble = 2.0 * errorSingle - assert(math.abs(single1 - q1 * n) < error_single) - assert(math.abs(double2 - 2 * q2 * n) < error_double) - assert(math.abs(s1 - q1 * n) < error_single) - assert(math.abs(s2 - q2 * n) < error_single) - assert(math.abs(d1 - 2 * q1 * n) < error_double) - assert(math.abs(d2 - 2 * q2 * n) < error_double) + assert(math.abs(single1 - q1 * n) <= errorSingle) + assert(math.abs(double2 - 2 * q2 * n) <= errorDouble) + assert(math.abs(s1 - q1 * n) <= errorSingle) + assert(math.abs(s2 - q2 * n) <= errorSingle) + assert(math.abs(d1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(d2 - 2 * q2 * n) <= errorDouble) // Multiple columns val Array(Array(ms1, ms2), Array(md1, md2)) = df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon) - assert(math.abs(ms1 - q1 * n) < error_single) - assert(math.abs(ms2 - q2 * n) < error_single) - assert(math.abs(md1 - 2 * q1 * n) < error_double) - assert(math.abs(md2 - 2 * q2 * n) < error_double) + assert(math.abs(ms1 - q1 * n) <= errorSingle) + assert(math.abs(ms2 - q2 * n) <= errorSingle) + assert(math.abs(md1 - 2 * q1 * n) <= errorDouble) + assert(math.abs(md2 - 2 * q2 * n) <= errorDouble) } // quantile should be in the range [0.0, 1.0] From 84a076e0e9a38a26edf7b702c24fdbbcf1e697b9 Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Sat, 20 Jan 2018 14:34:37 -0800 Subject: [PATCH 0157/1161] [SPARK-23165][DOC] Spelling mistake fix in quick-start doc. ## What changes were proposed in this pull request? Fix spelling in quick-start doc. ## How was this patch tested? Doc only. Author: Shashwat Anand Closes #20336 from ashashwat/SPARK-23165. --- docs/cloud-integration.md | 4 ++-- docs/configuration.md | 14 +++++++------- docs/graphx-programming-guide.md | 4 ++-- docs/monitoring.md | 8 ++++---- docs/quick-start.md | 6 +++--- docs/running-on-mesos.md | 2 +- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/sql-programming-guide.md | 8 ++++---- docs/storage-openstack-swift.md | 2 +- docs/streaming-programming-guide.md | 4 ++-- docs/structured-streaming-kafka-integration.md | 4 ++-- docs/structured-streaming-programming-guide.md | 6 +++--- docs/submitting-applications.md | 8 ++++---- 14 files changed, 37 insertions(+), 37 deletions(-) diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index 751a192da4ffd..c150d9efc06ff 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -180,10 +180,10 @@ under the path, not the number of *new* files, so it can become a slow operation The size of the window needs to be set to handle this. 1. Files only appear in an object store once they are completely written; there -is no need for a worklow of write-then-rename to ensure that files aren't picked up +is no need for a workflow of write-then-rename to ensure that files aren't picked up while they are still being written. Applications can write straight to the monitored directory. -1. Streams should only be checkpointed to an store implementing a fast and +1. Streams should only be checkpointed to a store implementing a fast and atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading diff --git a/docs/configuration.md b/docs/configuration.md index eecb39dcafc9e..e7f2419cc2fa4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -79,7 +79,7 @@ Then, you can supply configuration values at runtime: {% endhighlight %} The Spark shell and [`spark-submit`](submitting-applications.html) -tool support two ways to load configurations dynamically. The first are command line options, +tool support two ways to load configurations dynamically. The first is command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. Running `./bin/spark-submit --help` will show the entire list of these options. @@ -413,7 +413,7 @@ Apart from these, the following properties are also available, and may be useful false Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), - or it will be displayed before the driver exiting. It also can be dumped into disk by + or it will be displayed before the driver exits. It also can be dumped into disk by sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. @@ -446,7 +446,7 @@ Apart from these, the following properties are also available, and may be useful true Reuse Python worker or not. If yes, it will use a fixed number of Python workers, - does not need to fork() a Python process for every tasks. It will be very useful + does not need to fork() a Python process for every task. It will be very useful if there is large broadcast, then the broadcast will not be needed to transferred from JVM to Python worker for every task. @@ -1294,7 +1294,7 @@ Apart from these, the following properties are also available, and may be useful spark.files.openCostInBytes 4194304 (4 MB) - The estimated cost to open a file, measured by the number of bytes could be scanned in the same + The estimated cost to open a file, measured by the number of bytes could be scanned at the same time. This is used when putting multiple files into a partition. It is better to over estimate, then the partitions with small files will be faster than partitions with bigger files. @@ -1855,8 +1855,8 @@ Apart from these, the following properties are also available, and may be useful spark.user.groups.mapping org.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user are determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can configured by this property. + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider which can be specified to resolve a list of groups for a user. Note: This implementation supports only a Unix/Linux based environment. Windows environment is @@ -2465,7 +2465,7 @@ should be included on Spark's classpath: The location of these configuration files varies across Hadoop versions, but a common location is inside of `/etc/hadoop/conf`. Some tools create -configurations on-the-fly, but offer a mechanisms to download copies of them. +configurations on-the-fly, but offer a mechanism to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/conf/spark-env.sh` to a location containing the configuration files. diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 46225dc598da8..5c97a248df4bc 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,7 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodically checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): @@ -928,7 +928,7 @@ switch to 2D-partitioning or other heuristics included in GraphX.

-Once the edges have be partitioned the key challenge to efficient graph-parallel computation is +Once the edges have been partitioned the key challenge to efficient graph-parallel computation is efficiently joining vertex attributes with the edges. Because real-world graphs typically have more edges than vertices, we move vertex attributes to the edges. Because not all partitions will contain edges adjacent to all vertices we internally maintain a routing table which identifies where diff --git a/docs/monitoring.md b/docs/monitoring.md index f8d3ce91a0691..6f6cfc1288d73 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -118,7 +118,7 @@ The history server can be configured as follows: The number of applications to retain UI data for in the cache. If this cap is exceeded, then the oldest applications will be removed from the cache. If an application is not in the cache, - it will have to be loaded from disk if its accessed from the UI. + it will have to be loaded from disk if it is accessed from the UI. @@ -407,7 +407,7 @@ can be identified by their `[attempt-id]`. In the API listed below, when running -The number of jobs and stages which can retrieved is constrained by the same retention +The number of jobs and stages which can be retrieved is constrained by the same retention mechanism of the standalone Spark UI; `"spark.ui.retainedJobs"` defines the threshold value triggering garbage collection on jobs, and `spark.ui.retainedStages` that for stages. Note that the garbage collection takes place on playback: it is possible to retrieve @@ -422,10 +422,10 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future as a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version. -Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is +Note that even when examining the UI of running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. diff --git a/docs/quick-start.md b/docs/quick-start.md index 200b97230e866..07c520cbee6be 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -67,7 +67,7 @@ res3: Long = 15 ./bin/pyspark -Or if PySpark is installed with pip in your current enviroment: +Or if PySpark is installed with pip in your current environment: pyspark @@ -156,7 +156,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i >>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count() {% endhighlight %} -Here, we use the `explode` function in `select`, to transfrom a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: +Here, we use the `explode` function in `select`, to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: {% highlight python %} >>> wordCounts.collect() @@ -422,7 +422,7 @@ $ YOUR_SPARK_HOME/bin/spark-submit \ Lines with a: 46, Lines with b: 23 {% endhighlight %} -If you have PySpark pip installed into your enviroment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. +If you have PySpark pip installed into your environment (e.g., `pip install pyspark`), you can run your application with the regular Python interpreter or use the provided 'spark-submit' as you prefer. {% highlight bash %} # Use the Python interpreter to run your application diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 382cbfd5301b0..2bb5ecf1b8509 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -154,7 +154,7 @@ can find the results of the driver from the Mesos Web UI. To use cluster mode, you must start the `MesosClusterDispatcher` in your cluster via the `sbin/start-mesos-dispatcher.sh` script, passing in the Mesos master URL (e.g: mesos://host:5050). This starts the `MesosClusterDispatcher` as a daemon running on the host. -By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispacther, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. +By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI. If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e7edec5990363..e4f5a0c659e66 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -445,7 +445,7 @@ To use a custom metrics.properties for the application master and executors, upd yarn.nodemanager.log-aggregation.roll-monitoring-interval-seconds should be configured in yarn-site.xml. This feature can only be used with Hadoop 2.6.4+. The Spark log4j appender needs be changed to use - FileAppender or another appender that can handle the files being removed while its running. Based + FileAppender or another appender that can handle the files being removed while it is running. Based on the file name configured in the log4j configuration (like spark.log), the user should set the regex (spark*) to include all the log files that need to be aggregated. diff --git a/docs/security.md b/docs/security.md index 15aadf07cf873..bebc28ddbfb0e 100644 --- a/docs/security.md +++ b/docs/security.md @@ -62,7 +62,7 @@ component-specific configuration namespaces used to override the default setting -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). +The full breakdown of available SSL options can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. ### YARN mode diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3e2e48a0ef249..502c0a8c37e01 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1253,7 +1253,7 @@ provide a ClassTag. (Note that this is different than the Spark SQL JDBC server, which allows other applications to run queries using Spark SQL). -To get started you will need to include the JDBC driver for you particular database on the +To get started you will need to include the JDBC driver for your particular database on the spark classpath. For example, to connect to postgres from the Spark Shell you would run the following command: @@ -1793,7 +1793,7 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. - - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant with SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. @@ -1821,7 +1821,7 @@ options. transformations (e.g., `map`, `filter`, and `groupByKey`) and untyped transformations (e.g., `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in Python and R is not a language feature, the concept of Dataset does not apply to these languages’ - APIs. Instead, `DataFrame` remains the primary programing abstraction, which is analogous to the + APIs. Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the single-node data frame notion in these languages. - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union` @@ -1997,7 +1997,7 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users -of either language should use `SQLContext` and `DataFrame`. In general theses classes try to +of either language should use `SQLContext` and `DataFrame`. In general these classes try to use types that are usable from both languages (i.e. `Array` instead of language specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index f4bb2353e3c49..1dd54719b21aa 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -42,7 +42,7 @@ Create core-site.xml and place it inside Spark's conf The main category of parameters that should be configured are the authentication parameters required by Keystone. -The following table contains a list of Keystone mandatory parameters. PROVIDER can be +The following table contains a list of Keystone mandatory parameters. PROVIDER can be any (alphanumeric) name. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 868acc41226dc..ffda36d64a770 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -74,7 +74,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ // not necessary since Spark 1.3 // Create a local StreamingContext with two working thread and batch interval of 1 second. -// The master requires 2 cores to prevent from a starvation scenario. +// The master requires 2 cores to prevent a starvation scenario. val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) @@ -172,7 +172,7 @@ each line will be split into multiple words and the stream of words is represent `words` DStream. Note that we defined the transformation using a [FlatMapFunction](api/scala/index.html#org.apache.spark.api.java.function.FlatMapFunction) object. As we will discover along the way, there are a number of such convenience classes in the Java API -that help define DStream transformations. +that help defines DStream transformations. Next, we want to count these words. diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 461c29ce1ba89..5647ec6bc5797 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -125,7 +125,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") ### Creating a Kafka Source for Batch Queries If you have a use case that is better suited to batch processing, -you can create an Dataset/DataFrame for a defined range of offsets. +you can create a Dataset/DataFrame for a defined range of offsets.
@@ -597,7 +597,7 @@ Note that the following Kafka params cannot be set and the Kafka source or sink - **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use DataFrame operations to explicitly serialize the keys into either strings or byte arrays. - **value.serializer**: values are always serialized with ByteArraySerializer or StringSerializer. Use -DataFrame oeprations to explicitly serialize the values into either strings or byte arrays. +DataFrame operations to explicitly serialize the values into either strings or byte arrays. - **enable.auto.commit**: Kafka source doesn't commit any offset. - **interceptor.classes**: Kafka source always read keys and values as byte arrays. It's not safe to use ConsumerInterceptor as it may break the query. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2ddba2f0d942e..2ef5d3168a87b 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -10,7 +10,7 @@ title: Structured Streaming Programming Guide # Overview Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able choose the mode based on your application requirements. +Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. In this guide, we are going to walk you through the programming model and the APIs. We are going to explain the concepts mostly using the default micro-batch processing model, and then [later](#continuous-processing-experimental) discuss Continuous Processing model. First, let's start with a simple example of a Structured Streaming query - a streaming word count. @@ -1121,7 +1121,7 @@ Let’s discuss the different types of supported stream-stream joins and how to ##### Inner Joins with optional Watermarking Inner joins on any kind of columns along with any kind of join conditions are supported. However, as the stream runs, the size of streaming state will keep growing indefinitely as -*all* past input must be saved as the any new input can match with any input from the past. +*all* past input must be saved as any new input can match with any input from the past. To avoid unbounded state, you have to define additional join conditions such that indefinitely old inputs cannot match with future inputs and therefore can be cleared from the state. In other words, you will have to do the following additional steps in the join. @@ -1839,7 +1839,7 @@ aggDF \ .format("console") \ .start() -# Have all the aggregates in an in memory table. The query name will be the table name +# Have all the aggregates in an in-memory table. The query name will be the table name aggDF \ .writeStream \ .queryName("aggregates") \ diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 0473ab73a5e6c..a3643bf0838a1 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -5,7 +5,7 @@ title: Submitting Applications The `spark-submit` script in Spark's `bin` directory is used to launch applications on a cluster. It can use all of Spark's supported [cluster managers](cluster-overview.html#cluster-manager-types) -through a uniform interface so you don't have to configure your application specially for each one. +through a uniform interface so you don't have to configure your application especially for each one. # Bundling Your Application's Dependencies If your code depends on other projects, you will need to package them alongside @@ -58,7 +58,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Currently, standalone mode does not support cluster mode for Python +the drivers and the executors. Currently, the standalone mode does not support cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, @@ -68,7 +68,7 @@ There are a few options available that are specific to the [cluster manager](cluster-overview.html#cluster-manager-types) that is being used. For example, with a [Spark standalone cluster](spark-standalone.html) with `cluster` deploy mode, you can also specify `--supervise` to make sure that the driver is automatically restarted if it -fails with non-zero exit code. To enumerate all such options available to `spark-submit`, +fails with a non-zero exit code. To enumerate all such options available to `spark-submit`, run it with `--help`. Here are a few examples of common options: {% highlight bash %} @@ -192,7 +192,7 @@ debugging information by running `spark-submit` with the `--verbose` option. # Advanced Dependency Management When using `spark-submit`, the application jar along with any jars included with the `--jars` option -will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included on the driver and executor classpaths. Directory expansion does not work with `--jars`. +will be automatically transferred to the cluster. URLs supplied after `--jars` must be separated by commas. That list is included in the driver and executor classpaths. Directory expansion does not work with `--jars`. Spark uses the following URL scheme to allow different strategies for disseminating jars: From 00d169156d4b1c91d2bcfd788b254b03c509dc41 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 20 Jan 2018 14:49:49 -0800 Subject: [PATCH 0158/1161] [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [SPARK-21786][SQL] The 'spark.sql.parquet.compression.codec' and 'spark.sql.orc.compression.codec' configuration doesn't take effect on hive table writing What changes were proposed in this pull request? Pass ‘spark.sql.parquet.compression.codec’ value to ‘parquet.compression’. Pass ‘spark.sql.orc.compression.codec’ value to ‘orc.compress’. How was this patch tested? Add test. Note: This is the same issue mentioned in #19218 . That branch was deleted mistakenly, so make a new pr instead. gatorsmile maropu dongjoon-hyun discipleforteen Author: fjh100456 Author: Takeshi Yamamuro Author: Wenchen Fan Author: gatorsmile Author: Yinan Li Author: Marcelo Vanzin Author: Juliusz Sompolski Author: Felix Cheung Author: jerryshao Author: Li Jin Author: Gera Shegalov Author: chetkhatri Author: Joseph K. Bradley Author: Bago Amirbekian Author: Xianjin YE Author: Bruce Robbins Author: zuotingbing Author: Kent Yao Author: hyukjinkwon Author: Adrian Ionescu Closes #20087 from fjh100456/HiveTableWriting. --- .../datasources/orc/OrcOptions.scala | 2 + .../datasources/parquet/ParquetOptions.scala | 6 +- .../sql/hive/execution/HiveOptions.scala | 22 ++ .../sql/hive/execution/SaveAsHiveFile.scala | 20 +- .../sql/hive/CompressionCodecSuite.scala | 353 ++++++++++++++++++ 5 files changed, 397 insertions(+), 6 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a525..0ad3862f6cf01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,4 +67,6 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + def getORCCompressionCodecName(name: String): String = shortOrcCompressionCodecNames(name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index ef67ea7d17cea..f36a89a4c3c5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -82,4 +82,8 @@ object ParquetOptions { "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) + + def getParquetCompressionCodecName(name: String): String = { + shortParquetCompressionCodecNames(name).name() + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 5c515515b9b9c..802ddafdbee4d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -19,7 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf /** * Options for the Hive data source. Note that rule `DetermineHiveSerde` will extract Hive @@ -102,4 +111,17 @@ object HiveOptions { "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } + + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { + val tableProps = tableInfo.getProperties.asScala.toMap + tableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("parquetoutputformat") => + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName + Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) + case formatName if formatName.endsWith("orcoutputformat") => + val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec + Option((COMPRESS.getAttribute, compressionCodec)) + case _ => None + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 9a6607f2f2c6c..e484356906e87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -55,18 +55,28 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + val isCompressed = + fileSinkConf.getTableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("orcoutputformat") => + // For ORC,"mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact because it uses table properties to store compression information. + false + case _ => hadoopConf.get("hive.exec.compress.output", "false").toBoolean + } + if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) fileSinkConf.setCompressCodec(hadoopConf .get("mapreduce.output.fileoutputformat.compress.codec")) fileSinkConf.setCompressType(hadoopConf .get("mapreduce.output.fileoutputformat.compress.type")) + } else { + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } val committer = FileCommitProtocol.instantiate( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala new file mode 100644 index 0000000000000..d10a6f25c64fc --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} +import org.apache.spark.sql.hive.orc.OrcFileOperator +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with BeforeAndAfterAll { + import spark.implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + } + + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("table_source") + } finally { + super.afterAll() + } + } + + private val maxRecordNum = 50 + + private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } + + private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } + + private def getHiveCompressPropName(format: String): String = format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + + private def normalizeCodecName(format: String, name: String): String = { + format.toLowerCase match { + case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) + case "orc" => OrcOptions.getORCCompressionCodecName(name) + } + } + + private def getTableCompressionCodec(path: String, format: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = format.toLowerCase match { + case "parquet" => for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + case "orc" => new File(path).listFiles().filter { file => + file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" + }.map { orcFile => + OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString + }.toSeq + } + codecs.distinct + } + + private def createTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String]): Unit = { + val tblProperties = compressionCodec match { + case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" + case _ => "" + } + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else "" + sql( + s""" + |CREATE TABLE $tableName(a int) + |$partitionCreate + |STORED AS $format + |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' + |$tblProperties + """.stripMargin) + } + + private def writeDataToTable( + tableName: String, + partitionValue: Option[String]): Unit = { + val partitionInsert = partitionValue.map(p => s"partition (p='$p')").mkString + sql( + s""" + |INSERT INTO TABLE $tableName + |$partitionInsert + |SELECT * FROM table_source + """.stripMargin) + } + + private def writeDateToTableUsingCTAS( + rootDir: File, + tableName: String, + partitionValue: Option[String], + format: String, + compressionCodec: Option[String]): Unit = { + val partitionCreate = partitionValue.map(p => s"PARTITIONED BY (p)").mkString + val compressionOption = compressionCodec.map { codec => + s",'${getHiveCompressPropName(format)}'='$codec'" + }.mkString + val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString + sql( + s""" + |CREATE TABLE $tableName + |USING $format + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' $compressionOption) + |$partitionCreate + |AS SELECT * $partitionSelect FROM table_source + """.stripMargin) + } + + private def getPreparedTablePath( + tmpDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String], + usingCTAS: Boolean): String = { + val partitionValue = if (isPartitioned) Some("test") else None + if (usingCTAS) { + writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + } else { + createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + writeDataToTable(tableName, partitionValue) + } + getTablePartitionPath(tmpDir, tableName, partitionValue) + } + + private def getTableSize(path: String): Long = { + val dir = new File(path) + val files = dir.listFiles().filter(_.getName.startsWith("part-")) + files.map(_.length()).sum + } + + private def getTablePartitionPath( + dir: File, + tableName: String, + partitionValue: Option[String]) = { + val partitionPath = partitionValue.map(p => s"p=$p").mkString + s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath" + } + + private def getUncompressedDataSizeByFormat( + format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = { + var totalSize = 0L + val tableName = s"tbl_$format" + val codecName = normalizeCodecName(format, "uncompressed") + withSQLConf(getSparkCompressionConfName(format) -> codecName) { + withTempDir { tmpDir => + withTable(tableName) { + val compressionCodec = Option(codecName) + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + totalSize = getTableSize(path) + } + } + } + assert(totalSize > 0L) + totalSize + } + + private def checkCompressionCodecForTable( + format: String, + isPartitioned: Boolean, + compressionCodec: Option[String], + usingCTAS: Boolean) + (assertion: (String, Long) => Unit): Unit = { + val tableName = + if (usingCTAS) s"tbl_$format$isPartitioned" else s"tbl_$format${isPartitioned}_CAST" + withTempDir { tmpDir => + withTable(tableName) { + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + val relCompressionCodecs = getTableCompressionCodec(path, format) + assert(relCompressionCodecs.length == 1) + val tableSize = getTableSize(path) + assertion(relCompressionCodecs.head, tableSize) + } + } + } + + private def checkTableCompressionCodecForCodecs( + format: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + compressionCodecs: List[String], + tableCompressionCodecs: List[String]) + (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { + tableCompressionCodecs.foreach { tableCompression => + compressionCodecs.foreach { sessionCompressionCodec => + withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { + // 'tableCompression = null' means no table-level compression + val compression = Option(tableCompression) + checkCompressionCodecForTable(format, isPartitioned, compression, usingCTAS) { + case (realCompressionCodec, tableSize) => + assertionCompressionCodec( + compression, sessionCompressionCodec, realCompressionCodec, tableSize) + } + } + } + } + } + } + + // When the amount of data is small, compressed data size may be larger than uncompressed one, + // so we just check the difference when compressionCodec is not NONE or UNCOMPRESSED. + private def checkTableSize( + format: String, + compressionCodec: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + tableSize: Long): Boolean = { + val uncompressedSize = getUncompressedDataSizeByFormat(format, isPartitioned, usingCTAS) + compressionCodec match { + case "UNCOMPRESSED" if format == "parquet" => tableSize == uncompressedSize + case "NONE" if format == "orc" => tableSize == uncompressedSize + case _ => tableSize != uncompressedSize + } + } + + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // For non-partitioned table and when convertMetastore is true, Expect session-level + // take effect, and in other cases expect table-level take effect + // TODO: It should always be table-level taking effect when the bug(SPARK-22926) + // is fixed + val expectCodec = + if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + assert(expectCodec == realCodec) + assert(checkTableSize( + format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // Always expect session-level take effect + assert(sessionCodec == realCodec) + assert(checkTableSize( + format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + test("both table-level and session-level compression are set") { + checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + test("table-level compression is not set but session-level compressions is set ") { + checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + withTempDir { tmpDir => + val tableName = s"tbl_$format$isPartitioned" + createTable(tmpDir, tableName, isPartitioned, format, None) + withTable(tableName) { + compressCodecs.foreach { compressionCodec => + val partitionValue = if (isPartitioned) Some(compressionCodec) else None + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, + getSparkCompressionConfName(format) -> compressionCodec + ) { writeDataToTable(tableName, partitionValue) } + } + val tablePath = getTablePartitionPath(tmpDir, tableName, None) + val realCompressionCodecs = + if (isPartitioned) compressCodecs.flatMap { codec => + getTableCompressionCodec(s"$tablePath/p=$codec", format) + } else { + getTableCompressionCodec(tablePath, format) + } + + assert(realCompressionCodecs.distinct.sorted == compressCodecs.sorted) + val recordsNum = sql(s"SELECT * from $tableName").count() + assert(recordsNum == maxRecordNum * compressCodecs.length) + } + } + } + } + } + + test("test table containing mixed compression codec") { + checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB")) + } +} From 121dc96f088a7b157d5b2cffb626b0e22d1fc052 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 20 Jan 2018 22:39:49 -0800 Subject: [PATCH 0159/1161] [SPARK-23087][SQL] CheckCartesianProduct too restrictive when condition is false/null ## What changes were proposed in this pull request? CheckCartesianProduct raises an AnalysisException also when the join condition is always false/null. In this case, we shouldn't raise it, since the result will not be a cartesian product. ## How was this patch tested? added UT Author: Marco Gaido Closes #20333 from mgaido91/SPARK-23087. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 10 +++++++--- .../org/apache/spark/sql/DataFrameJoinSuite.scala | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c794ba8619322..0f9daa5f04c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1108,15 +1108,19 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { */ def isCartesianProduct(join: Join): Boolean = { val conditions = join.condition.map(splitConjunctivePredicates).getOrElse(Nil) - !conditions.map(_.references).exists(refs => refs.exists(join.left.outputSet.contains) - && refs.exists(join.right.outputSet.contains)) + + conditions match { + case Seq(Literal.FalseLiteral) | Seq(Literal(null, BooleanType)) => false + case _ => !conditions.map(_.references).exists(refs => + refs.exists(join.left.outputSet.contains) && refs.exists(join.right.outputSet.contains)) + } } def apply(plan: LogicalPlan): LogicalPlan = if (SQLConf.get.crossJoinEnabled) { plan } else plan transform { - case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, condition) + case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( s"""Detected cartesian product for ${j.joinType.sql} join between logical plans diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index aef0d7f3e425b..1656f290ee19c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -274,4 +274,18 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { checkAnswer(innerJoin, Row(1) :: Nil) } + test("SPARK-23087: don't throw Analysis Exception in CheckCartesianProduct when join condition " + + "is false or null") { + val df = spark.range(10) + val dfNull = spark.range(10).select(lit(null).as("b")) + val planNull = df.join(dfNull, $"id" === $"b", "left").queryExecution.analyzed + + spark.sessionState.executePlan(planNull).optimizedPlan + + val dfOne = df.select(lit(1).as("a")) + val dfTwo = spark.range(10).select(lit(2).as("b")) + val planFalse = dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.analyzed + + spark.sessionState.executePlan(planFalse).optimizedPlan + } } From 4f43d27c9e97be8605b120b3d7c11c7c61e3ca6f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 21 Jan 2018 08:51:12 -0600 Subject: [PATCH 0160/1161] [SPARK-22119][ML] Add cosine distance to KMeans ## What changes were proposed in this pull request? Currently, KMeans assumes the only possible distance measure to be used is the Euclidean. This PR aims to add the cosine distance support to the KMeans algorithm. ## How was this patch tested? existing and added UTs. Author: Marco Gaido Author: Marco Gaido Closes #19340 from mgaido91/SPARK-22119. --- .../apache/spark/ml/clustering/KMeans.scala | 22 +- .../mllib/clustering/BisectingKMeans.scala | 11 +- .../spark/mllib/clustering/KMeans.scala | 216 ++++++++++++++---- .../spark/mllib/clustering/KMeansModel.scala | 74 +++++- .../spark/mllib/clustering/LocalKMeans.scala | 10 +- .../spark/ml/clustering/KMeansSuite.scala | 42 +++- .../spark/mllib/clustering/KMeansSuite.scala | 6 +- 7 files changed, 315 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index f2af7fe082b41..c8145de564cbe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD @@ -71,6 +71,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitMode: String = $(initMode) + @Since("2.4.0") + final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " + + "Supported options: 'euclidean' and 'cosine'.", + (value: String) => MLlibKMeans.validateDistanceMeasure(value)) + + /** @group expertGetParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + /** * Param for the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. @@ -260,7 +269,8 @@ class KMeans @Since("1.5.0") ( maxIter -> 20, initMode -> MLlibKMeans.K_MEANS_PARALLEL, initSteps -> 2, - tol -> 1e-4) + tol -> 1e-4, + distanceMeasure -> DistanceMeasure.EUCLIDEAN) @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) @@ -284,6 +294,10 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setInitMode(value: String): this.type = set(initMode, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + /** @group expertSetParam */ @Since("1.5.0") def setInitSteps(value: Int): this.type = set(initSteps, value) @@ -314,7 +328,8 @@ class KMeans @Since("1.5.0") ( } val instr = Instrumentation.create(this, instances) - instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) + instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, + maxIter, seed, tol) val algo = new MLlibKMeans() .setK($(k)) .setInitializationMode($(initMode)) @@ -322,6 +337,7 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 9b9c70cfe5109..2221f4c0edc17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -350,7 +350,7 @@ private object BisectingKMeans extends Serializable { val newClusterChildren = children.filter(newClusterCenters.contains(_)) if (newClusterChildren.nonEmpty) { val selected = newClusterChildren.minBy { child => - KMeans.fastSquaredDistance(newClusterCenters(child), v) + EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v) } (selected, v) } else { @@ -387,7 +387,7 @@ private object BisectingKMeans extends Serializable { val rightIndex = rightChildIndex(rawIndex) val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) val height = math.sqrt(indexes.map { childIndex => - KMeans.fastSquaredDistance(center, clusters(childIndex).center) + EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center) }.max) val children = indexes.map(buildSubTree(_)).toArray new ClusteringTreeNode(index, size, center, cost, height, children) @@ -457,7 +457,7 @@ private[clustering] class ClusteringTreeNode private[clustering] ( this :: Nil } else { val selected = children.minBy { child => - KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm) } selected :: selected.predictPath(pointWithNorm) } @@ -475,7 +475,8 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * Predicts the cluster index and the cost of the input point. */ private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + predict(pointWithNorm, + EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm)) } /** @@ -490,7 +491,7 @@ private[clustering] class ClusteringTreeNode private[clustering] ( (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) selectedChild.predict(pointWithNorm, minCost) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 49043b5acb807..607145cb59fba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -46,14 +46,23 @@ class KMeans private ( private var initializationMode: String, private var initializationSteps: Int, private var epsilon: Double, - private var seed: Long) extends Serializable with Logging { + private var seed: Long, + private var distanceMeasure: String) extends Serializable with Logging { + + @Since("0.8.0") + private def this(k: Int, maxIterations: Int, initializationMode: String, initializationSteps: Int, + epsilon: Double, seed: Long) = + this(k, maxIterations, initializationMode, initializationSteps, + epsilon, seed, DistanceMeasure.EUCLIDEAN) /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, - * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}. + * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random, + * distanceMeasure: "euclidean"}. */ @Since("0.8.0") - def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) + def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong(), + DistanceMeasure.EUCLIDEAN) /** * Number of clusters to create (k). @@ -184,6 +193,22 @@ class KMeans private ( this } + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + KMeans.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + // Initial cluster centers can be provided as a KMeansModel object rather than using the // random or k-means|| initializationMode private var initialModel: Option[KMeansModel] = None @@ -246,6 +271,8 @@ class KMeans private ( val initStartTime = System.nanoTime() + val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure) + val centers = initialModel match { case Some(kMeansCenters) => kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) @@ -253,7 +280,7 @@ class KMeans private ( if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { - initKMeansParallel(data) + initKMeansParallel(data, distanceMeasureInstance) } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 @@ -281,7 +308,7 @@ class KMeans private ( val counts = Array.fill(thisCenters.length)(0L) points.foreach { point => - val (bestCenter, cost) = KMeans.findClosest(thisCenters, point) + val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) val sum = sums(bestCenter) axpy(1.0, point.vector, sum) @@ -302,7 +329,8 @@ class KMeans private ( // Update the cluster centers and costs converged = true newCenters.foreach { case (j, newCenter) => - if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) { + if (converged && + !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) { converged = false } centers(j) = newCenter @@ -323,7 +351,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector)) + new KMeansModel(centers.map(_.vector), distanceMeasure) } /** @@ -345,7 +373,8 @@ class KMeans private ( * * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. */ - private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = { + private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm], + distanceMeasureInstance: DistanceMeasure): Array[VectorWithNorm] = { // Initialize empty centers and point costs. var costs = data.map(_ => Double.PositiveInfinity) @@ -369,7 +398,7 @@ class KMeans private ( bcNewCentersList += bcNewCenters val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - math.min(KMeans.pointCost(bcNewCenters.value, point), cost) + math.min(distanceMeasureInstance.pointCost(bcNewCenters.value, point), cost) }.persist(StorageLevel.MEMORY_AND_DISK) val sumCosts = costs.sum() @@ -397,7 +426,9 @@ class KMeans private ( // candidate by the number of points in the dataset mapping to it and run a local k-means++ // on the weighted centers to pick k of them val bcCenters = data.context.broadcast(distinctCenters) - val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue() + val countMap = data + .map(distanceMeasureInstance.findClosest(bcCenters.value, _)._1) + .countByValue() bcCenters.destroy(blocking = false) @@ -546,10 +577,110 @@ object KMeans { .run(data) } + private[spark] def validateInitMode(initMode: String): Boolean = { + initMode match { + case KMeans.RANDOM => true + case KMeans.K_MEANS_PARALLEL => true + case _ => false + } + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +/** + * A vector with its norm for fast distance computation. + */ +private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) + extends Serializable { + + def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) + + def this(array: Array[Double]) = this(Vectors.dense(array)) + + /** Converts the vector to a dense vector. */ + def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) +} + + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + /** - * Returns the index of the closest center to the given point, as well as the squared distance. + * @return the K-means cost of a given point against the given cluster centers. */ - private[mllib] def findClosest( + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the cosine distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( centers: TraversableOnce[VectorWithNorm], point: VectorWithNorm): (Int, Double) = { var bestDistance = Double.PositiveInfinity @@ -561,7 +692,7 @@ object KMeans { var lowerBoundOfSqDist = center.norm - point.norm lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = fastSquaredDistance(center, point) + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) if (distance < bestDistance) { bestDistance = distance bestIndex = i @@ -573,15 +704,29 @@ object KMeans { } /** - * Returns the K-means cost of a given point against the given cluster centers. + * @return whether a center converged or not, given the epsilon parameter. */ - private[mllib] def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = - findClosest(centers, point)._2 + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } +} + +private[spark] object EuclideanDistanceMeasure { /** - * Returns the squared Euclidean distance between two vectors computed by + * @return the squared Euclidean distance between two vectors computed by * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. */ private[clustering] def fastSquaredDistance( @@ -589,28 +734,15 @@ object KMeans { v2: VectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } - - private[spark] def validateInitMode(initMode: String): Boolean = { - initMode match { - case KMeans.RANDOM => true - case KMeans.K_MEANS_PARALLEL => true - case _ => false - } - } } -/** - * A vector with its norm for fast distance computation. - * - * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]] - */ -private[clustering] -class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable { - - def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0)) - - def this(array: Array[Double]) = this(Vectors.dense(array)) - - /** Converts the vector to a dense vector. */ - def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 3ad08c46d204d..a78c21e838e44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,12 +36,20 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) +class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String) extends Saveable with Serializable with PMMLExportable { + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("1.1.0") + def this(clusterCenters: Array[Vector]) = + this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) + /** * A Java-friendly constructor that takes an Iterable of Vectors. */ @@ -59,7 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec */ @Since("0.8.0") def predict(point: Vector): Int = { - KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 + distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } /** @@ -68,7 +76,8 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm) - points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) + points.map(p => + distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } /** @@ -85,8 +94,9 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm) - val cost = data - .map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum() + val cost = data.map(p => + distanceMeasureInstance.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))) + .sum() bcCentersWithNorm.destroy(blocking = false) cost } @@ -94,7 +104,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { - KMeansModel.SaveLoadV1_0.save(sc, this, path) + KMeansModel.SaveLoadV2_0.save(sc, this, path) } override protected def formatVersion: String = "1.0" @@ -105,7 +115,20 @@ object KMeansModel extends Loader[KMeansModel] { @Since("1.4.0") override def load(sc: SparkContext, path: String): KMeansModel = { - KMeansModel.SaveLoadV1_0.load(sc, path) + val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + val classNameV2_0 = SaveLoadV2_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case (className, "2.0") if className == classNameV2_0 => + SaveLoadV2_0.load(sc, path) + case _ => throw new Exception( + s"KMeansModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)\n" + + s" ($classNameV2_0, 2.0)") + } } private case class Cluster(id: Int, point: Vector) @@ -116,8 +139,7 @@ object KMeansModel extends Loader[KMeansModel] { } } - private[clustering] - object SaveLoadV1_0 { + private[clustering] object SaveLoadV1_0 { private val thisFormatVersion = "1.0" @@ -149,4 +171,38 @@ object KMeansModel extends Loader[KMeansModel] { new KMeansModel(localCentroids.sortBy(_.id).map(_.point)) } } + + private[clustering] object SaveLoadV2_0 { + + private val thisFormatVersion = "2.0" + + private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" + + def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) => + Cluster(id, p.vector) + } + spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): KMeansModel = { + implicit val formats = DefaultFormats + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val k = (metadata \ "k").extract[Int] + val centroids = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Cluster](centroids.schema) + val localCentroids = centroids.rdd.map(Cluster.apply).collect() + assert(k == localCentroids.length) + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index 53587670a5db0..4a08c0a55e68f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -46,7 +46,7 @@ private[mllib] object LocalKMeans extends Logging { // Initialize centers by sampling using the k-means++ procedure. centers(0) = pickWeighted(rand, points, weights).toDense - val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0))) + val costArray = points.map(EuclideanDistanceMeasure.fastSquaredDistance(_, centers(0))) for (i <- 1 until k) { val sum = costArray.zip(weights).map(p => p._1 * p._2).sum @@ -67,11 +67,15 @@ private[mllib] object LocalKMeans extends Logging { // update costArray for (p <- points.indices) { - costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p)) + costArray(p) = math.min( + EuclideanDistanceMeasure.fastSquaredDistance(points(p), centers(i)), + costArray(p)) } } + val distanceMeasureInstance = new EuclideanDistanceMeasure + // Run up to maxIterations iterations of Lloyd's algorithm val oldClosest = Array.fill(points.length)(-1) var iteration = 0 @@ -83,7 +87,7 @@ private[mllib] object LocalKMeans extends Logging { var i = 0 while (i < points.length) { val p = points(i) - val index = KMeans.findClosest(centers, p)._1 + val index = distanceMeasureInstance.findClosest(centers, p)._1 axpy(weights(i), p.vector, sums(index)) counts(index) += weights(i) if (index != oldClosest(i)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 119fe1dead9a9..e4506f23feb31 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -50,6 +50,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN) val model = kmeans.setMaxIter(1).fit(dataset) MLTestingUtils.checkCopyAndUids(kmeans, model) @@ -68,6 +69,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR .setInitSteps(3) .setSeed(123) .setTol(1e-3) + .setDistanceMeasure(DistanceMeasure.COSINE) assert(kmeans.getK === 9) assert(kmeans.getFeaturesCol === "test_feature") @@ -77,6 +79,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitSteps === 3) assert(kmeans.getSeed === 123) assert(kmeans.getTol === 1e-3) + assert(kmeans.getDistanceMeasure === DistanceMeasure.COSINE) } test("parameters validation") { @@ -89,6 +92,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR intercept[IllegalArgumentException] { new KMeans().setInitSteps(0) } + intercept[IllegalArgumentException] { + new KMeans().setDistanceMeasure("no_such_a_measure") + } } test("fit, transform and summary") { @@ -144,6 +150,37 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.getPredictionCol == predictionColName) } + test("KMeans using cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + + val model = new KMeans() + .setK(3) + .setSeed(1) + .setInitMode(MLlibKMeans.RANDOM) + .setTol(1e-6) + .setDistanceMeasure(DistanceMeasure.COSINE) + .fit(df) + + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + } + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) @@ -182,6 +219,7 @@ object KMeansSuite { "predictionCol" -> "myPrediction", "k" -> 3, "maxIter" -> 2, - "tol" -> 0.01 + "tol" -> 0.01, + "distanceMeasure" -> DistanceMeasure.EUCLIDEAN ) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 00d7e2f2d3864..1b98250061c7a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -89,7 +89,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters = km.initKMeansParallel(normedData).map(_.vector) + + val distanceMeasureInstance = new EuclideanDistanceMeasure + val initialCenters = km.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters.length === initialCenters.distinct.length) assert(initialCenters.length <= numDistinctPoints) @@ -104,7 +106,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { .setInitializationMode("k-means||") .setInitializationSteps(10) .setSeed(seed) - val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector) + val initialCenters2 = km2.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector) assert(initialCenters2.length === initialCenters2.distinct.length) assert(initialCenters2.length === k) From 2239d7a410e906ccd40aa8e84d637e9d06cd7b8a Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Sun, 21 Jan 2018 11:23:51 -0800 Subject: [PATCH 0161/1161] [SPARK-21293][SS][SPARKR] Add doc example for streaming join, dedup ## What changes were proposed in this pull request? streaming programming guide changes ## How was this patch tested? manually Author: Felix Cheung Closes #20340 from felixcheung/rstreamdoc. --- .../structured-streaming-programming-guide.md | 74 ++++++++++++++++++- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 2ef5d3168a87b..62589a62ac4c4 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1100,6 +1100,21 @@ streamingDf.join(staticDf, "type") # inner equi-join with a static DF streamingDf.join(staticDf, "type", "right_join") # right outer join with a static DF {% endhighlight %} +
+ +
+ +{% highlight r %} +staticDf <- read.df(...) +streamingDf <- read.stream(...) +joined <- merge(streamingDf, staticDf, sort = FALSE) # inner equi-join with a static DF +joined <- join( + staticDf, + streamingDf, + streamingDf$value == staticDf$value, + "right_outer") # right outer join with a static DF +{% endhighlight %} +
@@ -1227,6 +1242,30 @@ impressionsWithWatermark.join( {% endhighlight %} + +
+ +{% highlight r %} +impressions <- read.stream(...) +clicks <- read.stream(...) + +# Apply watermarks on event-time columns +impressionsWithWatermark <- withWatermark(impressions, "impressionTime", "2 hours") +clicksWithWatermark <- withWatermark(clicks, "clickTime", "3 hours") + +# Join with event-time constraints +joined <- join( + impressionsWithWatermark, + clicksWithWatermark, + expr( + paste( + "clickAdId = impressionAdId AND", + "clickTime >= impressionTime AND", + "clickTime <= impressionTime + interval 1 hour" +))) + +{% endhighlight %} +
@@ -1287,6 +1326,23 @@ impressionsWithWatermark.join( {% endhighlight %} + +
+ +{% highlight r %} +joined <- join( + impressionsWithWatermark, + clicksWithWatermark, + expr( + paste( + "clickAdId = impressionAdId AND", + "clickTime >= impressionTime AND", + "clickTime <= impressionTime + interval 1 hour"), + "left_outer" # can be "inner", "left_outer", "right_outer" +)) + +{% endhighlight %} +
@@ -1441,15 +1497,29 @@ streamingDf {% highlight python %} streamingDf = spark.readStream. ... -// Without watermark using guid column +# Without watermark using guid column streamingDf.dropDuplicates("guid") -// With watermark using guid and eventTime columns +# With watermark using guid and eventTime columns streamingDf \ .withWatermark("eventTime", "10 seconds") \ .dropDuplicates("guid", "eventTime") {% endhighlight %} + +
+ +{% highlight r %} +streamingDf <- read.stream(...) + +# Without watermark using guid column +streamingDf <- dropDuplicates(streamingDf, "guid") + +# With watermark using guid and eventTime columns +streamingDf <- withWatermark(streamingDf, "eventTime", "10 seconds") +streamingDf <- dropDuplicates(streamingDf, "guid", "eventTime") +{% endhighlight %} +
From 12faae295e42820b99a695ba49826051944244e1 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 22 Jan 2018 09:45:27 +0900 Subject: [PATCH 0162/1161] [SPARK-23169][INFRA][R] Run lintr on the changes of lint-r script and .lintr configuration ## What changes were proposed in this pull request? When running the `run-tests` script, seems we don't run lintr on the changes of `lint-r` script and `.lintr` configuration. ## How was this patch tested? Jenkins builds Author: hyukjinkwon Closes #20339 from HyukjinKwon/check-r-changed. --- dev/run-tests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 7e6f7ff060351..fb270c4ee0508 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -578,7 +578,10 @@ def main(): pass if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() - if not changed_files or any(f.endswith(".R") for f in changed_files): + if not changed_files or any(f.endswith(".R") + or f.endswith("lint-r") + or f.endswith(".lintr") + for f in changed_files): run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment From 602c6d82d893a7f34b37d674642669048eb59b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=99=93=E5=93=B2?= Date: Mon, 22 Jan 2018 10:43:12 +0900 Subject: [PATCH 0163/1161] [SPARK-20947][PYTHON] Fix encoding/decoding error in pipe action MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Pipe action convert objects into strings using a way that was affected by the default encoding setting of Python environment. This patch fixed the problem. The detailed description is added here: https://issues.apache.org/jira/browse/SPARK-20947 ## How was this patch tested? Run the following statement in pyspark-shell, and it will NOT raise exception if this patch is applied: ```python sc.parallelize([u'\u6d4b\u8bd5']).pipe('cat').collect() ``` Author: 王晓哲 Closes #18277 from chaoslawful/fix_pipe_encoding_error. --- python/pyspark/rdd.py | 2 +- python/pyspark/tests.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b7470..1b3915548fb14 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -766,7 +766,7 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - s = str(obj).rstrip('\n') + '\n' + s = unicode(obj).rstrip('\n') + '\n' out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index da99872da2f0e..511585763cb01 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1239,6 +1239,13 @@ def test_pipe_functions(self): self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) + def test_pipe_unicode(self): + # Regression test for SPARK-20947 + data = [u'\u6d4b\u8bd5', '1'] + rdd = self.sc.parallelize(data) + result = rdd.pipe('cat').collect() + self.assertEqual(data, result) + class ProfilerTests(PySparkTestCase): From 11daeb833222b1cd349fb1410307d64ab33981db Mon Sep 17 00:00:00 2001 From: Russell Spitzer Date: Mon, 22 Jan 2018 12:27:51 +0800 Subject: [PATCH 0164/1161] [SPARK-22976][CORE] Cluster mode driver dir removed while running ## What changes were proposed in this pull request? The clean up logic on the worker perviously determined the liveness of a particular applicaiton based on whether or not it had running executors. This would fail in the case that a directory was made for a driver running in cluster mode if that driver had no running executors on the same machine. To preserve driver directories we consider both executors and running drivers when checking directory liveness. ## How was this patch tested? Manually started up two node cluster with a single core on each node. Turned on worker directory cleanup and set the interval to 1 second and liveness to one second. Without the patch the driver directory is removed immediately after the app is launched. With the patch it is not ### Without Patch ``` INFO 2018-01-05 23:48:24,693 Logging.scala:54 - Asked to launch driver driver-20180105234824-0000 INFO 2018-01-05 23:48:25,293 Logging.scala:54 - Changing view acls to: cassandra INFO 2018-01-05 23:48:25,293 Logging.scala:54 - Changing modify acls to: cassandra INFO 2018-01-05 23:48:25,294 Logging.scala:54 - Changing view acls groups to: INFO 2018-01-05 23:48:25,294 Logging.scala:54 - Changing modify acls groups to: INFO 2018-01-05 23:48:25,294 Logging.scala:54 - SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(cassandra); groups with view permissions: Set(); users with modify permissions: Set(cassandra); groups with modify permissions: Set() INFO 2018-01-05 23:48:25,330 Logging.scala:54 - Copying user jar file:/home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180105234824-0000/writeRead-0.1.jar INFO 2018-01-05 23:48:25,332 Logging.scala:54 - Copying /home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180105234824-0000/writeRead-0.1.jar INFO 2018-01-05 23:48:25,361 Logging.scala:54 - Launch Command: "/usr/lib/jvm/jdk1.8.0_40//bin/java" .... **** INFO 2018-01-05 23:48:56,577 Logging.scala:54 - Removing directory: /var/lib/spark/worker/driver-20180105234824-0000 ### << Cleaned up **** -- One minute passes while app runs (app has 1 minute sleep built in) -- WARN 2018-01-05 23:49:58,080 ShuffleSecretManager.java:73 - Attempted to unregister application app-20180105234831-0000 when it is not registered INFO 2018-01-05 23:49:58,081 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = false INFO 2018-01-05 23:49:58,081 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = false INFO 2018-01-05 23:49:58,082 ExternalShuffleBlockResolver.java:163 - Application app-20180105234831-0000 removed, cleanupLocalDirs = true INFO 2018-01-05 23:50:00,999 Logging.scala:54 - Driver driver-20180105234824-0000 exited successfully ``` With Patch ``` INFO 2018-01-08 23:19:54,603 Logging.scala:54 - Asked to launch driver driver-20180108231954-0002 INFO 2018-01-08 23:19:54,975 Logging.scala:54 - Changing view acls to: automaton INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing modify acls to: automaton INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing view acls groups to: INFO 2018-01-08 23:19:54,976 Logging.scala:54 - Changing modify acls groups to: INFO 2018-01-08 23:19:54,976 Logging.scala:54 - SecurityManager: authentication disabled; ui acls disabled; users with view permissions: Set(automaton); groups with view permissions: Set(); users with modify permissions: Set(automaton); groups with modify permissions: Set() INFO 2018-01-08 23:19:55,029 Logging.scala:54 - Copying user jar file:/home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180108231954-0002/writeRead-0.1.jar INFO 2018-01-08 23:19:55,031 Logging.scala:54 - Copying /home/automaton/writeRead-0.1.jar to /var/lib/spark/worker/driver-20180108231954-0002/writeRead-0.1.jar INFO 2018-01-08 23:19:55,038 Logging.scala:54 - Launch Command: ...... INFO 2018-01-08 23:21:28,674 ShuffleSecretManager.java:69 - Unregistered shuffle secret for application app-20180108232000-0000 INFO 2018-01-08 23:21:28,675 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = false INFO 2018-01-08 23:21:28,675 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = false INFO 2018-01-08 23:21:28,681 ExternalShuffleBlockResolver.java:163 - Application app-20180108232000-0000 removed, cleanupLocalDirs = true INFO 2018-01-08 23:21:31,703 Logging.scala:54 - Driver driver-20180108231954-0002 exited successfully ***** INFO 2018-01-08 23:21:32,346 Logging.scala:54 - Removing directory: /var/lib/spark/worker/driver-20180108231954-0002 ### < Happening AFTER the Run completes rather than during it ***** ``` Author: Russell Spitzer Closes #20298 from RussellSpitzer/SPARK-22976-master. --- core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3962d422f81d3..563b84934f264 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -441,7 +441,7 @@ private[deploy] class Worker( // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker // rpcEndpoint. // Copy ids so that it can be used in the cleanup thread. - val appIds = executors.values.map(_.appId).toSet + val appIds = (executors.values.map(_.appId) ++ drivers.values.map(_.driverId)).toSet val cleanupFuture = concurrent.Future { val appDirs = workDir.listFiles() if (appDirs == null) { From 8142a3b883a5fe6fc620a2c5b25b6bde4fda32e5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 22 Jan 2018 15:18:57 +0900 Subject: [PATCH 0165/1161] [MINOR][SQL] Fix wrong comments on org.apache.spark.sql.parquet.row.attributes ## What changes were proposed in this pull request? This PR fixes the wrong comment on `org.apache.spark.sql.parquet.row.attributes` which is useful for UDTs like Vector/Matrix. Please see [SPARK-22320](https://issues.apache.org/jira/browse/SPARK-22320) for the usage. Originally, [SPARK-19411](https://github.com/apache/spark/commit/bf493686eb17006727b3ec81849b22f3df68fdef#diff-ee26d4c4be21e92e92a02e9f16dbc285L314) left this behind during removing optional column metadatas. In the same PR, the same comment was removed at line 310-311. ## How was this patch tested? N/A (This is about comments). Author: Dongjoon Hyun Closes #20346 from dongjoon-hyun/minor_comment_parquet. --- .../sql/execution/datasources/parquet/ParquetFileFormat.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 45bedf70f975c..f53a97ba45a26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -108,8 +108,7 @@ class ParquetFileFormat ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) - // We want to clear this temporary metadata from saving into Parquet file. - // This metadata is only useful for detecting optional columns when pushdowning filters. + // This metadata is useful for keeping UDTs like Vector/Matrix. ParquetWriteSupport.setSchema(dataSchema, conf) // Sets flags for `ParquetWriteSupport`, which converts Catalyst schema to Parquet From ec228976156619ed8df21a85bceb5fd3bdeb5855 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 22 Jan 2018 14:49:12 +0800 Subject: [PATCH 0166/1161] [SPARK-23020][CORE] Fix races in launcher code, test. The race in the code is because the handle might update its state to the wrong state if the connection handling thread is still processing incoming data; so the handle needs to wait for the connection to finish up before checking the final state. The race in the test is because when waiting for a handle to reach a final state, the waitFor() method needs to wait until all handle state is updated (which also includes waiting for the connection thread above to finish). Otherwise, waitFor() may return too early, which would cause a bunch of different races (like the listener not being yet notified of the state change, or being in the middle of being notified, or the handle not being properly disposed and causing postChecks() to assert). On top of that I found, by code inspection, a couple of potential races that could make a handle end up in the wrong state when being killed. The original version of this fix introduced the flipped version of the first race described above; the connection closing might override the handle state before the handle might have a chance to do cleanup. The fix there is to only dispose of the handle from the connection when there is an error, and let the handle dispose itself in the normal case. The fix also caused a bug in YarnClusterSuite to be surfaced; the code was checking for a file in the classpath that was not expected to be there in client mode. Because of the above issues, the error was not propagating correctly and the (buggy) test was incorrectly passing. Tested by running the existing unit tests a lot (and not seeing the errors I was seeing before). Author: Marcelo Vanzin Closes #20297 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 53 +++++++++++-------- .../spark/launcher/AbstractAppHandle.java | 22 ++++++-- .../spark/launcher/ChildProcAppHandle.java | 18 ++++--- .../spark/launcher/InProcessAppHandle.java | 17 +++--- .../spark/launcher/LauncherConnection.java | 18 +++---- .../apache/spark/launcher/LauncherServer.java | 49 +++++++++++++---- .../org/apache/spark/launcher/BaseSuite.java | 42 ++++++++++++--- .../spark/launcher/LauncherServerSuite.java | 23 ++++---- .../spark/deploy/yarn/YarnClusterSuite.scala | 4 +- 9 files changed, 165 insertions(+), 81 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index dffa609f1cbdf..a042375c6ae91 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.Arrays; import java.util.ArrayList; import java.util.HashMap; @@ -25,13 +26,13 @@ import java.util.Properties; import java.util.concurrent.TimeUnit; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.*; import static org.junit.Assume.*; import static org.mockito.Mockito.*; import org.apache.spark.SparkContext; +import org.apache.spark.SparkContext$; import org.apache.spark.internal.config.package$; import org.apache.spark.util.Utils; @@ -121,8 +122,7 @@ public void testChildProcLauncher() throws Exception { assertEquals(0, app.waitFor()); } - // TODO: [SPARK-23020] Re-enable this - @Ignore + @Test public void testInProcessLauncher() throws Exception { // Because this test runs SparkLauncher in process and in client mode, it pollutes the system // properties, and that can cause test failures down the test pipeline. So restore the original @@ -139,7 +139,9 @@ public void testInProcessLauncher() throws Exception { // Here DAGScheduler is stopped, while SparkContext.clearActiveContext may not be called yet. // Wait for a reasonable amount of time to avoid creating two active SparkContext in JVM. // See SPARK-23019 and SparkContext.stop() for details. - TimeUnit.MILLISECONDS.sleep(500); + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertTrue("SparkContext is still alive.", SparkContext$.MODULE$.getActive().isEmpty()); + }); } } @@ -148,26 +150,35 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle.Listener listener = mock(SparkAppHandle.Listener.class); doAnswer(invocation -> { SparkAppHandle h = (SparkAppHandle) invocation.getArguments()[0]; - transitions.add(h.getState()); + synchronized (transitions) { + transitions.add(h.getState()); + } return null; }).when(listener).stateChanged(any(SparkAppHandle.class)); - SparkAppHandle handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); - - waitFor(handle); - assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); - - // Matches the behavior of LocalSchedulerBackend. - List expected = Arrays.asList( - SparkAppHandle.State.CONNECTED, - SparkAppHandle.State.RUNNING, - SparkAppHandle.State.FINISHED); - assertEquals(expected, transitions); + SparkAppHandle handle = null; + try { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + waitFor(handle); + assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); + + // Matches the behavior of LocalSchedulerBackend. + List expected = Arrays.asList( + SparkAppHandle.State.CONNECTED, + SparkAppHandle.State.RUNNING, + SparkAppHandle.State.FINISHED); + assertEquals(expected, transitions); + } finally { + if (handle != null) { + handle.kill(); + } + } } public static class SparkLauncherTestApp { diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index df1e7316861d4..daf0972f824dd 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -33,7 +33,7 @@ abstract class AbstractAppHandle implements SparkAppHandle { private List listeners; private State state; private String appId; - private boolean disposed; + private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; @@ -70,8 +70,7 @@ public void stop() { @Override public synchronized void disconnect() { - if (!disposed) { - disposed = true; + if (!isDisposed()) { if (connection != null) { try { connection.close(); @@ -79,7 +78,7 @@ public synchronized void disconnect() { // no-op. } } - server.unregister(this); + dispose(); } } @@ -95,6 +94,21 @@ boolean isDisposed() { return disposed; } + /** + * Mark the handle as disposed, and set it as LOST in case the current state is not final. + */ + synchronized void dispose() { + if (!isDisposed()) { + // Unregister first to make sure that the connection with the app has been really + // terminated. + server.unregister(this); + if (!getState().isFinal()) { + setState(State.LOST); + } + this.disposed = true; + } + } + void setState(State s) { setState(s, false); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 8b3f427b7750e..2b99461652e1f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -48,14 +48,16 @@ public synchronized void disconnect() { @Override public synchronized void kill() { - disconnect(); - if (childProc != null) { - if (childProc.isAlive()) { - childProc.destroyForcibly(); + if (!isDisposed()) { + setState(State.KILLED); + disconnect(); + if (childProc != null) { + if (childProc.isAlive()) { + childProc.destroyForcibly(); + } + childProc = null; } - childProc = null; } - setState(State.KILLED); } void setChildProc(Process childProc, String loggerName, InputStream logStream) { @@ -94,8 +96,6 @@ void monitorChild() { return; } - disconnect(); - int ec; try { ec = proc.exitValue(); @@ -118,6 +118,8 @@ void monitorChild() { if (newState != null) { setState(newState, true); } + + disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index acd64c962604f..f04263cb74a58 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -39,15 +39,16 @@ class InProcessAppHandle extends AbstractAppHandle { @Override public synchronized void kill() { - LOG.warning("kill() may leave the underlying app running in in-process mode."); - disconnect(); - - // Interrupt the thread. This is not guaranteed to kill the app, though. - if (app != null) { - app.interrupt(); + if (!isDisposed()) { + LOG.warning("kill() may leave the underlying app running in in-process mode."); + setState(State.KILLED); + disconnect(); + + // Interrupt the thread. This is not guaranteed to kill the app, though. + if (app != null) { + app.interrupt(); + } } - - setState(State.KILLED); } synchronized void start(String appName, Method main, String[] args) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java index b4a8719e26053..e8ab3f5e369ab 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java @@ -53,7 +53,7 @@ abstract class LauncherConnection implements Closeable, Runnable { public void run() { try { FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream()); - while (!closed) { + while (isOpen()) { Message msg = (Message) in.readObject(); handle(msg); } @@ -95,15 +95,15 @@ protected synchronized void send(Message msg) throws IOException { } @Override - public void close() throws IOException { - if (!closed) { - synchronized (this) { - if (!closed) { - closed = true; - socket.close(); - } - } + public synchronized void close() throws IOException { + if (isOpen()) { + closed = true; + socket.close(); } } + boolean isOpen() { + return !closed; + } + } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index b8999a1d7a4f4..8091885c4f562 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -217,6 +217,33 @@ void unregister(AbstractAppHandle handle) { break; } } + + // If there is a live connection for this handle, we need to wait for it to finish before + // returning, otherwise there might be a race between the connection thread processing + // buffered data and the handle cleaning up after itself, leading to potentially the wrong + // state being reported for the handle. + ServerConnection conn = null; + synchronized (clients) { + for (ServerConnection c : clients) { + if (c.handle == handle) { + conn = c; + break; + } + } + } + + if (conn != null) { + synchronized (conn) { + if (conn.isOpen()) { + try { + conn.wait(); + } catch (InterruptedException ie) { + // Ignore. + } + } + } + } + unref(); } @@ -288,7 +315,7 @@ private String createSecret() { private class ServerConnection extends LauncherConnection { private TimerTask timeout; - private AbstractAppHandle handle; + volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); @@ -313,7 +340,7 @@ protected void handle(Message msg) throws IOException { } else { if (handle == null) { throw new IllegalArgumentException("Expected hello, got: " + - msg != null ? msg.getClass().getName() : null); + msg != null ? msg.getClass().getName() : null); } if (msg instanceof SetAppId) { SetAppId set = (SetAppId) msg; @@ -331,6 +358,9 @@ protected void handle(Message msg) throws IOException { timeout.cancel(); } close(); + if (handle != null) { + handle.dispose(); + } } finally { timeoutTimer.purge(); } @@ -338,16 +368,17 @@ protected void handle(Message msg) throws IOException { @Override public void close() throws IOException { + if (!isOpen()) { + return; + } + synchronized (clients) { clients.remove(this); } - super.close(); - if (handle != null) { - if (!handle.getState().isFinal()) { - LOG.log(Level.WARNING, "Lost connection to spark application."); - handle.setState(SparkAppHandle.State.LOST); - } - handle.disconnect(); + + synchronized (this) { + super.close(); + notifyAll(); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3e1a90eae98d4..3722a59d9438e 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.time.Duration; import java.util.concurrent.TimeUnit; import org.junit.After; @@ -47,19 +48,46 @@ public void postChecks() { assertNull(server); } - protected void waitFor(SparkAppHandle handle) throws Exception { - long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + protected void waitFor(final SparkAppHandle handle) throws Exception { try { - while (!handle.getState().isFinal()) { - assertTrue("Timed out waiting for handle to transition to final state.", - System.nanoTime() < deadline); - TimeUnit.MILLISECONDS.sleep(10); - } + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is not in final state.", handle.getState().isFinal()); + }); } finally { if (!handle.getState().isFinal()) { handle.kill(); } } + + // Wait until the handle has been marked as disposed, to make sure all cleanup tasks + // have been performed. + AbstractAppHandle ahandle = (AbstractAppHandle) handle; + eventually(Duration.ofSeconds(10), Duration.ofMillis(10), () -> { + assertTrue("Handle is still not marked as disposed.", ahandle.isDisposed()); + }); + } + + /** + * Call a closure that performs a check every "period" until it succeeds, or the timeout + * elapses. + */ + protected void eventually(Duration timeout, Duration period, Runnable check) throws Exception { + assertTrue("Timeout needs to be larger than period.", timeout.compareTo(period) > 0); + long deadline = System.nanoTime() + timeout.toNanos(); + int count = 0; + while (true) { + try { + count++; + check.run(); + return; + } catch (Throwable t) { + if (System.nanoTime() >= deadline) { + String msg = String.format("Failed check after %d tries: %s.", count, t.getMessage()); + throw new IllegalStateException(msg, t); + } + Thread.sleep(period.toMillis()); + } + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 7e2b09ce25c9b..024efac33c391 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -23,12 +23,14 @@ import java.net.InetAddress; import java.net.Socket; import java.net.SocketException; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import static org.junit.Assert.*; @@ -143,7 +145,8 @@ public void infoChanged(SparkAppHandle handle) { assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); // Make sure the server matched the client to the handle. assertNotNull(handle.getConnection()); - close(client); + client.close(); + handle.dispose(); assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); assertEquals(SparkAppHandle.State.LOST, handle.getState()); } finally { @@ -197,28 +200,20 @@ private void close(Closeable c) { * server-side close immediately. */ private void waitForError(TestClient client, String secret) throws Exception { - boolean helloSent = false; - int maxTries = 10; - for (int i = 0; i < maxTries; i++) { + final AtomicBoolean helloSent = new AtomicBoolean(); + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { try { - if (!helloSent) { + if (!helloSent.get()) { client.send(new Hello(secret, "1.4.0")); - helloSent = true; + helloSent.set(true); } else { client.send(new SetAppId("appId")); } fail("Expected error but message went through."); } catch (IllegalStateException | IOException e) { // Expected. - break; - } catch (AssertionError e) { - if (i < maxTries - 1) { - Thread.sleep(100); - } else { - throw new AssertionError("Test failed after " + maxTries + " attempts.", e); - } } - } + }); } private static class TestClient extends LauncherConnection { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 061f653b97b7a..e9dcfaf6ba4f0 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -381,7 +381,9 @@ private object YarnClusterDriver extends Logging with Matchers { // Verify that the config archive is correctly placed in the classpath of all containers. val confFile = "/" + Client.SPARK_CONF_FILE - assert(getClass().getResource(confFile) != null) + if (conf.getOption(SparkLauncher.DEPLOY_MODE) == Some("cluster")) { + assert(getClass().getResource(confFile) != null) + } val configFromExecutors = sc.parallelize(1 to 4, 4) .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull } .collect() From 60175e959f275d2961798fbc5a9150dac9de51ff Mon Sep 17 00:00:00 2001 From: Arseniy Tashoyan Date: Mon, 22 Jan 2018 20:17:05 +0800 Subject: [PATCH 0167/1161] [MINOR][DOC] Fix the path to the examples jar ## What changes were proposed in this pull request? The example jar file is now in ./examples/jars directory of Spark distribution. Author: Arseniy Tashoyan Closes #20349 from tashoyan/patch-1. --- docs/running-on-yarn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e4f5a0c659e66..c010af35f8d2e 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -35,7 +35,7 @@ For example: --executor-memory 2g \ --executor-cores 1 \ --queue thequeue \ - lib/spark-examples*.jar \ + examples/jars/spark-examples*.jar \ 10 The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. From 73281161fc7fddd645c712986ec376ac2b1bd213 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:27:59 -0800 Subject: [PATCH 0168/1161] [SPARK-23122][PYSPARK][FOLLOW-UP] Update the docs for UDF Registration ## What changes were proposed in this pull request? This PR is to update the docs for UDF registration ## How was this patch tested? N/A Author: gatorsmile Closes #20348 from gatorsmile/testUpdateDoc. --- python/pyspark/sql/udf.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c77f19f89a442..134badb8485f5 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -199,8 +199,8 @@ def __init__(self, sparkSession): @ignore_unicode_prefix @since("1.3.1") def register(self, name, f, returnType=None): - """Registers a Python function (including lambda function) or a user-defined function - in SQL statements. + """Register a Python function (including lambda function) or a user-defined function + as a SQL function. :param name: name of the user-defined function in SQL statements. :param f: a Python function, or a user-defined function. The user-defined function can @@ -210,6 +210,10 @@ def register(self, name, f, returnType=None): be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :return: a user-defined function. + To register a nondeterministic Python function, users need to first build + a nondeterministic user-defined function for the Python function and then register it + as a SQL function. + `returnType` can be optionally specified when `f` is a Python function but not when `f` is a user-defined function. Please see below. @@ -297,7 +301,7 @@ def register(self, name, f, returnType=None): @ignore_unicode_prefix @since(2.3) def registerJavaFunction(self, name, javaClassName, returnType=None): - """Register a Java user-defined function so it can be used in SQL statements. + """Register a Java user-defined function as a SQL function. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not specified we would infer it via reflection. @@ -334,7 +338,7 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): @ignore_unicode_prefix @since(2.3) def registerJavaUDAF(self, name, javaClassName): - """Register a Java user-defined aggregate function so it can be used in SQL statements. + """Register a Java user-defined aggregate function as a SQL function. :param name: name of the user-defined aggregate function :param javaClassName: fully qualified name of java class From 78801881c405de47f7e53eea3e0420dd69593dbd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:31:24 -0800 Subject: [PATCH 0169/1161] [SPARK-23170][SQL] Dump the statistics of effective runs of analyzer and optimizer rules ## What changes were proposed in this pull request? Dump the statistics of effective runs of analyzer and optimizer rules. ## How was this patch tested? Do a manual run of TPCDSQuerySuite ``` === Metrics of Analyzer/Optimizer Rules === Total number of runs: 175899 Total time: 25.486559948 seconds Rule Effective Time / Total Time Effective Runs / Total Runs org.apache.spark.sql.catalyst.optimizer.ColumnPruning 1603280450 / 2868461549 761 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$CTESubstitution 2045860009 / 2056602674 37 / 788 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions 440719059 / 1693110949 38 / 1982 org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries 1429834919 / 1446016225 39 / 285 org.apache.spark.sql.catalyst.optimizer.PruneFilters 33273083 / 1389586938 3 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences 821183615 / 1266668754 616 / 1982 org.apache.spark.sql.catalyst.optimizer.ReorderJoin 775837028 / 866238225 132 / 1592 org.apache.spark.sql.catalyst.analysis.DecimalPrecision 550683593 / 748854507 211 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery 513075345 / 634370596 49 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$FixNullability 33475731 / 606406532 12 / 742 org.apache.spark.sql.catalyst.analysis.TypeCoercion$ImplicitTypeCasts 193144298 / 545403925 86 / 1982 org.apache.spark.sql.catalyst.optimizer.BooleanSimplification 18651497 / 495725004 7 / 1592 org.apache.spark.sql.catalyst.optimizer.PushPredicateThroughJoin 369257217 / 489934378 709 / 1592 org.apache.spark.sql.catalyst.optimizer.RemoveRedundantAliases 3707000 / 468291609 9 / 1592 org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints 410155900 / 435254175 192 / 285 org.apache.spark.sql.execution.datasources.FindDataSourceTable 348885539 / 371855866 233 / 1982 org.apache.spark.sql.catalyst.optimizer.NullPropagation 11307645 / 307531225 26 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions 120324545 / 304948785 294 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$FunctionArgumentConversion 92323199 / 286695007 38 / 1982 org.apache.spark.sql.catalyst.optimizer.PushDownPredicate 230084193 / 265845972 785 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$PromoteStrings 45938401 / 265144009 40 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$InConversion 14888776 / 261499450 1 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$CaseWhenCoercion 113796384 / 244913861 29 / 1982 org.apache.spark.sql.catalyst.optimizer.ConstantFolding 65008069 / 236548480 126 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractGenerator 0 / 226338929 0 / 1982 org.apache.spark.sql.catalyst.analysis.ResolveTimeZone 98134906 / 221323770 417 / 1982 org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator 0 / 208421703 0 / 1592 org.apache.spark.sql.catalyst.optimizer.OptimizeIn 8762534 / 199351958 16 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$DateTimeOperations 11980016 / 190779046 27 / 1982 org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison 0 / 188887385 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals 0 / 186812106 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions 0 / 183885230 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCasts 17128295 / 182901910 69 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$Division 14579110 / 180309340 8 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$BooleanEquality 0 / 176740516 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$IfCoercion 0 / 170781986 0 / 1982 org.apache.spark.sql.catalyst.optimizer.LikeSimplification 771605 / 164136736 1 / 1592 org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions 0 / 155958962 0 / 1592 org.apache.spark.sql.catalyst.analysis.ResolveCreateNamedStruct 0 / 151222943 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder 7534632 / 146596355 14 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$EltCoercion 0 / 144488654 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$ConcatCoercion 0 / 142403338 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame 12067635 / 141500665 21 / 1982 org.apache.spark.sql.catalyst.analysis.TimeWindowing 0 / 140431958 0 / 1982 org.apache.spark.sql.catalyst.analysis.TypeCoercion$WindowFrameCoercion 0 / 125471960 0 / 1982 org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin 14226972 / 124922019 11 / 1592 org.apache.spark.sql.catalyst.analysis.TypeCoercion$StackCoercion 0 / 123613887 0 / 1982 org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery 8491071 / 121179056 7 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics 55526073 / 120290529 11 / 1982 org.apache.spark.sql.catalyst.optimizer.ConstantPropagation 0 / 113886790 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer 52383759 / 107160222 148 / 1982 org.apache.spark.sql.catalyst.analysis.CleanupAliases 52543524 / 102091518 344 / 1086 org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject 40682895 / 94403652 342 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions 38473816 / 89740578 23 / 1982 org.apache.spark.sql.catalyst.optimizer.CollapseProject 46806090 / 83315506 281 / 1877 org.apache.spark.sql.catalyst.optimizer.FoldablePropagation 0 / 78750087 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases 13742765 / 77227258 47 / 1982 org.apache.spark.sql.catalyst.optimizer.CombineFilters 53386729 / 76960344 448 / 1592 org.apache.spark.sql.execution.datasources.DataSourceAnalysis 68034341 / 75724186 24 / 742 org.apache.spark.sql.catalyst.analysis.Analyzer$LookupFunctions 0 / 71151084 0 / 750 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveMissingReferences 12139848 / 67599140 8 / 1982 org.apache.spark.sql.catalyst.optimizer.PullupCorrelatedPredicates 45017938 / 65968777 23 / 285 org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource 0 / 60937767 0 / 285 org.apache.spark.sql.catalyst.optimizer.CollapseRepartition 0 / 59897237 0 / 1592 org.apache.spark.sql.catalyst.optimizer.PushProjectionThroughUnion 8547262 / 53941370 10 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$HandleNullInputsForUDF 0 / 52735976 0 / 742 org.apache.spark.sql.catalyst.analysis.TypeCoercion$WidenSetOperationTypes 9797713 / 52401665 9 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$PullOutNondeterministic 0 / 51741500 0 / 742 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations 28614911 / 51061186 233 / 1990 org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions 0 / 50621510 0 / 285 org.apache.spark.sql.catalyst.optimizer.CombineUnions 2777800 / 50262112 17 / 1877 org.apache.spark.sql.catalyst.analysis.Analyzer$GlobalAggregates 1640641 / 49633909 46 / 1982 org.apache.spark.sql.catalyst.optimizer.DecimalAggregates 20198374 / 48488419 100 / 385 org.apache.spark.sql.catalyst.optimizer.LimitPushDown 0 / 45052523 0 / 1592 org.apache.spark.sql.catalyst.optimizer.CombineLimits 0 / 44719443 0 / 1592 org.apache.spark.sql.catalyst.optimizer.EliminateSorts 0 / 44216930 0 / 1592 org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery 36235699 / 44165786 148 / 285 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNewInstance 0 / 42750307 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast 0 / 41811748 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveOrdinalInOrderByAndGroupBy 3819476 / 41776562 4 / 1982 org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime 0 / 40527808 0 / 285 org.apache.spark.sql.catalyst.optimizer.CollapseWindow 0 / 36832538 0 / 1592 org.apache.spark.sql.catalyst.optimizer.EliminateSerialization 0 / 36120667 0 / 1592 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggAliasInGroupBy 0 / 32435826 0 / 1982 org.apache.spark.sql.execution.datasources.PreprocessTableCreation 0 / 32145218 0 / 742 org.apache.spark.sql.execution.datasources.ResolveSQLOnFile 0 / 30295614 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolvePivot 0 / 30111655 0 / 1982 org.apache.spark.sql.catalyst.expressions.codegen.package$ExpressionCanonicalizer$CleanExpressions 59930 / 28038201 26 / 8280 org.apache.spark.sql.catalyst.analysis.ResolveInlineTables 0 / 27808108 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases 0 / 27066690 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGenerate 0 / 26660210 0 / 1982 org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin 0 / 25255184 0 / 1982 org.apache.spark.sql.catalyst.analysis.ResolveTableValuedFunctions 0 / 24663088 0 / 1990 org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals 9709079 / 24450670 4 / 788 org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveBroadcastHints 0 / 23776535 0 / 750 org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions 0 / 22697895 0 / 285 org.apache.spark.sql.catalyst.optimizer.CheckCartesianProducts 0 / 22523798 0 / 285 org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate 988593 / 21535410 15 / 300 org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects 0 / 20269996 0 / 285 org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates 0 / 19388592 0 / 285 org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases 17675532 / 18971185 215 / 285 org.apache.spark.sql.catalyst.optimizer.GetCurrentDatabase 0 / 18271152 0 / 285 org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation 2077097 / 17190855 3 / 288 org.apache.spark.sql.catalyst.analysis.EliminateBarriers 0 / 16736359 0 / 1086 org.apache.spark.sql.execution.OptimizeMetadataOnlyQuery 0 / 16669341 0 / 285 org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences 0 / 14470235 0 / 742 org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithAntiJoin 6715625 / 12190561 1 / 300 org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin 3451793 / 11431432 7 / 300 org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate 0 / 10810568 0 / 285 org.apache.spark.sql.catalyst.optimizer.RemoveRepetitionFromGroupExpressions 344198 / 10475276 1 / 286 org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution 0 / 10386630 0 / 788 org.apache.spark.sql.catalyst.analysis.EliminateUnions 0 / 10096526 0 / 788 org.apache.spark.sql.catalyst.analysis.AliasViewChild 0 / 9991706 0 / 742 org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation 0 / 9649334 0 / 288 org.apache.spark.sql.catalyst.analysis.ResolveHints$RemoveAllHints 0 / 8739109 0 / 750 org.apache.spark.sql.execution.datasources.PreprocessTableInsertion 0 / 8420889 0 / 742 org.apache.spark.sql.catalyst.analysis.EliminateView 0 / 8319134 0 / 285 org.apache.spark.sql.catalyst.optimizer.RemoveLiteralFromGroupExpressions 0 / 7392627 0 / 286 org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithFilter 0 / 7170516 0 / 300 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateArrayOps 0 / 7109643 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateStructOps 0 / 6837590 0 / 1592 org.apache.spark.sql.catalyst.optimizer.SimplifyCreateMapOps 0 / 6617848 0 / 1592 org.apache.spark.sql.catalyst.optimizer.CombineConcats 0 / 5768406 0 / 1592 org.apache.spark.sql.catalyst.optimizer.ReplaceDeduplicateWithAggregate 0 / 5349831 0 / 285 org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters 0 / 5186642 0 / 285 org.apache.spark.sql.catalyst.optimizer.EliminateDistinct 0 / 2427686 0 / 285 org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder 0 / 2420436 0 / 285 ``` Author: gatorsmile Closes #20342 from gatorsmile/reportExecution. --- .../rules/QueryExecutionMetering.scala | 91 +++++++++++++++++++ .../sql/catalyst/rules/RuleExecutor.scala | 32 +++---- .../apache/spark/sql/BenchmarkQueryTest.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../execution/HiveCompatibilitySuite.scala | 2 +- 5 files changed, 109 insertions(+), 20 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala new file mode 100644 index 0000000000000..62f7541150a6e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/QueryExecutionMetering.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.rules + +import scala.collection.JavaConverters._ + +import com.google.common.util.concurrent.AtomicLongMap + +case class QueryExecutionMetering() { + private val timeMap = AtomicLongMap.create[String]() + private val numRunsMap = AtomicLongMap.create[String]() + private val numEffectiveRunsMap = AtomicLongMap.create[String]() + private val timeEffectiveRunsMap = AtomicLongMap.create[String]() + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + timeMap.clear() + numRunsMap.clear() + numEffectiveRunsMap.clear() + timeEffectiveRunsMap.clear() + } + + def totalTime: Long = { + timeMap.sum() + } + + def totalNumRuns: Long = { + numRunsMap.sum() + } + + def incExecutionTimeBy(ruleName: String, delta: Long): Unit = { + timeMap.addAndGet(ruleName, delta) + } + + def incTimeEffectiveExecutionBy(ruleName: String, delta: Long): Unit = { + timeEffectiveRunsMap.addAndGet(ruleName, delta) + } + + def incNumEffectiveExecution(ruleName: String): Unit = { + numEffectiveRunsMap.incrementAndGet(ruleName) + } + + def incNumExecution(ruleName: String): Unit = { + numRunsMap.incrementAndGet(ruleName) + } + + /** Dump statistics about time spent running specific rules. */ + def dumpTimeSpent(): String = { + val map = timeMap.asMap().asScala + val maxLengthRuleNames = map.keys.map(_.toString.length).max + + val colRuleName = "Rule".padTo(maxLengthRuleNames, " ").mkString + val colRunTime = "Effective Time / Total Time".padTo(len = 47, " ").mkString + val colNumRuns = "Effective Runs / Total Runs".padTo(len = 47, " ").mkString + + val ruleMetrics = map.toSeq.sortBy(_._2).reverseMap { case (name, time) => + val timeEffectiveRun = timeEffectiveRunsMap.get(name) + val numRuns = numRunsMap.get(name) + val numEffectiveRun = numEffectiveRunsMap.get(name) + + val ruleName = name.padTo(maxLengthRuleNames, " ").mkString + val runtimeValue = s"$timeEffectiveRun / $time".padTo(len = 47, " ").mkString + val numRunValue = s"$numEffectiveRun / $numRuns".padTo(len = 47, " ").mkString + s"$ruleName $runtimeValue $numRunValue" + }.mkString("\n", "\n", "") + + s""" + |=== Metrics of Analyzer/Optimizer Rules === + |Total number of runs: $totalNumRuns + |Total time: ${totalTime / 1000000000D} seconds + | + |$colRuleName $colRunTime $colNumRuns + |$ruleMetrics + """.stripMargin + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 7e4b784033bfc..dccb44ddebfa4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.rules -import scala.collection.JavaConverters._ - -import com.google.common.util.concurrent.AtomicLongMap - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees.TreeNode @@ -28,18 +24,16 @@ import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.util.Utils object RuleExecutor { - protected val timeMap = AtomicLongMap.create[String]() - - /** Resets statistics about time spent running specific rules */ - def resetTime(): Unit = timeMap.clear() + protected val queryExecutionMeter = QueryExecutionMetering() /** Dump statistics about time spent running specific rules. */ def dumpTimeSpent(): String = { - val map = timeMap.asMap().asScala - val maxSize = map.keys.map(_.toString.length).max - map.toSeq.sortBy(_._2).reverseMap { case (k, v) => - s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n", "\n", "") + queryExecutionMeter.dumpTimeSpent() + } + + /** Resets statistics about time spent running specific rules */ + def resetMetrics(): Unit = { + queryExecutionMeter.resetMetrics() } } @@ -77,6 +71,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { */ def execute(plan: TreeType): TreeType = { var curPlan = plan + val queryExecutionMetrics = RuleExecutor.queryExecutionMeter batches.foreach { batch => val batchStartPlan = curPlan @@ -91,15 +86,18 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val startTime = System.nanoTime() val result = rule(plan) val runTime = System.nanoTime() - startTime - RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime) if (!result.fastEquals(plan)) { + queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName) + queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime) logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} """.stripMargin) } + queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime) + queryExecutionMetrics.incNumExecution(rule.ruleName) // Run the structural integrity checker against the plan after each rule. if (!isPlanIntegral(result)) { @@ -135,9 +133,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (!batchStartPlan.fastEquals(curPlan)) { logDebug( s""" - |=== Result of Batch ${batch.name} === - |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} - """.stripMargin) + |=== Result of Batch ${batch.name} === + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} + """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala index 7037749f14478..e51aad021fcbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/BenchmarkQueryTest.scala @@ -46,7 +46,7 @@ abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with B override def beforeAll() { super.beforeAll() - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } protected def checkGeneratedCode(plan: SparkPlan): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index e3901af4b9988..054ada56d99ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -291,7 +291,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll(): Unit = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 45791c69b4cb7..cebaad5b4ad9b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -62,7 +62,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") - RuleExecutor.resetTime() + RuleExecutor.resetMetrics() } override def afterAll() { From 896e45af5fea264683b1d7d20a1711f33908a06f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Jan 2018 04:32:59 -0800 Subject: [PATCH 0170/1161] [MINOR][SQL][TEST] Test case cleanups for recent PRs ## What changes were proposed in this pull request? Revert the unneeded test case changes we made in SPARK-23000 Also fixes the test suites that do not call `super.afterAll()` in the local `afterAll`. The `afterAll()` of `TestHiveSingleton` actually reset the environments. ## How was this patch tested? N/A Author: gatorsmile Closes #20341 from gatorsmile/testRelated. --- .../apache/spark/sql/DataFrameJoinSuite.scala | 21 ++++++----- .../apache/spark/sql/hive/test/TestHive.scala | 3 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 26 +++++++------- .../sql/hive/execution/HiveUDAFSuite.scala | 8 +++-- .../hive/execution/Hive_2_1_DDLSuite.scala | 6 +++- .../execution/ObjectHashAggregateSuite.scala | 6 +++- .../apache/spark/sql/hive/parquetSuites.scala | 35 +++++++++++-------- 7 files changed, 60 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 1656f290ee19c..0d9eeabb397a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class DataFrameJoinSuite extends QueryTest with SharedSQLContext { @@ -276,16 +277,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { test("SPARK-23087: don't throw Analysis Exception in CheckCartesianProduct when join condition " + "is false or null") { - val df = spark.range(10) - val dfNull = spark.range(10).select(lit(null).as("b")) - val planNull = df.join(dfNull, $"id" === $"b", "left").queryExecution.analyzed - - spark.sessionState.executePlan(planNull).optimizedPlan - - val dfOne = df.select(lit(1).as("a")) - val dfTwo = spark.range(10).select(lit(2).as("b")) - val planFalse = dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.analyzed - - spark.sessionState.executePlan(planFalse).optimizedPlan + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(10) + val dfNull = spark.range(10).select(lit(null).as("b")) + df.join(dfNull, $"id" === $"b", "left").queryExecution.optimizedPlan + + val dfOne = df.select(lit(1).as("a")) + val dfTwo = spark.range(10).select(lit(2).as("b")) + dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index c84131fc3212a..7287e20d55bbe 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -492,8 +492,7 @@ private[hive] class TestHiveSparkSession( protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** - * Resets the test instance by deleting any tables that have been created. - * TODO: also clear out UDFs, views, etc. + * Resets the test instance by deleting any table, view, temp view, and UDF that have been created */ def reset() { try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 83b4c862e2546..ba9b944e4a055 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -166,13 +166,13 @@ class DataSourceWithHiveMetastoreCatalogSuite )) ).foreach { case (provider, (inputFormat, outputFormat, serde)) => test(s"Persist non-partitioned $provider relation into metastore as managed table") { - withTable("default.t") { + withTable("t") { withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { testDF .write .mode(SaveMode.Overwrite) .format(provider) - .saveAsTable("default.t") + .saveAsTable("t") } val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) @@ -187,15 +187,14 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("default.t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === - Seq("1.1\t1", "2.1\t2")) + checkAnswer(table("t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } test(s"Persist non-partitioned $provider relation into metastore as external table") { withTempPath { dir => - withTable("default.t") { + withTable("t") { val path = dir.getCanonicalFile withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { @@ -204,7 +203,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .mode(SaveMode.Overwrite) .format(provider) .option("path", path.toString) - .saveAsTable("default.t") + .saveAsTable("t") } val hiveTable = @@ -220,8 +219,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(DecimalType(10, 3), StringType)) - checkAnswer(table("default.t"), testDF) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === + checkAnswer(table("t"), testDF) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -229,9 +228,9 @@ class DataSourceWithHiveMetastoreCatalogSuite test(s"Persist non-partitioned $provider relation into metastore as managed table using CTAS") { withTempPath { dir => - withTable("default.t") { + withTable("t") { sql( - s"""CREATE TABLE default.t USING $provider + s"""CREATE TABLE t USING $provider |OPTIONS (path '${dir.toURI}') |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) @@ -249,9 +248,8 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.name) === Seq("d1", "d2")) assert(columns.map(_.dataType) === Seq(IntegerType, StringType)) - checkAnswer(table("default.t"), Row(1, "val_1")) - assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM default.t") === - Seq("1\tval_1")) + checkAnswer(table("t"), Row(1, "val_1")) + assert(sparkSession.metadataHive.runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 8986fb58c6460..7402c9626873c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -49,8 +49,12 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS mock") + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("built-in Hive UDAF") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala index bc828877e35ec..eaedac1fa95d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -74,7 +74,11 @@ class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with Before } override def afterAll(): Unit = { - catalog = null + try { + catalog = null + } finally { + super.afterAll() + } } test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 9eaf44c043c71..8dbcd24cd78de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -47,7 +47,11 @@ class ObjectHashAggregateSuite } protected override def afterAll(): Unit = { - sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + try { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } finally { + super.afterAll() + } } test("typed_count without grouping keys") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 740e0837350cc..2327d83a1b4f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -180,15 +180,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { - dropTables("partitioned_parquet", - "partitioned_parquet_with_key", - "partitioned_parquet_with_complextypes", - "partitioned_parquet_with_key_and_complextypes", - "normal_parquet", - "jt", - "jt_array", - "test_parquet") - super.afterAll() + try { + dropTables("partitioned_parquet", + "partitioned_parquet_with_key", + "partitioned_parquet_with_complextypes", + "partitioned_parquet_with_key_and_complextypes", + "normal_parquet", + "jt", + "jt_array", + "test_parquet") + } finally { + super.afterAll() + } } test(s"conversion is working") { @@ -931,11 +934,15 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with } override protected def afterAll(): Unit = { - partitionedTableDir.delete() - normalTableDir.delete() - partitionedTableDirWithKey.delete() - partitionedTableDirWithComplexTypes.delete() - partitionedTableDirWithKeyAndComplexTypes.delete() + try { + partitionedTableDir.delete() + normalTableDir.delete() + partitionedTableDirWithKey.delete() + partitionedTableDirWithComplexTypes.delete() + partitionedTableDirWithKeyAndComplexTypes.delete() + } finally { + super.afterAll() + } } /** From 5d680cae486c77cdb12dbe9e043710e49e8d51e4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jan 2018 20:56:38 +0800 Subject: [PATCH 0171/1161] [SPARK-23090][SQL] polish ColumnVector ## What changes were proposed in this pull request? Several improvements: * provide a default implementation for the batch get methods * rename `getChildColumn` to `getChild`, which is more concise * remove `getStruct(int, int)`, it's only used to simplify the codegen, which is an internal thing, we should not add a public API for this purpose. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20277 from cloud-fan/column-vector. --- .../expressions/codegen/CodeGenerator.scala | 18 ++-- .../datasources/orc/OrcColumnVector.java | 65 +----------- .../orc/OrcColumnarBatchReader.java | 23 ++--- .../vectorized/ColumnVectorUtils.java | 10 +- .../vectorized/MutableColumnarRow.java | 4 +- .../vectorized/WritableColumnVector.java | 10 +- .../sql/vectorized/ArrowColumnVector.java | 99 +------------------ .../spark/sql/vectorized/ColumnVector.java | 79 +++++++++++---- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 46 ++++----- .../sql/execution/ColumnarBatchScan.scala | 2 +- .../VectorizedHashMapGenerator.scala | 4 +- .../execution/arrow/ArrowWriterSuite.scala | 14 +-- .../vectorized/ArrowColumnVectorSuite.scala | 12 +-- .../vectorized/ColumnVectorSuite.scala | 12 +-- .../vectorized/ColumnarBatchBenchmark.scala | 38 +++---- .../vectorized/ColumnarBatchSuite.scala | 20 ++-- 17 files changed, 164 insertions(+), 296 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2c714c228e6c9..f96ed7628fda1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -688,17 +688,13 @@ class CodegenContext { /** * Returns the specialized code to access a value from a column vector for a given `DataType`. */ - def getValue(vector: String, rowId: String, dataType: DataType): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.get${primitiveTypeName(jt)}($rowId)" - case t: DecimalType => - s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})" - case StringType => - s"$vector.getUTF8String($rowId)" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index b6e792274da11..aaf2a380034a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -110,57 +110,21 @@ public boolean getBoolean(int rowId) { return longData.vector[getRowIndex(rowId)] == 1; } - @Override - public boolean[] getBooleans(int rowId, int count) { - boolean[] res = new boolean[count]; - for (int i = 0; i < count; i++) { - res[i] = getBoolean(rowId + i); - } - return res; - } - @Override public byte getByte(int rowId) { return (byte) longData.vector[getRowIndex(rowId)]; } - @Override - public byte[] getBytes(int rowId, int count) { - byte[] res = new byte[count]; - for (int i = 0; i < count; i++) { - res[i] = getByte(rowId + i); - } - return res; - } - @Override public short getShort(int rowId) { return (short) longData.vector[getRowIndex(rowId)]; } - @Override - public short[] getShorts(int rowId, int count) { - short[] res = new short[count]; - for (int i = 0; i < count; i++) { - res[i] = getShort(rowId + i); - } - return res; - } - @Override public int getInt(int rowId) { return (int) longData.vector[getRowIndex(rowId)]; } - @Override - public int[] getInts(int rowId, int count) { - int[] res = new int[count]; - for (int i = 0; i < count; i++) { - res[i] = getInt(rowId + i); - } - return res; - } - @Override public long getLong(int rowId) { int index = getRowIndex(rowId); @@ -171,43 +135,16 @@ public long getLong(int rowId) { } } - @Override - public long[] getLongs(int rowId, int count) { - long[] res = new long[count]; - for (int i = 0; i < count; i++) { - res[i] = getLong(rowId + i); - } - return res; - } - @Override public float getFloat(int rowId) { return (float) doubleData.vector[getRowIndex(rowId)]; } - @Override - public float[] getFloats(int rowId, int count) { - float[] res = new float[count]; - for (int i = 0; i < count; i++) { - res[i] = getFloat(rowId + i); - } - return res; - } - @Override public double getDouble(int rowId) { return doubleData.vector[getRowIndex(rowId)]; } - @Override - public double[] getDoubles(int rowId, int count) { - double[] res = new double[count]; - for (int i = 0; i < count; i++) { - res[i] = getDouble(rowId + i); - } - return res; - } - @Override public int getArrayLength(int rowId) { throw new UnsupportedOperationException(); @@ -245,7 +182,7 @@ public org.apache.spark.sql.vectorized.ColumnVector arrayData() { } @Override - public org.apache.spark.sql.vectorized.ColumnVector getChildColumn(int ordinal) { + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 89bae4326e93b..5e7cad470e1d1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -289,10 +289,9 @@ private void putRepeatingValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); int size = data.vector[0].length; - arrayData.reserve(size); - arrayData.putBytes(0, size, data.vector[0], 0); + toColumn.arrayData().reserve(size); + toColumn.arrayData().putBytes(0, size, data.vector[0], 0); for (int index = 0; index < batchSize; index++) { toColumn.putArray(index, 0, size); } @@ -352,7 +351,7 @@ private void putNonNullValues( toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0); } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector data = ((BytesColumnVector)fromColumn); - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(data.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) { @@ -363,8 +362,7 @@ private void putNonNullValues( DecimalType decimalType = (DecimalType)type; DecimalColumnVector data = ((DecimalColumnVector)fromColumn); if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { putDecimalWritable( @@ -459,7 +457,7 @@ private void putValues( } } else if (type instanceof StringType || type instanceof BinaryType) { BytesColumnVector vector = (BytesColumnVector)fromColumn; - WritableColumnVector arrayData = toColumn.getChildColumn(0); + WritableColumnVector arrayData = toColumn.arrayData(); int totalNumBytes = IntStream.of(vector.length).sum(); arrayData.reserve(totalNumBytes); for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) { @@ -474,8 +472,7 @@ private void putValues( DecimalType decimalType = (DecimalType)type; HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector; if (decimalType.precision() > Decimal.MAX_LONG_DIGITS()) { - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(batchSize * 16); + toColumn.arrayData().reserve(batchSize * 16); } for (int index = 0; index < batchSize; index++) { if (fromColumn.isNull[index]) { @@ -521,8 +518,7 @@ private static void putDecimalWritable( toColumn.putLong(index, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.putBytes(index * 16, bytes.length, bytes, 0); + toColumn.arrayData().putBytes(index * 16, bytes.length, bytes, 0); toColumn.putArray(index, index * 16, bytes.length); } } @@ -547,9 +543,8 @@ private static void putDecimalWritables( toColumn.putLongs(0, size, value.toUnscaledLong()); } else { byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray(); - WritableColumnVector arrayData = toColumn.getChildColumn(0); - arrayData.reserve(bytes.length); - arrayData.putBytes(0, bytes.length, bytes, 0); + toColumn.arrayData().reserve(bytes.length); + toColumn.arrayData().putBytes(0, bytes.length, bytes, 0); for (int index = 0; index < size; index++) { toColumn.putArray(index, 0, bytes.length); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 5ee8cc8da2309..a2853bbadc92b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -85,8 +85,8 @@ public static void populate(WritableColumnVector col, InternalRow row, int field } } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t); - col.getChildColumn(0).putInts(0, capacity, c.months); - col.getChildColumn(1).putLongs(0, capacity, c.microseconds); + col.getChild(0).putInts(0, capacity, c.months); + col.getChild(1).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType) { col.putInts(0, capacity, row.getInt(fieldIdx)); } else if (t instanceof TimestampType) { @@ -149,8 +149,8 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } else if (t instanceof CalendarIntervalType) { CalendarInterval c = (CalendarInterval)o; dst.appendStruct(false); - dst.getChildColumn(0).appendInt(c.months); - dst.getChildColumn(1).appendLong(c.microseconds); + dst.getChild(0).appendInt(c.months); + dst.getChild(1).appendLong(c.microseconds); } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); } else { @@ -179,7 +179,7 @@ private static void appendValue(WritableColumnVector dst, DataType t, Row src, i dst.appendStruct(false); Row c = src.getStruct(fieldIdx); for (int i = 0; i < st.fields().length; i++) { - appendValue(dst.getChildColumn(i), st.fields()[i].dataType(), c, i); + appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 70057a9def6c0..2bab095d4d951 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -146,8 +146,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChildColumn(0).getInt(rowId); - final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + final int months = columns[ordinal].getChild(0).getInt(rowId); + final long microseconds = columns[ordinal].getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index d2ae32b06f83b..ca4f00985c2a3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -599,17 +599,13 @@ public final int appendStruct(boolean isNull) { return elementsAppended; } - /** - * Returns the data for the underlying array. - */ + // `WritableColumnVector` puts the data of array in the first child column vector, and puts the + // array offsets and lengths in the current column vector. @Override public WritableColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override - public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } /** * Returns the elements appended. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index bfd1b4cb0ef12..ca7a4751450d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -33,18 +33,6 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; - private void ensureAccessible(int index) { - ensureAccessible(index, 1); - } - - private void ensureAccessible(int index, int count) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index + count > valueCount) { - throw new IndexOutOfBoundsException( - String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); - } - } - @Override public int numNulls() { return accessor.getNullCount(); @@ -55,156 +43,75 @@ public void close() { if (childColumns != null) { for (int i = 0; i < childColumns.length; i++) { childColumns[i].close(); + childColumns[i] = null; } + childColumns = null; } accessor.close(); } @Override public boolean isNullAt(int rowId) { - ensureAccessible(rowId); return accessor.isNullAt(rowId); } @Override public boolean getBoolean(int rowId) { - ensureAccessible(rowId); return accessor.getBoolean(rowId); } - @Override - public boolean[] getBooleans(int rowId, int count) { - ensureAccessible(rowId, count); - boolean[] array = new boolean[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getBoolean(rowId + i); - } - return array; - } - @Override public byte getByte(int rowId) { - ensureAccessible(rowId); return accessor.getByte(rowId); } - @Override - public byte[] getBytes(int rowId, int count) { - ensureAccessible(rowId, count); - byte[] array = new byte[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getByte(rowId + i); - } - return array; - } - @Override public short getShort(int rowId) { - ensureAccessible(rowId); return accessor.getShort(rowId); } - @Override - public short[] getShorts(int rowId, int count) { - ensureAccessible(rowId, count); - short[] array = new short[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getShort(rowId + i); - } - return array; - } - @Override public int getInt(int rowId) { - ensureAccessible(rowId); return accessor.getInt(rowId); } - @Override - public int[] getInts(int rowId, int count) { - ensureAccessible(rowId, count); - int[] array = new int[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getInt(rowId + i); - } - return array; - } - @Override public long getLong(int rowId) { - ensureAccessible(rowId); return accessor.getLong(rowId); } - @Override - public long[] getLongs(int rowId, int count) { - ensureAccessible(rowId, count); - long[] array = new long[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getLong(rowId + i); - } - return array; - } - @Override public float getFloat(int rowId) { - ensureAccessible(rowId); return accessor.getFloat(rowId); } - @Override - public float[] getFloats(int rowId, int count) { - ensureAccessible(rowId, count); - float[] array = new float[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getFloat(rowId + i); - } - return array; - } - @Override public double getDouble(int rowId) { - ensureAccessible(rowId); return accessor.getDouble(rowId); } - @Override - public double[] getDoubles(int rowId, int count) { - ensureAccessible(rowId, count); - double[] array = new double[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getDouble(rowId + i); - } - return array; - } - @Override public int getArrayLength(int rowId) { - ensureAccessible(rowId); return accessor.getArrayLength(rowId); } @Override public int getArrayOffset(int rowId) { - ensureAccessible(rowId); return accessor.getArrayOffset(rowId); } @Override public Decimal getDecimal(int rowId, int precision, int scale) { - ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { - ensureAccessible(rowId); return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { - ensureAccessible(rowId); return accessor.getBinary(rowId); } @@ -212,7 +119,7 @@ public byte[] getBinary(int rowId) { public ArrowColumnVector arrayData() { return childColumns[0]; } @Override - public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } public ArrowColumnVector(ValueVector vector) { super(ArrowUtils.fromArrowField(vector.getField())); diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index d1196e1299fee..f9936214035b6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -51,12 +51,16 @@ public abstract class ColumnVector implements AutoCloseable { public final DataType dataType() { return type; } /** - * Cleans up memory for this column. The column is not usable after this. + * Cleans up memory for this column vector. The column vector is not usable after this. + * + * This overwrites `AutoCloseable.close` to remove the `throws` clause, as column vector is + * in-memory and we don't expect any exception to happen during closing. */ + @Override public abstract void close(); /** - * Returns the number of nulls in this column. + * Returns the number of nulls in this column vector. */ public abstract int numNulls(); @@ -73,7 +77,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract boolean[] getBooleans(int rowId, int count); + public boolean[] getBooleans(int rowId, int count) { + boolean[] res = new boolean[count]; + for (int i = 0; i < count; i++) { + res[i] = getBoolean(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -83,7 +93,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract byte[] getBytes(int rowId, int count); + public byte[] getBytes(int rowId, int count) { + byte[] res = new byte[count]; + for (int i = 0; i < count; i++) { + res[i] = getByte(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -93,7 +109,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract short[] getShorts(int rowId, int count); + public short[] getShorts(int rowId, int count) { + short[] res = new short[count]; + for (int i = 0; i < count; i++) { + res[i] = getShort(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -103,7 +125,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract int[] getInts(int rowId, int count); + public int[] getInts(int rowId, int count) { + int[] res = new int[count]; + for (int i = 0; i < count; i++) { + res[i] = getInt(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -113,7 +141,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract long[] getLongs(int rowId, int count); + public long[] getLongs(int rowId, int count) { + long[] res = new long[count]; + for (int i = 0; i < count; i++) { + res[i] = getLong(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -123,7 +157,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract float[] getFloats(int rowId, int count); + public float[] getFloats(int rowId, int count) { + float[] res = new float[count]; + for (int i = 0; i < count; i++) { + res[i] = getFloat(rowId + i); + } + return res; + } /** * Returns the value for rowId. @@ -133,7 +173,13 @@ public abstract class ColumnVector implements AutoCloseable { /** * Gets values from [rowId, rowId + count) */ - public abstract double[] getDoubles(int rowId, int count); + public double[] getDoubles(int rowId, int count) { + double[] res = new double[count]; + for (int i = 0; i < count; i++) { + res[i] = getDouble(rowId + i); + } + return res; + } /** * Returns the length of the array for rowId. @@ -152,14 +198,6 @@ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } - /** - * A special version of {@link #getStruct(int)}, which is only used as an adapter for Spark - * codegen framework, the second parameter is totally ignored. - */ - public final ColumnarRow getStruct(int rowId, int size) { - return getStruct(rowId); - } - /** * Returns the array for rowId. */ @@ -196,9 +234,9 @@ public MapData getMap(int ordinal) { public abstract ColumnVector arrayData(); /** - * Returns the ordinal's child data column. + * Returns the ordinal's child column vector. */ - public abstract ColumnVector getChildColumn(int ordinal); + public abstract ColumnVector getChild(int ordinal); /** * Data type for this column. @@ -206,8 +244,7 @@ public MapData getMap(int ordinal) { protected DataType type; /** - * Sets up the common state and also handles creating the child columns if this is a nested - * type. + * Sets up the data type of this column vector. */ protected ColumnVector(DataType type) { this.type = type; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d89a52e7a4fe..522c39580389f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -133,8 +133,8 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + int month = data.getChild(0).getInt(offset + ordinal); + long microseconds = data.getChild(1).getLong(offset + ordinal); return new CalendarInterval(month, microseconds); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 3c6656dec77cd..2e59085a82768 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -28,7 +28,7 @@ */ public final class ColumnarRow extends InternalRow { // The data for this row. - // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. + // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. private final ColumnVector data; private final int rowId; private final int numFields; @@ -53,7 +53,7 @@ public InternalRow copy() { if (isNullAt(i)) { row.setNullAt(i); } else { - DataType dt = data.getChildColumn(i).dataType(); + DataType dt = data.getChild(i).dataType(); if (dt instanceof BooleanType) { row.setBoolean(i, getBoolean(i)); } else if (dt instanceof ByteType) { @@ -93,65 +93,65 @@ public boolean anyNull() { } @Override - public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } + public boolean isNullAt(int ordinal) { return data.getChild(ordinal).isNullAt(rowId); } @Override - public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } + public boolean getBoolean(int ordinal) { return data.getChild(ordinal).getBoolean(rowId); } @Override - public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } + public byte getByte(int ordinal) { return data.getChild(ordinal).getByte(rowId); } @Override - public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } + public short getShort(int ordinal) { return data.getChild(ordinal).getShort(rowId); } @Override - public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } + public int getInt(int ordinal) { return data.getChild(ordinal).getInt(rowId); } @Override - public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } + public long getLong(int ordinal) { return data.getChild(ordinal).getLong(rowId); } @Override - public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } + public float getFloat(int ordinal) { return data.getChild(ordinal).getFloat(rowId); } @Override - public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } + public double getDouble(int ordinal) { return data.getChild(ordinal).getDouble(rowId); } @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getUTF8String(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getBinary(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); - final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + final int months = data.getChild(ordinal).getChild(0).getInt(rowId); + final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); return new CalendarInterval(months, microseconds); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getStruct(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getArray(rowId); + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getArray(rowId); } @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index dd68df9686691..04f2619ed7541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -50,7 +50,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { dataType: DataType, nullable: Boolean): ExprCode = { val javaType = ctx.javaType(dataType) - val value = ctx.getValue(columnVar, dataType, ordinal) + val value = ctx.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index eb48584d0c1ee..633eeac180974 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -127,8 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"vectors[$ordinal]", "buckets[idx]", - key.dataType), key.name)})""" + val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index c42bc60a59d67..92506032ab2e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -217,21 +217,21 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct0 = reader.getStruct(0, 2) + val struct0 = reader.getStruct(0) assert(struct0.getInt(0) === 1) assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) - val struct1 = reader.getStruct(1, 2) + val struct1 = reader.getStruct(1) assert(struct1.isNullAt(0)) assert(struct1.isNullAt(1)) assert(reader.isNullAt(2)) - val struct3 = reader.getStruct(3, 2) + val struct3 = reader.getStruct(3) assert(struct3.getInt(0) === 4) assert(struct3.isNullAt(1)) - val struct4 = reader.getStruct(4, 2) + val struct4 = reader.getStruct(4) assert(struct4.isNullAt(0)) assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) @@ -252,15 +252,15 @@ class ArrowWriterSuite extends SparkFunSuite { val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) - val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + val struct00 = reader.getStruct(0).getStruct(0, 2) assert(struct00.getInt(0) === 1) assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) - val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + val struct10 = reader.getStruct(1).getStruct(0, 2) assert(struct10.isNullAt(0)) assert(struct10.isNullAt(1)) - val struct2 = reader.getStruct(2, 1) + val struct2 = reader.getStruct(2) assert(struct2.isNullAt(0)) assert(reader.isNullAt(3)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 53432669e215d..e794f50781ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -346,11 +346,11 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 0) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) @@ -398,21 +398,21 @@ class ArrowColumnVectorSuite extends SparkFunSuite { assert(columnVector.dataType === schema) assert(columnVector.numNulls === 1) - val row0 = columnVector.getStruct(0, 2) + val row0 = columnVector.getStruct(0) assert(row0.getInt(0) === 1) assert(row0.getLong(1) === 1L) - val row1 = columnVector.getStruct(1, 2) + val row1 = columnVector.getStruct(1) assert(row1.getInt(0) === 2) assert(row1.isNullAt(1)) - val row2 = columnVector.getStruct(2, 2) + val row2 = columnVector.getStruct(2) assert(row2.isNullAt(0)) assert(row2.getLong(1) === 3L) assert(columnVector.isNullAt(3)) - val row4 = columnVector.getStruct(4, 2) + val row4 = columnVector.getStruct(4) assert(row4.getInt(0) === 5) assert(row4.getLong(1) === 5L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 944240f3bade5..2d1ad4b456783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -199,17 +199,17 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) testVectors("struct", 10, structType) { testVector => - val c1 = testVector.getChildColumn(0) - val c2 = testVector.getChildColumn(1) + val c1 = testVector.getChild(0) + val c2 = testVector.getChild(1) c1.putInt(0, 123) c2.putDouble(0, 3.45) c1.putInt(1, 456) c2.putDouble(1, 5.67) - assert(testVector.getStruct(0, structType.length).get(0, IntegerType) === 123) - assert(testVector.getStruct(0, structType.length).get(1, DoubleType) === 3.45) - assert(testVector.getStruct(1, structType.length).get(0, IntegerType) === 456) - assert(testVector.getStruct(1, structType.length).get(1, DoubleType) === 5.67) + assert(testVector.getStruct(0).get(0, IntegerType) === 123) + assert(testVector.getStruct(0).get(1, DoubleType) === 3.45) + assert(testVector.getStruct(1).get(0, IntegerType) === 456) + assert(testVector.getStruct(1).get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 38ea2e47fdef8..ad74fb99b0c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -268,17 +268,17 @@ object ColumnarBatchBenchmark { Int Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Java Array 177 / 181 1856.4 0.5 1.0X - ByteBuffer Unsafe 318 / 322 1032.0 1.0 0.6X - ByteBuffer API 1411 / 1418 232.2 4.3 0.1X - DirectByteBuffer 467 / 474 701.8 1.4 0.4X - Unsafe Buffer 178 / 185 1843.6 0.5 1.0X - Column(on heap) 178 / 184 1840.8 0.5 1.0X - Column(off heap) 341 / 344 961.8 1.0 0.5X - Column(off heap direct) 178 / 184 1845.4 0.5 1.0X - UnsafeRow (on heap) 378 / 389 866.3 1.2 0.5X - UnsafeRow (off heap) 393 / 402 834.0 1.2 0.4X - Column On Heap Append 309 / 318 1059.1 0.9 0.6X + Java Array 177 / 183 1851.1 0.5 1.0X + ByteBuffer Unsafe 314 / 330 1043.7 1.0 0.6X + ByteBuffer API 1298 / 1307 252.4 4.0 0.1X + DirectByteBuffer 465 / 483 704.2 1.4 0.4X + Unsafe Buffer 179 / 183 1835.5 0.5 1.0X + Column(on heap) 181 / 186 1815.2 0.6 1.0X + Column(off heap) 344 / 349 951.7 1.1 0.5X + Column(off heap direct) 178 / 186 1838.6 0.5 1.0X + UnsafeRow (on heap) 388 / 394 844.8 1.2 0.5X + UnsafeRow (off heap) 400 / 403 819.4 1.2 0.4X + Column On Heap Append 315 / 325 1041.8 1.0 0.6X */ val benchmark = new Benchmark("Int Read/Write", count * iters) benchmark.addCase("Java Array")(javaArray) @@ -337,8 +337,8 @@ object ColumnarBatchBenchmark { Boolean Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - Bitset 726 / 727 462.4 2.2 1.0X - Byte Array 530 / 542 632.7 1.6 1.4X + Bitset 741 / 747 452.6 2.2 1.0X + Byte Array 531 / 542 631.6 1.6 1.4X */ benchmark.run() } @@ -394,8 +394,8 @@ object ColumnarBatchBenchmark { String Read/Write: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap 332 / 338 49.3 20.3 1.0X - Off Heap 466 / 467 35.2 28.4 0.7X + On Heap 351 / 362 46.6 21.4 1.0X + Off Heap 456 / 466 35.9 27.8 0.8X */ val benchmark = new Benchmark("String Read/Write", count * iters) benchmark.addCase("On Heap")(column(MemoryMode.ON_HEAP)) @@ -479,10 +479,10 @@ object ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 415 / 422 394.7 2.5 1.0X - Off Heap Read Size Only 394 / 402 415.9 2.4 1.1X - On Heap Read Elements 2558 / 2593 64.0 15.6 0.2X - Off Heap Read Elements 3316 / 3317 49.4 20.2 0.1X + On Heap Read Size Only 416 / 423 393.5 2.5 1.0X + Off Heap Read Size Only 396 / 404 413.6 2.4 1.1X + On Heap Read Elements 2569 / 2590 63.8 15.7 0.2X + Off Heap Read Elements 3302 / 3333 49.6 20.2 0.1X */ benchmark.run } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index cd90681ecabc6..1873c24ab063c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -732,8 +732,8 @@ class ColumnarBatchSuite extends SparkFunSuite { "Struct Column", 10, new StructType().add("int", IntegerType).add("double", DoubleType)) { column => - val c1 = column.getChildColumn(0) - val c2 = column.getChildColumn(1) + val c1 = column.getChild(0) + val c2 = column.getChild(1) assert(c1.dataType() == IntegerType) assert(c2.dataType() == DoubleType) @@ -787,8 +787,8 @@ class ColumnarBatchSuite extends SparkFunSuite { 10, new ArrayType(structType, true)) { column => val data = column.arrayData() - val c0 = data.getChildColumn(0) - val c1 = data.getChildColumn(1) + val c0 = data.getChild(0) + val c1 = data.getChild(1) // Structs in child column: (0, 0), (1, 10), (2, 20), (3, 30), (4, 40), (5, 50) (0 until 6).foreach { i => c0.putInt(i, i) @@ -815,8 +815,8 @@ class ColumnarBatchSuite extends SparkFunSuite { new StructType() .add("int", IntegerType) .add("array", new ArrayType(IntegerType, true))) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) @@ -844,13 +844,13 @@ class ColumnarBatchSuite extends SparkFunSuite { "Nest Struct in Struct", 10, new StructType().add("int", IntegerType).add("struct", subSchema)) { column => - val c0 = column.getChildColumn(0) - val c1 = column.getChildColumn(1) + val c0 = column.getChild(0) + val c1 = column.getChild(1) c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) - val c1c0 = c1.getChildColumn(0) - val c1c1 = c1.getChildColumn(1) + val c1c0 = c1.getChild(0) + val c1c1 = c1.getChild(1) // Structs in c1: (7, 70), (8, 80), (9, 90) c1c0.putInt(0, 7) c1c0.putInt(1, 8) From 87ffe7adddf517541aac0d1e8536b02ad8881606 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 22 Jan 2018 22:12:50 +0900 Subject: [PATCH 0172/1161] [SPARK-7721][PYTHON][TESTS] Adds PySpark coverage generation script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Note that this PR was made based on the top of https://github.com/apache/spark/pull/20151. So, it almost leaves the main codes intact. This PR proposes to add a script for the preparation of automatic PySpark coverage generation. Now, it's difficult to check the actual coverage in case of PySpark. With this script, it allows to run tests by the way we did via `run-tests` script before. The usage is exactly the same with `run-tests` script as this basically wraps it. This script and PR alone should also be useful. I was asked about how to run this before, and seems some reviewers (including me) need this. It would be also useful to run it manually. It usually requires a small diff in normal Python projects but PySpark cases are a bit different because apparently we are unable to track the coverage after it's forked. So, here, I made a custom worker that forces the coverage, based on the top of https://github.com/apache/spark/pull/20151. I made a simple demo. Please take a look - https://spark-test.github.io/pyspark-coverage-site. To show up the structure, this PR adds the files as below: ``` python ├── .coveragerc # Runtime configuration when we run the script. ├── run-tests-with-coverage # The script that has coverage support and wraps run-tests script. └── test_coverage # Directories that have files required when running coverage. ├── conf │   └── spark-defaults.conf # Having the configuration 'spark.python.daemon.module'. ├── coverage_daemon.py # A daemon having custom fix and wrapping our daemon.py └── sitecustomize.py # Initiate coverage with COVERAGE_PROCESS_START ``` Note that this PR has a minor nit: [This scope](https://github.com/apache/spark/blob/04e44b37cc04f62fbf9e08c7076349e0a4d12ea8/python/pyspark/daemon.py#L148-L169) in `daemon.py` is not in the coverage results as basically I am producing the coverage results in `worker.py` separately and then merging it. I believe it's not a big deal. In a followup, I might have a site that has a single up-to-date PySpark coverage from the master branch as the fallback / default, or have a site that has multiple PySpark coverages and the site link will be left to each pull request. ## How was this patch tested? Manually tested. Usage is the same with the existing Python test script - `./python/run-tests`. For example, ``` sh run-tests-with-coverage --python-executables=python3 --modules=pyspark-sql ``` Running this will generate HTMLs under `./python/test_coverage/htmlcov`. Console output example: ``` sh run-tests-with-coverage --python-executables=python3,python --modules=pyspark-core Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python3', 'python'] Will test the following Python modules: ['pyspark-core'] Starting test(python): pyspark.tests Starting test(python3): pyspark.tests ... Tests passed in 231 seconds Combining collected coverage data under /.../spark/python/test_coverage/coverage_data Reporting the coverage data at /...spark/python/test_coverage/coverage_data/coverage Name Stmts Miss Branch BrPart Cover -------------------------------------------------------------- pyspark/__init__.py 41 0 8 2 96% ... pyspark/profiler.py 74 11 22 5 83% pyspark/rdd.py 871 40 303 32 93% pyspark/rddsampler.py 68 10 32 2 82% ... -------------------------------------------------------------- TOTAL 8521 3077 2748 191 59% Generating HTML files for PySpark coverage under /.../spark/python/test_coverage/htmlcov ``` Author: hyukjinkwon Closes #20204 from HyukjinKwon/python-coverage. --- .gitignore | 2 + python/.coveragerc | 21 ++++++ python/run-tests-with-coverage | 69 +++++++++++++++++++ python/run-tests.py | 5 +- python/test_coverage/conf/spark-defaults.conf | 21 ++++++ python/test_coverage/coverage_daemon.py | 45 ++++++++++++ python/test_coverage/sitecustomize.py | 23 +++++++ 7 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 python/.coveragerc create mode 100755 python/run-tests-with-coverage create mode 100644 python/test_coverage/conf/spark-defaults.conf create mode 100644 python/test_coverage/coverage_daemon.py create mode 100644 python/test_coverage/sitecustomize.py diff --git a/.gitignore b/.gitignore index 903297db96901..39085904e324c 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,8 @@ project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip python/deps +python/test_coverage/coverage_data +python/test_coverage/htmlcov python/pyspark/python reports/ scalastyle-on-compile.generated.xml diff --git a/python/.coveragerc b/python/.coveragerc new file mode 100644 index 0000000000000..b3339cd356a6e --- /dev/null +++ b/python/.coveragerc @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +[run] +branch = true +parallel = true +data_file = ${COVERAGE_DIR}/coverage_data/coverage diff --git a/python/run-tests-with-coverage b/python/run-tests-with-coverage new file mode 100755 index 0000000000000..6d74b563e9140 --- /dev/null +++ b/python/run-tests-with-coverage @@ -0,0 +1,69 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -o pipefail +set -e + +# This variable indicates which coverage executable to run to combine coverages +# and generate HTMLs, for example, 'coverage3' in Python 3. +COV_EXEC="${COV_EXEC:-coverage}" +FWDIR="$(cd "`dirname $0`"; pwd)" +pushd "$FWDIR" > /dev/null + +# Ensure that coverage executable is installed. +if ! hash $COV_EXEC 2>/dev/null; then + echo "Missing coverage executable in your path, skipping PySpark coverage" + exit 1 +fi + +# Set up the directories for coverage results. +export COVERAGE_DIR="$FWDIR/test_coverage" +rm -fr "$COVERAGE_DIR/coverage_data" +rm -fr "$COVERAGE_DIR/htmlcov" +mkdir -p "$COVERAGE_DIR/coverage_data" + +# Current directory are added in the python path so that it doesn't refer our built +# pyspark zip library first. +export PYTHONPATH="$FWDIR:$PYTHONPATH" +# Also, our sitecustomize.py and coverage_daemon.py are included in the path. +export PYTHONPATH="$COVERAGE_DIR:$PYTHONPATH" + +# We use 'spark.python.daemon.module' configuration to insert the coverage supported workers. +export SPARK_CONF_DIR="$COVERAGE_DIR/conf" + +# This environment variable enables the coverage. +export COVERAGE_PROCESS_START="$FWDIR/.coveragerc" + +# If you'd like to run a specific unittest class, you could do such as +# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests +./run-tests "$@" + +# Don't run coverage for the coverage command itself +unset COVERAGE_PROCESS_START + +# Coverage could generate empty coverage data files. Remove it to get rid of warnings when combining. +find $COVERAGE_DIR/coverage_data -size 0 -print0 | xargs -0 rm +echo "Combining collected coverage data under $COVERAGE_DIR/coverage_data" +$COV_EXEC combine +echo "Reporting the coverage data at $COVERAGE_DIR/coverage_data/coverage" +$COV_EXEC report --include "pyspark/*" +echo "Generating HTML files for PySpark coverage under $COVERAGE_DIR/htmlcov" +$COV_EXEC html --ignore-errors --include "pyspark/*" --directory "$COVERAGE_DIR/htmlcov" + +popd diff --git a/python/run-tests.py b/python/run-tests.py index 1341086f02db0..f03284c334285 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -38,7 +38,7 @@ from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which, subprocess_check_output # noqa +from sparktestsupport.shellutils import which, subprocess_check_output, run_cmd # noqa from sparktestsupport.modules import all_modules # noqa @@ -175,6 +175,9 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: + if "COVERAGE_PROCESS_START" in os.environ: + # Make sure if coverage is installed. + run_cmd([python_exec, "-c", "import coverage"]) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() diff --git a/python/test_coverage/conf/spark-defaults.conf b/python/test_coverage/conf/spark-defaults.conf new file mode 100644 index 0000000000000..bf44ea6e7cfec --- /dev/null +++ b/python/test_coverage/conf/spark-defaults.conf @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This is used to generate PySpark coverage results. Seems there's no way to +# add a configuration when SPARK_TESTING environment variable is set because +# we will directly execute modules by python -m. +spark.python.daemon.module coverage_daemon diff --git a/python/test_coverage/coverage_daemon.py b/python/test_coverage/coverage_daemon.py new file mode 100644 index 0000000000000..c87366a1ac23b --- /dev/null +++ b/python/test_coverage/coverage_daemon.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import imp + + +# This is a hack to always refer the main code rather than built zip. +main_code_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +daemon = imp.load_source("daemon", "%s/pyspark/daemon.py" % main_code_dir) + +if "COVERAGE_PROCESS_START" in os.environ: + worker = imp.load_source("worker", "%s/pyspark/worker.py" % main_code_dir) + + def _cov_wrapped(*args, **kwargs): + import coverage + cov = coverage.coverage( + config_file=os.environ["COVERAGE_PROCESS_START"]) + cov.start() + try: + worker.main(*args, **kwargs) + finally: + cov.stop() + cov.save() + daemon.worker_main = _cov_wrapped +else: + raise RuntimeError("COVERAGE_PROCESS_START environment variable is not set, exiting.") + + +if __name__ == '__main__': + daemon.manager() diff --git a/python/test_coverage/sitecustomize.py b/python/test_coverage/sitecustomize.py new file mode 100644 index 0000000000000..630237a518126 --- /dev/null +++ b/python/test_coverage/sitecustomize.py @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Note that this 'sitecustomize' module is a built-in feature in Python. +# If this module is defined, it's executed when the Python session begins. +# `coverage.process_startup()` seeks if COVERAGE_PROCESS_START environment +# variable is set or not. If set, it starts to run the coverage. +import coverage +coverage.process_startup() From 4327ccf289b5a0dc51f6294113d01af6eb52eea0 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 22 Jan 2018 08:36:17 -0600 Subject: [PATCH 0173/1161] [SPARK-11630][CORE] ClosureCleaner moved from warning to debug ## What changes were proposed in this pull request? ClosureCleaner moved from warning to debug ## How was this patch tested? Existing tests Author: Rekha Joshi Author: rjoshi2 Closes #20337 from rekhajoshm/SPARK-11630-1. --- core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 40616421b5bca..ad0c0639521f6 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -207,7 +207,7 @@ private[spark] object ClosureCleaner extends Logging { accessedFields: Map[Class[_], Set[String]]): Unit = { if (!isClosure(func.getClass)) { - logWarning("Expected a closure; got " + func.getClass.getName) + logDebug(s"Expected a closure; got ${func.getClass.getName}") return } From 446948af1d8dbc080a26a6eec6f743d338f1d12b Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Mon, 22 Jan 2018 10:36:28 -0800 Subject: [PATCH 0174/1161] [SPARK-23121][CORE] Fix for ui becoming unaccessible for long running streaming apps ## What changes were proposed in this pull request? The allJobs and the job pages attempt to use stage attempt and DAG visualization from the store, but for long running jobs they are not guaranteed to be retained, leading to exceptions when these pages are rendered. To fix it `store.lastStageAttempt(stageId)` and `store.operationGraphForJob(jobId)` are wrapped in `store.asOption` and default values are used if the info is missing. ## How was this patch tested? Manual testing of the UI, also using the test command reported in SPARK-23121: ./bin/spark-submit --class org.apache.spark.examples.streaming.HdfsWordCount ./examples/jars/spark-examples_2.11-2.4.0-SNAPSHOT.jar /spark Closes #20287 Author: Sandor Murakozi Closes #20330 from smurakozi/SPARK-23121. --- .../apache/spark/ui/jobs/AllJobsPage.scala | 24 ++++++++++--------- .../org/apache/spark/ui/jobs/JobPage.scala | 10 ++++++-- .../org/apache/spark/ui/jobs/StagePage.scala | 9 ++++--- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index e3b72f1f34859..2b0f4acbac72a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -36,6 +36,9 @@ import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished jobs */ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends WebUIPage("") { + + import ApiHelper._ + private val JOBS_LEGEND =
val jobId = job.jobId val status = job.status - val jobDescription = store.lastStageAttempt(job.stageIds.max).description - val displayJobDescription = jobDescription - .map(UIUtils.makeDescription(_, "", plainText = true).text) - .getOrElse("") + val (_, lastStageDescription) = lastStageNameAndDescription(store, job) + val jobDescription = UIUtils.makeDescription(lastStageDescription, "", plainText = true).text + val submissionTime = job.submissionTime.get.getTime() val completionTime = job.completionTime.map(_.getTime()).getOrElse(System.currentTimeMillis()) val classNameByStatus = status match { @@ -80,7 +82,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We // The timeline library treats contents as HTML, so we have to escape them. We need to add // extra layers of escaping in order to embed this in a Javascript string literal. - val escapedDesc = Utility.escape(displayJobDescription) + val escapedDesc = Utility.escape(jobDescription) val jsEscapedDesc = StringEscapeUtils.escapeEcmaScript(escapedDesc) val jobEventJsonAsStr = s""" @@ -430,6 +432,8 @@ private[ui] class JobDataSource( sortColumn: String, desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) { + import ApiHelper._ + // Convert JobUIData to JobTableRowData which contains the final contents to show in the table // so that we can avoid creating duplicate contents during sorting the data private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc)) @@ -454,23 +458,21 @@ private[ui] class JobDataSource( val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val submissionTime = jobData.submissionTime val formattedSubmissionTime = submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val lastStageAttempt = store.lastStageAttempt(jobData.stageIds.max) - val lastStageDescription = lastStageAttempt.description.getOrElse("") + val (lastStageName, lastStageDescription) = lastStageNameAndDescription(store, jobData) - val formattedJobDescription = - UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) + val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, - lastStageAttempt.name, + lastStageName, lastStageDescription, duration.getOrElse(-1), formattedDuration, submissionTime.map(_.getTime()).getOrElse(-1L), formattedSubmissionTime, - formattedJobDescription, + jobDescription, detailUrl ) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index c27f30c21a843..46f2a76cc651b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -336,8 +336,14 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP content ++= makeTimeline(activeStages ++ completedStages ++ failedStages, store.executorList(false), appStartTime) - content ++= UIUtils.showDagVizForJob( - jobId, store.operationGraphForJob(jobId)) + val operationGraphContent = store.asOption(store.operationGraphForJob(jobId)) match { + case Some(operationGraph) => UIUtils.showDagVizForJob(jobId, operationGraph) + case None => +
+

No DAG visualization information to display for job {jobId}

+
+ } + content ++= operationGraphContent if (shouldShowActiveStages) { content ++= diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0eb3190205c3e..5c2b0c3a19996 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -23,12 +23,10 @@ import java.util.concurrent.TimeUnit import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} -import scala.xml.{Elem, Node, Unparsed} +import scala.xml.{Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkConf -import org.apache.spark.internal.config._ import org.apache.spark.scheduler.TaskLocality import org.apache.spark.status._ import org.apache.spark.status.api.v1._ @@ -1020,4 +1018,9 @@ private object ApiHelper { } } + def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { + val stage = store.asOption(store.lastStageAttempt(job.stageIds.max)) + (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) + } + } From 76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 22 Jan 2018 13:55:14 -0600 Subject: [PATCH 0175/1161] [MINOR] Typo fixes ## What changes were proposed in this pull request? Typo fixes ## How was this patch tested? Local build / Doc-only changes Author: Jacek Laskowski Closes #20344 from jaceklaskowski/typo-fixes. --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/kafka010/KafkaSourceProvider.scala | 4 ++-- .../apache/spark/sql/kafka010/KafkaWriteTask.scala | 2 +- .../org/apache/spark/sql/streaming/OutputMode.java | 2 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 8 ++++---- .../spark/sql/catalyst/analysis/unresolved.scala | 2 +- .../catalyst/expressions/aggregate/interfaces.scala | 12 +++++------- .../catalyst/plans/logical/LogicalPlanVisitor.scala | 2 +- .../statsEstimation/BasicStatsPlanVisitor.scala | 2 +- .../SizeInBytesOnlyStatsPlanVisitor.scala | 4 ++-- .../org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../apache/spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- .../apache/spark/sql/execution/SparkSqlParser.scala | 2 +- .../spark/sql/execution/WholeStageCodegenExec.scala | 2 +- .../spark/sql/execution/command/SetCommand.scala | 4 ++-- .../spark/sql/execution/datasources/rules.scala | 2 +- .../sql/execution/streaming/HDFSMetadataLog.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeq.scala | 2 +- .../spark/sql/execution/streaming/OffsetSeqLog.scala | 2 +- .../execution/streaming/StreamingQueryWrapper.scala | 2 +- .../sql/execution/streaming/state/StateStore.scala | 2 +- .../spark/sql/execution/ui/ExecutionPage.scala | 2 +- .../spark/sql/expressions/UserDefinedFunction.scala | 4 ++-- .../spark/sql/internal/BaseSessionStateBuilder.scala | 4 ++-- .../spark/sql/streaming/DataStreamReader.scala | 6 +++--- .../results/columnresolution-negative.sql.out | 2 +- .../sql-tests/results/columnresolution-views.sql.out | 2 +- .../sql-tests/results/columnresolution.sql.out | 6 +++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- .../apache/spark/sql/execution/SQLViewSuite.scala | 2 +- .../apache/spark/sql/hive/HiveExternalCatalog.scala | 4 ++-- 32 files changed, 50 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 31f3cb9dfa0ae..3828d4f703247 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2276,7 +2276,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Clean a closure to make it ready to be serialized and send to tasks + * Clean a closure to make it ready to be serialized and sent to tasks * (removes unreferenced variables in $outer's, updates REPL variables) * If checkSerializable is set, clean will also proactively * check to see if f is serializable and throw a SparkException diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 3914370a96595..62a998fbfb30b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -307,7 +307,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.GROUP_ID_CONFIG}' is not supported as " + - s"user-specified consumer groups is not used to track offsets.") + s"user-specified consumer groups are not used to track offsets.") } if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) { @@ -335,7 +335,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister { throw new IllegalArgumentException( s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as " - + "value are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame " + + "values are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame " + "operations to explicitly deserialize the values.") } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index baa60febf661d..d90630a8adc93 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Unsa import org.apache.spark.sql.types.{BinaryType, StringType} /** - * A simple trait for writing out data in a single Spark task, without any concerns about how + * Writes out data in a single Spark task, without any concerns about how * to commit or abort tasks. Exceptions thrown by the implementation of this class will * automatically trigger task aborts. */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 2800b3068f87b..470c128ee6c3d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** - * OutputMode is used to what data will be written to a streaming sink when there is + * OutputMode describes what data will be written to a streaming sink when there is * new data available in a streaming DataFrame/Dataset. * * @since 2.0.0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 35b35110e491f..2b14c8220d43b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -611,8 +611,8 @@ class Analyzer( if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) { view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " + s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " + - "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + - "aroud this.") + s"avoid errors. Increase the value of ${SQLConf.MAX_NESTED_VIEW_DEPTH.key} to work " + + "around this.") } executeSameContext(child) } @@ -653,7 +653,7 @@ class Analyzer( // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exsits.") + s"database ${e.db} doesn't exist.") } } @@ -1524,7 +1524,7 @@ class Analyzer( } /** - * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] + * Extracts [[Generator]] from the projectList of a [[Project]] operator and creates [[Generate]] * operator under [[Project]]. * * This rule will throw [[AnalysisException]] for following cases: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index d336f801d0770..a65f58fa61ff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -294,7 +294,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu } else { val from = input.inputSet.map(_.name).mkString(", ") val targetString = target.get.mkString(".") - throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'") + throw new AnalysisException(s"cannot resolve '$targetString.*' given input columns '$from'") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 19abce01a26cf..e1d16a2cd38b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -190,17 +190,15 @@ abstract class AggregateFunction extends Expression { def defaultResult: Option[Literal] = None /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because - * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, - * and the flag indicating if this aggregation is distinct aggregation or not. - * An [[AggregateFunction]] should not be used without being wrapped in - * an [[AggregateExpression]]. + * Creates [[AggregateExpression]] with `isDistinct` flag disabled. + * + * @see `toAggregateExpression(isDistinct: Boolean)` for detailed description */ def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) /** - * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct - * field of the [[AggregateExpression]] to the given value because + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets `isDistinct` + * flag of the [[AggregateExpression]] to the given value because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, * and the flag indicating if this aggregation is distinct aggregation or not. * An [[AggregateFunction]] should not be used without being wrapped in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index e0748043c46e2..2c248d74869ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical /** - * A visitor pattern for traversing a [[LogicalPlan]] tree and compute some properties. + * A visitor pattern for traversing a [[LogicalPlan]] tree and computing some properties. */ trait LogicalPlanVisitor[T] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index ca0775a2e8408..b6c16079d1984 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import org.apache.spark.sql.catalyst.plans.logical._ /** - * An [[LogicalPlanVisitor]] that computes a the statistics used in a cost-based optimizer. + * A [[LogicalPlanVisitor]] that computes the statistics for the cost-based optimizer. */ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 5e1c4e0bd6069..85f67c7d66075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -48,8 +48,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { } /** - * For leaf nodes, use its computeStats. For other nodes, we assume the size in bytes is the - * sum of all of the children's. + * For leaf nodes, use its `computeStats`. For other nodes, we assume the size in bytes is the + * product of all of the children's `computeStats`. */ override def default(p: LogicalPlan): Statistics = p match { case p: LeafNode => p.computeStats() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cc4f4bf332459..1cef09a5bf053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -894,7 +894,7 @@ object SQLConf { .internal() .doc("The number of bins when generating histograms.") .intConf - .checkValue(num => num > 1, "The number of bins must be large than 1.") + .checkValue(num => num > 1, "The number of bins must be larger than 1.") .createWithDefault(254) val PERCENTILE_ACCURACY = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 82c5307d54360..6241d5cbb1d25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -154,7 +154,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => } /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL * configurations. */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 97f12ff625c42..5f3d4448e4e54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -311,7 +311,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (partitioningColumns.isDefined) { throw new AnalysisException( "insertInto() can't be used together with partitionBy(). " + - "Partition columns have already be defined for the table. " + + "Partition columns have already been defined for the table. " + "It is not necessary to use partitionBy()." ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d3cfd2a1ffbf2..4828fa60a7b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -327,7 +327,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create a [[DescribeTableCommand]] logical plan. + * Create a [[DescribeColumnCommand]] or [[DescribeTableCommand]] logical commands. */ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { val isExtended = ctx.EXTENDED != null || ctx.FORMATTED != null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 065954559e487..6102937852347 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -58,7 +58,7 @@ trait CodegenSupport extends SparkPlan { } /** - * Whether this SparkPlan support whole stage codegen or not. + * Whether this SparkPlan supports whole stage codegen or not. */ def supportCodegen: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7477d025dfe89..3c900be839aa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -91,8 +91,8 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm if (sparkSession.conf.get(CATALOG_IMPLEMENTATION.key).equals("hive") && key.startsWith("hive.")) { logWarning(s"'SET $key=$value' might not work, since Spark doesn't support changing " + - "the Hive config dynamically. Please passing the Hive-specific config by adding the " + - s"prefix spark.hadoop (e.g., spark.hadoop.$key) when starting a Spark application. " + + "the Hive config dynamically. Please pass the Hive-specific config by adding the " + + s"prefix spark.hadoop (e.g. spark.hadoop.$key) when starting a Spark application. " + "For details, see the link: https://spark.apache.org/docs/latest/configuration.html#" + "dynamically-loading-spark-properties.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index f64e079539c4f..5dbcf4a915cbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.util.SchemaUtils /** - * Try to replaces [[UnresolvedRelation]]s if the plan is for direct query on files. + * Replaces [[UnresolvedRelation]]s if the plan is for direct query on files. */ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { private def maybeSQLFile(u: UnresolvedRelation): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 6e8154d58d4c6..00bc215a5dc8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -330,7 +330,7 @@ object HDFSMetadataLog { /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ trait FileManager { - /** List the files in a path that matches a filter. */ + /** List the files in a path that match a filter. */ def list(path: Path, filter: PathFilter): Array[FileStatus] /** Make directory at the give path and all its parent directories as needed. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index a1b63a6de3823..73945b39b8967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PR case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) { /** - * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of + * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of * sources. * * This method is typically used to associate a serialized offset with actual sources (which diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index e3f4abcf9f1dc..2c8d7c7b0f3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * This class is used to log offsets to persistent files in HDFS. * Each file corresponds to a specific batch of offsets. The file - * format contain a version string in the first line, followed + * format contains a version string in the first line, followed * by a the JSON string representation of the offsets separated * by a newline character. If a source offset is missing, then * that line will contain a string value defined in the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala index 020c9cb4a7304..3f2cdadfbaeee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} /** - * Wrap non-serializable StreamExecution to make the query serializable as it's easy to for it to + * Wrap non-serializable StreamExecution to make the query serializable as it's easy for it to * get captured with normal usage. It's safe to capture the query but not use it in executors. * However, if the user tries to call its methods, it will throw `IllegalStateException`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 6fe632f958ffc..d1d9f95cb0977 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -94,7 +94,7 @@ trait StateStore { def abort(): Unit /** - * Return an iterator containing all the key-value pairs in the SateStore. Implementations must + * Return an iterator containing all the key-value pairs in the StateStore. Implementations must * ensure that updates (puts, removes) can be made while iterating over this iterator. */ def iterator(): Iterator[UnsafeRowPair] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index f29e135ac357f..e0554f0c4d337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -80,7 +80,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging planVisualization(metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse { -
No information to display for Plan {executionId}
+
No information to display for query {executionId}
} UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 40a058d2cadd2..bdc4bb4422ae7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.types.DataType * * As an example: * {{{ - * // Defined a UDF that returns true or false based on some numeric score. - * val predict = udf((score: Double) => if (score > 0.5) true else false) + * // Define a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => score > 0.5) * * // Projects a column that adds a prediction column based on the score column. * df.select( predict(df("score")) ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2867b4cd7da5e..007f8760edf82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -206,7 +206,7 @@ abstract class BaseSessionStateBuilder( /** * Logical query plan optimizer. * - * Note: this depends on the `conf`, `catalog` and `experimentalMethods` fields. + * Note: this depends on `catalog` and `experimentalMethods` fields. */ protected def optimizer: Optimizer = { new SparkOptimizer(catalog, experimentalMethods) { @@ -263,7 +263,7 @@ abstract class BaseSessionStateBuilder( * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. * - * This gets cloned from parent if available, otherwise is a new instance is created. + * This gets cloned from parent if available, otherwise a new instance is created. */ protected def listenerManager: ExecutionListenerManager = { parentState.map(_.listenerManager.clone()).getOrElse( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 52f2e2639cd86..9f5ca9f914284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -118,7 +118,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * You can set the following option(s): *
    *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
  • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
* * @since 2.0.0 @@ -129,12 +129,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Adds input options for the underlying data source. + * (Java-specific) Adds input options for the underlying data source. * * You can set the following option(s): *
    *
  • `timeZone` (default session local timezone): sets the string that indicates a timezone - * to be used to parse timestamps in the JSON/CSV datasources or partition values.
  • + * to be used to parse timestamps in the JSON/CSV data sources or partition values. *
* * @since 2.0.0 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out index b5a4f5c2bf654..539f673c9d679 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -195,7 +195,7 @@ SELECT t1.x.y.* FROM t1 struct<> -- !query 22 output org.apache.spark.sql.AnalysisException -cannot resolve 't1.x.y.*' give input columns 'i1'; +cannot resolve 't1.x.y.*' given input columns 'i1'; -- !query 23 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out index 7c451c2aa5b5c..2092119600954 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -88,7 +88,7 @@ SELECT global_temp.view1.* FROM global_temp.view1 struct<> -- !query 10 output org.apache.spark.sql.AnalysisException -cannot resolve 'global_temp.view1.*' give input columns 'i1'; +cannot resolve 'global_temp.view1.*' given input columns 'i1'; -- !query 11 diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out index d3ca4443cce55..e10f516ad6e5b 100644 --- a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -179,7 +179,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 22 @@ -212,7 +212,7 @@ SELECT mydb1.t1.* FROM mydb1.t1 struct<> -- !query 25 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t1.*' give input columns 'i1'; +cannot resolve 'mydb1.t1.*' given input columns 'i1'; -- !query 26 @@ -420,7 +420,7 @@ SELECT mydb1.t5.* FROM mydb1.t5 struct<> -- !query 50 output org.apache.spark.sql.AnalysisException -cannot resolve 'mydb1.t5.*' give input columns 'i1, t5'; +cannot resolve 'mydb1.t5.*' given input columns 'i1, t5'; -- !query 51 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 083a0c0b1b9a0..a79ab47f0197e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1896,12 +1896,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { var e = intercept[AnalysisException] { sql("SELECT a.* FROM temp_table_no_cols a") }.getMessage - assert(e.contains("cannot resolve 'a.*' give input columns ''")) + assert(e.contains("cannot resolve 'a.*' given input columns ''")) e = intercept[AnalysisException] { dfNoCols.select($"b.*") }.getMessage - assert(e.contains("cannot resolve 'b.*' give input columns ''")) + assert(e.contains("cannot resolve 'b.*' given input columns ''")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 14082197ba0bd..ce8fde28a941c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -663,7 +663,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e.contains("The depth of view `default`.`view0` exceeds the maximum view " + "resolution depth (10). Analysis is aborted to avoid errors. Increase the value " + - "of spark.sql.view.maxNestedViewDepth to work aroud this.")) + "of spark.sql.view.maxNestedViewDepth to work around this.")) } val e = intercept[IllegalArgumentException] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 632e3e0c4c3f9..3b8a8ca301c27 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -109,8 +109,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Get the raw table metadata from hive metastore directly. The raw table metadata may contains - * special data source properties and should not be exposed outside of `HiveExternalCatalog`. We + * Get the raw table metadata from hive metastore directly. The raw table metadata may contain + * special data source properties that should not be exposed outside of `HiveExternalCatalog`. We * should interpret these special data source properties and restore the original table metadata * before returning it. */ From 51eb750263dd710434ddb60311571fa3dcec66eb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jan 2018 15:21:09 -0800 Subject: [PATCH 0176/1161] [SPARK-22389][SQL] data source v2 partitioning reporting interface ## What changes were proposed in this pull request? a new interface which allows data source to report partitioning and avoid shuffle at Spark side. The design is pretty like the internal distribution/partitioing framework. Spark defines a `Distribution` interfaces and several concrete implementations, and ask the data source to report a `Partitioning`, the `Partitioning` should tell Spark if it can satisfy a `Distribution` or not. ## How was this patch tested? new test Author: Wenchen Fan Closes #20201 from cloud-fan/partition-reporting. --- .../plans/physical/partitioning.scala | 2 +- .../v2/reader/ClusteredDistribution.java | 38 ++++++ .../sql/sources/v2/reader/Distribution.java | 39 +++++++ .../sql/sources/v2/reader/Partitioning.java | 46 ++++++++ .../v2/reader/SupportsReportPartitioning.java | 33 ++++++ .../v2/DataSourcePartitioning.scala | 56 +++++++++ .../datasources/v2/DataSourceV2ScanExec.scala | 9 ++ .../v2/JavaPartitionAwareDataSource.java | 110 ++++++++++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 79 +++++++++++++ 9 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0189bd73c56bf..4d9a9925fe3ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -153,7 +153,7 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { * 1. number of partitions. * 2. if it can satisfy a given distribution. */ -sealed trait Partitioning { +trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java new file mode 100644 index 0000000000000..7346500de45b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A concrete implementation of {@link Distribution}. Represents a distribution where records that + * share the same values for the {@link #clusteredColumns} will be produced by the same + * {@link ReadTask}. + */ +@InterfaceStability.Evolving +public class ClusteredDistribution implements Distribution { + + /** + * The names of the clustered columns. Note that they are order insensitive. + */ + public final String[] clusteredColumns; + + public ClusteredDistribution(String[] clusteredColumns) { + this.clusteredColumns = clusteredColumns; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java new file mode 100644 index 0000000000000..a6201a222f541 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent data distribution requirement, which specifies how the records should + * be distributed among the {@link ReadTask}s that are returned by + * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with + * the data ordering inside one partition(the output records of a single {@link ReadTask}). + * + * The instance of this interface is created and provided by Spark, then consumed by + * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to + * implement this interface, but need to catch as more concrete implementations of this interface + * as possible in {@link Partitioning#satisfy(Distribution)}. + * + * Concrete implementations until now: + *
    + *
  • {@link ClusteredDistribution}
  • + *
+ */ +@InterfaceStability.Evolving +public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java new file mode 100644 index 0000000000000..199e45d4a02ab --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent the output data partitioning for a data source, which is returned by + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a + * snapshot. Once created, it should be deterministic and always report the same number of + * partitions and the same "satisfy" result for a certain distribution. + */ +@InterfaceStability.Evolving +public interface Partitioning { + + /** + * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs. + */ + int numPartitions(); + + /** + * Returns true if this partitioning can satisfy the given distribution, which means Spark does + * not need to shuffle the output data of this data source for some certain operations. + * + * Note that, Spark may add new concrete implementations of {@link Distribution} in new releases. + * This method should be aware of it and always return false for unrecognized distributions. It's + * recommended to check every Spark new release and support new distributions if possible, to + * avoid shuffle at Spark side for more cases. + */ + boolean satisfy(Distribution distribution); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java new file mode 100644 index 0000000000000..f786472ccf345 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to report data partitioning and try to avoid shuffle at Spark side. + */ +@InterfaceStability.Evolving +public interface SupportsReportPartitioning { + + /** + * Returns the output data partitioning that this reader guarantees. + */ + Partitioning outputPartitioning(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala new file mode 100644 index 0000000000000..943d0100aca56 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} + +/** + * An adapter from public data source partitioning to catalyst internal `Partitioning`. + */ +class DataSourcePartitioning( + partitioning: Partitioning, + colNames: AttributeMap[String]) extends physical.Partitioning { + + override val numPartitions: Int = partitioning.numPartitions() + + override def satisfies(required: physical.Distribution): Boolean = { + super.satisfies(required) || { + required match { + case d: physical.ClusteredDistribution if isCandidate(d.clustering) => + val attrs = d.clustering.map(_.asInstanceOf[Attribute]) + partitioning.satisfy( + new ClusteredDistribution(attrs.map { a => + val name = colNames.get(a) + assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") + name.get + }.toArray)) + + case _ => false + } + } + } + + private def isCandidate(clustering: Seq[Expression]): Boolean = { + clustering.forall { + case a: Attribute => colNames.contains(a) + case _ => false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index beb66738732be..69d871df3e1dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ @@ -42,6 +43,14 @@ case class DataSourceV2ScanExec( override def producedAttributes: AttributeSet = AttributeSet(fullOutput) + override def outputPartitioning: physical.Partitioning = reader match { + case s: SupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + + case _ => super.outputPartitioning + } + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() case _ => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java new file mode 100644 index 0000000000000..806d0bcd93f18 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsReportPartitioning { + private final StructType schema = new StructType().add("a", "int").add("b", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Arrays.asList( + new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + } + + @Override + public Partitioning outputPartitioning() { + return new MyPartitioning(); + } + } + + static class MyPartitioning implements Partitioning { + + @Override + public int numPartitions() { + return 2; + } + + @Override + public boolean satisfy(Distribution distribution) { + if (distribution instanceof ClusteredDistribution) { + String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; + return Arrays.asList(clusteredCols).contains("a"); + } + + return false; + } + } + + static class SpecificReadTask implements ReadTask, DataReader { + private int[] i; + private int[] j; + private int current = -1; + + SpecificReadTask(int[] i, int[] j) { + assert i.length == j.length; + this.i = i; + this.j = j; + } + + @Override + public boolean next() throws IOException { + current += 1; + return current < i.length; + } + + @Override + public Row get() { + return new GenericRow(new Object[] {i[current], j[current]}); + } + + @Override + public void close() throws IOException { + + } + + @Override + public DataReader createDataReader() { + return this; + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 0ca29524c6d05..0620693b35d16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ @@ -95,6 +96,40 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("partitioning reporting") { + import org.apache.spark.sql.functions.{count, sum} + Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) + + val groupByColA = df.groupBy('a).agg(sum('b)) + checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) + assert(groupByColA.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) + assert(groupByColAB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColB = df.groupBy('b).agg(sum('a)) + checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(groupByColB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + + val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + } + } + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -365,3 +400,47 @@ class BatchReadTask(start: Int, end: Int) override def close(): Unit = batch.close() } + +class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsReportPartitioning { + override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") + + override def createReadTasks(): JList[ReadTask[Row]] = { + // Note that we don't have same value of column `a` across partitions. + java.util.Arrays.asList( + new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def outputPartitioning(): Partitioning = new MyPartitioning + } + + class MyPartitioning extends Partitioning { + override def numPartitions(): Int = 2 + + override def satisfy(distribution: Distribution): Boolean = distribution match { + case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case _ => false + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] { + assert(i.length == j.length) + + private var current = -1 + + override def createDataReader(): DataReader[Row] = this + + override def next(): Boolean = { + current += 1 + current < i.length + } + + override def get(): Row = Row(i(current), j(current)) + + override def close(): Unit = {} +} From b2ce17b4c9fea58140a57ca1846b2689b15c0d61 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 23 Jan 2018 14:11:30 +0900 Subject: [PATCH 0177/1161] [SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle) ## What changes were proposed in this pull request? Add support for using pandas UDFs with groupby().agg(). This PR introduces a new type of pandas UDF - group aggregate pandas UDF. This type of UDF defines a transformation of multiple pandas Series -> a scalar value. Group aggregate pandas UDFs can be used with groupby().agg(). Note group aggregate pandas UDF doesn't support partial aggregation, i.e., a full shuffle is required. This PR doesn't support group aggregate pandas UDFs that return ArrayType, StructType or MapType. Support for these types is left for future PR. ## How was this patch tested? GroupbyAggPandasUDFTests Author: Li Jin Closes #19872 from icexelloss/SPARK-22274-groupby-agg. --- .../spark/api/python/PythonRunner.scala | 2 + python/pyspark/rdd.py | 1 + python/pyspark/sql/functions.py | 36 +- python/pyspark/sql/group.py | 33 +- python/pyspark/sql/tests.py | 486 +++++++++++++++++- python/pyspark/sql/udf.py | 13 +- python/pyspark/worker.py | 22 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 +- .../sql/catalyst/expressions}/PythonUDF.scala | 31 +- .../sql/catalyst/planning/patterns.scala | 12 +- .../spark/sql/RelationalGroupedDataset.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 29 +- .../python/AggregateInPandasExec.scala | 155 ++++++ .../execution/python/ExtractPythonUDFs.scala | 16 +- .../python/UserDefinedPythonFunction.scala | 2 +- 15 files changed, 792 insertions(+), 61 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/python => catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions}/PythonUDF.scala (60%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 1ec0e717fac29..29148a7ee558b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,12 +39,14 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_PANDAS_GROUP_AGG_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" + case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1b3915548fb14..6b018c3a38444 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -70,6 +70,7 @@ class PythonEvalType(object): SQL_PANDAS_SCALAR_UDF = 200 SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_PANDAS_GROUP_AGG_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 961b3267b44cf..a291c9b71913f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2089,6 +2089,8 @@ class PandasUDFType(object): GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + @since(1.3) def udf(f=None, returnType=StringType()): @@ -2159,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., `DoubleType()`. + The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and @@ -2221,6 +2223,35 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + 3. GROUP_AGG + + A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. + The returned scalar can be either a python primitive type, e.g., `int` or `float` + or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. + + :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as + output types. + + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. If your function is not deterministic, call @@ -2267,7 +2298,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): raise ValueError("Invalid returnType: returnType can not be None") if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]: + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 22061b83eb78c..f90a909d7c2b1 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -65,13 +65,27 @@ def __init__(self, jgd, df): def agg(self, *exprs): """Compute aggregates and returns the result as a :class:`DataFrame`. - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + The available aggregate functions can be: + + 1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count` + + 2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf` + + .. note:: There is no partial aggregation with group aggregate UDFs, i.e., + a full shuffle is required. Also, all the data of a group will be loaded into + memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. seealso:: :func:`pyspark.sql.functions.pandas_udf` If ``exprs`` is a single :class:`dict` mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function. Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed + in a single call to this function. + :param exprs: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. @@ -82,6 +96,13 @@ def agg(self, *exprs): >>> from pyspark.sql import functions as F >>> sorted(gdf.agg(F.min(df.age)).collect()) [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP + ... def min_udf(v): + ... return v.min() + >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP + [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -204,16 +225,18 @@ def apply(self, udf): The user-defined function should take a `pandas.DataFrame` and return another `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` - to the user-function and the returned `pandas.DataFrame`s are combined as a + to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the returnType of the pandas udf. - This function does not support partial aggregation, and requires shuffling all the data in - the :class:`DataFrame`. + .. note:: This function requires a full shuffle. all the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. :param udf: a group map user-defined function returned by - :meth:`pyspark.sql.functions.pandas_udf`. + :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4fee2ecde391b..84e8eec71dd8a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -197,6 +197,12 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -3371,12 +3377,6 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def create_pandas_data_frame(self): import pandas as pd import numpy as np @@ -3414,7 +3414,7 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3425,11 +3425,11 @@ def test_toPandas_respect_session_timezone(self): self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") try: pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_la, pdf_la) + self.assertPandasEqual(pdf_arrow_la, pdf_la) finally: self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) self.assertFalse(pdf_ny.equals(pdf_la)) @@ -3439,7 +3439,7 @@ def test_toPandas_respect_session_timezone(self): if isinstance(field.dataType, TimestampType): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) - self.assertFramesEqual(pdf_ny, pdf_la_corrected) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) @@ -3447,7 +3447,7 @@ def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_filtered_frame(self): df = self.spark.range(3).toDF("i") @@ -3505,7 +3505,7 @@ def test_createDataFrame_with_schema(self): df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEquals(self.schema, df.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() @@ -3717,7 +3717,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class VectorizedUDFTests(ReusedSQLTestCase): +class ScalarPandasUDF(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4196,13 +4196,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyTests(ReusedSQLTestCase): - - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) +class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4227,7 +4221,7 @@ def test_simple(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_register_group_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4251,7 +4245,7 @@ def foo(pdf): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_coerce(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4266,7 +4260,7 @@ def test_coerce(self): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_complex_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4285,7 +4279,7 @@ def normalize(pdf): expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_empty_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4304,7 +4298,7 @@ def normalize(pdf): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_datatype_string(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4318,7 +4312,7 @@ def test_datatype_string(self): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4370,6 +4364,446 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyAggPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + + @udf('double') + def plus_one(v): + assert isinstance(v, (int, float)) + return v + 1 + return plus_one + + @property + def pandas_scalar_plus_two(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.SCALAR) + def plus_two(v): + assert isinstance(v, pd.Series) + return v + 2 + return plus_two + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_sum_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def sum(v): + return v.sum() + return sum + + @property + def pandas_agg_weighted_mean_udf(self): + import numpy as np + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUP_AGG) + def weighted_mean(v, w): + return np.average(v, weights=w) + return weighted_mean + + def test_manual(self): + df = self.data + sum_udf = self.pandas_agg_sum_udf + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + expected1 = self.spark.createDataFrame( + [[0, 245.0, 24.5], + [1, 255.0, 25.5], + [2, 265.0, 26.5], + [3, 275.0, 27.5], + [4, 285.0, 28.5], + [5, 295.0, 29.5], + [6, 305.0, 30.5], + [7, 315.0, 31.5], + [8, 325.0, 32.5], + [9, 335.0, 33.5]], + ['id', 'sum(v)', 'avg(v)']) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_basic(self): + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + # Groupby one column and aggregate one UDF with literal + result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id') + expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id') + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + # Groupby one expression and aggregate one UDF with literal + result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\ + .sort(df.id + 1) + expected2 = df.groupby((col('id') + 1))\ + .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + # Groupby one column and aggregate one UDF without literal + result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id') + expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id') + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + # Groupby one expression and aggregate one UDF without literal + result4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(weighted_mean_udf(df.v, df.w))\ + .sort('id') + expected4 = df.groupby((col('id') + 1).alias('id'))\ + .agg(mean(df.v).alias('weighted_mean(v, w)'))\ + .sort('id') + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_unsupported_types(self): + from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return [v.mean(), v.std()] + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return v.mean(), v.std() + + with QuietTest(self.sc): + with self.assertRaisesRegex(NotImplementedError, 'not supported'): + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + def mean_and_std_udf(v): + return {v.mean(): v.std()} + + def test_alias(self): + from pyspark.sql.functions import mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + + result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias')) + expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + """ + Test mixing group aggregate pandas UDF with sql expression. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF with sql expression + result1 = (df.groupby('id') + .agg(sum_udf(df.v) + 1) + .sort('id')) + expected1 = (df.groupby('id') + .agg(sum(df.v) + 1) + .sort('id')) + + # Mix group aggregate pandas UDF with sql expression (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(df.v + 1)) + .sort('id')) + + expected2 = (df.groupby('id') + .agg(sum(df.v + 1)) + .sort('id')) + + # Wrap group aggregate pandas UDF with two sql expressions + result3 = (df.groupby('id') + .agg(sum_udf(df.v + 1) + 2) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(df.v + 1) + 2) + .sort('id')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + + def test_mixed_udfs(self): + """ + Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF. + """ + from pyspark.sql.functions import sum, mean + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Mix group aggregate pandas UDF and python UDF + result1 = (df.groupby('id') + .agg(plus_one(sum_udf(df.v))) + .sort('id')) + expected1 = (df.groupby('id') + .agg(plus_one(sum(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and python UDF (order swapped) + result2 = (df.groupby('id') + .agg(sum_udf(plus_one(df.v))) + .sort('id')) + expected2 = (df.groupby('id') + .agg(sum(plus_one(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF + result3 = (df.groupby('id') + .agg(sum_udf(plus_two(df.v))) + .sort('id')) + expected3 = (df.groupby('id') + .agg(sum(plus_two(df.v))) + .sort('id')) + + # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped) + result4 = (df.groupby('id') + .agg(plus_two(sum_udf(df.v))) + .sort('id')) + expected4 = (df.groupby('id') + .agg(plus_two(sum(df.v))) + .sort('id')) + + # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby + result5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum_udf(plus_one(df.v)))) + .sort('plus_one(id)')) + expected5 = (df.groupby(plus_one(df.id)) + .agg(plus_one(sum(plus_one(df.v)))) + .sort('plus_one(id)')) + + # Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in + # groupby + result6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum_udf(plus_two(df.v)))) + .sort('plus_two(id)')) + expected6 = (df.groupby(plus_two(df.id)) + .agg(plus_two(sum(plus_two(df.v)))) + .sort('plus_two(id)')) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + + def test_multiple_udfs(self): + """ + Test multiple group aggregate pandas UDFs in one agg function. + """ + from pyspark.sql.functions import col, lit, sum, mean + + df = self.data + mean_udf = self.pandas_agg_mean_udf + sum_udf = self.pandas_agg_sum_udf + weighted_mean_udf = self.pandas_agg_weighted_mean_udf + + result1 = (df.groupBy('id') + .agg(mean_udf(df.v), + sum_udf(df.v), + weighted_mean_udf(df.v, df.w)) + .sort('id') + .toPandas()) + expected1 = (df.groupBy('id') + .agg(mean(df.v), + sum(df.v), + mean(df.v).alias('weighted_mean(v, w)')) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + + def test_complex_groupby(self): + from pyspark.sql.functions import lit, sum + + df = self.data + sum_udf = self.pandas_agg_sum_udf + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + + # groupby one expression + result1 = df.groupby(df.v % 2).agg(sum_udf(df.v)) + expected1 = df.groupby(df.v % 2).agg(sum(df.v)) + + # empty groupby + result2 = df.groupby().agg(sum_udf(df.v)) + expected2 = df.groupby().agg(sum(df.v)) + + # groupby one column and one sql expression + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + + # groupby one python UDF + result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) + expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)) + + # groupby one scalar pandas UDF + result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)) + expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)) + + # groupby one expression and one python UDF + result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v)) + expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)) + + # groupby one expression and one scalar pandas UDF + result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)') + expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)') + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + self.assertPandasEqual(expected5.toPandas(), result5.toPandas()) + self.assertPandasEqual(expected6.toPandas(), result6.toPandas()) + self.assertPandasEqual(expected7.toPandas(), result7.toPandas()) + + def test_complex_expressions(self): + from pyspark.sql.functions import col, sum + + df = self.data + plus_one = self.python_plus_one + plus_two = self.pandas_scalar_plus_two + sum_udf = self.pandas_agg_sum_udf + + # Test complex expressions with sql expression, python UDF and + # group aggregate pandas UDF + result1 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_one(sum_udf(col('v1'))), + sum_udf(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + expected1 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_one(sum(col('v1'))), + sum(plus_one(col('v2')))) + .sort('id') + .toPandas()) + + # Test complex expressions with sql expression, scala pandas UDF and + # group aggregate pandas UDF + result2 = (df.withColumn('v1', plus_one(df.v)) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum_udf(col('v')), + sum_udf(col('v1') + 3), + sum_udf(col('v2')) + 5, + plus_two(sum_udf(col('v1'))), + sum_udf(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + expected2 = (df.withColumn('v1', df.v + 1) + .withColumn('v2', df.v + 2) + .groupby(df.id, df.v % 2) + .agg(sum(col('v')), + sum(col('v1') + 3), + sum(col('v2')) + 5, + plus_two(sum(col('v1'))), + sum(plus_two(col('v2')))) + .sort('id') + .toPandas()) + + # Test sequential groupby aggregate + result3 = (df.groupby('id') + .agg(sum_udf(df.v).alias('v')) + .groupby('id') + .agg(sum_udf(col('v'))) + .sort('id') + .toPandas()) + + expected3 = (df.groupby('id') + .agg(sum(df.v).alias('v')) + .groupby('id') + .agg(sum(col('v'))) + .sort('id') + .toPandas()) + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) + + def test_retain_group_columns(self): + from pyspark.sql.functions import sum, lit, col + orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) + self.spark.conf.set("spark.sql.retainGroupColumns", False) + try: + df = self.data + sum_udf = self.pandas_agg_sum_udf + + result1 = df.groupby(df.id).agg(sum_udf(df.v)) + expected1 = df.groupby(df.id).agg(sum(df.v)) + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + finally: + if orig_value is None: + self.spark.conf.unset("spark.sql.retainGroupColumns") + else: + self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) + + def test_invalid_args(self): + from pyspark.sql.functions import mean + + df = self.data + plus_one = self.python_plus_one + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'nor.*aggregate function'): + df.groupby(df.id).agg(plus_one(df.v)).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'aggregate function.*argument.*aggregate function'): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + 'mixture.*aggregate function.*group aggregate pandas UDF'): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 134badb8485f5..de96846c5c774 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -22,7 +22,8 @@ from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string +from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ + _parse_datatype_string __all__ = ["UDFRegistration"] @@ -36,8 +37,10 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ - evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + import inspect from pyspark.sql.utils import require_minimum_pyarrow_version @@ -113,6 +116,10 @@ def returnType(self): and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " "pandas_udf with function type GROUP_MAP") + elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): + raise NotImplementedError( + "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e6737ae1c1285..173d8fb2856fa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -110,6 +110,17 @@ def wrapped(*series): return wrapped +def wrap_pandas_group_agg_udf(f, return_type): + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series(result) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -126,8 +137,12 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - else: + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: + return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) + else: + raise ValueError("Unknown eval type: {}".format(eval_type)) def read_udfs(pickleSer, infile, eval_type): @@ -148,8 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, + PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bbcec5627bd49..ef91d79f3302c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -153,11 +153,19 @@ trait CheckAnalysis extends PredicateHelper { s"of type ${condition.dataType.simpleString} is not a boolean.") case Aggregate(groupingExprs, aggregateExprs, child) => + def isAggregateExpression(expr: Expression) = { + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + } + def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case aggExpr: AggregateExpression => - aggExpr.aggregateFunction.children.foreach { child => + case expr: Expression if isAggregateExpression(expr) => + val aggFunction = expr match { + case agg: AggregateExpression => agg.aggregateFunction + case udf: PythonUDF => udf + } + aggFunction.children.foreach { child => child.foreach { - case agg: AggregateExpression => + case expr: Expression if isAggregateExpression(expr) => failAnalysis( s"It is not allowed to use an aggregate function in the argument of " + s"another aggregate function. Please use the inner aggregate function " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index d3f743d9eb61e..4ba8ff6e3802f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -15,12 +15,31 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.python +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types.DataType +/** + * Helper functions for [[PythonUDF]] + */ +object PythonUDF { + private[this] val SCALAR_TYPES = Set( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_PANDAS_SCALAR_UDF + ) + + def isScalarPythonUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) + } + + def isGroupAggPandasUDF(e: Expression): Boolean = { + e.isInstanceOf[PythonUDF] && + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + } +} + /** * A serialized version of a Python lambda function. */ @@ -30,12 +49,16 @@ case class PythonUDF( dataType: DataType, children: Seq[Expression], evalType: Int, - udfDeterministic: Boolean) + udfDeterministic: Boolean, + resultId: ExprId = NamedExpression.newExprId) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"$name(${children.mkString(", ")})" + lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)( + exprId = resultId) + override def nullable: Boolean = true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cc391aae55787..132241061d510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.planning +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { object PhysicalAggregation { // groupingExpressions, aggregateExpressions, resultExpressions, child type ReturnType = - (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) def unapply(a: Any): Option[ReturnType] = a match { case logical.Aggregate(groupingExpressions, resultExpressions, child) => @@ -213,7 +214,10 @@ object PhysicalAggregation { expr.collect { // addExpr() always returns false for non-deterministic expressions and do not add them. case agg: AggregateExpression - if (!equivalentAggregateExpressions.addExpr(agg)) => agg + if !equivalentAggregateExpressions.addExpr(agg) => agg + case udf: PythonUDF + if PythonUDF.isGroupAggPandasUDF(udf) && + !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -241,6 +245,10 @@ object PhysicalAggregation { // so replace each aggregate expression by its corresponding attribute in the set: equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute + // Similar to AggregateExpression + case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + equivalentAggregateExpressions.getEquivalentExprs(ue).headOption + .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index a009c00b0abc5..d320c1c359411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 910294853c318..ce512bc46563a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -288,9 +289,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + throw new AnalysisException( + "Streaming aggregation doesn't support group aggregate pandas UDF") + } + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, - aggregateExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, planLater(child)) @@ -333,8 +339,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalAggregation( - groupingExpressions, aggregateExpressions, resultExpressions, child) => + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + val aggregateExpressions = aggExpressions.map(expr => + expr.asInstanceOf[AggregateExpression]) val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -363,6 +371,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateOperator + case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) + if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) => + val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF]) + + Seq(execution.python.AggregateInPandasExec( + groupingExpressions, + udfExpressions, + resultExpressions, + planLater(child))) + + case PhysicalAggregation(_, _, _, _) => + // If cannot match the two cases above, then it's an error + throw new AnalysisException( + "Cannot use a mixture of aggregate function and group aggregate pandas UDF") + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala new file mode 100644 index 0000000000000..18e5f8605c60d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +/** + * Physical node for aggregation with group aggregate Pandas UDF. + * + * This plan works by sending the necessary (projected) input grouped data as Arrow record batches + * to the python worker, the python worker invokes the UDF and sends the results to the executor, + * finally the executor evaluates any post-aggregation expressions and join the result with the + * grouped key. + */ +case class AggregateInPandasExec( + groupingExpressions: Seq[NamedExpression], + udfExpressions: Seq[PythonUDF], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingExpressions.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val prunedProj = UnsafeProjection.create(allInputs, child.output) + + val grouped = if (groupingExpressions.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, groupingExpressions, child.output) + }.map { case (key, rows) => + (key, rows.map(prunedProj)) + } + + val context = TaskContext.get() + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + // Add rows to queue to join later with the result. + val projectedRowIter = grouped.map { case (groupingKey, rows) => + queue.add(groupingKey.asInstanceOf[UnsafeRow]) + rows + } + + val columnarBatchIter = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(projectedRowIter, context.partitionId(), context) + + val joinedAttributes = + groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes) + + columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, aggOutputRow) + resultProj(joinedRow) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 2f53fe788c7d0..1862e3f6e12ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,12 +39,13 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || + PythonUDF.isGroupAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined }.isDefined } @@ -93,7 +94,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { - e.find(_.isInstanceOf[PythonUDF]).isDefined + e.find(PythonUDF.isScalarPythonUDF).isDefined } private def canEvaluateInPython(e: PythonUDF): Boolean = { @@ -106,12 +107,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { - case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) case e => e.children.flatMap(collectEvaluatableUDF) } def apply(plan: SparkPlan): SparkPlan = plan transformUp { - // FlatMapGroupsInPandas can be evaluated directly in python worker + // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker // Therefore we don't need to extract the UDFs case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) @@ -149,10 +150,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { - require(validUdfs.forall(udf => - udf.evalType == PythonEvalType.SQL_BATCHED_UDF || - udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF - ), "Can only extract scalar vectorized udf or sql batch udf") + require( + validUdfs.forall(PythonUDF.isScalarPythonUDF), + "Can only extract scalar vectorized udf or sql batch udf") val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 50dca32cb7861..f4c2d02ee9420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF} import org.apache.spark.sql.types.DataType /** From 96cb60bc33936c1aaf728a1738781073891480ff Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 23 Jan 2018 04:08:32 -0800 Subject: [PATCH 0178/1161] [SPARK-22465][FOLLOWUP] Update the number of partitions of default partitioner when defaultParallelism is set ## What changes were proposed in this pull request? #20002 purposed a way to safe check the default partitioner, however, if `spark.default.parallelism` is set, the defaultParallelism still could be smaller than the proper number of partitions for upstreams RDDs. This PR tries to extend the approach to address the condition when `spark.default.parallelism` is set. The requirements where the PR helps with are : - Max partitioner is not eligible since it is atleast an order smaller, and - User has explicitly set 'spark.default.parallelism', and - Value of 'spark.default.parallelism' is lower than max partitioner - Since max partitioner was discarded due to being at least an order smaller, default parallelism is worse - even though user specified. Under the rest cases, the changes should be no-op. ## How was this patch tested? Add corresponding test cases in `PairRDDFunctionsSuite` and `PartitioningSuite`. Author: Xingbo Jiang Closes #20091 from jiangxb1987/partitioner. --- .../scala/org/apache/spark/Partitioner.scala | 51 ++++++++++--------- .../org/apache/spark/PartitioningSuite.scala | 44 +++++++++++++--- .../spark/rdd/PairRDDFunctionsSuite.scala | 45 +++++++++++++++- 3 files changed, 108 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 437bbaae1968b..c940cb25d478b 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -43,17 +43,19 @@ object Partitioner { /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * - * If any of the RDDs already has a partitioner, and the number of partitions of the - * partitioner is either greater than or is less than and within a single order of - * magnitude of the max number of upstream partitions, choose that one. + * If spark.default.parallelism is set, we'll use the value of SparkContext defaultParallelism + * as the default partitions number, otherwise we'll use the max number of upstream partitions. * - * Otherwise, we use a default HashPartitioner. For the number of partitions, if - * spark.default.parallelism is set, then we'll use the value from SparkContext - * defaultParallelism, otherwise we'll use the max number of upstream partitions. + * When available, we choose the partitioner from rdds with maximum number of partitions. If this + * partitioner is eligible (number of partitions within an order of maximum number of partitions + * in rdds), or has partition number higher than default partitions number - we use this + * partitioner. * - * Unless spark.default.parallelism is set, the number of partitions will be the - * same as the number of partitions in the largest upstream RDD, as this should - * be least likely to cause out-of-memory errors. + * Otherwise, we'll use a new HashPartitioner with the default partitions number. + * + * Unless spark.default.parallelism is set, the number of partitions will be the same as the + * number of partitions in the largest upstream RDD, as this should be least likely to cause + * out-of-memory errors. * * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. */ @@ -67,31 +69,32 @@ object Partitioner { None } - if (isEligiblePartitioner(hasMaxPartitioner, rdds)) { + val defaultNumPartitions = if (rdd.context.conf.contains("spark.default.parallelism")) { + rdd.context.defaultParallelism + } else { + rdds.map(_.partitions.length).max + } + + // If the existing max partitioner is an eligible one, or its partitions number is larger + // than the default number of partitions, use the existing partitioner. + if (hasMaxPartitioner.nonEmpty && (isEligiblePartitioner(hasMaxPartitioner.get, rdds) || + defaultNumPartitions < hasMaxPartitioner.get.getNumPartitions)) { hasMaxPartitioner.get.partitioner.get } else { - if (rdd.context.conf.contains("spark.default.parallelism")) { - new HashPartitioner(rdd.context.defaultParallelism) - } else { - new HashPartitioner(rdds.map(_.partitions.length).max) - } + new HashPartitioner(defaultNumPartitions) } } /** - * Returns true if the number of partitions of the RDD is either greater - * than or is less than and within a single order of magnitude of the - * max number of upstream partitions; - * otherwise, returns false + * Returns true if the number of partitions of the RDD is either greater than or is less than and + * within a single order of magnitude of the max number of upstream partitions, otherwise returns + * false. */ private def isEligiblePartitioner( - hasMaxPartitioner: Option[RDD[_]], + hasMaxPartitioner: RDD[_], rdds: Seq[RDD[_]]): Boolean = { - if (hasMaxPartitioner.isEmpty) { - return false - } val maxPartitions = rdds.map(_.partitions.length).max - log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1 + log10(maxPartitions) - log10(hasMaxPartitioner.getNumPartitions) < 1 } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 155ca17db726b..9206b5debf4f3 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -262,14 +262,11 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("defaultPartitioner") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) - val rdd2 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(10)) - val rdd3 = sc - .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) .partitionBy(new HashPartitioner(100)) - val rdd4 = sc - .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) .partitionBy(new HashPartitioner(9)) val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) @@ -284,7 +281,42 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva assert(partitioner3.numPartitions == rdd3.getNumPartitions) assert(partitioner4.numPartitions == rdd3.getNumPartitions) assert(partitioner5.numPartitions == rdd4.getNumPartitions) + } + test("defaultPartitioner when defaultParallelism is set") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150) + val rdd2 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(10)) + val rdd3 = sc.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14))) + .partitionBy(new HashPartitioner(100)) + val rdd4 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(9)) + val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11) + val rdd6 = sc.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4))) + .partitionBy(new HashPartitioner(3)) + + val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2) + val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3) + val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1) + val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3) + val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5) + val partitioner6 = Partitioner.defaultPartitioner(rdd5, rdd5) + val partitioner7 = Partitioner.defaultPartitioner(rdd1, rdd6) + + assert(partitioner1.numPartitions == rdd2.getNumPartitions) + assert(partitioner2.numPartitions == rdd3.getNumPartitions) + assert(partitioner3.numPartitions == rdd3.getNumPartitions) + assert(partitioner4.numPartitions == rdd3.getNumPartitions) + assert(partitioner5.numPartitions == rdd4.getNumPartitions) + assert(partitioner6.numPartitions == sc.defaultParallelism) + assert(partitioner7.numPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index a39e0469272fe..47af5c3320dd9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -322,8 +322,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } // See SPARK-22465 - test("cogroup between multiple RDD" + - " with number of partitions similar in order of magnitude") { + test("cogroup between multiple RDD with number of partitions similar in order of magnitude") { val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) val rdd2 = sc .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) @@ -332,6 +331,48 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(joined.getNumPartitions == rdd2.getNumPartitions) } + test("cogroup between multiple RDD when defaultParallelism is set without proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)), 10) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == sc.defaultParallelism) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set with proper partitioner") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + + test("cogroup between multiple RDD when defaultParallelism is set; with huge number of " + + "partitions in upstream RDDs") { + assert(!sc.conf.contains("spark.default.parallelism")) + try { + sc.conf.set("spark.default.parallelism", "4") + val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000) + val rdd2 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + .partitionBy(new HashPartitioner(10)) + val joined = rdd1.cogroup(rdd2) + assert(joined.getNumPartitions == rdd2.getNumPartitions) + } finally { + sc.conf.remove("spark.default.parallelism") + } + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) From ee572ba8c1339d21c592001ec4f7f270005ff1cf Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 21:36:20 +0900 Subject: [PATCH 0179/1161] [SPARK-20749][SQL][FOLLOW-UP] Override prettyName for bit_length and octet_length ## What changes were proposed in this pull request? We need to override the prettyName for bit_length and octet_length for getting the expected auto-generated alias name. ## How was this patch tested? The existing tests Author: gatorsmile Closes #20358 from gatorsmile/test2.3More. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../expressions/stringExpressions.scala | 4 ++ .../sql-tests/results/operators.sql.out | 4 +- .../scalar-subquery-predicate.sql.out | 45 ++++++++++--------- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 39d5e4ed56628..5fa75fe348e68 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -141,7 +141,7 @@ statement (LIKE? pattern=STRING)? #showTables | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? LIKE pattern=STRING partitionSpec? #showTable - | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases + | SHOW DATABASES (LIKE? pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW COLUMNS (FROM | IN) tableIdentifier diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e004bfc6af473..5cf783f1a5979 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1708,6 +1708,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length * 8") } } + + override def prettyName: String = "bit_length" } /** @@ -1735,6 +1737,8 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } + + override def prettyName: String = "octet_length" } /** diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 237b618a8b904..840655b7a6447 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -425,7 +425,7 @@ struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NUL -- !query 51 select BIT_LENGTH('abc') -- !query 51 schema -struct +struct -- !query 51 output 24 @@ -449,7 +449,7 @@ struct -- !query 54 select OCTET_LENGTH('abc') -- !query 54 schema -struct +struct -- !query 54 output 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index a2b86db3e4f4c..dd82efba0dde1 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 27 -- !query 0 @@ -307,7 +307,8 @@ struct val1c val1d --- !query 22 + +-- !query 20 SELECT count(t1a) FROM t1 RIGHT JOIN t2 ON t1d = t2d @@ -315,13 +316,13 @@ WHERE t1a < (SELECT max(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 22 schema +-- !query 20 schema struct --- !query 22 output +-- !query 20 output 7 --- !query 23 +-- !query 21 SELECT t1a FROM t1 WHERE t1b <= (SELECT max(t2b) @@ -332,14 +333,14 @@ AND t1b >= (SELECT min(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 23 schema +-- !query 21 schema struct --- !query 23 output +-- !query 21 output val1b val1c --- !query 24 +-- !query 22 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -353,14 +354,14 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 24 schema +-- !query 22 schema struct --- !query 24 output +-- !query 22 output val1b val1c --- !query 25 +-- !query 23 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -374,9 +375,9 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 25 schema +-- !query 23 schema struct --- !query 25 output +-- !query 23 output val1a val1a val1b @@ -387,7 +388,7 @@ val1d val1d --- !query 26 +-- !query 24 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -401,16 +402,16 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 26 schema +-- !query 24 schema struct --- !query 26 output +-- !query 24 output val1a val1b val1c val1d --- !query 27 +-- !query 25 SELECT t1a FROM t1 WHERE t1a <= (SELECT max(t2a) @@ -424,13 +425,13 @@ WHERE t1a >= (SELECT min(t2a) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 27 schema +-- !query 25 schema struct --- !query 27 output +-- !query 25 output val1a --- !query 28 +-- !query 26 SELECT t1a FROM t1 GROUP BY t1a, t1c @@ -438,8 +439,8 @@ HAVING max(t1b) <= (SELECT max(t2b) FROM t2 WHERE t2c = t1c GROUP BY t2c) --- !query 28 schema +-- !query 26 schema struct --- !query 28 output +-- !query 26 output val1b val1c From bdebb8e48eafcca0382d1a3173b2f3ce969abab3 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Jan 2018 10:12:13 -0800 Subject: [PATCH 0180/1161] [SPARK-20664][SPARK-23103][CORE] Follow-up: remove workaround for . Author: Marcelo Vanzin Closes #20353 from vanzin/SPARK-20664. --- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 787de59edf465..fde5f25bce456 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -716,9 +716,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("SPARK-21571: clean up removes invalid history files") { - // TODO: "maxTime" becoming negative in cleanLogs() causes this test to fail, so avoid that - // until we figure out what's causing the problem. - val clock = new ManualClock(TimeUnit.DAYS.toMillis(120)) + val clock = new ManualClock() val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") val provider = new FsHistoryProvider(conf, clock) { override def getNewLastScanTime(): Long = clock.getTimeMillis() From dc4761fd8f0eec1d001e53837e65f7c5fe4e248d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Jan 2018 12:51:40 -0800 Subject: [PATCH 0181/1161] [SPARK-17088][HIVE] Fix 'sharesHadoopClasses' option when creating client. Because the call to the constructor of HiveClientImpl crosses class loader boundaries, different versions of the same class (Configuration in this case) were loaded, and that caused a runtime error when instantiating the client. By using a safer type in the signature of the constructor, it's possible to avoid the problem. I considered removing 'sharesHadoopClasses', but it may still be desired (even though there are 0 users of it since it was not working). When Spark starts to support Hadoop 3, it may be necessary to use that option to load clients for older Hive metastore versions that don't know about Hadoop 3. Tested with added unit test. Author: Marcelo Vanzin Closes #20169 from vanzin/SPARK-17088. --- .../spark/sql/hive/client/HiveClientImpl.scala | 8 +++++--- .../sql/hive/client/IsolatedClientLoader.scala | 16 ++++++++++------ .../sql/hive/client/HiveClientBuilder.scala | 6 ++++-- .../spark/sql/hive/client/HiveClientSuite.scala | 4 ++++ .../spark/sql/hive/client/HiveVersionSuite.scala | 11 ++++++++--- 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 4b923f5235a90..39d839059be75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} -import java.util.Locale +import java.lang.{Iterable => JIterable} +import java.util.{Locale, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -82,8 +83,9 @@ import org.apache.spark.util.{CircularBuffer, Utils} */ private[hive] class HiveClientImpl( override val version: HiveVersion, + warehouseDir: Option[String], sparkConf: SparkConf, - hadoopConf: Configuration, + hadoopConf: JIterable[JMap.Entry[String, String]], extraConfig: Map[String, String], initClassLoader: ClassLoader, val clientLoader: IsolatedClientLoader) @@ -130,7 +132,7 @@ private[hive] class HiveClientImpl( if (ret != null) { // hive.metastore.warehouse.dir is determined in SharedState after the CliSessionState // instance constructed, we need to follow that change here. - Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)).foreach { dir => + warehouseDir.foreach { dir => ret.getConf.setVar(ConfVars.METASTOREWAREHOUSE, dir) } ret diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7a76fd3fd2eb3..dac0e333b63bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -26,6 +26,7 @@ import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkSubmitUtils @@ -48,11 +49,12 @@ private[hive] object IsolatedClientLoader extends Logging { config: Map[String, String] = Map.empty, ivyPath: Option[String] = None, sharedPrefixes: Seq[String] = Seq.empty, - barrierPrefixes: Seq[String] = Seq.empty): IsolatedClientLoader = synchronized { + barrierPrefixes: Seq[String] = Seq.empty, + sharesHadoopClasses: Boolean = true): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(hiveMetastoreVersion) // We will first try to share Hadoop classes. If we cannot resolve the Hadoop artifact // with the given version, we will use Hadoop 2.6 and then will not share Hadoop classes. - var sharesHadoopClasses = true + var _sharesHadoopClasses = sharesHadoopClasses val files = if (resolvedVersions.contains((resolvedVersion, hadoopVersion))) { resolvedVersions((resolvedVersion, hadoopVersion)) } else { @@ -68,7 +70,7 @@ private[hive] object IsolatedClientLoader extends Logging { "Hadoop classes will not be shared between Spark and Hive metastore client. " + "It is recommended to set jars used by Hive metastore client through " + "spark.sql.hive.metastore.jars in the production environment.") - sharesHadoopClasses = false + _sharesHadoopClasses = false (downloadVersion(resolvedVersion, "2.6.5", ivyPath), "2.6.5") } resolvedVersions.put((resolvedVersion, actualHadoopVersion), downloadedFiles) @@ -81,7 +83,7 @@ private[hive] object IsolatedClientLoader extends Logging { execJars = files, hadoopConf = hadoopConf, config = config, - sharesHadoopClasses = sharesHadoopClasses, + sharesHadoopClasses = _sharesHadoopClasses, sharedPrefixes = sharedPrefixes, barrierPrefixes = barrierPrefixes) } @@ -249,8 +251,10 @@ private[hive] class IsolatedClientLoader( /** The isolated client interface to Hive. */ private[hive] def createClient(): HiveClient = synchronized { + val warehouseDir = Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)) if (!isolationOn) { - return new HiveClientImpl(version, sparkConf, hadoopConf, config, baseClassLoader, this) + return new HiveClientImpl(version, warehouseDir, sparkConf, hadoopConf, config, + baseClassLoader, this) } // Pre-reflective instantiation setup. logDebug("Initializing the logger to avoid disaster...") @@ -261,7 +265,7 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[HiveClientImpl].getName) .getConstructors.head - .newInstance(version, sparkConf, hadoopConf, config, classLoader, this) + .newInstance(version, warehouseDir, sparkConf, hadoopConf, config, classLoader, this) .asInstanceOf[HiveClient] } catch { case e: InvocationTargetException => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index ae804ce7c7b07..ab73f668c6ca6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -46,13 +46,15 @@ private[client] object HiveClientBuilder { def buildClient( version: String, hadoopConf: Configuration, - extraConf: Map[String, String] = Map.empty): HiveClient = { + extraConf: Map[String, String] = Map.empty, + sharesHadoopClasses: Boolean = true): HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = new SparkConf(), hadoopConf = hadoopConf, config = buildConf(extraConf), - ivyPath = ivyPath).createClient() + ivyPath = ivyPath, + sharesHadoopClasses = sharesHadoopClasses).createClient() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index a5dfd89b3a574..f991352b207d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -202,6 +202,10 @@ class HiveClientSuite(version: String) day1 :: day2 :: Nil) } + test("create client with sharesHadoopClasses = false") { + buildClient(new Configuration(), sharesHadoopClasses = false) + } + private def testMetastorePartitionFiltering( filterString: String, expectedDs: Seq[Int], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index bb8a4697b0a13..a70fb6464cc1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -28,7 +28,9 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu override protected val enableAutoThreadAudit = false protected var client: HiveClient = null - protected def buildClient(hadoopConf: Configuration): HiveClient = { + protected def buildClient( + hadoopConf: Configuration, + sharesHadoopClasses: Boolean = true): HiveClient = { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 @@ -36,8 +38,11 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } - HiveClientBuilder - .buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) + HiveClientBuilder.buildClient( + version, + hadoopConf, + HiveUtils.formatTimeVarsForHiveClient(hadoopConf), + sharesHadoopClasses = sharesHadoopClasses) } override def suiteName: String = s"${super.suiteName}($version)" From 05839d164836e544af79c13de25802552eadd636 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 23 Jan 2018 14:11:23 -0800 Subject: [PATCH 0182/1161] [SPARK-22735][ML][DOC] Added VectorSizeHint docs and examples. ## What changes were proposed in this pull request? Added documentation for new transformer. Author: Bago Amirbekian Closes #20285 from MrBago/sizeHintDocs. --- docs/ml-features.md | 51 ++++++++++++ .../ml/JavaVectorSizeHintExample.java | 79 +++++++++++++++++++ .../python/ml/vector_size_hint_example.py | 57 +++++++++++++ .../examples/ml/VectorSizeHintExample.scala | 63 +++++++++++++++ 4 files changed, 250 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java create mode 100644 examples/src/main/python/ml/vector_size_hint_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 466a8fbe99cf6..3370eb3893272 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1291,6 +1291,57 @@ for more details on the API.
+## VectorSizeHint + +It can sometimes be useful to explicitly specify the size of the vectors for a column of +`VectorType`. For example, `VectorAssembler` uses size information from its input columns to +produce size information and metadata for its output column. While in some cases this information +can be obtained by inspecting the contents of the column, in a streaming dataframe the contents are +not available until the stream is started. `VectorSizeHint` allows a user to explicitly specify the +vector size for a column so that `VectorAssembler`, or other transformers that might +need to know vector size, can use that column as an input. + +To use `VectorSizeHint` a user must set the `inputCol` and `size` parameters. Applying this +transformer to a dataframe produces a new dataframe with updated metadata for `inputCol` specifying +the vector size. Downstream operations on the resulting dataframe can get this size using the +meatadata. + +`VectorSizeHint` can also take an optional `handleInvalid` parameter which controls its +behaviour when the vector column contains nulls or vectors of the wrong size. By default +`handleInvalid` is set to "error", indicating an exception should be thrown. This parameter can +also be set to "skip", indicating that rows containing invalid values should be filtered out from +the resulting dataframe, or "optimistic", indicating that the column should not be checked for +invalid values and all rows should be kept. Note that the use of "optimistic" can cause the +resulting dataframe to be in an inconsistent state, me:aning the metadata for the column +`VectorSizeHint` was applied to does not match the contents of that column. Users should take care +to avoid this kind of inconsistent state. + +
+
+ +Refer to the [VectorSizeHint Scala docs](api/scala/index.html#org.apache.spark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala %} +
+ +
+ +Refer to the [VectorSizeHint Java docs](api/java/org/apache/spark/ml/feature/VectorSizeHint.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java %} +
+ +
+ +Refer to the [VectorSizeHint Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.VectorSizeHint) +for more details on the API. + +{% include_example python/ml/vector_size_hint_example.py %} +
+
+ ## QuantileDiscretizer `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java new file mode 100644 index 0000000000000..d649a2ccbaa72 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSizeHintExample.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.ml.feature.VectorSizeHint; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; +// $example off$ + +public class JavaVectorSizeHintExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaVectorSizeHintExample") + .getOrCreate(); + + // $example on$ + StructType schema = createStructType(new StructField[]{ + createStructField("id", IntegerType, false), + createStructField("hour", IntegerType, false), + createStructField("mobile", DoubleType, false), + createStructField("userFeatures", new VectorUDT(), false), + createStructField("clicked", DoubleType, false) + }); + Row row0 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); + Row row1 = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0); + Dataset dataset = spark.createDataFrame(Arrays.asList(row0, row1), schema); + + VectorSizeHint sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3); + + Dataset datasetWithSize = sizeHint.transform(dataset); + System.out.println("Rows where 'userFeatures' is not the right size are filtered out"); + datasetWithSize.show(false); + + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) + .setOutputCol("features"); + + // This dataframe can be used by downstream transformers as before + Dataset output = assembler.transform(datasetWithSize); + System.out.println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column " + + "'features'"); + output.select("features", "clicked").show(false); + // $example off$ + + spark.stop(); + } +} + diff --git a/examples/src/main/python/ml/vector_size_hint_example.py b/examples/src/main/python/ml/vector_size_hint_example.py new file mode 100644 index 0000000000000..fb77dacec629d --- /dev/null +++ b/examples/src/main/python/ml/vector_size_hint_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.linalg import Vectors +from pyspark.ml.feature import (VectorSizeHint, VectorAssembler) +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("VectorSizeHintExample")\ + .getOrCreate() + + # $example on$ + dataset = spark.createDataFrame( + [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0), + (0, 18, 1.0, Vectors.dense([0.0, 10.0]), 0.0)], + ["id", "hour", "mobile", "userFeatures", "clicked"]) + + sizeHint = VectorSizeHint( + inputCol="userFeatures", + handleInvalid="skip", + size=3) + + datasetWithSize = sizeHint.transform(dataset) + print("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(truncate=False) + + assembler = VectorAssembler( + inputCols=["hour", "mobile", "userFeatures"], + outputCol="features") + + # This dataframe can be used by downstream transformers as before + output = assembler.transform(datasetWithSize) + print("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala new file mode 100644 index 0000000000000..688731a791f35 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSizeHintExample.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint} +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +object VectorSizeHintExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("VectorSizeHintExample") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame( + Seq( + (0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0), + (0, 18, 1.0, Vectors.dense(0.0, 10.0), 0.0)) + ).toDF("id", "hour", "mobile", "userFeatures", "clicked") + + val sizeHint = new VectorSizeHint() + .setInputCol("userFeatures") + .setHandleInvalid("skip") + .setSize(3) + + val datasetWithSize = sizeHint.transform(dataset) + println("Rows where 'userFeatures' is not the right size are filtered out") + datasetWithSize.show(false) + + val assembler = new VectorAssembler() + .setInputCols(Array("hour", "mobile", "userFeatures")) + .setOutputCol("features") + + // This dataframe can be used by downstream transformers as before + val output = assembler.transform(datasetWithSize) + println("Assembled columns 'hour', 'mobile', 'userFeatures' to vector column 'features'") + output.select("features", "clicked").show(false) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From 613c290336e3826111164c24319f66774b1f65a3 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 14:56:28 -0800 Subject: [PATCH 0183/1161] [SPARK-23192][SQL] Keep the Hint after Using Cached Data ## What changes were proposed in this pull request? The hint of the plan segment is lost, if the plan segment is replaced by the cached data. ```Scala val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") df2.cache() val df3 = df1.join(broadcast(df2), Seq("key"), "inner") ``` This PR is to fix it. ## How was this patch tested? Added a test Author: gatorsmile Closes #20365 from gatorsmile/fixBroadcastHintloss. --- .../apache/spark/sql/execution/CacheManager.scala | 12 ++++++++---- .../sql/execution/joins/BroadcastJoinSuite.scala | 13 +++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index b05fe49a6ac3b..432eb59d6fe57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -170,9 +170,13 @@ class CacheManager extends Logging { def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { case currentFragment => - lookupCachedData(currentFragment) - .map(_.cachedRepresentation.withOutput(currentFragment.output)) - .getOrElse(currentFragment) + lookupCachedData(currentFragment).map { cached => + val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) + currentFragment match { + case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints) + case _ => cachedPlan + } + }.getOrElse(currentFragment) } newPlan transformAllExpressions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0bcd54e1fceab..1704bc8376f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -109,6 +109,19 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint is retained after using the cached data") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } + } + test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") From 44cc4daf3a03f1a220eef8ce3c86867745db9ab7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 16:17:09 -0800 Subject: [PATCH 0184/1161] [SPARK-23195][SQL] Keep the Hint of Cached Data ## What changes were proposed in this pull request? The broadcast hint of the cached plan is lost if we cache the plan. This PR is to correct it. ```Scala val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") broadcast(df2).cache() df2.collect() val df3 = df1.join(df2, Seq("key"), "inner") ``` ## How was this patch tested? Added a test. Author: gatorsmile Closes #20368 from gatorsmile/cachedBroadcastHint. --- .../execution/columnar/InMemoryRelation.scala | 4 ++-- .../sql/execution/joins/BroadcastJoinSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..5945808c4abfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -63,7 +63,7 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -77,7 +77,7 @@ case class InMemoryRelation( // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache statsOfPlanToCache } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = batchStats.value.longValue, hints = statsOfPlanToCache.hints) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1704bc8376f0d..889cab0489534 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -139,6 +139,22 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint is retained in a cached plan") { + Seq(true, false).foreach { materialized => + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + broadcast(df2).cache() + if (materialized) df2.collect() + val df3 = df1.join(df2, Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } + } + } + private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") From 15adcc8273e73352e5e1c3fc9915c0b004ec4836 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 23 Jan 2018 16:24:20 -0800 Subject: [PATCH 0185/1161] [SPARK-23197][DSTREAMS] Increased timeouts to resolve flakiness ## What changes were proposed in this pull request? Increased timeout from 50 ms to 300 ms (50 ms was really too low). ## How was this patch tested? Multiple rounds of tests. Author: Tathagata Das Closes #20371 from tdas/SPARK-23197. --- .../scala/org/apache/spark/streaming/ReceiverSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 145c48e5a9a72..fc6218a33f741 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -105,13 +105,13 @@ class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable { assert(executor.errors.head.eq(exception)) // Verify restarting actually stops and starts the receiver - receiver.restart("restarting", null, 100) - eventually(timeout(50 millis), interval(10 millis)) { + receiver.restart("restarting", null, 600) + eventually(timeout(300 millis), interval(10 millis)) { // receiver will be stopped async assert(receiver.isStopped) assert(receiver.onStopCalled) } - eventually(timeout(1000 millis), interval(100 millis)) { + eventually(timeout(1000 millis), interval(10 millis)) { // receiver will be started async assert(receiver.onStartCalled) assert(executor.isReceiverStarted) From a3911cf896de6e9386042ae4d93632cba69eef0f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Jan 2018 11:43:48 +0900 Subject: [PATCH 0186/1161] [SPARK-23177][SQL][PYSPARK] Extract zero-parameter UDFs from aggregate ## What changes were proposed in this pull request? We extract Python UDFs in logical aggregate which depends on aggregate expression or grouping key in ExtractPythonUDFFromAggregate rule. But Python UDFs which don't depend on above expressions should also be extracted to avoid the issue reported in the JIRA. A small code snippet to reproduce that issue looks like: ```python import pyspark.sql.functions as f df = spark.createDataFrame([(1,2), (3,4)]) f_udf = f.udf(lambda: str("const_str")) df2 = df.distinct().withColumn("a", f_udf()) df2.show() ``` Error exception is raised as: ``` : org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#50 at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:91) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:90) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$2.apply(TreeNode.scala:267) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:266) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:306) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187) at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:304) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:272) at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:256) at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:90) at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:514) at org.apache.spark.sql.execution.aggregate.HashAggregateExec$$anonfun$38.apply(HashAggregateExec.scala:513) ``` This exception raises because `HashAggregateExec` tries to bind the aliased Python UDF expression (e.g., `pythonUDF0#50 AS a#44`) to grouping key. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20360 from viirya/SPARK-23177. --- python/pyspark/sql/tests.py | 8 ++++++++ .../spark/sql/execution/python/ExtractPythonUDFs.scala | 5 +++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 84e8eec71dd8a..a466ab87d882d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1106,6 +1106,14 @@ def myudf(x): rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + def test_nonparam_udf_with_aggregate(self): + import pyspark.sql.functions as f + + df = self.spark.createDataFrame([(1, 2), (1, 2)]) + f_udf = f.udf(lambda: "const_str") + rows = df.distinct().withColumn("a", f_udf()).collect() + self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')]) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 1862e3f6e12ca..4ae4e164830be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** * Extracts all the Python UDFs in logical aggregate, which depends on aggregate expression or - * grouping key, evaluate them after aggregate. + * grouping key, or doesn't depend on any above expressions, evaluate them after aggregate. */ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { @@ -45,7 +45,8 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { expr.find { - e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined + e => PythonUDF.isScalarPythonUDF(e) && + (e.references.isEmpty || e.find(belongAggregate(_, agg)).isDefined) }.isDefined } From f54b65c15a732540f7a41a9083eeb7a08feca125 Mon Sep 17 00:00:00 2001 From: neilalex Date: Tue, 23 Jan 2018 22:31:14 -0800 Subject: [PATCH 0187/1161] [SPARK-21727][R] Allow multi-element atomic vector as column type in SparkR DataFrame MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? A fix to https://issues.apache.org/jira/browse/SPARK-21727, "Operating on an ArrayType in a SparkR DataFrame throws error" ## How was this patch tested? - Ran tests at R\pkg\tests\run-all.R (see below attached results) - Tested the following lines in SparkR, which now seem to execute without error: ``` indices <- 1:4 myDf <- data.frame(indices) myDf$data <- list(rep(0, 20)) mySparkDf <- as.DataFrame(myDf) collect(mySparkDf) ``` [2018-01-22 SPARK-21727 Test Results.txt](https://github.com/apache/spark/files/1653535/2018-01-22.SPARK-21727.Test.Results.txt) felixcheung yanboliang sun-rui shivaram _The contribution is my original work and I license the work to the project under the project’s open source license_ Author: neilalex Closes #20352 from neilalex/neilalex-sparkr-arraytype. --- R/pkg/R/serialize.R | 11 +++---- R/pkg/tests/fulltests/test_Serde.R | 47 ++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 3bbf60d9b668c..263b9b576c0c5 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -30,14 +30,17 @@ # POSIXct,POSIXlt -> Time # # list[T] -> Array[T], where T is one of above mentioned types +# Multi-element vector of any of the above (except raw) -> Array[T] # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend # nolint end getSerdeType <- function(object) { type <- class(object)[[1]] - if (type != "list") { - type + if (is.atomic(object) & !is.raw(object) & length(object) > 1) { + "array" + } else if (type != "list") { + type } else { # Check if all elements are of same type elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) })) @@ -50,9 +53,7 @@ getSerdeType <- function(object) { } writeObject <- function(con, object, writeType = TRUE) { - # NOTE: In R vectors have same type as objects. So we don't support - # passing in vectors as arrays and instead require arrays to be passed - # as lists. + # NOTE: In R vectors have same type as objects type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") # Checking types is needed here, since 'is.na' only handles atomic vectors, # lists and pairlists diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R index 6bbd201bf1d82..3577929323b8b 100644 --- a/R/pkg/tests/fulltests/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -37,6 +37,53 @@ test_that("SerDe of primitive types", { expect_equal(class(x), "character") }) +test_that("SerDe of multi-element primitive vectors inside R data.frame", { + # vector of integers embedded in R data.frame + indices <- 1L:3L + myDf <- data.frame(indices) + myDf$data <- list(rep(0L, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0L, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "integer") + + # vector of numeric embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(0, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(0, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "numeric") + + # vector of logical embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep(TRUE, 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep(TRUE, 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "logical") + + # vector of character embedded in R data.frame + myDf <- data.frame(indices) + myDf$data <- list(rep("abc", 3L)) + mySparkDf <- as.DataFrame(myDf) + myResultingDf <- collect(mySparkDf) + myDfListedData <- data.frame(indices) + myDfListedData$data <- list(as.list(rep("abc", 3L))) + expect_equal(myResultingDf, myDfListedData) + expect_equal(class(myResultingDf[["data"]][[1]]), "list") + expect_equal(class(myResultingDf[["data"]][[1]][[1]]), "character") +}) + test_that("SerDe of list of primitive types", { x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) From 4e7b49041aceca0beafec20f697b63a473a2b42f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Jan 2018 22:38:20 -0800 Subject: [PATCH 0188/1161] Revert "[SPARK-23195][SQL] Keep the Hint of Cached Data" This reverts commit 44cc4daf3a03f1a220eef8ce3c86867745db9ab7. --- .../execution/columnar/InMemoryRelation.scala | 4 ++-- .../sql/execution/joins/BroadcastJoinSuite.scala | 16 ---------------- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 5945808c4abfb..51928d914841e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -63,7 +63,7 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics = null) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -77,7 +77,7 @@ case class InMemoryRelation( // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache statsOfPlanToCache } else { - Statistics(sizeInBytes = batchStats.value.longValue, hints = statsOfPlanToCache.hints) + Statistics(sizeInBytes = batchStats.value.longValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 889cab0489534..1704bc8376f0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -139,22 +139,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("broadcast hint is retained in a cached plan") { - Seq(true, false).foreach { materialized => - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - broadcast(df2).cache() - if (materialized) df2.collect() - val df3 = df1.join(df2, Seq("key"), "inner") - val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { - case b: BroadcastHashJoinExec => b - }.size - assert(numBroadCastHashJoin === 1) - } - } - } - private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") From 7af1a325da57daa2e25c713472a320f4ccb43d71 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 24 Jan 2018 21:13:47 +0900 Subject: [PATCH 0189/1161] [SPARK-23174][BUILD][PYTHON] python code style checker update ## What changes were proposed in this pull request? Referencing latest python code style checking from PyPi/pycodestyle Removed pending TODO For now, in tox.ini excluded the additional style error discovered on existing python due to latest style checker (will fallback on review comment to finalize exclusion or fix py) Any further code styling requirement needs to be part of pycodestyle, not in SPARK. ## How was this patch tested? ./dev/run-tests Author: Rekha Joshi Author: rjoshi2 Closes #20338 from rekhajoshm/SPARK-11222. --- dev/lint-python | 37 ++++++++++++++++++------------------- dev/run-tests.py | 5 ++++- dev/tox.ini | 4 ++-- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index df8df037a5f69..e069cafa1b8c6 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,7 +21,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" # Exclude auto-generated configuration file. PATHS_TO_CHECK="$( cd "$SPARK_ROOT_DIR" && find . -name "*.py" )" -PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYCODESTYLE_REPORT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" SPHINXBUILD=${SPHINXBUILD:=sphinx-build} @@ -30,23 +30,22 @@ SPHINX_REPORT_PATH="$SPARK_ROOT_DIR/dev/sphinx-report.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" compile_status="${PIPESTATUS[0]}" -# Get pep8 at runtime so that we don't rely on it being installed on the build server. +# Get pycodestyle at runtime so that we don't rely on it being installed on the build server. #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -#+ TODOs: -#+ - Download pep8 from PyPI. It's more "official". -PEP8_VERSION="1.7.0" -PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" +# Updated to latest official version for pep8. pep8 is formally renamed to pycodestyle. +PYCODESTYLE_VERSION="2.3.1" +PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" +PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" -if [ ! -e "$PEP8_SCRIPT_PATH" ]; then - curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" +if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then + curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" curl_status="$?" if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + echo "Failed to download pycodestyle.py from \"$PYCODESTYLE_SCRIPT_REMOTE_PATH\"." exit "$curl_status" fi fi @@ -64,23 +63,23 @@ export "PATH=$PYTHONPATH:$PATH" #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" -pep8_status="${PIPESTATUS[0]}" +python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" +pycodestyle_status="${PIPESTATUS[0]}" -if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then +if [ "$compile_status" -eq 0 -a "$pycodestyle_status" -eq 0 ]; then lint_status=0 else lint_status=1 fi if [ "$lint_status" -ne 0 ]; then - echo "PEP8 checks failed." - cat "$PEP8_REPORT_PATH" - rm "$PEP8_REPORT_PATH" + echo "PYCODESTYLE checks failed." + cat "$PYCODESTYLE_REPORT_PATH" + rm "$PYCODESTYLE_REPORT_PATH" exit "$lint_status" else - echo "PEP8 checks passed." - rm "$PEP8_REPORT_PATH" + echo "pycodestyle checks passed." + rm "$PYCODESTYLE_REPORT_PATH" fi # Check that the documentation builds acceptably, skip check if sphinx is not installed. diff --git a/dev/run-tests.py b/dev/run-tests.py index fb270c4ee0508..fe75ef4411c8c 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -576,7 +576,10 @@ def main(): for f in changed_files): # run_java_style_checks() pass - if not changed_files or any(f.endswith(".py") for f in changed_files): + if not changed_files or any(f.endswith("lint-python") + or f.endswith("tox.ini") + or f.endswith(".py") + for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") or f.endswith("lint-r") diff --git a/dev/tox.ini b/dev/tox.ini index eb8b1eb2c2886..583c1eaaa966b 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -[pep8] -ignore=E402,E731,E241,W503,E226 +[pycodestyle] +ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* From de36f65d3a819c00d6bf6979deef46c824203669 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 24 Jan 2018 21:19:09 +0900 Subject: [PATCH 0190/1161] [SPARK-23148][SQL] Allow pathnames with special characters for CSV / JSON / text MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …JSON / text ## What changes were proposed in this pull request? Fix for JSON and CSV data sources when file names include characters that would be changed by URL encoding. ## How was this patch tested? New unit tests for JSON, CSV and text suites Author: Henry Robinson Closes #20355 from henryr/spark-23148. --- .../execution/datasources/CodecStreams.scala | 6 +++--- .../datasources/csv/CSVDataSource.scala | 11 ++++++----- .../datasources/json/JsonDataSource.scala | 10 ++++++---- .../spark/sql/FileBasedDataSourceSuite.scala | 18 ++++++++++++++++-- 4 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 54549f698aca5..c0df6c779d7bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -45,11 +45,11 @@ object CodecStreams { } /** - * Creates an input stream from the string path and add a closure for the input stream to be + * Creates an input stream from the given path and add a closure for the input stream to be * closed on task completion. */ - def createInputStreamWithCloseResource(config: Configuration, path: String): InputStream = { - val inputStream = createInputStream(config, new Path(path)) + def createInputStreamWithCloseResource(config: Configuration, path: Path): InputStream = { + val inputStream = createInputStream(config, path) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) inputStream } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 2031381dd2e10..4870d75fc5f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.datasources.csv +import java.net.URI import java.nio.charset.{Charset, StandardCharsets} import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.Job @@ -32,7 +33,6 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -206,7 +206,7 @@ object MultiLineCSVDataSource extends CSVDataSource { parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { UnivocityParser.parseStream( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath), + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, schema) @@ -218,8 +218,9 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions: CSVOptions): StructType = { val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) csv.flatMap { lines => + val path = new Path(lines.getPath()) UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), + CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path), shouldDropHeader = false, new CsvParser(parsedOptions.asParserSettings)) }.take(1).headOption match { @@ -230,7 +231,7 @@ object MultiLineCSVDataSource extends CSVDataSource { UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource( lines.getConfiguration, - lines.getPath()), + new Path(lines.getPath())), parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 8b7c2709afde1..77e7edc8e7a20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.execution.datasources.json import java.io.InputStream +import java.net.URI import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.FileInputFormat @@ -168,9 +169,10 @@ object MultiLineJsonDataSource extends JsonDataSource { } private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + val path = new Path(record.getPath()) CreateJacksonParser.inputStream( jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath())) + CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) } override def readFile( @@ -180,7 +182,7 @@ object MultiLineJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { def partitionedFileString(ignored: Any): UTF8String = { Utils.tryWithResource { - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))) } { inputStream => UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } @@ -193,6 +195,6 @@ object MultiLineJsonDataSource extends JsonDataSource { parser.options.columnNameOfCorruptRecord) safeParser.parse( - CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)) + CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 22fb496bc838e..c272c99ae45a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -23,6 +23,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { import testImplicits._ private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") + private val nameWithSpecialChars = "sp&cial%c hars" allFileBasedDataSources.foreach { format => test(s"Writing empty datasets should not fail - $format") { @@ -54,7 +55,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { // Only ORC/Parquet support this. `CSV` and `JSON` returns an empty schema. // `TEXT` data source always has a single column whose name is `value`. Seq("orc", "parquet").foreach { format => - test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") { + test(s"SPARK-15474 Write and read back non-empty schema with empty dataframe - $format") { withTempPath { file => val path = file.getCanonicalPath val emptyDf = Seq((true, 1, "str")).toDF().limit(0) @@ -69,7 +70,6 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { - val nameWithSpecialChars = s"sp&cial%chars" withTempDir { dir => val tmpFile = s"$dir/$nameWithSpecialChars" spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) @@ -78,4 +78,18 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { } } } + + // Separate test case for formats that support multiLine as an option. + Seq("json", "csv").foreach { format => + test("SPARK-23148 read files containing special characters " + + s"using $format with multiline enabled") { + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + val reader = spark.read.format(format).option("multiLine", true) + val fileContent = reader.load(tmpFile) + checkAnswer(fileContent, Seq(Row("a"), Row("b"))) + } + } + } } From 0ec95bb7df775be33fc8983f6c0983a67032d2c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 24 Jan 2018 11:34:59 -0600 Subject: [PATCH 0191/1161] [SPARK-22577][CORE] executor page blacklist status should update with TaskSet level blacklisting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR stage blacklisting is propagated to UI by introducing a new Spark listener event (SparkListenerExecutorBlacklistedForStage) which indicates the executor is blacklisted for a stage. Either because of the number of failures are exceeded a limit given for an executor (spark.blacklist.stage.maxFailedTasksPerExecutor) or because of the whole node is blacklisted for a stage (spark.blacklist.stage.maxFailedExecutorsPerNode). In case of the node is blacklisting all executors will listed as blacklisted for the stage. Blacklisting state for a selected stage can be seen "Aggregated Metrics by Executor" table's blacklisting column, where after this change three possible labels could be found: - "for application": when the executor is blacklisted for the application (see the configuration spark.blacklist.application.maxFailedTasksPerExecutor for details) - "for stage": when the executor is **only** blacklisted for the stage - "false" : when the executor is not blacklisted at all ## How was this patch tested? It is tested both manually and with unit tests. #### Unit tests - HistoryServerSuite - TaskSetBlacklistSuite - AppStatusListenerSuite #### Manual test for executor blacklisting Running Spark as a local cluster: ``` $ bin/spark-shell --master "local-cluster[2,1,1024]" --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" --conf "spark.eventLog.enabled=true" ``` Executing: ``` scala import org.apache.spark.SparkEnv sc.parallelize(1 to 10, 10).map { x => if (SparkEnv.get.executorId == "0") throw new RuntimeException("Bad executor") else (x % 3, x) }.reduceByKey((a, b) => a + b).collect() ``` To see result check the "Aggregated Metrics by Executor" section at the bottom of picture: ![UI screenshot for stage level blacklisting executor](https://issues.apache.org/jira/secure/attachment/12905283/stage_blacklisting.png) #### Manual test for node blacklisting Running Spark as on a cluster: ``` bash ./bin/spark-shell --master yarn --deploy-mode client --executor-memory=2G --num-executors=8 --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.stage.maxFailedExecutorsPerNode=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" --conf "spark.eventLog.enabled=true" ``` And the job was: ``` scala import org.apache.spark.SparkEnv sc.parallelize(1 to 10000, 10).map { x => if (SparkEnv.get.executorId.toInt >= 4) throw new RuntimeException("Bad executor") else (x % 3, x) }.reduceByKey((a, b) => a + b).collect() ``` The result is: ![UI screenshot for stage level node blacklisting](https://issues.apache.org/jira/secure/attachment/12906833/node_blacklisting_for_stage.png) Here you can see apiros3.gce.test.com was node blacklisted for the stage because of failures on executor 4 and 5. As expected executor 3 is also blacklisted even it has no failures itself but sharing the node with 4 and 5. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20203 from attilapiros/SPARK-22577. --- .../apache/spark/SparkFirehoseListener.java | 12 + .../scheduler/EventLoggingListener.scala | 9 + .../spark/scheduler/SparkListener.scala | 35 + .../spark/scheduler/SparkListenerBus.scala | 4 + .../spark/scheduler/TaskSetBlacklist.scala | 19 +- .../spark/scheduler/TaskSetManager.scala | 2 +- .../spark/status/AppStatusListener.scala | 25 + .../org/apache/spark/status/LiveEntity.scala | 4 +- .../org/apache/spark/status/api/v1/api.scala | 3 +- .../apache/spark/ui/jobs/ExecutorTable.scala | 10 +- .../application_list_json_expectation.json | 70 +- .../blacklisting_for_stage_expectation.json | 639 ++++++++++++++ ...acklisting_node_for_stage_expectation.json | 783 ++++++++++++++++++ .../completed_app_list_json_expectation.json | 71 +- .../limit_app_list_json_expectation.json | 54 +- .../minDate_app_list_json_expectation.json | 62 +- .../minEndDate_app_list_json_expectation.json | 34 +- .../one_stage_attempt_json_expectation.json | 3 +- .../one_stage_json_expectation.json | 3 +- ...age_with_accumulable_json_expectation.json | 3 +- .../spark-events/app-20180109111548-0000 | 59 ++ .../application_1516285256255_0012 | 71 ++ .../deploy/history/HistoryServerSuite.scala | 2 + .../scheduler/BlacklistTrackerSuite.scala | 2 +- .../scheduler/TaskSetBlacklistSuite.scala | 119 ++- .../spark/status/AppStatusListenerSuite.scala | 43 + dev/.rat-excludes | 2 + 27 files changed, 2040 insertions(+), 103 deletions(-) create mode 100644 core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json create mode 100644 core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json create mode 100755 core/src/test/resources/spark-events/app-20180109111548-0000 create mode 100755 core/src/test/resources/spark-events/application_1516285256255_0012 diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 3583856d88998..94c5c11b61a50 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -118,6 +118,18 @@ public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executo onEvent(executorBlacklisted); } + @Override + public void onExecutorBlacklistedForStage( + SparkListenerExecutorBlacklistedForStage executorBlacklistedForStage) { + onEvent(executorBlacklistedForStage); + } + + @Override + public void onNodeBlacklistedForStage( + SparkListenerNodeBlacklistedForStage nodeBlacklistedForStage) { + onEvent(nodeBlacklistedForStage); + } + @Override public final void onExecutorUnblacklisted( SparkListenerExecutorUnblacklisted executorUnblacklisted) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index b3a5b1f1e05b3..69bc51c1ecf90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -207,6 +207,15 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + logEvent(event, flushLogger = true) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { logEvent(event, flushLogger = true) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 3b677ca9657db..8a112f6a37b96 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -120,6 +120,24 @@ case class SparkListenerExecutorBlacklisted( taskFailures: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorBlacklistedForStage( + time: Long, + executorId: String, + taskFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerNodeBlacklistedForStage( + time: Long, + hostId: String, + executorFailures: Int, + stageId: Int, + stageAttemptId: Int) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String) extends SparkListenerEvent @@ -261,6 +279,17 @@ private[spark] trait SparkListenerInterface { */ def onExecutorBlacklisted(executorBlacklisted: SparkListenerExecutorBlacklisted): Unit + /** + * Called when the driver blacklists an executor for a stage. + */ + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit + + /** + * Called when the driver blacklists a node for a stage. + */ + def onNodeBlacklistedForStage(nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit + /** * Called when the driver re-enables a previously blacklisted executor. */ @@ -339,6 +368,12 @@ abstract class SparkListener extends SparkListenerInterface { override def onExecutorBlacklisted( executorBlacklisted: SparkListenerExecutorBlacklisted): Unit = { } + def onExecutorBlacklistedForStage( + executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage): Unit = { } + + def onNodeBlacklistedForStage( + nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage): Unit = { } + override def onExecutorUnblacklisted( executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 056c0cbded435..ff19cc65552e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,10 @@ private[spark] trait SparkListenerBus listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case executorBlacklistedForStage: SparkListenerExecutorBlacklistedForStage => + listener.onExecutorBlacklistedForStage(executorBlacklistedForStage) + case nodeBlacklistedForStage: SparkListenerNodeBlacklistedForStage => + listener.onNodeBlacklistedForStage(nodeBlacklistedForStage) case executorBlacklisted: SparkListenerExecutorBlacklisted => listener.onExecutorBlacklisted(executorBlacklisted) case executorUnblacklisted: SparkListenerExecutorUnblacklisted => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index 233781f3d9719..b680979a466a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -36,8 +36,12 @@ import org.apache.spark.util.Clock * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. */ -private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, val clock: Clock) - extends Logging { +private[scheduler] class TaskSetBlacklist( + private val listenerBus: LiveListenerBus, + val conf: SparkConf, + val stageId: Int, + val stageAttemptId: Int, + val clock: Clock) extends Logging { private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) @@ -128,16 +132,23 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, } // Check if enough tasks have failed on the executor to blacklist it for the entire stage. - if (execFailures.numUniqueTasksWithFailures >= MAX_FAILURES_PER_EXEC_STAGE) { + val numFailures = execFailures.numUniqueTasksWithFailures + if (numFailures >= MAX_FAILURES_PER_EXEC_STAGE) { if (blacklistedExecs.add(exec)) { logInfo(s"Blacklisting executor ${exec} for stage $stageId") // This executor has been pushed into the blacklist for this stage. Let's check if it // pushes the whole node into the blacklist. val blacklistedExecutorsOnNode = execsWithFailuresOnNode.filter(blacklistedExecs.contains(_)) - if (blacklistedExecutorsOnNode.size >= MAX_FAILED_EXEC_PER_NODE_STAGE) { + val now = clock.getTimeMillis() + listenerBus.post( + SparkListenerExecutorBlacklistedForStage(now, exec, numFailures, stageId, stageAttemptId)) + val numFailExec = blacklistedExecutorsOnNode.size + if (numFailExec >= MAX_FAILED_EXEC_PER_NODE_STAGE) { if (blacklistedNodes.add(host)) { logInfo(s"Blacklisting ${host} for stage $stageId") + listenerBus.post( + SparkListenerNodeBlacklistedForStage(now, host, numFailExec, stageId, stageAttemptId)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c3ed11bfe352a..886c2c99f1ff3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -102,7 +102,7 @@ private[spark] class TaskSetManager( private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { blacklistTracker.map { _ => - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(sched.sc.listenerBus, conf, stageId, taskSet.stageAttemptId, clock) } } diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index b4edcf23abc09..3e34bdc0c7b63 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -211,6 +211,31 @@ private[spark] class AppStatusListener( updateBlackListStatus(event.executorId, true) } + override def onExecutorBlacklistedForStage( + event: SparkListenerExecutorBlacklistedForStage): Unit = { + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + val now = System.nanoTime() + val esummary = stage.executorSummary(event.executorId) + esummary.isBlacklisted = true + maybeUpdate(esummary, now) + } + } + + override def onNodeBlacklistedForStage(event: SparkListenerNodeBlacklistedForStage): Unit = { + val now = System.nanoTime() + + // Implicitly blacklist every available executor for the stage associated with this node + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => + liveExecutors.values.foreach { exec => + if (exec.hostname == event.hostId) { + val esummary = stage.executorSummary(exec.executorId) + esummary.isBlacklisted = true + maybeUpdate(esummary, now) + } + } + } + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 4295e664e131c..d5f9e19ffdcd0 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -316,6 +316,7 @@ private class LiveExecutorStageSummary( var succeededTasks = 0 var failedTasks = 0 var killedTasks = 0 + var isBlacklisted = false var metrics = createMetrics(default = 0L) @@ -334,7 +335,8 @@ private class LiveExecutorStageSummary( metrics.shuffleWriteMetrics.bytesWritten, metrics.shuffleWriteMetrics.recordsWritten, metrics.memoryBytesSpilled, - metrics.diskBytesSpilled) + metrics.diskBytesSpilled, + isBlacklisted) new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 7d8e4de3c8efb..550eac3952bbb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -68,7 +68,8 @@ class ExecutorStageSummary private[spark]( val shuffleWrite : Long, val shuffleWriteRecords : Long, val memoryBytesSpilled : Long, - val diskBytesSpilled : Long) + val diskBytesSpilled : Long, + val isBlacklistedForStage: Boolean) class ExecutorSummary private[spark]( val id: String, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 95c12b1e73653..0ff64f053f371 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -136,7 +136,15 @@ private[ui] class ExecutorTable(stage: StageData, store: AppStatusStore) { {Utils.bytesToString(v.diskBytesSpilled)} }} -
+ { + if (executor.map(_.isBlacklisted).getOrElse(false)) { + + } else if (v.isBlacklistedForStage) { + + } else { + + } + } } } diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index f2c3ec5da8891..4fecf84db65a2 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,9 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -125,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -140,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json new file mode 100644 index 0000000000000..5e9e8230e2745 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_for_stage_expectation.json @@ -0,0 +1,639 @@ +{ + "status": "COMPLETE", + "stageId": 0, + "attemptId": 0, + "numTasks": 10, + "numActiveTasks": 0, + "numCompleteTasks": 10, + "numFailedTasks": 2, + "numKilledTasks": 0, + "numCompletedIndices": 10, + "executorRunTime": 761, + "executorCpuTime": 269916000, + "submissionTime": "2018-01-09T10:21:18.152GMT", + "firstTaskLaunchedTime": "2018-01-09T10:21:18.347GMT", + "completionTime": "2018-01-09T10:21:19.062GMT", + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleReadBytes": 0, + "shuffleReadRecords": 0, + "shuffleWriteBytes": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "name": "map at :26", + "details": "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool": "default", + "rddIds": [ + 1, + 0 + ], + "accumulatorUpdates": [], + "tasks": { + "0": { + "taskId": 0, + "index": 0, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.347GMT", + "duration": 562, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 460, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 14, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 3873006, + "recordsWritten": 0 + } + } + }, + "5": { + "taskId": 5, + "index": 3, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.958GMT", + "duration": 22, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2586000, + "executorRunTime": 9, + "executorCpuTime": 9635000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 262919, + "recordsWritten": 1 + } + } + }, + "10": { + "taskId": 10, + "index": 8, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.034GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1803000, + "executorRunTime": 6, + "executorCpuTime": 6157000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 243647, + "recordsWritten": 1 + } + } + }, + "1": { + "taskId": 1, + "index": 1, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.364GMT", + "duration": 565, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 301, + "executorDeserializeCpuTime": 200029000, + "executorRunTime": 212, + "executorCpuTime": 198479000, + "resultSize": 1115, + "jvmGcTime": 13, + "resultSerializationTime": 1, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 2409488, + "recordsWritten": 1 + } + } + }, + "6": { + "taskId": 6, + "index": 4, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.980GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2610000, + "executorRunTime": 10, + "executorCpuTime": 9622000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 385110, + "recordsWritten": 1 + } + } + }, + "9": { + "taskId": 9, + "index": 7, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.022GMT", + "duration": 12, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 1981000, + "executorRunTime": 7, + "executorCpuTime": 6335000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 259354, + "recordsWritten": 1 + } + } + }, + "2": { + "taskId": 2, + "index": 2, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.899GMT", + "duration": 27, + "executorId": "0", + "host": "172.30.65.138", + "status": "FAILED", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "errorMessage": "java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics": { + "executorDeserializeTime": 0, + "executorDeserializeCpuTime": 0, + "executorRunTime": 16, + "executorCpuTime": 0, + "resultSize": 0, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 0, + "writeTime": 126128, + "recordsWritten": 0 + } + } + }, + "7": { + "taskId": 7, + "index": 5, + "attempt": 0, + "launchTime": "2018-01-09T10:21:18.996GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2231000, + "executorRunTime": 9, + "executorCpuTime": 8407000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 205520, + "recordsWritten": 1 + } + } + }, + "3": { + "taskId": 3, + "index": 0, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.919GMT", + "duration": 24, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 8, + "executorDeserializeCpuTime": 8878000, + "executorRunTime": 10, + "executorCpuTime": 9364000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 207014, + "recordsWritten": 1 + } + } + }, + "11": { + "taskId": 11, + "index": 9, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.045GMT", + "duration": 15, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 3, + "executorDeserializeCpuTime": 2017000, + "executorRunTime": 6, + "executorCpuTime": 6676000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 233652, + "recordsWritten": 1 + } + } + }, + "8": { + "taskId": 8, + "index": 6, + "attempt": 0, + "launchTime": "2018-01-09T10:21:19.011GMT", + "duration": 11, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 1, + "executorDeserializeCpuTime": 1554000, + "executorRunTime": 7, + "executorCpuTime": 6034000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 213296, + "recordsWritten": 1 + } + } + }, + "4": { + "taskId": 4, + "index": 2, + "attempt": 1, + "launchTime": "2018-01-09T10:21:18.943GMT", + "duration": 16, + "executorId": "1", + "host": "172.30.65.138", + "status": "SUCCESS", + "taskLocality": "PROCESS_LOCAL", + "speculative": false, + "accumulatorUpdates": [], + "taskMetrics": { + "executorDeserializeTime": 2, + "executorDeserializeCpuTime": 2211000, + "executorRunTime": 9, + "executorCpuTime": 9207000, + "resultSize": 1029, + "jvmGcTime": 0, + "resultSerializationTime": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "peakExecutionMemory": 0, + "inputMetrics": { + "bytesRead": 0, + "recordsRead": 0 + }, + "outputMetrics": { + "bytesWritten": 0, + "recordsWritten": 0 + }, + "shuffleReadMetrics": { + "remoteBlocksFetched": 0, + "localBlocksFetched": 0, + "fetchWaitTime": 0, + "remoteBytesRead": 0, + "remoteBytesReadToDisk": 0, + "localBytesRead": 0, + "recordsRead": 0 + }, + "shuffleWriteMetrics": { + "bytesWritten": 46, + "writeTime": 292381, + "recordsWritten": 1 + } + } + } + }, + "executorSummary": { + "0": { + "taskTime": 589, + "failedTasks": 2, + "succeededTasks": 0, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 0, + "shuffleWriteRecords": 0, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": true + }, + "1": { + "taskTime": 708, + "failedTasks": 0, + "succeededTasks": 10, + "killedTasks": 0, + "inputBytes": 0, + "inputRecords": 0, + "outputBytes": 0, + "outputRecords": 0, + "shuffleRead": 0, + "shuffleReadRecords": 0, + "shuffleWrite": 460, + "shuffleWriteRecords": 10, + "memoryBytesSpilled": 0, + "diskBytesSpilled": 0, + "isBlacklistedForStage": false + } + }, + "killedTasksSummary": {} +} diff --git a/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json new file mode 100644 index 0000000000000..acd4cc53de6cd --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/blacklisting_node_for_stage_expectation.json @@ -0,0 +1,783 @@ +{ + "status" : "COMPLETE", + "stageId" : 0, + "attemptId" : 0, + "numTasks" : 10, + "numActiveTasks" : 0, + "numCompleteTasks" : 10, + "numFailedTasks" : 4, + "numKilledTasks" : 0, + "numCompletedIndices" : 10, + "executorRunTime" : 5080, + "executorCpuTime" : 1163210819, + "submissionTime" : "2018-01-18T18:33:12.658GMT", + "firstTaskLaunchedTime" : "2018-01-18T18:33:12.816GMT", + "completionTime" : "2018-01-18T18:33:15.279GMT", + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleReadBytes" : 0, + "shuffleReadRecords" : 0, + "shuffleWriteBytes" : 1461, + "shuffleWriteRecords" : 30, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "name" : "map at :27", + "details" : "org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)", + "schedulingPool" : "default", + "rddIds" : [ 1, 0 ], + "accumulatorUpdates" : [ ], + "tasks" : { + "0" : { + "taskId" : 0, + "index" : 0, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.816GMT", + "duration" : 2064, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1081, + "executorDeserializeCpuTime" : 353981050, + "executorRunTime" : 914, + "executorCpuTime" : 368865439, + "resultSize" : 1134, + "jvmGcTime" : 75, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3662221, + "recordsWritten" : 3 + } + } + }, + "5" : { + "taskId" : 5, + "index" : 5, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.320GMT", + "duration" : 73, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 27, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 191901, + "recordsWritten" : 0 + } + } + }, + "10" : { + "taskId" : 10, + "index" : 1, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:15.069GMT", + "duration" : 132, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4598966, + "executorRunTime" : 76, + "executorCpuTime" : 20826337, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 301705, + "recordsWritten" : 3 + } + } + }, + "1" : { + "taskId" : 1, + "index" : 1, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1506, + "executorId" : "5", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1332, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 33, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 3075188, + "recordsWritten" : 0 + } + } + }, + "6" : { + "taskId" : 6, + "index" : 6, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:14.323GMT", + "duration" : 67, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 51, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 183718, + "recordsWritten" : 0 + } + } + }, + "9" : { + "taskId" : 9, + "index" : 4, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.973GMT", + "duration" : 96, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 5, + "executorDeserializeCpuTime" : 4793905, + "executorRunTime" : 48, + "executorCpuTime" : 25678331, + "resultSize" : 1091, + "jvmGcTime" : 0, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 366050, + "recordsWritten" : 3 + } + } + }, + "13" : { + "taskId" : 13, + "index" : 9, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.200GMT", + "duration" : 76, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 25, + "executorDeserializeCpuTime" : 5860574, + "executorRunTime" : 25, + "executorCpuTime" : 20585619, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 369513, + "recordsWritten" : 3 + } + } + }, + "2" : { + "taskId" : 2, + "index" : 2, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.832GMT", + "duration" : 1774, + "executorId" : "3", + "host" : "apiros-2.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1206, + "executorDeserializeCpuTime" : 263386625, + "executorRunTime" : 493, + "executorCpuTime" : 278399617, + "resultSize" : 1134, + "jvmGcTime" : 78, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 144, + "writeTime" : 3322956, + "recordsWritten" : 3 + } + } + }, + "12" : { + "taskId" : 12, + "index" : 8, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.165GMT", + "duration" : 60, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4010338, + "executorRunTime" : 34, + "executorCpuTime" : 21657558, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 319101, + "recordsWritten" : 3 + } + } + }, + "7" : { + "taskId" : 7, + "index" : 5, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.859GMT", + "duration" : 115, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 11, + "executorDeserializeCpuTime" : 10894331, + "executorRunTime" : 84, + "executorCpuTime" : 28283110, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 377601, + "recordsWritten" : 3 + } + } + }, + "3" : { + "taskId" : 3, + "index" : 3, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 2027, + "executorId" : "2", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 1282, + "executorDeserializeCpuTime" : 365807898, + "executorRunTime" : 681, + "executorCpuTime" : 349920830, + "resultSize" : 1134, + "jvmGcTime" : 102, + "resultSerializationTime" : 1, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 3587839, + "recordsWritten" : 3 + } + } + }, + "11" : { + "taskId" : 11, + "index" : 7, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:15.072GMT", + "duration" : 93, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 4, + "executorDeserializeCpuTime" : 4239884, + "executorRunTime" : 77, + "executorCpuTime" : 21689428, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 323898, + "recordsWritten" : 3 + } + } + }, + "8" : { + "taskId" : 8, + "index" : 6, + "attempt" : 1, + "launchTime" : "2018-01-18T18:33:14.879GMT", + "duration" : 194, + "executorId" : "1", + "host" : "apiros-3.gce.test.com", + "status" : "SUCCESS", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "taskMetrics" : { + "executorDeserializeTime" : 56, + "executorDeserializeCpuTime" : 12246145, + "executorRunTime" : 54, + "executorCpuTime" : 27304550, + "resultSize" : 1048, + "jvmGcTime" : 0, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 147, + "writeTime" : 311940, + "recordsWritten" : 3 + } + } + }, + "4" : { + "taskId" : 4, + "index" : 4, + "attempt" : 0, + "launchTime" : "2018-01-18T18:33:12.833GMT", + "duration" : 1522, + "executorId" : "4", + "host" : "apiros-2.gce.test.com", + "status" : "FAILED", + "taskLocality" : "PROCESS_LOCAL", + "speculative" : false, + "accumulatorUpdates" : [ ], + "errorMessage" : "java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n", + "taskMetrics" : { + "executorDeserializeTime" : 0, + "executorDeserializeCpuTime" : 0, + "executorRunTime" : 1184, + "executorCpuTime" : 0, + "resultSize" : 0, + "jvmGcTime" : 82, + "resultSerializationTime" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "peakExecutionMemory" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "remoteBytesReadToDisk" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 16858066, + "recordsWritten" : 0 + } + } + } + }, + "executorSummary" : { + "4" : { + "taskTime" : 1589, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "5" : { + "taskTime" : 1579, + "failedTasks" : 2, + "succeededTasks" : 0, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 0, + "shuffleWriteRecords" : 0, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + }, + "1" : { + "taskTime" : 2411, + "failedTasks" : 0, + "succeededTasks" : 4, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 585, + "shuffleWriteRecords" : 12, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false + }, + "2" : { + "taskTime" : 2446, + "failedTasks" : 0, + "succeededTasks" : 5, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 732, + "shuffleWriteRecords" : 15, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false + }, + "3" : { + "taskTime" : 1774, + "failedTasks" : 0, + "succeededTasks" : 1, + "killedTasks" : 0, + "inputBytes" : 0, + "inputRecords" : 0, + "outputBytes" : 0, + "outputRecords" : 0, + "shuffleRead" : 0, + "shuffleReadRecords" : 0, + "shuffleWrite" : 144, + "shuffleWriteRecords" : 3, + "memoryBytesSpilled" : 0, + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : true + } + }, + "killedTasksSummary" : { } +} \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index c925c1dd8a4d3..4fecf84db65a2 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,10 +140,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] }, { "id" : "local-1422981780767", @@ -126,9 +155,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981779720, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981788731 } ] }, { "id" : "local-1422981759269", @@ -141,8 +170,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1422981758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1422981766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index cc0b2b0022bd3..79950b0dc6486 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,46 +1,46 @@ [ { - "id" : "app-20161116163331-0000", + "id" : "application_1516285256255_0012", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-16T22:33:29.916GMT", - "endTime" : "2016-11-16T22:33:40.587GMT", + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", "lastUpdated" : "", - "duration" : 10671, - "sparkUser" : "jose", + "duration" : 472819, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, - "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 } ] }, { - "id" : "app-20161115172038-0000", + "id" : "app-20180109111548-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2016-11-15T23:20:37.079GMT", - "endTime" : "2016-11-15T23:22:18.874GMT", + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", "lastUpdated" : "", - "duration" : 101795, - "sparkUser" : "jose", + "duration" : 535234, + "sparkUser" : "attilapiros", "completed" : true, - "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, - "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 } ] }, { - "id" : "local-1430917381534", + "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:11.398GMT", + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", "lastUpdated" : "", - "duration" : 10505, - "sparkUser" : "irashid", + "duration" : 10671, + "sparkUser" : "jose", "completed" : true, - "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, - "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "appSparkVersion" : "2.1.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, + "endTimeEpoch" : 1479335620587 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 5af50abd85330..7d60977dcd4fe 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,9 +39,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479335620587, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479335609916, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479335620587 } ] }, { "id" : "app-20161115172038-0000", @@ -24,9 +54,9 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "endTimeEpoch" : 1479252138874, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1479252037079, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1479252138874 } ] }, { "id" : "local-1430917381534", @@ -39,9 +69,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917391398 } ] }, { "id" : "local-1430917381535", @@ -55,9 +85,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380893, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380950 }, { "attemptId" : "1", "startTime" : "2015-05-06T13:03:00.880GMT", @@ -67,9 +97,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "1.4.0-SNAPSHOT", - "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1430917380880, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1430917380890 } ] }, { "id" : "local-1426533911241", @@ -83,9 +113,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426633910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426633945177 }, { "attemptId" : "1", "startTime" : "2015-03-16T19:25:10.242GMT", @@ -95,9 +125,9 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1426533910242, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1426533945177 } ] }, { "id" : "local-1425081759269", @@ -110,8 +140,8 @@ "sparkUser" : "irashid", "completed" : true, "appSparkVersion" : "", - "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTimeEpoch" : 1425081758277, - "lastUpdatedEpoch" : 0 + "endTimeEpoch" : 1425081766912 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json index 7f896c74b5be1..dfbfd8aedcc23 100644 --- a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "application_1516285256255_0012", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-18T18:30:35.119GMT", + "endTime" : "2018-01-18T18:38:27.938GMT", + "lastUpdated" : "", + "duration" : 472819, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1516300235119, + "endTimeEpoch" : 1516300707938 + } ] +}, { + "id" : "app-20180109111548-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2018-01-09T10:15:42.372GMT", + "endTime" : "2018-01-09T10:24:37.606GMT", + "lastUpdated" : "", + "duration" : 535234, + "sparkUser" : "attilapiros", + "completed" : true, + "appSparkVersion" : "2.3.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1515492942372, + "endTimeEpoch" : 1515493477606 + } ] +}, { "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { @@ -9,8 +39,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479335609916, "endTimeEpoch" : 1479335620587 } ] }, { @@ -24,8 +54,8 @@ "sparkUser" : "jose", "completed" : true, "appSparkVersion" : "2.1.0-SNAPSHOT", - "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1479252037079, "endTimeEpoch" : 1479252138874 } ] }, { diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 31093a661663b..03f886afa5413 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 601d70695b17c..947c89906955d 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -421,7 +421,8 @@ "shuffleWrite" : 13180, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 9cdcef0746185..963f010968b62 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -465,7 +465,8 @@ "shuffleWrite" : 0, "shuffleWriteRecords" : 0, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "isBlacklistedForStage" : false } }, "killedTasksSummary" : { } diff --git a/core/src/test/resources/spark-events/app-20180109111548-0000 b/core/src/test/resources/spark-events/app-20180109111548-0000 new file mode 100755 index 0000000000000..50893d3001b95 --- /dev/null +++ b/core/src/test/resources/spark-events/app-20180109111548-0000 @@ -0,0 +1,59 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","Java Version":"1.8.0_152 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"172.30.65.138","spark.eventLog.enabled":"true","spark.driver.port":"64273","spark.repl.class.uri":"spark://172.30.65.138:64273/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/spark-811c1b49-eb66-4bfb-91ae-33b45efa269d/repl-c4438f51-ee23-41ed-8e04-71496e2f40f5","spark.app.name":"Spark shell","spark.scheduler.mode":"FIFO","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"local-cluster[2,1,1024]","spark.home":"*********(redacted)","spark.sql.catalogImplementation":"in-memory","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.app.id":"app-20180109111548-0000"},"System Properties":{"java.io.tmpdir":"/var/folders/9g/gf583nd1765cvfgb_lsvwgp00000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib","user.dir":"*********(redacted)","java.library.path":"*********(redacted)","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.152-b16","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_152-b16","java.vm.info":"mixed mode","java.ext.dirs":"*********(redacted)","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.12.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"*********(redacted)","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[2,1,1024] --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell spark-shell","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_152.jdk/Contents/Home/jre","java.version":"1.8.0_152","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/conf/":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.5.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/curator-client-2.6.0.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"*********(redacted)","/Users/attilapiros/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"*********(redacted)"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20180109111548-0000","Timestamp":1515492942372,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965588,"Executor ID":"0","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stdout","stderr":"http://172.30.65.138:64279/logPage/?appId=app-20180109111548-0000&executorId=0&logType=stderr"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1515492965598,"Executor ID":"1","Executor Info":{"Host":"172.30.65.138","Total Cores":1,"Log Urls":{"stdout":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stdout","stderr":"http://172.30.65.138:64278/logPage/?appId=app-20180109111548-0000&executorId=1&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.30.65.138","Port":64290},"Maximum Memory":384093388,"Timestamp":1515492965643,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.30.65.138","Port":64291},"Maximum Memory":384093388,"Timestamp":1515492965652,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1515493278122,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1515493278918,"executorId":"0","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"460","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"14","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3873006","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1515493278347,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278909,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3873006,"Value":3873006,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":14,"Value":14,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":460,"Value":460,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":460,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":14,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3873006,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat $line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"16","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"126128","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1515493278899,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278926,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":126128,"Value":3999134,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":476,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":16,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":126128,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1515493278364,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278929,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":2409488,"Value":6408622,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":46,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":896,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":13,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1115,"Value":1115,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":198479000,"Value":198479000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":212,"Value":688,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":200029000,"Value":200029000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":301,"Value":301,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":301,"Executor Deserialize CPU Time":200029000,"Executor Run Time":212,"Executor CPU Time":198479000,"Result Size":1115,"JVM GC Time":13,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":2409488,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":0,"Attempt":1,"Launch Time":1515493278919,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278943,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":207014,"Value":6615636,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":92,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":896,"Value":1792,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":2144,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9364000,"Value":207843000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":698,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":8878000,"Value":208907000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":309,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":8,"Executor Deserialize CPU Time":8878000,"Executor Run Time":10,"Executor CPU Time":9364000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":207014,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":2,"Attempt":1,"Launch Time":1515493278943,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278959,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":292381,"Value":6908017,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":2704,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":3173,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9207000,"Value":217050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":707,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2211000,"Value":211118000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":311,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2211000,"Executor Run Time":9,"Executor CPU Time":9207000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":292381,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":5,"Index":3,"Attempt":0,"Launch Time":1515493278958,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278980,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":262919,"Value":7170936,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":3616,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":4202,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9635000,"Value":226685000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":716,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2586000,"Value":213704000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":314,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2586000,"Executor Run Time":9,"Executor CPU Time":9635000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":262919,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":4,"Attempt":0,"Launch Time":1515493278980,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493278996,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":385110,"Value":7556046,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":230,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":4528,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":5231,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9622000,"Value":236307000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":726,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2610000,"Value":216314000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":317,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2610000,"Executor Run Time":10,"Executor CPU Time":9622000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":385110,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":0,"Launch Time":1515493278996,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279011,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":205520,"Value":7761566,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":276,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":5440,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":6260,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":8407000,"Value":244714000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":9,"Value":735,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2231000,"Value":218545000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":319,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2231000,"Executor Run Time":9,"Executor CPU Time":8407000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":205520,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":0,"Launch Time":1515493279011,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279022,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":213296,"Value":7974862,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":6352,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":7289,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6034000,"Value":250748000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":742,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1554000,"Value":220099000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1,"Value":320,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1,"Executor Deserialize CPU Time":1554000,"Executor Run Time":7,"Executor CPU Time":6034000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":213296,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":7,"Attempt":0,"Launch Time":1515493279022,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279034,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":259354,"Value":8234216,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":8,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":368,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":7264,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":8318,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6335000,"Value":257083000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":7,"Value":749,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1981000,"Value":222080000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":322,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1981000,"Executor Run Time":7,"Executor CPU Time":6335000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":259354,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":8,"Attempt":0,"Launch Time":1515493279034,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279046,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":243647,"Value":8477863,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":414,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":8176,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":9347,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6157000,"Value":263240000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":755,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1803000,"Value":223883000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":324,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1803000,"Executor Run Time":6,"Executor CPU Time":6157000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":243647,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":9,"Attempt":0,"Launch Time":1515493279045,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279060,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":233652,"Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":1,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":46,"Value":460,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":912,"Value":9088,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1029,"Value":10376,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6676000,"Value":269916000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":6,"Value":761,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2017000,"Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":327,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2017000,"Executor Run Time":6,"Executor CPU Time":6676000,"Result Size":1029,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":46,"Shuffle Write Time":233652,"Shuffle Records Written":1},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :26","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493278152,"Completion Time":1515493279062,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":761,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":8711515,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":27,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10376,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":225900000,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":10,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":9088,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":460,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":269916000,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":1,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":327,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":1,"Attempt":0,"Launch Time":1515493279078,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279152,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":184,"Value":184,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":1286,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":41280000,"Value":41280000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":53,"Value":53,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":11820000,"Value":11820000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":17,"Value":17,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":17,"Executor Deserialize CPU Time":11820000,"Executor Run Time":53,"Executor CPU Time":41280000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":184,"Total Records Read":4},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":2,"Attempt":0,"Launch Time":1515493279152,"Executor ID":"1","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279167,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":138,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":3,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":2572,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7673000,"Value":48953000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":61,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1706000,"Value":13526000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":19,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1706000,"Executor Run Time":8,"Executor CPU Time":7673000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":3,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":138,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":3,"Attempt":0,"Launch Time":1515493279166,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279180,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3706,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6972000,"Value":55925000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":68,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1569000,"Value":15095000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":21,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1569000,"Executor Run Time":7,"Executor CPU Time":6972000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":4,"Attempt":0,"Launch Time":1515493279179,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279190,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4840,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":60830000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":73,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1882000,"Value":16977000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":23,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1882000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":0,"Attempt":0,"Launch Time":1515493279077,"Executor ID":"0","Host":"172.30.65.138","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279194,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":3,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":23,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":138,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":6126,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":56742000,"Value":117572000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":89,"Value":162,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12625000,"Value":29602000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":18,"Value":41,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":18,"Executor Deserialize CPU Time":12625000,"Executor Run Time":89,"Executor CPU Time":56742000,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":3,"Local Blocks Fetched":0,"Fetch Wait Time":23,"Remote Bytes Read":138,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":3},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":5,"Attempt":0,"Launch Time":1515493279190,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279203,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7260,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6476000,"Value":124048000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":169,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1890000,"Value":31492000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":43,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1890000,"Executor Run Time":7,"Executor CPU Time":6476000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":7,"Attempt":0,"Launch Time":1515493279202,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279216,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":8394,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":6927000,"Value":130975000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":7,"Value":176,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2038000,"Value":33530000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":45,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2038000,"Executor Run Time":7,"Executor CPU Time":6927000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":6,"Attempt":0,"Launch Time":1515493279193,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279218,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":9528,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11214000,"Value":142189000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":16,"Value":192,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2697000,"Value":36227000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":49,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":2697000,"Executor Run Time":16,"Executor CPU Time":11214000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":8,"Attempt":0,"Launch Time":1515493279215,"Executor ID":"1","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279226,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":10662,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":4905000,"Value":147094000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":5,"Value":197,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":1700000,"Value":37927000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":51,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":1700000,"Executor Run Time":5,"Executor CPU Time":4905000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":9,"Attempt":0,"Launch Time":1515493279218,"Executor ID":"0","Host":"172.30.65.138","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1515493279232,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":23,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":322,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":138,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":7,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":7850000,"Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":8,"Value":205,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2186000,"Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":54,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2186000,"Executor Run Time":8,"Executor CPU Time":7850000,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :29","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :29","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:29)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:34)\n$line17.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:36)\n$line17.$read$$iw$$iw$$iw$$iw$$iw.(:38)\n$line17.$read$$iw$$iw$$iw$$iw.(:40)\n$line17.$read$$iw$$iw$$iw.(:42)\n$line17.$read$$iw$$iw.(:44)\n$line17.$read$$iw.(:46)\n$line17.$read.(:48)\n$line17.$read$.(:52)\n$line17.$read$.()\n$line17.$eval$.$print$lzycompute(:7)\n$line17.$eval$.$print(:6)\n$line17.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1515493279071,"Completion Time":1515493279232,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":23,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":40113000,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":138,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":322,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":54,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":7,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":154944000,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":205,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":3,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":10,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1515493279237,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1515493477606} diff --git a/core/src/test/resources/spark-events/application_1516285256255_0012 b/core/src/test/resources/spark-events/application_1516285256255_0012 new file mode 100755 index 0000000000000..3e1736c3fe224 --- /dev/null +++ b/core/src/test/resources/spark-events/application_1516285256255_0012 @@ -0,0 +1,71 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.3.0-SNAPSHOT"} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","Java Version":"1.8.0_161 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.enabled":"true","spark.driver.host":"apiros-1.gce.test.com","spark.eventLog.enabled":"true","spark.driver.port":"33058","spark.repl.class.uri":"spark://apiros-1.gce.test.com:33058/classes","spark.jars":"","spark.repl.class.outputDir":"/tmp/spark-6781fb17-e07a-4b32-848b-9936c2e88b33/repl-c0fd7008-04be-471e-a173-6ad3e62d53d7","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"1","spark.scheduler.mode":"FIFO","spark.executor.instances":"8","spark.ui.showConsoleProgress":"true","spark.blacklist.stage.maxFailedTasksPerExecutor":"1","spark.executor.id":"driver","spark.submit.deployMode":"client","spark.master":"yarn","spark.ui.filters":"org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter","spark.executor.memory":"2G","spark.home":"/github/spark","spark.sql.catalogImplementation":"hive","spark.driver.appUIAddress":"http://apiros-1.gce.test.com:4040","spark.blacklist.application.maxFailedTasksPerExecutor":"10","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_HOSTS":"apiros-1.gce.test.com","spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.param.PROXY_URI_BASES":"http://apiros-1.gce.test.com:8088/proxy/application_1516285256255_0012","spark.app.id":"application_1516285256255_0012"},"System Properties":{"java.io.tmpdir":"/tmp","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"*********(redacted)","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","sun.arch.data.model":"64","sun.boot.library.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/amd64","user.dir":"*********(redacted)","java.library.path":"/usr/java/packages/lib/amd64:/usr/lib64:/lib64:/lib:/usr/lib","sun.cpu.isalist":"","os.arch":"amd64","java.vm.version":"25.161-b14","java.endorsed.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/endorsed","java.runtime.version":"1.8.0_161-b14","java.vm.info":"mixed mode","java.ext.dirs":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/ext:/usr/java/packages/lib/ext","java.runtime.name":"OpenJDK Runtime Environment","file.separator":"/","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/resources.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/rt.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/sunrsasign.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jsse.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jce.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/charsets.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/lib/jfr.jar:/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre/classes","file.encoding":"UTF-8","user.timezone":"*********(redacted)","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"3.10.0-693.5.2.el7.x86_64","sun.os.patch.level":"unknown","java.vm.specification.vendor":"Oracle Corporation","user.country":"*********(redacted)","sun.jnu.encoding":"UTF-8","user.language":"*********(redacted)","java.vendor.url":"*********(redacted)","java.awt.printerjob":"sun.print.PSPrinterJob","java.awt.graphicsenv":"sun.awt.X11GraphicsEnvironment","awt.toolkit":"sun.awt.X11.XToolkit","os.name":"Linux","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"*********(redacted)","user.name":"*********(redacted)","java.vm.name":"OpenJDK 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master yarn --deploy-mode client --conf spark.blacklist.stage.maxFailedTasksPerExecutor=1 --conf spark.blacklist.enabled=true --conf spark.blacklist.application.maxFailedTasksPerExecutor=10 --conf spark.blacklist.stage.maxFailedExecutorsPerNode=1 --conf spark.eventLog.enabled=true --class org.apache.spark.repl.Main --name Spark shell --executor-memory 2G --num-executors 8 spark-shell","java.home":"/usr/lib/jvm/java-1.8.0-openjdk-1.8.0.161-0.b14.el7_4.x86_64/jre","java.version":"1.8.0_161","sun.io.unicode.encoding":"UnicodeLittle"},"Classpath Entries":{"/github/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-vector-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-io-2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-hive_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apache-log4j-extras-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-metastore-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/github/spark/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aircompressor-0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stax-api-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-recipes-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-format-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libthrift-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang-2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stringtemplate-3.2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-api-jdo-3.2.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-net-2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-core-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arrow-memory-0.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scalap-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/JavaEWAH-0.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bcprov-jdk15on-1.58.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javolution-5.5.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/libfb303-0.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jodd-core-3.5.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/janino-3.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/snappy-0.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/java-xmlbuilder-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-kerberos-codec-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/ST4-4.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-core-3.2.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guice-servlet-3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hive-exec-1.2.1.spark2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/base64-2.3.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/etc/hadoop/conf/":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-yarn_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.13.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-jaxrs-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/apacheds-i18n-2.0.0-M15.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/lz4-java-1.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-framework-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/curator-client-2.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-3.9.9.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-avatica-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/machinist_2.11-0.6.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jaxb-api-2.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-asn1-api-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-linq4j-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/orc-mapreduce-1.4.1-nohive.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xercesImpl-2.9.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hppc-0.7.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.3.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/activation-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/py4j-0.10.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.6.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-hadoop-bundle-1.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-runtime-3.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/eigenbase-properties-1.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/paranamer-2.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jta-1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/derby-10.12.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-logging-1.1.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-pool-1.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/htrace-core-3.0.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpclient-4.5.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.7.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/zstd-jni-1.3.2-2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-web-proxy-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/spark-kvstore_2.11-2.3.0-SNAPSHOT.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.5.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/chill-java-0.8.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/antlr-2.7.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/datanucleus-rdbms-3.2.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.6.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/api-util-1.0.0-M20.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jdo-api-3.0.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.7.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-dbcp-1.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/netty-all-4.1.17.Final.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/gson-2.2.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/calcite-core-1.2.0-incubating.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/macro-compat_2.11-1.1.1.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/flatbuffers-1.2.0-3f79e055.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.13.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jets3t-0.9.4.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.7.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jackson-xc-1.9.13.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/bonecp-0.8.0.RELEASE.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/joda-time-2.9.3.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/httpcore-4.4.8.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/github/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.8.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"application_1516285256255_0012","Timestamp":1516300235119,"User":"attilapiros"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252095,"Executor ID":"2","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000003/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"apiros-3.gce.test.com","Port":38670},"Maximum Memory":956615884,"Timestamp":1516300252260,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252715,"Executor ID":"3","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000004/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300252918,"Executor ID":"1","Executor Info":{"Host":"apiros-3.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stdout?start=-4096","stderr":"http://apiros-3.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000002/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"apiros-2.gce.test.com","Port":38641},"Maximum Memory":956615884,"Timestamp":1516300252959,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"apiros-3.gce.test.com","Port":34970},"Maximum Memory":956615884,"Timestamp":1516300252988,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300253542,"Executor ID":"4","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000005/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"4","Host":"apiros-2.gce.test.com","Port":33229},"Maximum Memory":956615884,"Timestamp":1516300253653,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerExecutorAdded","Timestamp":1516300254323,"Executor ID":"5","Executor Info":{"Host":"apiros-2.gce.test.com","Total Cores":1,"Log Urls":{"stdout":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stdout?start=-4096","stderr":"http://apiros-2.gce.test.com:8042/node/containerlogs/container_1516285256255_0012_01_000007/attilapiros/stderr?start=-4096"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"5","Host":"apiros-2.gce.test.com","Port":45147},"Maximum Memory":956615884,"Timestamp":1516300254385,"Maximum Onheap Memory":956615884,"Maximum Offheap Memory":0} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1516300392631,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0,1],"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394348,"executorId":"5","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklistedForStage","time":1516300394348,"hostId":"apiros-2.gce.test.com","executorFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklistedForStage","time":1516300394356,"executorId":"4","taskFailures":1,"stageId":0,"stageAttemptId":0} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1332","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"33","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"3075188","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394338,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3075188,"Value":3075188,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":33,"Value":33,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1332,"Value":1332,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1332,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":33,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":3075188,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"1184","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":5,"Update":"82","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"16858066","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394355,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":16858066,"Value":19933254,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":82,"Value":115,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1184,"Value":2516,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":1184,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":82,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":16858066,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"51","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"183718","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1516300394323,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394390,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":183718,"Value":20116972,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":51,"Value":2567,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":51,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":183718,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"Bad executor","Stack Trace":[{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":28},{"Declaring Class":"$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2","Method Name":"apply","File Name":"","Line Number":27},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.collection.ExternalSorter","Method Name":"insertAll","File Name":"ExternalSorter.scala","Line Number":193},{"Declaring Class":"org.apache.spark.shuffle.sort.SortShuffleWriter","Method Name":"write","File Name":"SortShuffleWriter.scala","Line Number":63},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":96},{"Declaring Class":"org.apache.spark.scheduler.ShuffleMapTask","Method Name":"runTask","File Name":"ShuffleMapTask.scala","Line Number":53},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":109},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":345},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1149},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":624},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":748}],"Full Stack Trace":"java.lang.RuntimeException: Bad executor\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:28)\n\tat $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$2.apply(:27)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)\n\tat org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)\n\tat org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:109)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:748)\n","Accumulator Updates":[{"ID":2,"Update":"27","Internal":false,"Count Failed Values":true},{"ID":4,"Update":"0","Internal":false,"Count Failed Values":true},{"ID":20,"Update":"191901","Internal":false,"Count Failed Values":true}]},"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1516300394320,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394393,"Failed":true,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":191901,"Value":20308873,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":27,"Value":2594,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":27,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":191901,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1516300392832,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394606,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3322956,"Value":23631829,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":144,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":1080,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":78,"Value":193,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":278399617,"Value":278399617,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":493,"Value":3087,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":263386625,"Value":263386625,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1206,"Value":1206,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1206,"Executor Deserialize CPU Time":263386625,"Executor Run Time":493,"Executor CPU Time":278399617,"Result Size":1134,"JVM GC Time":78,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3322956,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1516300392833,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394860,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3587839,"Value":27219668,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":291,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":2160,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":102,"Value":295,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":349920830,"Value":628320447,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":681,"Value":3768,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":365807898,"Value":629194523,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1282,"Value":2488,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1282,"Executor Deserialize CPU Time":365807898,"Executor Run Time":681,"Executor CPU Time":349920830,"Result Size":1134,"JVM GC Time":102,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":3587839,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1516300392816,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394880,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":3662221,"Value":30881889,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":435,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":3240,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":75,"Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":368865439,"Value":997185886,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":914,"Value":4682,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":353981050,"Value":983175573,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":1081,"Value":3569,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":1081,"Executor Deserialize CPU Time":353981050,"Executor Run Time":914,"Executor CPU Time":368865439,"Result Size":1134,"JVM GC Time":75,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":3662221,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":5,"Attempt":1,"Launch Time":1516300394859,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300394974,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":377601,"Value":31259490,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":12,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":582,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":4320,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":4450,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":28283110,"Value":1025468996,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":84,"Value":4766,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":10894331,"Value":994069904,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":11,"Value":3580,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":11,"Executor Deserialize CPU Time":10894331,"Executor Run Time":84,"Executor CPU Time":28283110,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":377601,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":9,"Index":4,"Attempt":1,"Launch Time":1516300394973,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395069,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":366050,"Value":31625540,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":15,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":729,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":5400,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":5541,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":25678331,"Value":1051147327,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":48,"Value":4814,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4793905,"Value":998863809,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3585,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4793905,"Executor Run Time":48,"Executor CPU Time":25678331,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":366050,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":6,"Attempt":1,"Launch Time":1516300394879,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395073,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":311940,"Value":31937480,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":18,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":876,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":6480,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":6589,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":27304550,"Value":1078451877,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":54,"Value":4868,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":12246145,"Value":1011109954,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":56,"Value":3641,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":56,"Executor Deserialize CPU Time":12246145,"Executor Run Time":54,"Executor CPU Time":27304550,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":311940,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":7,"Attempt":0,"Launch Time":1516300395072,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395165,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":323898,"Value":32261378,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":21,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1023,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":7560,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":7637,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21689428,"Value":1100141305,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":77,"Value":4945,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4239884,"Value":1015349838,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3645,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4239884,"Executor Run Time":77,"Executor CPU Time":21689428,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":323898,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":1,"Attempt":1,"Launch Time":1516300395069,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395201,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":301705,"Value":32563083,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":24,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":144,"Value":1167,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":8640,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1091,"Value":8728,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20826337,"Value":1120967642,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":76,"Value":5021,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4598966,"Value":1019948804,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":3650,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4598966,"Executor Run Time":76,"Executor CPU Time":20826337,"Result Size":1091,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":144,"Shuffle Write Time":301705,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":8,"Attempt":0,"Launch Time":1516300395165,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395225,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":319101,"Value":32882184,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":27,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1314,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":9720,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":9776,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":21657558,"Value":1142625200,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":34,"Value":5055,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4010338,"Value":1023959142,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":3654,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":4010338,"Executor Run Time":34,"Executor CPU Time":21657558,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":319101,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ShuffleMapTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":13,"Index":9,"Attempt":0,"Launch Time":1516300395200,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395276,"Failed":false,"Killed":false,"Accumulables":[{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Update":369513,"Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Update":3,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Update":147,"Value":1461,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Update":1080,"Value":10800,"Internal":true,"Count Failed Values":true},{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1048,"Value":10824,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":20585619,"Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":25,"Value":5080,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":5860574,"Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":25,"Value":3679,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":25,"Executor Deserialize CPU Time":5860574,"Executor Run Time":25,"Executor CPU Time":20585619,"Result Size":1048,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":147,"Shuffle Write Time":369513,"Shuffle Records Written":3},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"map at :27","Number of Tasks":10,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :27","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :27","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.map(RDD.scala:370)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300392658,"Completion Time":1516300395279,"Accumulables":[{"ID":8,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Value":5080,"Internal":true,"Count Failed Values":true},{"ID":20,"Name":"internal.metrics.shuffle.write.writeTime","Value":33251697,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":370,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":10824,"Internal":true,"Count Failed Values":true},{"ID":7,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":1029819716,"Internal":true,"Count Failed Values":true},{"ID":19,"Name":"internal.metrics.shuffle.write.recordsWritten","Value":30,"Internal":true,"Count Failed Values":true},{"ID":9,"Name":"internal.metrics.peakExecutionMemory","Value":10800,"Internal":true,"Count Failed Values":true},{"ID":18,"Name":"internal.metrics.shuffle.write.bytesWritten","Value":1461,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":1163210819,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":5,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":3679,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Accumulables":[]},"Properties":{"spark.rdd.scope.noOverride":"true","spark.rdd.scope":"{\"id\":\"3\",\"name\":\"collect\"}"}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":4,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395525,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":1134,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52455999,"Value":52455999,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":95,"Value":95,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":23136577,"Value":23136577,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":82,"Value":82,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":82,"Executor Deserialize CPU Time":23136577,"Executor Run Time":95,"Executor CPU Time":52455999,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":6,"Attempt":0,"Launch Time":1516300395525,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395576,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":2268,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":13617615,"Value":66073614,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":29,"Value":124,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3469612,"Value":26606189,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":86,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3469612,"Executor Run Time":29,"Executor CPU Time":13617615,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":5,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395581,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":3402,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":55540208,"Value":121613822,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":179,"Value":303,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22400065,"Value":49006254,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":78,"Value":164,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":78,"Executor Deserialize CPU Time":22400065,"Executor Run Time":179,"Executor CPU Time":55540208,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":3,"Attempt":0,"Launch Time":1516300395304,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395593,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":4536,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":52311573,"Value":173925395,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":153,"Value":456,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":20519033,"Value":69525287,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":67,"Value":231,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":67,"Executor Deserialize CPU Time":20519033,"Executor Run Time":153,"Executor CPU Time":52311573,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":7,"Attempt":0,"Launch Time":1516300395575,"Executor ID":"4","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395660,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":5670,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":11294260,"Value":185219655,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":489,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3570887,"Value":73096174,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":235,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3570887,"Executor Run Time":33,"Executor CPU Time":11294260,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":22,"Index":9,"Attempt":0,"Launch Time":1516300395593,"Executor ID":"5","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395669,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":6804,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":12983732,"Value":198203387,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":44,"Value":533,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3518757,"Value":76614931,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":239,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3518757,"Executor Run Time":44,"Executor CPU Time":12983732,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":8,"Attempt":0,"Launch Time":1516300395581,"Executor ID":"3","Host":"apiros-2.gce.test.com","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395674,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1134,"Value":7938,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":14706240,"Value":212909627,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":64,"Value":597,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":7698059,"Value":84312990,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":21,"Value":260,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":21,"Executor Deserialize CPU Time":7698059,"Executor Run Time":64,"Executor CPU Time":14706240,"Result Size":1134,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Remote Bytes Read To Disk":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":1,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":0,"Attempt":0,"Launch Time":1516300395302,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":10,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":52,"Value":52,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":195,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":292,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":6,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":944,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":9224,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91696783,"Value":304606410,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":221,"Value":818,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":24063461,"Value":108376451,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":150,"Value":410,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":150,"Executor Deserialize CPU Time":24063461,"Executor Run Time":221,"Executor CPU Time":91696783,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":52,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":1,"Attempt":0,"Launch Time":1516300395303,"Executor ID":"2","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395687,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":20,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":107,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":244,"Value":439,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":243,"Value":535,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":5,"Value":9,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":5,"Value":11,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":1888,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":10510,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":91683507,"Value":396289917,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":289,"Value":1107,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":22106726,"Value":130483177,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":79,"Value":489,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":79,"Executor Deserialize CPU Time":22106726,"Executor Run Time":289,"Executor CPU Time":91683507,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":5,"Local Blocks Fetched":5,"Fetch Wait Time":107,"Remote Bytes Read":243,"Remote Bytes Read To Disk":0,"Local Bytes Read":244,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":1,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":2,"Attempt":0,"Launch Time":1516300395686,"Executor ID":"1","Host":"apiros-3.gce.test.com","Locality":"NODE_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1516300395728,"Failed":false,"Killed":false,"Accumulables":[{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Update":10,"Value":30,"Internal":true,"Count Failed Values":true},{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Update":0,"Value":159,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Update":195,"Value":634,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Update":292,"Value":827,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Update":4,"Value":13,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Update":6,"Value":17,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Update":944,"Value":2832,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Update":0,"Value":0,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Update":1286,"Value":11796,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Update":17607810,"Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Update":33,"Value":1140,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2897647,"Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Update":2,"Value":491,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":2,"Executor Deserialize CPU Time":2897647,"Executor Run Time":33,"Executor CPU Time":17607810,"Result Size":1286,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":6,"Local Blocks Fetched":4,"Fetch Wait Time":0,"Remote Bytes Read":292,"Remote Bytes Read To Disk":0,"Local Bytes Read":195,"Total Records Read":10},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":1,"Stage Attempt ID":0,"Stage Name":"collect at :30","Number of Tasks":10,"RDD Info":[{"RDD ID":2,"Name":"ShuffledRDD","Scope":"{\"id\":\"2\",\"name\":\"reduceByKey\"}","Callsite":"reduceByKey at :30","Parent IDs":[1],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":10,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[0],"Details":"org.apache.spark.rdd.RDD.collect(RDD.scala:936)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:30)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:35)\n$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:37)\n$line15.$read$$iw$$iw$$iw$$iw$$iw.(:39)\n$line15.$read$$iw$$iw$$iw$$iw.(:41)\n$line15.$read$$iw$$iw$$iw.(:43)\n$line15.$read$$iw$$iw.(:45)\n$line15.$read$$iw.(:47)\n$line15.$read.(:49)\n$line15.$read$.(:53)\n$line15.$read$.()\n$line15.$eval$.$print$lzycompute(:7)\n$line15.$eval$.$print(:6)\n$line15.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1516300395292,"Completion Time":1516300395728,"Accumulables":[{"ID":41,"Name":"internal.metrics.shuffle.read.fetchWaitTime","Value":159,"Internal":true,"Count Failed Values":true},{"ID":32,"Name":"internal.metrics.memoryBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true},{"ID":26,"Name":"internal.metrics.executorDeserializeCpuTime","Value":133380824,"Internal":true,"Count Failed Values":true},{"ID":29,"Name":"internal.metrics.resultSize","Value":11796,"Internal":true,"Count Failed Values":true},{"ID":38,"Name":"internal.metrics.shuffle.read.remoteBytesRead","Value":827,"Internal":true,"Count Failed Values":true},{"ID":40,"Name":"internal.metrics.shuffle.read.localBytesRead","Value":634,"Internal":true,"Count Failed Values":true},{"ID":25,"Name":"internal.metrics.executorDeserializeTime","Value":491,"Internal":true,"Count Failed Values":true},{"ID":34,"Name":"internal.metrics.peakExecutionMemory","Value":2832,"Internal":true,"Count Failed Values":true},{"ID":37,"Name":"internal.metrics.shuffle.read.localBlocksFetched","Value":13,"Internal":true,"Count Failed Values":true},{"ID":28,"Name":"internal.metrics.executorCpuTime","Value":413897727,"Internal":true,"Count Failed Values":true},{"ID":27,"Name":"internal.metrics.executorRunTime","Value":1140,"Internal":true,"Count Failed Values":true},{"ID":36,"Name":"internal.metrics.shuffle.read.remoteBlocksFetched","Value":17,"Internal":true,"Count Failed Values":true},{"ID":39,"Name":"internal.metrics.shuffle.read.remoteBytesReadToDisk","Value":0,"Internal":true,"Count Failed Values":true},{"ID":42,"Name":"internal.metrics.shuffle.read.recordsRead","Value":30,"Internal":true,"Count Failed Values":true},{"ID":33,"Name":"internal.metrics.diskBytesSpilled","Value":0,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1516300395734,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"SparkListenerApplicationEnd","Timestamp":1516300707938} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 7aa60f2b60796..87f12f303cd5e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -156,6 +156,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "applications/local-1426533911241/1/stages/0/0/taskList", "stage task list from multi-attempt app json(2)" -> "applications/local-1426533911241/2/stages/0/0/taskList", + "blacklisting for stage" -> "applications/app-20180109111548-0000/stages/0/0", + "blacklisting node for stage" -> "applications/application_1516285256255_0012/stages/0/0", "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index cd1b7a9e5ab18..afebcdd7b9e31 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -92,7 +92,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } def createTaskSetBlacklist(stageId: Int = 0): TaskSetBlacklist = { - new TaskSetBlacklist(conf, stageId, clock) + new TaskSetBlacklist(listenerBusMock, conf, stageId, stageAttemptId = 0, clock = clock) } test("executors can be blacklisted with only a few failures per stage") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 18981d5be2f94..6e2709dbe1e8b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -16,18 +16,32 @@ */ package org.apache.spark.scheduler +import org.mockito.Matchers.isA +import org.mockito.Mockito.{never, verify} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config -import org.apache.spark.util.{ManualClock, SystemClock} +import org.apache.spark.util.ManualClock + +class TaskSetBlacklistSuite extends SparkFunSuite with BeforeAndAfterEach with MockitoSugar { -class TaskSetBlacklistSuite extends SparkFunSuite { + private var listenerBusMock: LiveListenerBus = _ + + override def beforeEach(): Unit = { + listenerBusMock = mock[LiveListenerBus] + super.beforeEach() + } test("Blacklisting tasks, executors, and nodes") { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val clock = new ManualClock + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, clock = clock) clock.setTime(0) // We will mark task 0 & 1 failed on both executor 1 & 2. // We should blacklist all executors on that host, for all tasks for the stage. Note the API @@ -46,27 +60,53 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val shouldBeBlacklisted = (executor == "exec1" && index == 0) assert(taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === shouldBeBlacklisted) } + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerExecutorBlacklistedForStage])) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) + // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "exec2", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, "exec2", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(0, "hostA", 2, 0, attemptId)) + // Make sure the blacklist has the correct per-task && per-executor responses, over a wider // range of inputs. for { @@ -81,6 +121,10 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // intentional, it keeps it fast and is sufficient for usage in the scheduler. taskSetBlacklist.isExecutorBlacklistedForTask(executor, index) === (badExec && badIndex)) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet(executor) === badExec) + if (badExec) { + verify(listenerBusMock).post( + SparkListenerExecutorBlacklistedForStage(0, executor, 2, 0, attemptId)) + } } } assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -110,7 +154,14 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_TASK_ATTEMPTS_PER_NODE, 3) .set(config.MAX_FAILURES_PER_EXEC_STAGE, 2) .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + + var time = 0 + clock.setTime(time) // Fail a task twice on hostA, exec:1 taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") @@ -118,37 +169,75 @@ class TaskSetBlacklistSuite extends SparkFunSuite { "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail the same task once more on hostA, exec:2 + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) + assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock, never()).post( + SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, // so its blacklisted + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "2", index = 2, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 3, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "3", index = 4, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "3", 2, 0, attemptId)) + assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 3, 0, attemptId)) } test("only blacklist nodes for the task set when all the blacklisted executors are all on " + @@ -157,22 +246,42 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // lead to any node blacklisting val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") - val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) + val clock = new ManualClock + + val attemptId = 0 + val taskSetBlacklist = new TaskSetBlacklist( + listenerBusMock, conf, stageId = 0, stageAttemptId = attemptId, clock = clock) + var time = 0 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostA", exec = "1", index = 1, failureReason = "testing") + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "1", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) + verify(listenerBusMock, never()).post( + SparkListenerNodeBlacklistedForStage(time, "hostA", 2, 0, attemptId)) + time += 1 + clock.setTime(time) taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 0, failureReason = "testing") taskSetBlacklist.updateBlacklistForFailedTask( "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) + verify(listenerBusMock) + .post(SparkListenerExecutorBlacklistedForStage(time, "2", 2, 0, attemptId)) + assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostB")) + verify(listenerBusMock, never()) + .post(isA(classOf[SparkListenerNodeBlacklistedForStage])) } } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index e7981bec6d64b..042bba7f226fd 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -251,6 +251,49 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + // Blacklisting executor for stage + time += 1 + listener.onExecutorBlacklistedForStage(SparkListenerExecutorBlacklistedForStage( + time = time, + executorId = execIds.head, + taskFailures = 2, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappers = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappers.nonEmpty) + executorStageSummaryWrappers.foreach { exec => + // only the first executor is expected to be blacklisted + val expectedBlacklistedFlag = exec.executorId == execIds.head + assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) + } + + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "2.example.com", // this is where the second executor is hosted + executorFailures = 1, + stageId = stages.head.stageId, + stageAttemptId = stages.head.attemptId)) + + val executorStageSummaryWrappersForNode = + store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)) + .last(key(stages.head)) + .asScala.toSeq + + assert(executorStageSummaryWrappersForNode.nonEmpty) + executorStageSummaryWrappersForNode.foreach { exec => + // both executor is expected to be blacklisted + assert(exec.info.isBlacklistedForStage === true) + } + // Fail one of the tasks, re-start it. time += 1 s1Tasks.head.markFinished(TaskState.FAILED, time) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 607234b4068d0..243fbe3e1bc24 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -73,8 +73,10 @@ logs .*dependency-reduced-pom.xml known_translations json_expectation +app-20180109111548-0000 app-20161115172038-0000 app-20161116163331-0000 +application_1516285256255_0012 local-1422981759269 local-1422981780767 local-1425081759269 From e18d6f5326e0d9ea03d31de5ce04cb84d3b8ab37 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 24 Jan 2018 09:37:54 -0800 Subject: [PATCH 0192/1161] [SPARK-20906][SPARKR] Add API doc example for Constrained Logistic Regression ## What changes were proposed in this pull request? doc only changes ## How was this patch tested? manual Author: Felix Cheung Closes #20380 from felixcheung/rclrdoc. --- R/pkg/R/mllib_classification.R | 15 ++++++++++++++- R/pkg/tests/fulltests/test_mllib_classification.R | 10 +++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 7cd072a1d6f89..f6e9b1357561b 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -279,11 +279,24 @@ function(object, path, overwrite = FALSE) { #' savedModel <- read.ml(path) #' summary(savedModel) #' -#' # multinomial logistic regression +#' # binary logistic regression against two classes with +#' # upperBoundsOnCoefficients and upperBoundsOnIntercepts +#' ubc <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) +#' model <- spark.logit(training, Species ~ ., +#' upperBoundsOnCoefficients = ubc, +#' upperBoundsOnIntercepts = 1.0) #' +#' # multinomial logistic regression #' model <- spark.logit(training, Class ~ ., regParam = 0.5) #' summary <- summary(model) #' +#' # multinomial logistic regression with +#' # lowerBoundsOnCoefficients and lowerBoundsOnIntercepts +#' lbc <- matrix(c(0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0), nrow = 2, ncol = 4) +#' lbi <- as.array(c(0.0, 0.0)) +#' model <- spark.logit(training, Species ~ ., family = "multinomial", +#' lowerBoundsOnCoefficients = lbc, +#' lowerBoundsOnIntercepts = lbi) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index ad47717ddc12f..a46c47dccd02e 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -124,7 +124,7 @@ test_that("spark.logit", { # Petal.Width 0.42122607 # nolint end - # Test multinomial logistic regression againt three classes + # Test multinomial logistic regression against three classes df <- suppressWarnings(createDataFrame(iris)) model <- spark.logit(df, Species ~ ., regParam = 0.5) summary <- summary(model) @@ -196,7 +196,7 @@ test_that("spark.logit", { # # nolint end - # Test multinomial logistic regression againt two classes + # Test multinomial logistic regression against two classes df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") @@ -208,7 +208,7 @@ test_that("spark.logit", { expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) - # Test binomial logistic regression againt two classes + # Test binomial logistic regression against two classes model <- spark.logit(training, Species ~ ., regParam = 0.5) summary <- summary(model) coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) @@ -239,7 +239,7 @@ test_that("spark.logit", { prediction2 <- collect(select(predict(model2, df2), "prediction")) expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) - # Test binomial logistic regression againt two classes with upperBoundsOnCoefficients + # Test binomial logistic regression against two classes with upperBoundsOnCoefficients # and upperBoundsOnIntercepts u <- matrix(c(1.0, 0.0, 1.0, 0.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., upperBoundsOnCoefficients = u, @@ -252,7 +252,7 @@ test_that("spark.logit", { expect_error(spark.logit(training, Species ~ ., upperBoundsOnCoefficients = as.array(c(1, 2)), upperBoundsOnIntercepts = 1.0)) - # Test binomial logistic regression againt two classes with lowerBoundsOnCoefficients + # Test binomial logistic regression against two classes with lowerBoundsOnCoefficients # and lowerBoundsOnIntercepts l <- matrix(c(0.0, -1.0, 0.0, -1.0), nrow = 1, ncol = 4) model <- spark.logit(training, Species ~ ., lowerBoundsOnCoefficients = l, From 8c273b4162b6138c4abba64f595c2750d1ef8bcb Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 24 Jan 2018 10:00:42 -0800 Subject: [PATCH 0193/1161] [SPARK-23020][CORE][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20297 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/test/java/org/apache/spark/launcher/BaseSuite.java:[21,8] (imports) UnusedImports: Unused import - java.util.concurrent.TimeUnit. [ERROR] src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java:[27,8] (imports) UnusedImports: Unused import - java.util.concurrent.TimeUnit. ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20376 from ueshin/issues/SPARK-23020/fup1. --- .../test/java/org/apache/spark/launcher/SparkLauncherSuite.java | 1 - launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java | 1 - 2 files changed, 2 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index a042375c6ae91..1543f4fdb0162 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Properties; -import java.util.concurrent.TimeUnit; import org.junit.Test; import static org.junit.Assert.*; diff --git a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java index 3722a59d9438e..438349e027a24 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/BaseSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.launcher; import java.time.Duration; -import java.util.concurrent.TimeUnit; import org.junit.After; import org.slf4j.bridge.SLF4JBridgeHandler; From bbb87b350d9d0d393db3fb7ca61dcbae538553bb Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Wed, 24 Jan 2018 10:07:24 -0800 Subject: [PATCH 0194/1161] [SPARK-22837][SQL] Session timeout checker does not work in SessionManager. ## What changes were proposed in this pull request? Currently we do not call the `super.init(hiveConf)` in `SparkSQLSessionManager.init`. So we do not load the config `HIVE_SERVER2_SESSION_CHECK_INTERVAL HIVE_SERVER2_IDLE_SESSION_TIMEOUT HIVE_SERVER2_IDLE_SESSION_CHECK_OPERATION` , which cause the session timeout checker does not work. ## How was this patch tested? manual tests Author: zuotingbing Closes #20025 from zuotingbing/SPARK-22837. --- .../thriftserver/SparkSQLSessionManager.scala | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 48c0ebef3e0ce..2958b771f3648 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -40,22 +40,8 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - // Create operation log root directory, if operation logging is enabled - if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { - invoke(classOf[SessionManager], this, "initOperationLogRootDir") - } - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) + super.init(hiveConf) } override def openSession( From 840dea64abd8a3a5960de830f19a57f5f1aa3bf6 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 24 Jan 2018 13:13:44 -0500 Subject: [PATCH 0195/1161] [SPARK-23152][ML] - Correctly guard against empty datasets ## What changes were proposed in this pull request? Correctly guard against empty datasets in `org.apache.spark.ml.classification.Classifier` ## How was this patch tested? existing tests Author: Matthew Tovbin Closes #20321 from tovbinm/SPARK-23152. --- .../org/apache/spark/ml/classification/Classifier.scala | 2 +- .../apache/spark/ml/classification/ClassifierSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index bc0b49d48d323..9d1d5aa1e0cff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -109,7 +109,7 @@ abstract class Classifier[ case None => // Get number of classes from dataset itself. val maxLabelRow: Array[Row] = dataset.select(max($(labelCol))).take(1) - if (maxLabelRow.isEmpty) { + if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) { throw new SparkException("ML algorithm was given empty dataset.") } val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index de712079329da..87bf2be06c2be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -90,6 +90,13 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } assert(e.getMessage.contains("requires integers in range")) } + val df3 = getTestData(Seq.empty[Double]) + withClue("getNumClasses should fail if dataset is empty") { + val e: SparkException = intercept[SparkException] { + c.getNumClasses(df3) + } + assert(e.getMessage == "ML algorithm was given empty dataset.") + } } } From 0e178e1523175a0be9437920045e80deb0a2712b Mon Sep 17 00:00:00 2001 From: Mark Petruska Date: Wed, 24 Jan 2018 10:25:14 -0800 Subject: [PATCH 0196/1161] [SPARK-22297][CORE TESTS] Flaky test: BlockManagerSuite "Shuffle registration timeout and maxAttempts conf" ## What changes were proposed in this pull request? [Ticket](https://issues.apache.org/jira/browse/SPARK-22297) - one of the tests seems to produce unreliable results due to execution speed variability Since the original test was trying to connect to the test server with `40 ms` timeout, and the test server replied after `50 ms`, the error might be produced under the following conditions: - it might occur that the test server replies correctly after `50 ms` - but the client does only receive the timeout after `51 ms`s - this might happen if the executor has to schedule a big number of threads, and decides to delay the thread/actor that is responsible to watch the timeout, because of high CPU load - running an entire test suite usually produces high loads on the CPU executing the tests ## How was this patch tested? The test's check cases remain the same and the set-up emulates the previous version's. Author: Mark Petruska Closes #19671 from mpetruska/SPARK-22297. --- .../spark/storage/BlockManagerSuite.scala | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 629eed49b04cc..b19d8ebf72c61 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.storage import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.Future import scala.concurrent.duration._ @@ -44,8 +43,9 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} +import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} @@ -1325,9 +1325,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { val tryAgainMsg = "test_spark_20640_try_again" + val timingoutExecutor = "timingoutExecutor" + val tryAgainExecutor = "tryAgainExecutor" + val succeedingExecutor = "succeedingExecutor" + // a server which delays response 50ms and must try twice for success. def newShuffleServer(port: Int): (TransportServer, Int) = { - val attempts = new mutable.HashMap[String, Int]() + val failure = new Exception(tryAgainMsg) + val success = ByteBuffer.wrap(new Array[Byte](0)) + + var secondExecutorFailedOnce = false + var thirdExecutorFailedOnce = false + val handler = new NoOpRpcHandler { override def receive( client: TransportClient, @@ -1335,15 +1344,26 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE callback: RpcResponseCallback): Unit = { val msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message) msgObj match { - case exec: RegisterExecutor => - Thread.sleep(50) - val attempt = attempts.getOrElse(exec.execId, 0) + 1 - attempts(exec.execId) = attempt - if (attempt < 2) { - callback.onFailure(new Exception(tryAgainMsg)) - return - } - callback.onSuccess(ByteBuffer.wrap(new Array[Byte](0))) + + case exec: RegisterExecutor if exec.execId == timingoutExecutor => + () // No reply to generate client-side timeout + + case exec: RegisterExecutor + if exec.execId == tryAgainExecutor && !secondExecutorFailedOnce => + secondExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == tryAgainExecutor => + callback.onSuccess(success) + + case exec: RegisterExecutor + if exec.execId == succeedingExecutor && !thirdExecutorFailedOnce => + thirdExecutorFailedOnce = true + callback.onFailure(failure) + + case exec: RegisterExecutor if exec.execId == succeedingExecutor => + callback.onSuccess(success) + } } } @@ -1352,6 +1372,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val transCtx = new TransportContext(transConf, handler, true) (transCtx.createServer(port, Seq.empty[TransportServerBootstrap].asJava), port) } + val candidatePort = RandomUtils.nextInt(1024, 65536) val (server, shufflePort) = Utils.startServiceOnPort(candidatePort, newShuffleServer, conf, "ShuffleServer") @@ -1360,21 +1381,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.shuffle.service.port", shufflePort.toString) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "40") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - var e = intercept[SparkException]{ - makeBlockManager(8000, "executor1") + var e = intercept[SparkException] { + makeBlockManager(8000, timingoutExecutor) }.getMessage assert(e.contains("TimeoutException")) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "1") - e = intercept[SparkException]{ - makeBlockManager(8000, "executor2") + e = intercept[SparkException] { + makeBlockManager(8000, tryAgainExecutor) }.getMessage assert(e.contains(tryAgainMsg)) conf.set(SHUFFLE_REGISTRATION_TIMEOUT.key, "1000") conf.set(SHUFFLE_REGISTRATION_MAX_ATTEMPTS.key, "2") - makeBlockManager(8000, "executor3") + makeBlockManager(8000, succeedingExecutor) server.close() } From bc9641d9026aeae3571915b003ac971f6245d53c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 24 Jan 2018 12:58:44 -0800 Subject: [PATCH 0197/1161] [SPARK-23198][SS][TEST] Fix KafkaContinuousSourceStressForDontFailOnDataLossSuite to test ContinuousExecution ## What changes were proposed in this pull request? Currently, `KafkaContinuousSourceStressForDontFailOnDataLossSuite` runs on `MicroBatchExecution`. It should test `ContinuousExecution`. ## How was this patch tested? Pass the updated test suite. Author: Dongjoon Hyun Closes #20374 from dongjoon-hyun/SPARK-23198. --- .../apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index b3dade414f625..a7083fa4e3417 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -91,6 +91,7 @@ class KafkaContinuousSourceStressForDontFailOnDataLossSuite ds.writeStream .format("memory") .queryName("memory") + .trigger(Trigger.Continuous("1 second")) .start() } } From 6f0ba8472d1128551fa8090deebcecde0daebc53 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Wed, 24 Jan 2018 13:06:09 -0800 Subject: [PATCH 0198/1161] [MINOR][SQL] add new unit test to LimitPushdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR is repaired as follows 1、update y -> x in "left outer join" test case ,maybe is mistake. 2、add a new test case:"left outer join and left sides are limited" 3、add a new test case:"left outer join and right sides are limited" 4、add a new test case: "right outer join and right sides are limited" 5、add a new test case: "right outer join and left sides are limited" 6、Remove annotations without code implementation ## How was this patch tested? add new unit test case. Author: caoxuewen Closes #20381 from heary-cao/LimitPushdownSuite. --- .../sql/catalyst/optimizer/Optimizer.scala | 1 - .../optimizer/LimitPushdownSuite.scala | 30 ++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f9daa5f04c76..8d207708c12ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -352,7 +352,6 @@ object LimitPushDown extends Rule[LogicalPlan] { // on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. - // - If neither side is limited, limit the side that is estimated to be bigger. case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index cc98d2350c777..17fb9fc5d11e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -93,7 +93,21 @@ class LimitPushdownSuite extends PlanTest { test("left outer join") { val originalQuery = x.join(y, LeftOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, y).join(y, LeftOuter)).analyze + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("left outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("left outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(Limit(2, y), LeftOuter)).analyze comparePlans(optimized, correctAnswer) } @@ -104,6 +118,20 @@ class LimitPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("right outer join and right sides are limited") { + val originalQuery = x.join(y.limit(2), RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("right outer join and left sides are limited") { + val originalQuery = x.limit(2).join(y, RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, Limit(2, x).join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + test("larger limits are not pushed on top of smaller ones in right outer join") { val originalQuery = x.join(y.limit(5), RightOuter).limit(10) val optimized = Optimize.execute(originalQuery.analyze) From 45b4bbfddc18a77011c3bc1bfd71b2cd3466443c Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 25 Jan 2018 15:24:52 +0800 Subject: [PATCH 0199/1161] [SPARK-23129][CORE] Make deserializeStream of DiskMapIterator init lazily ## What changes were proposed in this pull request? Currently,the deserializeStream in ExternalAppendOnlyMap#DiskMapIterator init when DiskMapIterator instance created.This will cause memory use overhead when ExternalAppendOnlyMap spill too much times. We can avoid this by making deserializeStream init when it is used the first time. This patch make deserializeStream init lazily. ## How was this patch tested? Exist tests Author: zhoukang Closes #20292 from caneGuy/zhoukang/lay-diskmapiterator. --- .../util/collection/ExternalAppendOnlyMap.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 375f4a6921225..5c6dd45ec58e3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -463,7 +463,7 @@ class ExternalAppendOnlyMap[K, V, C]( // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var deserializeStream = nextBatchStream() + private var deserializeStream: DeserializationStream = null private var nextItem: (K, C) = null private var objectsRead = 0 @@ -528,7 +528,11 @@ class ExternalAppendOnlyMap[K, V, C]( override def hasNext: Boolean = { if (nextItem == null) { if (deserializeStream == null) { - return false + // In case of deserializeStream has not been initialized + deserializeStream = nextBatchStream() + if (deserializeStream == null) { + return false + } } nextItem = readNextItem() } @@ -536,19 +540,18 @@ class ExternalAppendOnlyMap[K, V, C]( } override def next(): (K, C) = { - val item = if (nextItem == null) readNextItem() else nextItem - if (item == null) { + if (!hasNext) { throw new NoSuchElementException } + val item = nextItem nextItem = null item } private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - if (ds != null) { - ds.close() + if (deserializeStream != null) { + deserializeStream.close() deserializeStream = null } if (fileStream != null) { From e29b08add92462a6505fef966629e74ba30e994e Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 25 Jan 2018 16:40:41 +0800 Subject: [PATCH 0200/1161] [SPARK-23208][SQL] Fix code generation for complex create array (related) expressions ## What changes were proposed in this pull request? The `GenArrayData.genCodeToCreateArrayData` produces illegal java code when code splitting is enabled. This is used in `CreateArray` and `CreateMap` expressions for complex object arrays. This issue is caused by a typo. ## How was this patch tested? Added a regression test in `complexTypesSuite`. Author: Herman van Hovell Closes #20391 from hvanhovell/SPARK-23208. --- .../sql/catalyst/expressions/complexTypeCreator.scala | 2 +- .../spark/sql/catalyst/optimizer/complexTypesSuite.scala | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3dc2ee03a86e3..047b80ac5289c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -111,7 +111,7 @@ private [sql] object GenArrayData { val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", - extraArguments = ("Object[]", arrayDataName) :: Nil) + extraArguments = ("Object[]", arrayName) :: Nil) (s"Object[] $arrayName = new Object[$numElements];", assignmentString, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 0d11958876ce9..de544ac314789 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ /** @@ -31,7 +32,7 @@ import org.apache.spark.sql.types._ * i.e. {{{create_named_struct(square, `x` * `x`).square}}} can be simplified to {{{`x` * `x`}}}. * sam applies to create_array and create_map */ -class ComplexTypesSuite extends PlanTest{ +class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = @@ -171,6 +172,11 @@ class ComplexTypesSuite extends PlanTest{ assert(ctx.inlinedMutableStates.length == 0) } + test("SPARK-23208: Test code splitting for create array related methods") { + val inputs = (1 to 2500).map(x => Literal(s"l_$x")) + checkEvaluation(CreateArray(inputs), new GenericArrayData(inputs.map(_.eval()))) + } + test("simplify map ops") { val rel = relation .select( From 39ee2acf96f1e1496cff8e4d2614d27fca76d43b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 25 Jan 2018 01:48:11 -0800 Subject: [PATCH 0201/1161] [SPARK-23163][DOC][PYTHON] Sync ML Python API with Scala ## What changes were proposed in this pull request? This syncs the ML Python API with Scala for differences found after the 2.3 QA audit. ## How was this patch tested? NA Author: Bryan Cutler Closes #20354 from BryanCutler/pyspark-ml-doc-sync-23163. --- python/pyspark/ml/evaluation.py | 8 +++++++- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/fpm.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index aa8dbe708a115..0cbce9b40048f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -334,7 +334,13 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, .. note:: Experimental Evaluator for Clustering results, which expects two input - columns: prediction and features. + columns: prediction and features. The metric computes the Silhouette + measure using the squared Euclidean distance. + + The Silhouette is a measure for the validation of the consistency + within clusters. It ranges between 1 and -1, where a value close to + 1 means that the points in a cluster are close to the other points + in the same cluster and far from the points of the other clusters. >>> from pyspark.ml.linalg import Vectors >>> featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index eb79b193103e2..da85ba761a145 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3440,7 +3440,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja selectorType = Param(Params._dummy(), "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: numTopFeatures (default), percentile and fpr.", + "Supported options: numTopFeatures (default), percentile, fpr, fdr, fwe.", typeConverter=TypeConverters.toString) numTopFeatures = \ diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index dd7dda5f03124..b8dafd49d354d 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -144,7 +144,7 @@ def freqItemsets(self): @since("2.2.0") def associationRules(self): """ - Data with three columns: + DataFrame with three columns: * `antecedent` - Array of the same type as the input column. * `consequent` - Array of the same type as the input column. * `confidence` - Confidence for the rule (`DoubleType`). From d20bbc2d87ae6bd56d236a7c3d036b52c5f20ff5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Jan 2018 19:49:58 +0800 Subject: [PATCH 0202/1161] [SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen ## What changes were proposed in this pull request? It has been observed in SPARK-21603 that whole-stage codegen suffers performance degradation, if the generated functions are too long to be optimized by JIT. We basically produce a single function to incorporate generated codes from all physical operators in whole-stage. Thus, it is possibly to grow the size of generated function over a threshold that we can't have JIT optimization for it anymore. This patch is trying to decouple the logic of consuming rows in physical operators to avoid a giant function processing rows. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #18931 from viirya/SPARK-21717. --- .../expressions/codegen/CodeGenerator.scala | 38 ++++- .../apache/spark/sql/internal/SQLConf.scala | 12 ++ .../sql/execution/WholeStageCodegenExec.scala | 135 +++++++++++++++--- .../execution/WholeStageCodegenSuite.scala | 47 +++++- 4 files changed, 203 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f96ed7628fda1..4dcbb702893da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1245,6 +1245,31 @@ class CodegenContext { "" } } + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr(_)).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + } } /** @@ -1311,26 +1336,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { // This is the value of HugeMethodLimit in the OpenJDK JVM settings - val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + + // The max valid length of method parameters in JVM. + final val MAX_JVM_METHOD_PARAMS_LENGTH = 255 // This is the threshold over which the methods in an inner class are grouped in a single // method which is going to be called by the outer class instead of the many small ones - val MERGE_SPLIT_METHODS_THRESHOLD = 3 + final val MERGE_SPLIT_METHODS_THRESHOLD = 3 // The number of named constants that can exist in the class is limited by the Constant Pool // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a // threshold of 1000k bytes to determine when a function should be inlined to a private, inner // class. - val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 + final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000 // This is the threshold for the number of global variables, whose types are primitive type or // complex type (e.g. more than one-dimensional array), that will be placed at the outer class - val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 + final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000 // This is the maximum number of array elements to keep global variables in one Java array // 32767 is the maximum integer value that does not require a constant pool entry in a Java // bytecode instruction - val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 + final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768 /** * Compile the Java source code into a Java class, using Janino. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1cef09a5bf053..470f88c213561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -661,6 +661,15 @@ object SQLConf { .intConf .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) + val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = + buildConf("spark.sql.codegen.splitConsumeFuncByOperator") + .internal() + .doc("When true, whole stage codegen would put the logic of consuming rows of each " + + "physical operator into individual methods, instead of a single big method. This can be " + + "used to avoid oversized function that can miss the opportunity of JIT optimization.") + .booleanConf + .createWithDefault(true) + val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -1263,6 +1272,9 @@ class SQLConf extends Serializable with Logging { def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) + def wholeStageSplitConsumeFuncByOperator: Boolean = + getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR) + def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6102937852347..8ea9e81b2e53b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.Locale +import scala.collection.mutable + import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan { */ protected def doProduce(ctx: CodegenContext): String + private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { + if (row != null) { + ExprCode("", "false", row) + } else { + if (colVars.nonEmpty) { + val colExprs = output.zipWithIndex.map { case (attr, i) => + BoundReference(i, attr.dataType, attr.nullable) + } + val evaluateInputs = evaluateVariables(colVars) + // generate the code to create a UnsafeRow + ctx.INPUT_ROW = row + ctx.currentVars = colVars + val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) + val code = s""" + |$evaluateInputs + |${ev.code.trim} + """.stripMargin.trim + ExprCode(code, "false", ev.value) + } else { + // There is no columns + ExprCode("", "false", "unsafeRow") + } + } + } + /** * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. * @@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan { } } - val rowVar = if (row != null) { - ExprCode("", "false", row) - } else { - if (outputVars.nonEmpty) { - val colExprs = output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable) - } - val evaluateInputs = evaluateVariables(outputVars) - // generate the code to create a UnsafeRow - ctx.INPUT_ROW = row - ctx.currentVars = outputVars - val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" - |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim - ExprCode(code, "false", ev.value) - } else { - // There is no columns - ExprCode("", "false", "unsafeRow") - } - } + val rowVar = prepareRowVar(ctx, row, outputVars) // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to @@ -156,13 +162,96 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = null ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) + + // Under certain conditions, we can put the logic to consume the rows of this operator into + // another function. So we can prevent a generated function too long to be optimized by JIT. + // The conditions: + // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled. + // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses + // all variables in output (see `requireAllOutput`). + // 3. The number of output variables must less than maximum number of parameters in Java method + // declaration. + val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator + val requireAllOutput = output.forall(parent.usedInputs.contains(_)) + val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { + constructDoConsumeFunction(ctx, inputVars, row) + } else { + parent.doConsume(ctx, inputVars, rowVar) + } s""" |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} |$evaluated - |${parent.doConsume(ctx, inputVars, rowVar)} + |$consumeFunc + """.stripMargin + } + + /** + * To prevent concatenated function growing too long to be optimized by JIT. We can separate the + * parent's `doConsume` codes of a `CodegenSupport` operator into a function to call. + */ + private def constructDoConsumeFunction( + ctx: CodegenContext, + inputVars: Seq[ExprCode], + row: String): String = { + val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row) + val rowVar = prepareRowVar(ctx, row, inputVarsInFunc) + + val doConsume = ctx.freshName("doConsume") + ctx.currentVars = inputVarsInFunc + ctx.INPUT_ROW = null + + val doConsumeFuncName = ctx.addNewFunction(doConsume, + s""" + | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException { + | ${parent.doConsume(ctx, inputVarsInFunc, rowVar)} + | } + """.stripMargin) + + s""" + | $doConsumeFuncName(${args.mkString(", ")}); """.stripMargin } + /** + * Returns arguments for calling method and method definition parameters of the consume function. + * And also returns the list of `ExprCode` for the parameters. + */ + private def constructConsumeParameters( + ctx: CodegenContext, + attributes: Seq[Attribute], + variables: Seq[ExprCode], + row: String): (Seq[String], Seq[String], Seq[ExprCode]) = { + val arguments = mutable.ArrayBuffer[String]() + val parameters = mutable.ArrayBuffer[String]() + val paramVars = mutable.ArrayBuffer[ExprCode]() + + if (row != null) { + arguments += row + parameters += s"InternalRow $row" + } + + variables.zipWithIndex.foreach { case (ev, i) => + val paramName = ctx.freshName(s"expr_$i") + val paramType = ctx.javaType(attributes(i).dataType) + + arguments += ev.value + parameters += s"$paramType $paramName" + val paramIsNull = if (!attributes(i).nullable) { + // Use constant `false` without passing `isNull` for non-nullable variable. + "false" + } else { + val isNull = ctx.freshName(s"exprIsNull_$i") + arguments += ev.isNull + parameters += s"boolean $isNull" + isNull + } + + paramVars += ExprCode("", paramIsNull, paramName) + } + (arguments, parameters, paramVars) + } + /** * Returns source code to evaluate all the variables, and clear the code of them, to prevent * them to be evaluated twice. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 22ca128c27768..242bb48c22942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) - val codeWithLongFunctions = genGroupByCode(20) + val codeWithLongFunctions = genGroupByCode(50) val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } @@ -228,4 +228,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("Control splitting consume function by operators with config") { + import testImplicits._ + val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) + + Seq(true, false).foreach { config => + withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { + val plan = df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => true + case _ => false + }) + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == config) + } + } + } + + test("Skip splitting consume function when parameter number exceeds JVM limit") { + import testImplicits._ + + Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) + .write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", + SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { + val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i") + val df = spark.read.parquet(path).selectExpr(projection: _*) + + val plan = df.queryExecution.executedPlan + val wholeStageCodeGenExec = plan.find(p => p match { + case wp: WholeStageCodegenExec => true + case _ => false + }) + assert(wholeStageCodeGenExec.isDefined) + val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 + assert(code.body.contains("project_doConsume") == hasSplit) + } + } + } + } } From 8532e26f335b67b74c976712ad82c20ea6dbbf80 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 25 Jan 2018 15:01:22 +0200 Subject: [PATCH 0203/1161] [SPARK-23112][DOC] Add highlights and migration guide for 2.3 Update ML user guide with highlights and migration guide for `2.3`. ## How was this patch tested? Doc only. Author: Nick Pentreath Closes #20363 from MLnick/SPARK-23112-ml-guide. --- docs/ml-guide.md | 78 ++++++++++++++----------------------- docs/ml-migration-guides.md | 23 +++++++++++ 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index f6288e7c32d97..b957445579ffd 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -72,32 +72,31 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 [^1]: To learn more about the benefits and background of system optimised natives, you may wish to watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). -# Highlights in 2.2 +# Highlights in 2.3 -The list below highlights some of the new features and enhancements added to MLlib in the `2.2` +The list below highlights some of the new features and enhancements added to MLlib in the `2.3` release of Spark: -* [`ALS`](ml-collaborative-filtering.html) methods for _top-k_ recommendations for all - users or items, matching the functionality in `mllib` - ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)). - Performance was also improved for both `ml` and `mllib` - ([SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and - [SPARK-20587](https://issues.apache.org/jira/browse/SPARK-20587)) -* [`Correlation`](ml-statistics.html#correlation) and - [`ChiSquareTest`](ml-statistics.html#hypothesis-testing) stats functions for `DataFrames` - ([SPARK-19636](https://issues.apache.org/jira/browse/SPARK-19636) and - [SPARK-19635](https://issues.apache.org/jira/browse/SPARK-19635)) -* [`FPGrowth`](ml-frequent-pattern-mining.html#fp-growth) algorithm for frequent pattern mining - ([SPARK-14503](https://issues.apache.org/jira/browse/SPARK-14503)) -* `GLM` now supports the full `Tweedie` family - ([SPARK-18929](https://issues.apache.org/jira/browse/SPARK-18929)) -* [`Imputer`](ml-features.html#imputer) feature transformer to impute missing values in a dataset - ([SPARK-13568](https://issues.apache.org/jira/browse/SPARK-13568)) -* [`LinearSVC`](ml-classification-regression.html#linear-support-vector-machine) - for linear Support Vector Machine classification - ([SPARK-14709](https://issues.apache.org/jira/browse/SPARK-14709)) -* Logistic regression now supports constraints on the coefficients during training - ([SPARK-20047](https://issues.apache.org/jira/browse/SPARK-20047)) +* Built-in support for reading images into a `DataFrame` was added +([SPARK-21866](https://issues.apache.org/jira/browse/SPARK-21866)). +* [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) was added, and should be +used instead of the existing `OneHotEncoder` transformer. The new estimator supports +transforming multiple columns. +* Multiple column support was also added to `QuantileDiscretizer` and `Bucketizer` +([SPARK-22397](https://issues.apache.org/jira/browse/SPARK-22397) and +[SPARK-20542](https://issues.apache.org/jira/browse/SPARK-20542)) +* A new [`FeatureHasher`](ml-features.html#featurehasher) transformer was added + ([SPARK-13969](https://issues.apache.org/jira/browse/SPARK-13969)). +* Added support for evaluating multiple models in parallel when performing cross-validation using +[`TrainValidationSplit` or `CrossValidator`](ml-tuning.html) +([SPARK-19357](https://issues.apache.org/jira/browse/SPARK-19357)). +* Improved support for custom pipeline components in Python (see +[SPARK-21633](https://issues.apache.org/jira/browse/SPARK-21633) and +[SPARK-21542](https://issues.apache.org/jira/browse/SPARK-21542)). +* `DataFrame` functions for descriptive summary statistics over vector columns +([SPARK-19634](https://issues.apache.org/jira/browse/SPARK-19634)). +* Robust linear regression with Huber loss +([SPARK-3181](https://issues.apache.org/jira/browse/SPARK-3181)). # Migration guide @@ -115,36 +114,17 @@ There are no breaking changes. **Deprecations** -There are no deprecations. +* `OneHotEncoder` has been deprecated and will be removed in `3.0`. It has been replaced by the +new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) +(see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030)). **Note** that +`OneHotEncoderEstimator` will be renamed to `OneHotEncoder` in `3.0` (but +`OneHotEncoderEstimator` will be kept as an alias). **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial), in 2.2 and earlier version, - the `OneVsRest` parallelism would be parallelism of the default threadpool in scala. - -## From 2.1 to 2.2 - -### Breaking changes - -There are no breaking changes. - -### Deprecations and changes of behavior - -**Deprecations** - -There are no deprecations. - -**Changes of behavior** - -* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): - Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). - **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. -* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): - Fixed inconsistency between Python and Scala APIs for `Param.copy` method. -* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): - `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception - would always be thrown regardless of the setting of the `handleInvalid` parameter. + We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial). In 2.2 and + earlier versions, the level of parallelism was set to the default threadpool size in Scala. ## Previous Spark versions diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index 687d7c8930362..f4b0df58cf63b 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -7,6 +7,29 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide). +## From 2.1 to 2.2 + +### Breaking changes + +There are no breaking changes. + +### Deprecations and changes of behavior + +**Deprecations** + +There are no deprecations. + +**Changes of behavior** + +* [SPARK-19787](https://issues.apache.org/jira/browse/SPARK-19787): + Default value of `regParam` changed from `1.0` to `0.1` for `ALS.train` method (marked `DeveloperApi`). + **Note** this does _not affect_ the `ALS` Estimator or Model, nor MLlib's `ALS` class. +* [SPARK-14772](https://issues.apache.org/jira/browse/SPARK-14772): + Fixed inconsistency between Python and Scala APIs for `Param.copy` method. +* [SPARK-11569](https://issues.apache.org/jira/browse/SPARK-11569): + `StringIndexer` now handles `NULL` values in the same way as unseen values. Previously an exception + would always be thrown regardless of the setting of the `handleInvalid` parameter. + ## From 2.0 to 2.1 ### Breaking changes From 8480c0c57698b7dcccec5483d67b17cf2c7527ed Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 26 Jan 2018 07:50:48 +0900 Subject: [PATCH 0204/1161] [SPARK-23081][PYTHON] Add colRegex API to PySpark ## What changes were proposed in this pull request? Add colRegex API to PySpark ## How was this patch tested? add a test in sql/tests.py Author: Huaxin Gao Closes #20390 from huaxingao/spark-23081. --- python/pyspark/sql/dataframe.py | 23 +++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 8 +++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2d5e9b91468cf..ac403080acfdf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -819,6 +819,29 @@ def columns(self): """ return [f.name for f in self.schema.fields] + @since(2.3) + def colRegex(self, colName): + """ + Selects column based on the column name specified as a regex and returns it + as :class:`Column`. + + :param colName: string, column name specified as a regex. + + >>> df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"]) + >>> df.select(df.colRegex("`(Col1)?+.+`")).show() + +----+ + |Col2| + +----+ + | 1| + | 2| + | 3| + +----+ + """ + if not isinstance(colName, basestring): + raise ValueError("colName should be provided as string") + jc = self._jdf.colRegex(colName) + return Column(jc) + @ignore_unicode_prefix @since(1.3) def alias(self, alias): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 912f411fa3845..edb6644ed5ac0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1194,7 +1194,7 @@ class Dataset[T] private[sql]( def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. * @@ -1220,7 +1220,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name and return it as a [[Column]]. + * Selects column based on the column name and returns it as a [[Column]]. * * @note The column name can also reference to a nested column like `a.b`. * @@ -1240,7 +1240,7 @@ class Dataset[T] private[sql]( } /** - * Selects column based on the column name specified as a regex and return it as [[Column]]. + * Selects column based on the column name specified as a regex and returns it as [[Column]]. * @group untypedrel * @since 2.3.0 */ @@ -2729,7 +2729,7 @@ class Dataset[T] private[sql]( } /** - * Return an iterator that contains all rows in this Dataset. + * Returns an iterator that contains all rows in this Dataset. * * The iterator will consume as much memory as the largest partition in this Dataset. * From e57f394818b0a62f99609e1032fede7e981f306f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Thu, 25 Jan 2018 16:11:33 -0800 Subject: [PATCH 0205/1161] [SPARK-23032][SQL] Add a per-query codegenStageId to WholeStageCodegenExec ## What changes were proposed in this pull request? **Proposal** Add a per-query ID to the codegen stages as represented by `WholeStageCodegenExec` operators. This ID will be used in - the explain output of the physical plan, and in - the generated class name. Specifically, this ID will be stable within a query, counting up from 1 in depth-first post-order for all the `WholeStageCodegenExec` inserted into a plan. The ID value 0 is reserved for "free-floating" `WholeStageCodegenExec` objects, which may have been created for one-off purposes, e.g. for fallback handling of codegen stages that failed to codegen the whole stage and wishes to codegen a subset of the children operators (as seen in `org.apache.spark.sql.execution.FileSourceScanExec#doExecute`). Example: for the following query: ```scala scala> spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 1) scala> val df1 = spark.range(10).select('id as 'x, 'id + 1 as 'y).orderBy('x).select('x + 1 as 'z, 'y) df1: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint] scala> val df2 = spark.range(5) df2: org.apache.spark.sql.Dataset[Long] = [id: bigint] scala> val query = df1.join(df2, 'z === 'id) query: org.apache.spark.sql.DataFrame = [z: bigint, y: bigint ... 1 more field] ``` The explain output before the change is: ```scala scala> query.explain == Physical Plan == *SortMergeJoin [z#9L], [id#13L], Inner :- *Sort [z#9L ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(z#9L, 200) : +- *Project [(x#3L + 1) AS z#9L, y#4L] : +- *Sort [x#3L ASC NULLS FIRST], true, 0 : +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200) : +- *Project [id#0L AS x#3L, (id#0L + 1) AS y#4L] : +- *Range (0, 10, step=1, splits=8) +- *Sort [id#13L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(id#13L, 200) +- *Range (0, 5, step=1, splits=8) ``` Note how codegen'd operators are annotated with a prefix `"*"`. See how the `SortMergeJoin` operator and its direct children `Sort` operators are adjacent and all annotated with the `"*"`, so it's hard to tell they're actually in separate codegen stages. and after this change it'll be: ```scala scala> query.explain == Physical Plan == *(6) SortMergeJoin [z#9L], [id#13L], Inner :- *(3) Sort [z#9L ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(z#9L, 200) : +- *(2) Project [(x#3L + 1) AS z#9L, y#4L] : +- *(2) Sort [x#3L ASC NULLS FIRST], true, 0 : +- Exchange rangepartitioning(x#3L ASC NULLS FIRST, 200) : +- *(1) Project [id#0L AS x#3L, (id#0L + 1) AS y#4L] : +- *(1) Range (0, 10, step=1, splits=8) +- *(5) Sort [id#13L ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(id#13L, 200) +- *(4) Range (0, 5, step=1, splits=8) ``` Note that the annotated prefix becomes `"*(id) "`. See how the `SortMergeJoin` operator and its direct children `Sort` operators have different codegen stage IDs. It'll also show up in the name of the generated class, as a suffix in the format of `GeneratedClass$GeneratedIterator$id`. For example, note how `GeneratedClass$GeneratedIteratorForCodegenStage3` and `GeneratedClass$GeneratedIteratorForCodegenStage6` in the following stack trace corresponds to the IDs shown in the explain output above: ``` "Executor task launch worker for task 42412957" daemon prio=5 tid=0x58 nid=NA runnable java.lang.Thread.State: RUNNABLE at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:109) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.sort_addToSorter$(generated.java:32) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(generated.java:41) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$9$$anon$1.hasNext(WholeStageCodegenExec.scala:494) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.findNextInnerJoinRows$(generated.java:42) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.processNext(generated.java:101) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$2.hasNext(WholeStageCodegenExec.scala:513) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253) at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:828) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324) at org.apache.spark.rdd.RDD.iterator(RDD.scala:288) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:748) ``` **Rationale** Right now, the codegen from Spark SQL lacks the means to differentiate between a couple of things: 1. It's hard to tell which physical operators are in the same WholeStageCodegen stage. Note that this "stage" is a separate notion from Spark's RDD execution stages; this one is only to delineate codegen units. There can be adjacent physical operators that are both codegen'd but are in separate codegen stages. Some of this is due to hacky implementation details, such as the case with `SortMergeJoin` and its `Sort` inputs -- they're hard coded to be split into separate stages although both are codegen'd. When printing out the explain output of the physical plan, you'd only see the codegen'd physical operators annotated with a preceding star (`'*'`) but would have no way to figure out if they're in the same stage. 2. Performance/error diagnosis The generated code has class/method names that are hard to differentiate between queries or even between codegen stages within the same query. If we use a Java-level profiler to collect profiles, or if we encounter a Java-level exception with a stack trace in it, it's really hard to tell which part of a query it's at. By introducing a per-query codegen stage ID, we'd at least be able to know which codegen stage (and in turn, which group of physical operators) was a profile tick or an exception happened. The reason why this proposal uses a per-query ID is because it's stable within a query, so that multiple runs of the same query will see the same resulting IDs. This both benefits understandability for users, and also it plays well with the codegen cache in Spark SQL which uses the generated source code as the key. The downside to using per-query IDs as opposed to a per-session or globally incrementing ID is of course we can't tell apart different query runs with this ID alone. But for now I believe this is a good enough tradeoff. ## How was this patch tested? Existing tests. This PR does not involve any runtime behavior changes other than some name changes. The SQL query test suites that compares explain outputs have been updates to ignore the newly added `codegenStageId`. Author: Kris Mok Closes #20224 from rednaxelafx/wsc-codegenstageid. --- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../sql/execution/DataSourceScanExec.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 85 +++++++++++++++++-- .../columnar/InMemoryTableScanExec.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 3 +- .../execution/WholeStageCodegenSuite.scala | 34 ++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../sql/hive/execution/HiveExplainSuite.scala | 39 +++++++-- 9 files changed, 158 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 470f88c213561..b0d18b6dced76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -629,6 +629,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME = + buildConf("spark.sql.codegen.useIdInClassName") + .internal() + .doc("When true, embed the (whole-stage) codegen stage ID into " + + "the class name of the generated class as a suffix") + .booleanConf + .createWithDefault(true) + val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields") .internal() .doc("The maximum number of fields (including nested fields) that will be supported before" + @@ -1264,6 +1272,8 @@ class SQLConf extends Serializable with Logging { def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME) + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 7c7d79c2bbd7c..aa66ee7e948ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -324,7 +324,7 @@ case class FileSourceScanExec( // in the case of fallback, this batched scan should never fail because of: // 1) only primitive types are supported // 2) the number of columns should be smaller than spark.sql.codegen.maxFields - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { val unsafeRows = { val scan = inputRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 8ea9e81b2e53b..0e525b1e22eb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.Locale +import java.util.function.Supplier import scala.collection.mutable @@ -414,6 +415,58 @@ object WholeStageCodegenExec { } } +object WholeStageCodegenId { + // codegenStageId: ID for codegen stages within a query plan. + // It does not affect equality, nor does it participate in destructuring pattern matching + // of WholeStageCodegenExec. + // + // This ID is used to help differentiate between codegen stages. It is included as a part + // of the explain output for physical plans, e.g. + // + // == Physical Plan == + // *(5) SortMergeJoin [x#3L], [y#9L], Inner + // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0 + // : +- Exchange hashpartitioning(x#3L, 200) + // : +- *(1) Project [(id#0L % 2) AS x#3L] + // : +- *(1) Filter isnotnull((id#0L % 2)) + // : +- *(1) Range (0, 5, step=1, splits=8) + // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0 + // +- Exchange hashpartitioning(y#9L, 200) + // +- *(3) Project [(id#6L % 2) AS y#9L] + // +- *(3) Filter isnotnull((id#6L % 2)) + // +- *(3) Range (0, 5, step=1, splits=8) + // + // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the + // same codegen stage. + // + // The codegen stage ID is also optionally included in the name of the generated classes as + // a suffix, so that it's easier to associate a generated class back to the physical operator. + // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName + // + // The ID is also included in various log messages. + // + // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order". + // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order. + // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order. + // + // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object + // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec + // failed to generate/compile code. + + private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] { + override def get() = 1 // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+ + }) + + def resetPerQuery(): Unit = codegenStageCounter.set(1) + + def getNextStageId(): Int = { + val counter = codegenStageCounter + val id = counter.get() + counter.set(id + 1) + id + } +} + /** * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java * function. @@ -442,7 +495,8 @@ object WholeStageCodegenExec { * `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input, * used to generated code for [[BoundReference]]. */ -case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { +case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) + extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -454,6 +508,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, WholeStageCodegenExec.PIPELINE_DURATION_METRIC)) + def generatedClassName(): String = if (conf.wholeStageUseIdInClassName) { + s"GeneratedIteratorForCodegenStage$codegenStageId" + } else { + "GeneratedIterator" + } + /** * Generates code for this subtree. * @@ -471,19 +531,23 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } """, inlineToOuterClass = true) + val className = generatedClassName() + val source = s""" public Object generate(Object[] references) { - return new GeneratedIterator(references); + return new $className(references); } - ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")} - final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { + ${ctx.registerComment( + s"""Codegend pipeline for stage (id=$codegenStageId) + |${this.treeString.trim}""".stripMargin)} + final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; private scala.collection.Iterator[] inputs; ${ctx.declareMutableStates()} - public GeneratedIterator(Object[] references) { + public $className(Object[] references) { this.references = references; } @@ -516,7 +580,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co } catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => // We should already saw the error message - logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") + logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString") return child.execute() } @@ -525,7 +589,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logInfo(s"Found too long generated codes and JIT optimization might not work: " + s"the bytecode size ($maxCodeSize) is above the limit " + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + - s"for this plan. To avoid this, you can raise the limit " + + s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") child match { // The fallback solution of batch file source scan still uses WholeStageCodegenExec @@ -603,10 +667,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co verbose: Boolean, prefix: String = "", addSuffix: Boolean = false): StringBuilder = { - child.generateTreeString(depth, lastChildren, builder, verbose, "*") + child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ") } override def needStopCheck: Boolean = true + + override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer]) } @@ -657,13 +723,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => - WholeStageCodegenExec(insertInputAdapter(plan)) + WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId()) case other => other.withNewChildren(other.children.map(insertWholeStageCodegen)) } def apply(plan: SparkPlan): SparkPlan = { if (conf.wholeStageEnabled) { + WholeStageCodegenId.resetPerQuery() insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 28b3875505cd2..c167f1e7dc621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -274,7 +274,7 @@ case class InMemoryTableScanExec( protected override def doExecute(): RDD[InternalRow] = { if (supportsBatch) { - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { inputRDD } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 69d871df3e1dd..2c22239e81869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -88,7 +88,7 @@ case class DataSourceV2ScanExec( override protected def doExecute(): RDD[InternalRow] = { if (supportsBatch) { - WholeStageCodegenExec(this).execute() + WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { val numOutputRows = longMetric("numOutputRows") inputRDD.map { r => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 054ada56d99ad..beac9699585d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -230,7 +230,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/") .replaceAll("Created By.*", s"Created By $notIncludedMsg") .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") - .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")) + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds // If the output is not pre-sorted, sort it. if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 242bb48c22942..28ad712feaae6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -273,4 +274,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } } + + test("codegen stage IDs should be preserved in transformations after CollapseCodegenStages") { + // test case adapted from DataFrameSuite to trigger ReuseExchange + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") { + val df = spark.range(100) + val join = df.join(df, "id") + val plan = join.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined, + "codegen stage IDs should be preserved through ReuseExchange") + checkAnswer(join, df.toDF) + } + } + + test("including codegen stage ID in generated class name should not regress codegen caching") { + import testImplicits._ + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { + val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE + + // the same query run twice should hit the codegen cache + spark.range(3).select('id + 2).collect + val after1 = bytecodeSizeHisto.getCount + spark.range(3).select('id + 2).collect + val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately + // bytecodeSizeHisto's count is always monotonically increasing if new compilation to + // bytecode had occurred. If the count stayed the same that means we've got a cache hit. + assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected") + + // a different query can result in codegen cache miss, that's by design + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index ff7c5e58e9863..2280da927cf70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -477,7 +477,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) val execPlan = if (enabled == "true") { - WholeStageCodegenExec(planBeforeFilter.head) + WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0) } else { planBeforeFilter.head } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index a4273de5fe260..f84d188075b72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -154,14 +154,39 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } - test("EXPLAIN CODEGEN command") { - checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), - "WholeStageCodegen", - "Generated code:", - "/* 001 */ public Object generate(Object[] references) {", - "/* 002 */ return new GeneratedIterator(references);", - "/* 003 */ }" + test("explain output of physical plan should contain proper codegen stage ID") { + checkKeywordsExist(sql( + """ + |EXPLAIN SELECT t1.id AS a, t2.id AS b FROM + |(SELECT * FROM range(3)) t1 JOIN + |(SELECT * FROM range(10)) t2 ON t1.id == t2.id % 3 + """.stripMargin), + "== Physical Plan ==", + "*(2) Project ", + "+- *(2) BroadcastHashJoin ", + " :- BroadcastExchange ", + " : +- *(1) Range ", + " +- *(2) Range " ) + } + + test("EXPLAIN CODEGEN command") { + // the generated class name in this test should stay in sync with + // org.apache.spark.sql.execution.WholeStageCodegenExec.generatedClassName() + for ((useIdInClassName, expectedClassName) <- Seq( + ("true", "GeneratedIteratorForCodegenStage1"), + ("false", "GeneratedIterator"))) { + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> useIdInClassName) { + checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"), + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + s"/* 002 */ return new $expectedClassName(references);", + "/* 003 */ }" + ) + } + } checkKeywordsNotExist(sql("EXPLAIN CODEGEN SELECT 1"), "== Physical Plan ==" From 7bd46d9871567597216cc02e1dc72ff5806ecdf8 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Thu, 25 Jan 2018 18:15:29 -0600 Subject: [PATCH 0206/1161] [SPARK-23205][ML] Update ImageSchema.readImages to correctly set alpha values for four-channel images ## What changes were proposed in this pull request? When parsing raw image data in ImageSchema.decode(), we use a [java.awt.Color](https://docs.oracle.com/javase/7/docs/api/java/awt/Color.html#Color(int)) constructor that sets alpha = 255, even for four-channel images (which may have different alpha values). This PR fixes this issue & adds a unit test to verify correctness of reading four-channel images. ## How was this patch tested? Updates an existing unit test ("readImages pixel values test" in `ImageSchemaSuite`) to also verify correctness when reading a four-channel image. Author: Sid Murching Closes #20389 from smurching/image-schema-bugfix. --- data/mllib/images/multi-channel/BGRA_alpha_60.png | Bin 0 -> 747 bytes .../org/apache/spark/ml/image/ImageSchema.scala | 5 ++--- .../apache/spark/ml/image/ImageSchemaSuite.scala | 9 ++++++--- 3 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 data/mllib/images/multi-channel/BGRA_alpha_60.png diff --git a/data/mllib/images/multi-channel/BGRA_alpha_60.png b/data/mllib/images/multi-channel/BGRA_alpha_60.png new file mode 100644 index 0000000000000000000000000000000000000000..913637cd2828ab4e2ff4b2bbd92c4cf362f871c4 GIT binary patch literal 747 zcmV zL3V>M390}vf5ExS_g|iB()io}T zdzc-a9(QkC83JQkyn5Lt1Z5CrRi#xHpD{9oI-@&0jtqe@HUD3>kZs0McGO5TM~1+d z7FE_NXiaZ3z?maMU@%u%R=mvsm?J}AJhD?K_8)UCLtrp7B&y$7129;Iz!*D2yjN8< z0y9X4z_>+*xc3{0=Ex8j*D|Cx*=8j4K{5o!c7|A?%)le4CMiSsY+o_Vo>AV}VFh50 z41v-2iea`H1DYg5U<}czr}rCyox2Qyfqu6aXGTB<$q*RG3>nT0#|)AoFmf{_nrt%+ z=R=0T=#wE~ZFN$PgH=pIpR#r$}~v0vQ6s<%(gm8QC*8vEQiGG6cq@D~4&`dke3xtTJT?jHV1R zn*p1-WHaVkhQK(LAu?mT_I%GyhQKgo$ZgD^p$y@(n<2xRQG=Ep8^{nCn;A09ds8(A zT2-xU83JRGAy_kN4BT(jY8e8f?Uz2S`+1phqfY#&mLV|C{nDp(Kbg^7%Mci;cTmZU z&sv7SNV$V*STh2UAuvMkprV>#Myssnk&_`_<2W4$=?*U$0wXp)$=dj2RgM;~p7O*^V>AfDC~V@@+tF<5118 dqE*&-`~khx#S9AHa*F@}002ovPDHLkV1lO_Oc($F literal 0 HcmV?d00001 diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala index f7850b238465b..dcc40b6668c7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala @@ -169,12 +169,11 @@ object ImageSchema { var offset = 0 for (h <- 0 until height) { for (w <- 0 until width) { - val color = new Color(img.getRGB(w, h)) - + val color = new Color(img.getRGB(w, h), hasAlpha) decoded(offset) = color.getBlue.toByte decoded(offset + 1) = color.getGreen.toByte decoded(offset + 2) = color.getRed.toByte - if (nChannels == 4) { + if (hasAlpha) { decoded(offset + 3) = color.getAlpha.toByte } offset += nChannels diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index dba61cd1eb1cc..a8833c615865d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -53,11 +53,11 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(df.count === 1) df = readImages(imagePath, null, true, -1, false, 1.0, 0) - assert(df.count === 9) + assert(df.count === 10) df = readImages(imagePath, null, true, -1, true, 1.0, 0) val countTotal = df.count - assert(countTotal === 7) + assert(countTotal === 8) df = readImages(imagePath, null, true, -1, true, 0.5, 0) // Random number about half of the size of the original dataset @@ -103,6 +103,9 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { -71, -58, -56, -73, -64))), "BGRA.png" -> (("CV_8UC4", Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128, - -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))) + -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1))), + "BGRA_alpha_60.png" -> (("CV_8UC4", + Array[Byte](-128, -128, -8, 60, -128, -128, -8, 60, -128, + -128, -8, 60, 127, 127, -9, 60, 127, 127, -9, 60))) ) } From 70a68b328b856c17eb22cc86fee0ebe8d64f8825 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 26 Jan 2018 11:58:20 +0800 Subject: [PATCH 0207/1161] [SPARK-23020][CORE] Fix race in SparkAppHandle cleanup, again. Third time is the charm? There was still a race that was left in previous attempts. If the handle closes the connection, the close() implementation would clean up state that would prevent the thread from waiting on the connection thread to finish. That could cause the race causing the test flakiness reported in the bug. The fix is to move the "wait for connection thread" code to a separate close method that is used by the handle; that also simplifies the code a bit and makes it also easier to follow. I included an unrelated, but correct, change to a YARN test so that it triggers when the PR is built. Tested by inserting a sleep in the connection thread to mimic the race; test failed reliably with the sleep, passes now. (Sleep not included in the patch.) Also ran YARN tests to make sure. Author: Marcelo Vanzin Closes #20388 from vanzin/SPARK-23020. --- .../spark/launcher/AbstractAppHandle.java | 42 ++++++++------ .../spark/launcher/ChildProcAppHandle.java | 11 +--- .../spark/launcher/InProcessAppHandle.java | 9 +-- .../apache/spark/launcher/LauncherServer.java | 55 +++++++++---------- .../spark/deploy/yarn/YarnClusterSuite.scala | 5 +- 5 files changed, 55 insertions(+), 67 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index daf0972f824dd..84a25a5254151 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; @@ -29,15 +30,15 @@ abstract class AbstractAppHandle implements SparkAppHandle { private final LauncherServer server; - private LauncherConnection connection; + private LauncherServer.ServerConnection connection; private List listeners; - private State state; + private AtomicReference state; private String appId; private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { this.server = server; - this.state = State.UNKNOWN; + this.state = new AtomicReference<>(State.UNKNOWN); } @Override @@ -50,7 +51,7 @@ public synchronized void addListener(Listener l) { @Override public State getState() { - return state; + return state.get(); } @Override @@ -73,7 +74,7 @@ public synchronized void disconnect() { if (!isDisposed()) { if (connection != null) { try { - connection.close(); + connection.closeAndWait(); } catch (IOException ioe) { // no-op. } @@ -82,7 +83,7 @@ public synchronized void disconnect() { } } - void setConnection(LauncherConnection connection) { + void setConnection(LauncherServer.ServerConnection connection) { this.connection = connection; } @@ -99,12 +100,9 @@ boolean isDisposed() { */ synchronized void dispose() { if (!isDisposed()) { - // Unregister first to make sure that the connection with the app has been really - // terminated. server.unregister(this); - if (!getState().isFinal()) { - setState(State.LOST); - } + // Set state to LOST if not yet final. + setState(State.LOST, false); this.disposed = true; } } @@ -113,14 +111,24 @@ void setState(State s) { setState(s, false); } - synchronized void setState(State s, boolean force) { - if (force || !state.isFinal()) { - state = s; + void setState(State s, boolean force) { + if (force) { + state.set(s); fireEvent(false); - } else { - LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", - new Object[] { state, s }); + return; } + + State current = state.get(); + while (!current.isFinal()) { + if (state.compareAndSet(current, s)) { + fireEvent(false); + return; + } + current = state.get(); + } + + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { current, s }); } synchronized void setAppId(String appId) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 2b99461652e1f..5e3c95676ecbe 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -104,19 +104,12 @@ void monitorChild() { ec = 1; } - State currState = getState(); - State newState = null; if (ec != 0) { + State currState = getState(); // Override state with failure if the current state is not final, or is success. if (!currState.isFinal() || currState == State.FINISHED) { - newState = State.FAILED; + setState(State.FAILED, true); } - } else if (!currState.isFinal()) { - newState = State.LOST; - } - - if (newState != null) { - setState(newState, true); } disconnect(); diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index f04263cb74a58..b8030e0063a37 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -66,14 +66,7 @@ synchronized void start(String appName, Method main, String[] args) { setState(State.FAILED); } - synchronized (InProcessAppHandle.this) { - if (!isDisposed()) { - disconnect(); - if (!getState().isFinal()) { - setState(State.LOST, true); - } - } - } + disconnect(); }); app.setName(String.format(THREAD_NAME_FMT, THREAD_IDS.incrementAndGet(), appName)); diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 8091885c4f562..f4ecd52fdeab8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -218,32 +218,6 @@ void unregister(AbstractAppHandle handle) { } } - // If there is a live connection for this handle, we need to wait for it to finish before - // returning, otherwise there might be a race between the connection thread processing - // buffered data and the handle cleaning up after itself, leading to potentially the wrong - // state being reported for the handle. - ServerConnection conn = null; - synchronized (clients) { - for (ServerConnection c : clients) { - if (c.handle == handle) { - conn = c; - break; - } - } - } - - if (conn != null) { - synchronized (conn) { - if (conn.isOpen()) { - try { - conn.wait(); - } catch (InterruptedException ie) { - // Ignore. - } - } - } - } - unref(); } @@ -312,9 +286,10 @@ private String createSecret() { } } - private class ServerConnection extends LauncherConnection { + class ServerConnection extends LauncherConnection { private TimerTask timeout; + private volatile Thread connectionThread; volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { @@ -322,6 +297,12 @@ private class ServerConnection extends LauncherConnection { this.timeout = timeout; } + @Override + public void run() { + this.connectionThread = Thread.currentThread(); + super.run(); + } + @Override protected void handle(Message msg) throws IOException { try { @@ -376,9 +357,23 @@ public void close() throws IOException { clients.remove(this); } - synchronized (this) { - super.close(); - notifyAll(); + super.close(); + } + + /** + * Close the connection and wait for any buffered data to be processed before returning. + * This ensures any changes reported by the child application take effect. + */ + public void closeAndWait() throws IOException { + close(); + + Thread connThread = this.connectionThread; + if (Thread.currentThread() != connThread) { + try { + connThread.join(); + } catch (InterruptedException ie) { + // Ignore. + } } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index e9dcfaf6ba4f0..5003326b440bf 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -45,8 +45,7 @@ import org.apache.spark.util.Utils /** * Integration tests for YARN; these tests use a mini Yarn cluster to run Spark-on-YARN - * applications, and require the Spark assembly to be built before they can be successfully - * run. + * applications. */ @ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { @@ -152,7 +151,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { } test("run Python application in yarn-cluster mode using " + - " spark.yarn.appMasterEnv to override local envvar") { + "spark.yarn.appMasterEnv to override local envvar") { testPySpark( clientMode = false, extraConf = Map( From d1721816d26bedee3c72eeb75db49da500568376 Mon Sep 17 00:00:00 2001 From: Santiago Saavedra Date: Fri, 26 Jan 2018 15:24:06 +0800 Subject: [PATCH 0208/1161] [SPARK-23200] Reset Kubernetes-specific config on Checkpoint restore ## What changes were proposed in this pull request? When using the Kubernetes cluster-manager and spawning a Streaming workload, it is important to reset many spark.kubernetes.* properties that are generated by spark-submit but which would get rewritten when restoring a Checkpoint. This is so, because the spark-submit codepath creates Kubernetes resources, such as a ConfigMap, a Secret and other variables, which have an autogenerated name and the previous one will not resolve anymore. In short, this change enables checkpoint restoration for streaming workloads, and thus enables Spark Streaming workloads in Kubernetes, which were not possible to restore from a checkpoint before if the workload went down. ## How was this patch tested? This patch was tested with the twitter-streaming example in AWS, using checkpoints in s3 with the s3a:// protocol, as supported by Hadoop. This is similar to the YARN related code for resetting a Spark Streaming workload, but for the Kubernetes scheduler. I'm adding the initcontainers properties because even if the discussion is not completely settled on the mailing list, my understanding is that at this moment they are going forward for the moment. For a previous discussion, see the non-rebased work at: https://github.com/apache-spark-on-k8s/spark/pull/516 Author: Santiago Saavedra Closes #20383 from ssaavedra/fix-k8s-checkpointing. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a5027433..ed2a896033749 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,6 +53,21 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.bindAddress", "spark.driver.port", + "spark.kubernetes.driver.pod.name", + "spark.kubernetes.executor.podNamePrefix", + "spark.kubernetes.initcontainer.executor.configmapname", + "spark.kubernetes.initcontainer.executor.configmapkey", + "spark.kubernetes.initcontainer.downloadJarsResourceIdentifier", + "spark.kubernetes.initcontainer.downloadJarsSecretLocation", + "spark.kubernetes.initcontainer.downloadFilesResourceIdentifier", + "spark.kubernetes.initcontainer.downloadFilesSecretLocation", + "spark.kubernetes.initcontainer.remoteJars", + "spark.kubernetes.initcontainer.remoteFiles", + "spark.kubernetes.mountdependencies.jarsDownloadDir", + "spark.kubernetes.mountdependencies.filesDownloadDir", + "spark.kubernetes.initcontainer.executor.stagingServerSecret.name", + "spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir", + "spark.kubernetes.executor.limit.cores", "spark.master", "spark.yarn.jars", "spark.yarn.keytab", @@ -66,6 +81,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.bindAddress") + .remove("spark.kubernetes.driver.pod.name") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From cd3956df0f96dd416b6161bf7ce2962e06d0a62e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 26 Jan 2018 12:23:14 +0200 Subject: [PATCH 0209/1161] [SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set ## What changes were proposed in this pull request? Currently there is a mixed situation when both single- and multi-column are supported. In some cases exceptions are thrown, in others only a warning log is emitted. In this discussion https://issues.apache.org/jira/browse/SPARK-8418?focusedCommentId=16275049&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-16275049, the decision was to throw an exception. The PR throws an exception in `Bucketizer`, instead of logging a warning. ## How was this patch tested? modified UT Author: Marco Gaido Author: Joseph K. Bradley Closes #19993 from mgaido91/SPARK-22799. --- .../apache/spark/ml/feature/Bucketizer.scala | 44 +++++------- .../org/apache/spark/ml/param/params.scala | 69 +++++++++++++++++++ .../spark/ml/feature/BucketizerSuite.scala | 41 +++++------ .../apache/spark/ml/param/ParamsSuite.scala | 22 ++++++ 4 files changed, 131 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 8299a3e95d822..c13bf47eacb94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -32,11 +32,13 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * + * Since 2.3.0, * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that - * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and - * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is - * only used for single column usage, and `splitsArray` is for multiple columns. + * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The + * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple + * columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) @@ -134,28 +136,11 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("2.3.0") def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - /** - * Determines whether this `Bucketizer` is going to map multiple columns. If and only if - * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified - * by `inputCol`. A warning will be printed if both are set. - */ - private[feature] def isBucketizeMultipleColumns(): Boolean = { - if (isSet(inputCols) && isSet(inputCol)) { - logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + - "`Bucketizer` only map one column specified by `inputCol`") - false - } else if (isSet(inputCols)) { - true - } else { - false - } - } - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) - val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + val (inputColumns, outputColumns) = if (isSet(inputCols)) { ($(inputCols).toSeq, $(outputCols).toSeq) } else { (Seq($(inputCol)), Seq($(outputCol))) @@ -170,7 +155,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val seqOfSplits = if (isBucketizeMultipleColumns()) { + val seqOfSplits = if (isSet(inputCols)) { $(splitsArray).toSeq } else { Seq($(splits)) @@ -201,9 +186,18 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - if (isBucketizeMultipleColumns()) { + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits), + Seq(outputCols, splitsArray)) + + if (isSet(inputCols)) { + require(getInputCols.length == getOutputCols.length && + getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " + + s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " + + s"equal lengths, but they have different lengths: " + + s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).") + var transformedSchema = schema - $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) => SchemaUtils.checkNumericType(transformedSchema, inputCol) transformedSchema = SchemaUtils.appendColumn(transformedSchema, prepOutputField($(splitsArray)(idx), outputCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1b4b401ac4aa0..9a83a5882ce29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -249,6 +249,75 @@ object ParamValidators { def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => value.length > lowerBound } + + /** + * Utility for Param validity checks for Transformers which have both single- and multi-column + * support. This utility assumes that `inputCol` indicates single-column usage and + * that `inputCols` indicates multi-column usage. + * + * This checks to ensure that exactly one set of Params has been set, and it + * raises an `IllegalArgumentException` if not. + * + * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been + * set. This does not need to include `inputCol`. + * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been + * set. This does not need to include `inputCols`. + */ + def checkSingleVsMultiColumnParams( + model: Params, + singleColumnParams: Seq[Param[_]], + multiColumnParams: Seq[Param[_]]): Unit = { + val name = s"${model.getClass.getSimpleName} $model" + + def checkExclusiveParams( + isSingleCol: Boolean, + requiredParams: Seq[Param[_]], + excludedParams: Seq[Param[_]]): Unit = { + val badParamsMsgBuilder = new mutable.StringBuilder() + + val mustUnsetParams = excludedParams.filter(p => model.isSet(p)) + .map(_.name).mkString(", ") + if (mustUnsetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params are not applicable and should not be set: $mustUnsetParams." + } + + val mustSetParams = requiredParams.filter(p => !model.isDefined(p)) + .map(_.name).mkString(", ") + if (mustSetParams.nonEmpty) { + badParamsMsgBuilder ++= + s"The following Params must be defined but are not set: $mustSetParams." + } + + val badParamsMsg = badParamsMsgBuilder.toString() + + if (badParamsMsg.nonEmpty) { + val errPrefix = if (isSingleCol) { + s"$name has the inputCol Param set for single-column transform." + } else { + s"$name has the inputCols Param set for multi-column transform." + } + throw new IllegalArgumentException(s"$errPrefix $badParamsMsg") + } + } + + val inputCol = model.getParam("inputCol") + val inputCols = model.getParam("inputCols") + + if (model.isSet(inputCol)) { + require(!model.isSet(inputCols), s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but both are set.") + + checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams, + excludedParams = multiColumnParams) + } else if (model.isSet(inputCols)) { + checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams, + excludedParams = singleColumnParams) + } else { + throw new IllegalArgumentException(s"$name requires " + + s"exactly one of inputCol, inputCols Params to be set, but neither is set.") + } + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index d9c97ae8067d3..7403680ae3fdc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -216,8 +216,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer1.isBucketizeMultipleColumns()) - bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), Seq("result1", "result2"), @@ -233,8 +231,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result")) .setSplitsArray(Array(splits(0))) - assert(bucketizer2.isBucketizeMultipleColumns()) - withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer2.transform(badDF1).collect() @@ -268,8 +264,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), Seq("expected1", "expected2")) @@ -295,8 +289,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(splits) - assert(bucketizer.isBucketizeMultipleColumns()) - bucketizer.setHandleInvalid("keep") BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), Seq("result1", "result2"), @@ -335,7 +327,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - assert(t.isBucketizeMultipleColumns()) testDefaultReadWrite(t) } @@ -348,8 +339,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCols(Array("result1", "result2")) .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) - assert(bucket.isBucketizeMultipleColumns()) - val pl = new Pipeline() .setStages(Array(bucket)) .fit(df) @@ -401,15 +390,27 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } - test("Both inputCol and inputCols are set") { - val bucket = new Bucketizer() - .setInputCol("feature1") - .setOutputCol("result") - .setSplits(Array(-0.5, 0.0, 0.5)) - .setInputCols(Array("feature1", "feature2")) - - // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. - assert(bucket.isBucketizeMultipleColumns() == false) + test("assert exception is thrown if both multi-column and single-column params are set") { + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("outputCols", Array("result1", "result2"))) + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("splits", Array(-0.5, 0.0, 0.5)), + ("splitsArray", Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))) + + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("outputCol", "feature1"), + ("splits", Array(-0.5, 0.0, 0.5))) + + // the following should fail because not all the params are set + ParamsSuite.testExclusiveParams(new Bucketizer, df, ("inputCol", "feature1"), + ("outputCol", "result1")) + ParamsSuite.testExclusiveParams(new Bucketizer, df, + ("inputCols", Array("feature1", "feature2")), + ("outputCols", Array("result1", "result2"))) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 85198ad4c913a..36e06091d24de 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -20,8 +20,10 @@ package org.apache.spark.ml.param import java.io.{ByteArrayOutputStream, ObjectOutputStream} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.MyParams +import org.apache.spark.sql.Dataset class ParamsSuite extends SparkFunSuite { @@ -430,4 +432,24 @@ object ParamsSuite extends SparkFunSuite { require(copyReturnType === obj.getClass, s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } + + /** + * Checks that the class throws an exception in case multiple exclusive params are set. + * The params to be checked are passed as arguments with their value. + */ + def testExclusiveParams( + model: Params, + dataset: Dataset[_], + paramsAndValues: (String, Any)*): Unit = { + val m = model.copy(ParamMap.empty) + paramsAndValues.foreach { case (paramName, paramValue) => + m.set(m.getParam(paramName), paramValue) + } + intercept[IllegalArgumentException] { + m match { + case t: Transformer => t.transform(dataset) + case e: Estimator[_] => e.fit(dataset) + } + } + } } From c22eaa94e85aaac649566495dcf763a5de3c8d06 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 26 Jan 2018 12:28:27 +0200 Subject: [PATCH 0210/1161] [SPARK-22797][PYSPARK] Bucketizer support multi-column ## What changes were proposed in this pull request? Bucketizer support multi-column in the python side ## How was this patch tested? existing tests and added tests Author: Zheng RuiFeng Closes #19892 from zhengruifeng/20542_py. --- python/pyspark/ml/feature.py | 105 +++++++++++++++++++++------- python/pyspark/ml/param/__init__.py | 10 +++ python/pyspark/ml/tests.py | 9 +++ 3 files changed, 99 insertions(+), 25 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index da85ba761a145..fdc7787140490 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -317,26 +317,33 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, - JavaMLReadable, JavaMLWritable): - """ - Maps a column of continuous features to a column of feature buckets. - - >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)] - >>> df = spark.createDataFrame(values, ["values"]) +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, + HasHandleInvalid, JavaMLReadable, JavaMLWritable): + """ + Maps a column of continuous features to a column of feature buckets. Since 2.3.0, + :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols` + parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters + are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single + column usage, and :py:attr:`splitsArray` is for multiple columns. + + >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")), + ... (float("nan"), 1.0), (float("nan"), 0.0)] + >>> df = spark.createDataFrame(values, ["values1", "values2"]) >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], - ... inputCol="values", outputCol="buckets") - >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect() - >>> len(bucketed) - 6 - >>> bucketed[0].buckets - 0.0 - >>> bucketed[1].buckets - 0.0 - >>> bucketed[2].buckets - 1.0 - >>> bucketed[3].buckets - 2.0 + ... inputCol="values1", outputCol="buckets") + >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1")) + >>> bucketed.show(truncate=False) + +-------+-------+ + |values1|buckets| + +-------+-------+ + |0.1 |0.0 | + |0.4 |0.0 | + |1.2 |1.0 | + |1.5 |2.0 | + |NaN |3.0 | + |NaN |3.0 | + +-------+-------+ + ... >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 >>> bucketizerPath = temp_path + "/bucketizer" @@ -347,6 +354,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect() >>> len(bucketed) 4 + >>> bucketizer2 = Bucketizer(splitsArray= + ... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]], + ... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"]) + >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df) + >>> bucketed2.show(truncate=False) + +-------+-------+--------+--------+ + |values1|values2|buckets1|buckets2| + +-------+-------+--------+--------+ + |0.1 |0.0 |0.0 |0.0 | + |0.4 |1.0 |0.0 |1.0 | + |1.2 |1.3 |1.0 |1.0 | + |1.5 |NaN |2.0 |2.0 | + |NaN |1.0 |3.0 |1.0 | + |NaN |0.0 |3.0 |0.0 | + +-------+-------+--------+--------+ + ... .. versionadded:: 1.4.0 """ @@ -363,14 +386,30 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + "Options are 'skip' (filter out rows with invalid values), " + - "'error' (throw an error), or 'keep' (keep invalid values in a special " + - "additional bucket).", + "'error' (throw an error), or 'keep' (keep invalid values in a " + + "special additional bucket). Note that in the multiple column " + + "case, the invalid handling is applied to all columns. That said " + + "for 'error' it will throw an error if any invalids are found in " + + "any column, for 'skip' it will skip rows with any invalids in " + + "any columns, etc.", typeConverter=TypeConverters.toString) + splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " + + "continuous features into buckets for multiple columns. For each input " + + "column, with n+1 splits, there are n buckets. A bucket defined by " + + "splits x,y holds values in the range [x,y) except the last bucket, " + + "which also includes y. The splits should be of length >= 3 and " + + "strictly increasing. Values at -inf, inf must be explicitly provided " + + "to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.", + typeConverter=TypeConverters.toListListFloat) + @keyword_only - def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): + def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", + splitsArray=None, inputCols=None, outputCols=None): """ - __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") + __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ + splitsArray=None, inputCols=None, outputCols=None) """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) @@ -380,9 +419,11 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er @keyword_only @since("1.4.0") - def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): + def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", + splitsArray=None, inputCols=None, outputCols=None): """ - setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") + setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ + splitsArray=None, inputCols=None, outputCols=None) Sets params for this Bucketizer. """ kwargs = self._input_kwargs @@ -402,6 +443,20 @@ def getSplits(self): """ return self.getOrDefault(self.splits) + @since("2.3.0") + def setSplitsArray(self, value): + """ + Sets the value of :py:attr:`splitsArray`. + """ + return self._set(splitsArray=value) + + @since("2.3.0") + def getSplitsArray(self): + """ + Gets the array of split points or its default value. + """ + return self.getOrDefault(self.splitsArray) + @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 043c25cf9feb4..5b6b70292f099 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -134,6 +134,16 @@ def toListFloat(value): return [float(v) for v in value] raise TypeError("Could not convert %s to list of floats" % value) + @staticmethod + def toListListFloat(value): + """ + Convert a value to list of list of floats, if possible. + """ + if TypeConverters._can_convert_to_list(value): + value = TypeConverters.toList(value) + return [TypeConverters.toListFloat(v) for v in value] + raise TypeError("Could not convert %s to list of list of floats" % value) + @staticmethod def toListInt(value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..b8bddbd06f165 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -238,6 +238,15 @@ def test_bool(self): self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) + def test_list_list_float(self): + b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]]) + self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]]) + self.assertTrue(all([type(v) == list for v in b.getSplitsArray()])) + self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]])) + self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]])) + self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0])) + self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]])) + class PipelineTests(PySparkTestCase): From 3e252514741447004f3c18ddd77c617b4e37cfaa Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Fri, 26 Jan 2018 19:18:18 +0800 Subject: [PATCH 0211/1161] [SPARK-22068][CORE] Reduce the duplicate code between putIteratorAsValues and putIteratorAsBytes ## What changes were proposed in this pull request? The code logic between `MemoryStore.putIteratorAsValues` and `Memory.putIteratorAsBytes` are almost same, so we should reduce the duplicate code between them. ## How was this patch tested? Existing UT. Author: Xianyang Liu Closes #19285 from ConeyLiu/rmemorystore. --- .../spark/storage/memory/MemoryStore.scala | 336 ++++++++++-------- 1 file changed, 178 insertions(+), 158 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 17f7a69ad6ba1..4cc5bcb7f9baf 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -162,7 +162,7 @@ private[spark] class MemoryStore( } /** - * Attempt to put the given block in memory store as values. + * Attempt to put the given block in memory store as values or bytes. * * It's possible that the iterator is too large to materialize and store in memory. To avoid * OOM exceptions, this method will gradually unroll the iterator while periodically checking @@ -170,18 +170,24 @@ private[spark] class MemoryStore( * temporary unroll memory used during the materialization is "transferred" to storage memory, * so we won't acquire more memory than is actually needed to store the block. * - * @return in case of success, the estimated size of the stored data. In case of failure, return - * an iterator containing the values of the block. The returned iterator will be backed - * by the combination of the partially-unrolled block and the remaining elements of the - * original input iterator. The caller must either fully consume this iterator or call - * `close()` on it in order to free the storage memory consumed by the partially-unrolled - * block. + * @param blockId The block id. + * @param values The values which need be stored. + * @param classTag the [[ClassTag]] for the block. + * @param memoryMode The values saved memory mode(ON_HEAP or OFF_HEAP). + * @param valuesHolder A holder that supports storing record of values into memory store as + * values or bytes. + * @return if the block is stored successfully, return the stored data size. Else return the + * memory has reserved for unrolling the block (There are two reasons for store failed: + * First, the block is partially-unrolled; second, the block is entirely unrolled and + * the actual stored data size is larger than reserved, but we can't request extra + * memory). */ - private[storage] def putIteratorAsValues[T]( + private def putIterator[T]( blockId: BlockId, values: Iterator[T], - classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { - + classTag: ClassTag[T], + memoryMode: MemoryMode, + valuesHolder: ValuesHolder[T]): Either[Long, Long] = { require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") // Number of elements unrolled so far @@ -198,12 +204,10 @@ private[spark] class MemoryStore( val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Keep track of unroll memory used by this particular block / putIterator() operation var unrollMemoryUsedByThisBlock = 0L - // Underlying vector for unrolling the block - var vector = new SizeTrackingVector[T]()(classTag) // Request enough memory to begin unrolling keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -214,14 +218,14 @@ private[spark] class MemoryStore( // Unroll this block safely, checking whether we have exceeded our threshold periodically while (values.hasNext && keepUnrolling) { - vector += values.next() + valuesHolder.storeValue(values.next()) if (elementsUnrolled % memoryCheckPeriod == 0) { + val currentSize = valuesHolder.estimatedSize() // If our vector's size has exceeded the threshold, request more memory - val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong keepUnrolling = - reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP) + reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest } @@ -232,78 +236,86 @@ private[spark] class MemoryStore( elementsUnrolled += 1 } + // Make sure that we have enough memory to store the block. By this point, it is possible that + // the block's actual memory usage has exceeded the unroll memory by a small amount, so we + // perform one final call to attempt to allocate additional memory if necessary. if (keepUnrolling) { - // We successfully unrolled the entirety of this block - val arrayValues = vector.toArray - vector = null - val entry = - new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag) - val size = entry.size - def transferUnrollToStorage(amount: Long): Unit = { + val entryBuilder = valuesHolder.getBuilder() + val size = entryBuilder.preciseSize + if (size > unrollMemoryUsedByThisBlock) { + val amountToRequest = size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } + + if (keepUnrolling) { + val entry = entryBuilder.build() // Synchronize so that transfer is atomic memoryManager.synchronized { - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount) - val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP) + releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) + val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) assert(success, "transferring unroll memory to storage memory failed") } - } - // Acquire storage memory if necessary to store this block in memory. - val enoughStorageMemory = { - if (unrollMemoryUsedByThisBlock <= size) { - val acquiredExtra = - memoryManager.acquireStorageMemory( - blockId, size - unrollMemoryUsedByThisBlock, MemoryMode.ON_HEAP) - if (acquiredExtra) { - transferUnrollToStorage(unrollMemoryUsedByThisBlock) - } - acquiredExtra - } else { // unrollMemoryUsedByThisBlock > size - // If this task attempt already owns more unroll memory than is necessary to store the - // block, then release the extra memory that will not be used. - val excessUnrollMemory = unrollMemoryUsedByThisBlock - size - releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory) - transferUnrollToStorage(size) - true - } - } - if (enoughStorageMemory) { + entries.synchronized { entries.put(blockId, entry) } - logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(size) + + logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(blockId, + Utils.bytesToString(entry.size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) + Right(entry.size) } else { - assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, - "released too much unroll memory") + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, entryBuilder.preciseSize) + Left(unrollMemoryUsedByThisBlock) + } + } else { + // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, valuesHolder.estimatedSize()) + Left(unrollMemoryUsedByThisBlock) + } + } + + /** + * Attempt to put the given block in memory store as values. + * + * @return in case of success, the estimated size of the stored data. In case of failure, return + * an iterator containing the values of the block. The returned iterator will be backed + * by the combination of the partially-unrolled block and the remaining elements of the + * original input iterator. The caller must either fully consume this iterator or call + * `close()` on it in order to free the storage memory consumed by the partially-unrolled + * block. + */ + private[storage] def putIteratorAsValues[T]( + blockId: BlockId, + values: Iterator[T], + classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = { + + val valuesHolder = new DeserializedValuesHolder[T](classTag) + + putIterator(blockId, values, classTag, MemoryMode.ON_HEAP, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + val unrolledIterator = if (valuesHolder.vector != null) { + valuesHolder.vector.iterator + } else { + valuesHolder.arrayValues.toIterator + } + Left(new PartiallyUnrolledIterator( this, MemoryMode.ON_HEAP, unrollMemoryUsedByThisBlock, - unrolled = arrayValues.toIterator, - rest = Iterator.empty)) - } - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, vector.estimateSize()) - Left(new PartiallyUnrolledIterator( - this, - MemoryMode.ON_HEAP, - unrollMemoryUsedByThisBlock, - unrolled = vector.iterator, - rest = values)) + unrolled = unrolledIterator, + rest = values)) } } /** * Attempt to put the given block in memory store as bytes. * - * It's possible that the iterator is too large to materialize and store in memory. To avoid - * OOM exceptions, this method will gradually unroll the iterator while periodically checking - * whether there is enough free memory. If the block is successfully materialized, then the - * temporary unroll memory used during the materialization is "transferred" to storage memory, - * so we won't acquire more memory than is actually needed to store the block. - * * @return in case of success, the estimated size of the stored data. In case of failure, * return a handle which allows the caller to either finish the serialization by * spilling to disk or to deserialize the partially-serialized block and reconstruct @@ -319,25 +331,8 @@ private[spark] class MemoryStore( require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") - val allocator = memoryMode match { - case MemoryMode.ON_HEAP => ByteBuffer.allocate _ - case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ - } - - // Whether there is still enough memory for us to continue unrolling this block - var keepUnrolling = true - // Number of elements unrolled so far - var elementsUnrolled = 0L - // How often to check whether we need to request more memory - val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) - // Memory to request as a multiple of current bbos size - val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold - // Keep track of unroll memory used by this particular block / putIterator() operation - var unrollMemoryUsedByThisBlock = 0L - // Underlying buffer for unrolling the block - val redirectableStream = new RedirectableOutputStream val chunkSize = if (initialMemoryThreshold > Int.MaxValue) { logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " + s"is too large to be set as chunk size. Chunk size has been capped to " + @@ -346,85 +341,22 @@ private[spark] class MemoryStore( } else { initialMemoryThreshold.toInt } - val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) - redirectableStream.setOutputStream(bbos) - val serializationStream: SerializationStream = { - val autoPick = !blockId.isInstanceOf[StreamBlockId] - val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() - ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) - } - // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) + val valuesHolder = new SerializedValuesHolder[T](blockId, chunkSize, classTag, + memoryMode, serializerManager) - if (!keepUnrolling) { - logWarning(s"Failed to reserve initial memory threshold of " + - s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") - } else { - unrollMemoryUsedByThisBlock += initialMemoryThreshold - } - - def reserveAdditionalMemoryIfNecessary(): Unit = { - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = (bbos.size * memoryGrowthFactor - unrollMemoryUsedByThisBlock).toLong - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - // Unroll this block safely, checking whether we have exceeded our threshold - while (values.hasNext && keepUnrolling) { - serializationStream.writeObject(values.next())(classTag) - elementsUnrolled += 1 - if (elementsUnrolled % memoryCheckPeriod == 0) { - reserveAdditionalMemoryIfNecessary() - } - } - - // Make sure that we have enough memory to store the block. By this point, it is possible that - // the block's actual memory usage has exceeded the unroll memory by a small amount, so we - // perform one final call to attempt to allocate additional memory if necessary. - if (keepUnrolling) { - serializationStream.close() - if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) - if (keepUnrolling) { - unrollMemoryUsedByThisBlock += amountToRequest - } - } - } - - if (keepUnrolling) { - val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) - // Synchronize so that transfer is atomic - memoryManager.synchronized { - releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) - val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) - assert(success, "transferring unroll memory to storage memory failed") - } - entries.synchronized { - entries.put(blockId, entry) - } - logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(entry.size), - Utils.bytesToString(maxMemory - blocksMemoryUsed))) - Right(entry.size) - } else { - // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, bbos.size) - Left( - new PartiallySerializedBlock( + putIterator(blockId, values, classTag, memoryMode, valuesHolder) match { + case Right(storedSize) => Right(storedSize) + case Left(unrollMemoryUsedByThisBlock) => + Left(new PartiallySerializedBlock( this, serializerManager, blockId, - serializationStream, - redirectableStream, + valuesHolder.serializationStream, + valuesHolder.redirectableStream, unrollMemoryUsedByThisBlock, memoryMode, - bbos, + valuesHolder.bbos, values, classTag)) } @@ -702,6 +634,94 @@ private[spark] class MemoryStore( } } +private trait MemoryEntryBuilder[T] { + def preciseSize: Long + def build(): MemoryEntry[T] +} + +private trait ValuesHolder[T] { + def storeValue(value: T): Unit + def estimatedSize(): Long + + /** + * Note: After this method is called, the ValuesHolder is invalid, we can't store data and + * get estimate size again. + * @return a MemoryEntryBuilder which is used to build a memory entry and get the stored data + * size. + */ + def getBuilder(): MemoryEntryBuilder[T] +} + +/** + * A holder for storing the deserialized values. + */ +private class DeserializedValuesHolder[T] (classTag: ClassTag[T]) extends ValuesHolder[T] { + // Underlying vector for unrolling the block + var vector = new SizeTrackingVector[T]()(classTag) + var arrayValues: Array[T] = null + + override def storeValue(value: T): Unit = { + vector += value + } + + override def estimatedSize(): Long = { + vector.estimateSize() + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + arrayValues = vector.toArray + vector = null + + override val preciseSize: Long = SizeEstimator.estimate(arrayValues) + + override def build(): MemoryEntry[T] = + DeserializedMemoryEntry[T](arrayValues, preciseSize, classTag) + } +} + +/** + * A holder for storing the serialized values. + */ +private class SerializedValuesHolder[T]( + blockId: BlockId, + chunkSize: Int, + classTag: ClassTag[T], + memoryMode: MemoryMode, + serializerManager: SerializerManager) extends ValuesHolder[T] { + val allocator = memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + + val redirectableStream = new RedirectableOutputStream + val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) + redirectableStream.setOutputStream(bbos) + val serializationStream: SerializationStream = { + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) + } + + override def storeValue(value: T): Unit = { + serializationStream.writeObject(value)(classTag) + } + + override def estimatedSize(): Long = { + bbos.size + } + + override def getBuilder(): MemoryEntryBuilder[T] = new MemoryEntryBuilder[T] { + // We successfully unrolled the entirety of this block + serializationStream.close() + + override def preciseSize(): Long = bbos.size + + override def build(): MemoryEntry[T] = + SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) + } +} + /** * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. * From dd8e257d1ccf20f4383dd7f30d634010b176f0d3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Jan 2018 09:17:05 -0800 Subject: [PATCH 0212/1161] [SPARK-23218][SQL] simplify ColumnVector.getArray ## What changes were proposed in this pull request? `ColumnVector` is very flexible about how to implement array type. As a result `ColumnVector` has 3 abstract methods for array type: `arrayData`, `getArrayOffset`, `getArrayLength`. For example, in `WritableColumnVector` we use the first child vector as the array data vector, and store offsets and lengths in 2 arrays in the parent vector. `ArrowColumnVector` has a different implementation. This PR simplifies `ColumnVector` by using only one abstract method for array type: `getArray`. ## How was this patch tested? existing tests. rerun `ColumnarBatchBenchmark`, there is no performance regression. Author: Wenchen Fan Closes #20395 from cloud-fan/vector. --- .../datasources/orc/OrcColumnVector.java | 13 +-- .../vectorized/WritableColumnVector.java | 13 ++- .../sql/vectorized/ArrowColumnVector.java | 48 ++++------ .../spark/sql/vectorized/ColumnVector.java | 88 ++++++++++--------- .../spark/sql/vectorized/ColumnarArray.java | 2 + .../spark/sql/vectorized/ColumnarBatch.java | 2 + .../spark/sql/vectorized/ColumnarRow.java | 2 + .../vectorized/ColumnarBatchBenchmark.scala | 14 ++- 8 files changed, 87 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index aaf2a380034a9..5078bc7922ee2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.unsafe.types.UTF8String; /** @@ -145,16 +146,6 @@ public double getDouble(int rowId) { return doubleData.vector[getRowIndex(rowId)]; } - @Override - public int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public int getArrayOffset(int rowId) { - throw new UnsupportedOperationException(); - } - @Override public Decimal getDecimal(int rowId, int precision, int scale) { BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); @@ -177,7 +168,7 @@ public byte[] getBinary(int rowId) { } @Override - public org.apache.spark.sql.vectorized.ColumnVector arrayData() { + public ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index ca4f00985c2a3..a8ec8ef2aadf8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -602,7 +603,17 @@ public final int appendStruct(boolean isNull) { // `WritableColumnVector` puts the data of array in the first child column vector, and puts the // array offsets and lengths in the current column vector. @Override - public WritableColumnVector arrayData() { return childColumns[0]; } + public final ColumnarArray getArray(int rowId) { + return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); + } + + public WritableColumnVector arrayData() { + return childColumns[0]; + } + + public abstract int getArrayLength(int rowId); + + public abstract int getArrayOffset(int rowId); @Override public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index ca7a4751450d4..9803c3dec6de2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -17,17 +17,21 @@ package org.apache.spark.sql.vectorized; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. + * A column vector backed by Apache Arrow. Currently time interval type and map type are not + * supported. */ +@InterfaceStability.Evolving public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; @@ -90,16 +94,6 @@ public double getDouble(int rowId) { return accessor.getDouble(rowId); } - @Override - public int getArrayLength(int rowId) { - return accessor.getArrayLength(rowId); - } - - @Override - public int getArrayOffset(int rowId) { - return accessor.getArrayOffset(rowId); - } - @Override public Decimal getDecimal(int rowId, int precision, int scale) { return accessor.getDecimal(rowId, precision, scale); @@ -116,7 +110,9 @@ public byte[] getBinary(int rowId) { } @Override - public ArrowColumnVector arrayData() { return childColumns[0]; } + public ColumnarArray getArray(int rowId) { + return accessor.getArray(rowId); + } @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } @@ -151,9 +147,6 @@ public ArrowColumnVector(ValueVector vector) { } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - - childColumns = new ArrowColumnVector[1]; - childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); } else if (vector instanceof NullableMapVector) { NullableMapVector mapVector = (NullableMapVector) vector; accessor = new StructAccessor(mapVector); @@ -180,10 +173,6 @@ boolean isNullAt(int rowId) { return vector.isNull(rowId); } - final int getValueCount() { - return vector.getValueCount(); - } - final int getNullCount() { return vector.getNullCount(); } @@ -232,11 +221,7 @@ byte[] getBinary(int rowId) { throw new UnsupportedOperationException(); } - int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - int getArrayOffset(int rowId) { + ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } } @@ -433,10 +418,12 @@ final long getLong(int rowId) { private static class ArrayAccessor extends ArrowVectorAccessor { private final ListVector accessor; + private final ArrowColumnVector arrayData; ArrayAccessor(ListVector vector) { super(vector); this.accessor = vector; + this.arrayData = new ArrowColumnVector(vector.getDataVector()); } @Override @@ -450,13 +437,12 @@ final boolean isNullAt(int rowId) { } @Override - final int getArrayLength(int rowId) { - return accessor.getInnerValueCountAt(rowId); - } - - @Override - final int getArrayOffset(int rowId) { - return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); + final ColumnarArray getArray(int rowId) { + ArrowBuf offsets = accessor.getOffsetBuffer(); + int index = rowId * accessor.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + accessor.OFFSET_WIDTH); + return new ColumnarArray(arrayData, start, end - start); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index f9936214035b6..4b955ceddd0f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; @@ -29,11 +30,14 @@ * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values * in this ColumnVector. * + * Spark only calls specific `get` method according to the data type of this {@link ColumnVector}, + * e.g. if it's int type, Spark is guaranteed to only call {@link #getInt(int)} or + * {@link #getInts(int, int)}. + * * ColumnVector supports all the data types including nested types. To handle nested types, - * ColumnVector can have children and is a tree structure. For struct type, it stores the actual - * data of each field in the corresponding child ColumnVector, and only stores null information in - * the parent ColumnVector. For array type, it stores the actual array elements in the child - * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. + * ColumnVector can have children and is a tree structure. Please refer to {@link #getStruct(int)}, + * {@link #getArray(int)} and {@link #getMap(int)} for the details about how to implement nested + * types. * * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating * memory again and again. @@ -43,6 +47,7 @@ * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage * footprint is negligible. */ +@InterfaceStability.Evolving public abstract class ColumnVector implements AutoCloseable { /** @@ -70,12 +75,12 @@ public abstract class ColumnVector implements AutoCloseable { public abstract boolean isNullAt(int rowId); /** - * Returns the value for rowId. + * Returns the boolean type value for rowId. */ public abstract boolean getBoolean(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets boolean type values from [rowId, rowId + count) */ public boolean[] getBooleans(int rowId, int count) { boolean[] res = new boolean[count]; @@ -86,12 +91,12 @@ public boolean[] getBooleans(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the byte type value for rowId. */ public abstract byte getByte(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets byte type values from [rowId, rowId + count) */ public byte[] getBytes(int rowId, int count) { byte[] res = new byte[count]; @@ -102,12 +107,12 @@ public byte[] getBytes(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the short type value for rowId. */ public abstract short getShort(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets short type values from [rowId, rowId + count) */ public short[] getShorts(int rowId, int count) { short[] res = new short[count]; @@ -118,12 +123,12 @@ public short[] getShorts(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the int type value for rowId. */ public abstract int getInt(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets int type values from [rowId, rowId + count) */ public int[] getInts(int rowId, int count) { int[] res = new int[count]; @@ -134,12 +139,12 @@ public int[] getInts(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the long type value for rowId. */ public abstract long getLong(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets long type values from [rowId, rowId + count) */ public long[] getLongs(int rowId, int count) { long[] res = new long[count]; @@ -150,12 +155,12 @@ public long[] getLongs(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the float type value for rowId. */ public abstract float getFloat(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets float type values from [rowId, rowId + count) */ public float[] getFloats(int rowId, int count) { float[] res = new float[count]; @@ -166,12 +171,12 @@ public float[] getFloats(int rowId, int count) { } /** - * Returns the value for rowId. + * Returns the double type value for rowId. */ public abstract double getDouble(int rowId); /** - * Gets values from [rowId, rowId + count) + * Gets double type values from [rowId, rowId + count) */ public double[] getDoubles(int rowId, int count) { double[] res = new double[count]; @@ -182,57 +187,54 @@ public double[] getDoubles(int rowId, int count) { } /** - * Returns the length of the array for rowId. - */ - public abstract int getArrayLength(int rowId); - - /** - * Returns the offset of the array for rowId. - */ - public abstract int getArrayOffset(int rowId); - - /** - * Returns the struct for rowId. + * Returns the struct type value for rowId. + * + * To support struct type, implementations must implement {@link #getChild(int)} and make this + * vector a tree structure. The number of child vectors must be same as the number of fields of + * the struct type, and each child vector is responsible to store the data for its corresponding + * struct field. */ public final ColumnarRow getStruct(int rowId) { return new ColumnarRow(this, rowId); } /** - * Returns the array for rowId. + * Returns the array type value for rowId. + * + * To support array type, implementations must construct an {@link ColumnarArray} and return it in + * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all + * the elements of all the arrays in this vector, and an offset and length which points to a range + * in that {@link ColumnVector}, and the range represents the array for rowId. Implementations + * are free to decide where to put the data vector and offsets and lengths. For example, we can + * use the first child vector as the data vector, and store offsets and lengths in 2 int arrays in + * this vector. */ - public final ColumnarArray getArray(int rowId) { - return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); - } + public abstract ColumnarArray getArray(int rowId); /** - * Returns the map for rowId. + * Returns the map type value for rowId. */ public MapData getMap(int ordinal) { throw new UnsupportedOperationException(); } /** - * Returns the decimal for rowId. + * Returns the decimal type value for rowId. */ public abstract Decimal getDecimal(int rowId, int precision, int scale); /** - * Returns the UTF8String for rowId. Note that the returned UTF8String may point to the data of - * this column vector, please copy it if you want to keep it after this column vector is freed. + * Returns the string type value for rowId. Note that the returned UTF8String may point to the + * data of this column vector, please copy it if you want to keep it after this column vector is + * freed. */ public abstract UTF8String getUTF8String(int rowId); /** - * Returns the byte array for rowId. + * Returns the binary type value for rowId. */ public abstract byte[] getBinary(int rowId); - /** - * Returns the data for the underlying array. - */ - public abstract ColumnVector arrayData(); - /** * Returns the ordinal's child column vector. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 522c39580389f..0d2c3ec8648d3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; @@ -25,6 +26,7 @@ /** * Array abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from // data[offset] to data[offset + length). diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 4dc826cf60c15..d206c1df42abb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -18,6 +18,7 @@ import java.util.*; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; @@ -26,6 +27,7 @@ * batch so that Spark can access the data row by row. Instance of it is meant to be reused during * the entire data loading process. */ +@InterfaceStability.Evolving public final class ColumnarBatch { private int numRows; private final ColumnVector[] columns; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 2e59085a82768..25db7e09d20d0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; @@ -26,6 +27,7 @@ /** * Row abstraction in {@link ColumnVector}. */ +@InterfaceStability.Evolving public final class ColumnarRow extends InternalRow { // The data for this row. // E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index ad74fb99b0c73..1f31aa45a1220 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet +package org.apache.spark.sql.execution.vectorized import java.nio.ByteBuffer import java.nio.charset.StandardCharsets @@ -23,8 +23,6 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType} import org.apache.spark.unsafe.Platform import org.apache.spark.util.Benchmark @@ -434,7 +432,6 @@ object ColumnarBatchBenchmark { } def readArrays(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -448,7 +445,6 @@ object ColumnarBatchBenchmark { } def readArrayElements(onHeap: Boolean): Unit = { - System.gc() val vector = if (onHeap) onHeapVector else offHeapVector var sum = 0L @@ -479,10 +475,10 @@ object ColumnarBatchBenchmark { Array Vector Read: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - On Heap Read Size Only 416 / 423 393.5 2.5 1.0X - Off Heap Read Size Only 396 / 404 413.6 2.4 1.1X - On Heap Read Elements 2569 / 2590 63.8 15.7 0.2X - Off Heap Read Elements 3302 / 3333 49.6 20.2 0.1X + On Heap Read Size Only 426 / 437 384.9 2.6 1.0X + Off Heap Read Size Only 406 / 421 404.0 2.5 1.0X + On Heap Read Elements 2636 / 2642 62.2 16.1 0.2X + Off Heap Read Elements 3770 / 3774 43.5 23.0 0.1X */ benchmark.run } From a8a3e9b7cf7b9346c43cfbbf7b26fd2fd28dd521 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Fri, 26 Jan 2018 23:48:02 +0200 Subject: [PATCH 0213/1161] Revert "[SPARK-22797][PYSPARK] Bucketizer support multi-column" This reverts commit c22eaa94e85aaac649566495dcf763a5de3c8d06. --- python/pyspark/ml/feature.py | 105 +++++++--------------------- python/pyspark/ml/param/__init__.py | 10 --- python/pyspark/ml/tests.py | 9 --- 3 files changed, 25 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fdc7787140490..da85ba761a145 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -317,33 +317,26 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols, - HasHandleInvalid, JavaMLReadable, JavaMLWritable): - """ - Maps a column of continuous features to a column of feature buckets. Since 2.3.0, - :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols` - parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters - are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single - column usage, and :py:attr:`splitsArray` is for multiple columns. - - >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")), - ... (float("nan"), 1.0), (float("nan"), 0.0)] - >>> df = spark.createDataFrame(values, ["values1", "values2"]) +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): + """ + Maps a column of continuous features to a column of feature buckets. + + >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)] + >>> df = spark.createDataFrame(values, ["values"]) >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], - ... inputCol="values1", outputCol="buckets") - >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1")) - >>> bucketed.show(truncate=False) - +-------+-------+ - |values1|buckets| - +-------+-------+ - |0.1 |0.0 | - |0.4 |0.0 | - |1.2 |1.0 | - |1.5 |2.0 | - |NaN |3.0 | - |NaN |3.0 | - +-------+-------+ - ... + ... inputCol="values", outputCol="buckets") + >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect() + >>> len(bucketed) + 6 + >>> bucketed[0].buckets + 0.0 + >>> bucketed[1].buckets + 0.0 + >>> bucketed[2].buckets + 1.0 + >>> bucketed[3].buckets + 2.0 >>> bucketizer.setParams(outputCol="b").transform(df).head().b 0.0 >>> bucketizerPath = temp_path + "/bucketizer" @@ -354,22 +347,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect() >>> len(bucketed) 4 - >>> bucketizer2 = Bucketizer(splitsArray= - ... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]], - ... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"]) - >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df) - >>> bucketed2.show(truncate=False) - +-------+-------+--------+--------+ - |values1|values2|buckets1|buckets2| - +-------+-------+--------+--------+ - |0.1 |0.0 |0.0 |0.0 | - |0.4 |1.0 |0.0 |1.0 | - |1.2 |1.3 |1.0 |1.0 | - |1.5 |NaN |2.0 |2.0 | - |NaN |1.0 |3.0 |1.0 | - |NaN |0.0 |3.0 |0.0 | - +-------+-------+--------+--------+ - ... .. versionadded:: 1.4.0 """ @@ -386,30 +363,14 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOu handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + "Options are 'skip' (filter out rows with invalid values), " + - "'error' (throw an error), or 'keep' (keep invalid values in a " + - "special additional bucket). Note that in the multiple column " + - "case, the invalid handling is applied to all columns. That said " + - "for 'error' it will throw an error if any invalids are found in " + - "any column, for 'skip' it will skip rows with any invalids in " + - "any columns, etc.", + "'error' (throw an error), or 'keep' (keep invalid values in a special " + + "additional bucket).", typeConverter=TypeConverters.toString) - splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " + - "continuous features into buckets for multiple columns. For each input " + - "column, with n+1 splits, there are n buckets. A bucket defined by " + - "splits x,y holds values in the range [x,y) except the last bucket, " + - "which also includes y. The splits should be of length >= 3 and " + - "strictly increasing. Values at -inf, inf must be explicitly provided " + - "to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.", - typeConverter=TypeConverters.toListListFloat) - @keyword_only - def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", - splitsArray=None, inputCols=None, outputCols=None): + def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ - splitsArray=None, inputCols=None, outputCols=None) + __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") """ super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) @@ -419,11 +380,9 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er @keyword_only @since("1.4.0") - def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", - splitsArray=None, inputCols=None, outputCols=None): + def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \ - splitsArray=None, inputCols=None, outputCols=None) + setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this Bucketizer. """ kwargs = self._input_kwargs @@ -443,20 +402,6 @@ def getSplits(self): """ return self.getOrDefault(self.splits) - @since("2.3.0") - def setSplitsArray(self, value): - """ - Sets the value of :py:attr:`splitsArray`. - """ - return self._set(splitsArray=value) - - @since("2.3.0") - def getSplitsArray(self): - """ - Gets the array of split points or its default value. - """ - return self.getOrDefault(self.splitsArray) - @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 5b6b70292f099..043c25cf9feb4 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -134,16 +134,6 @@ def toListFloat(value): return [float(v) for v in value] raise TypeError("Could not convert %s to list of floats" % value) - @staticmethod - def toListListFloat(value): - """ - Convert a value to list of list of floats, if possible. - """ - if TypeConverters._can_convert_to_list(value): - value = TypeConverters.toList(value) - return [TypeConverters.toListFloat(v) for v in value] - raise TypeError("Could not convert %s to list of list of floats" % value) - @staticmethod def toListInt(value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b8bddbd06f165..1af2b91da900d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -238,15 +238,6 @@ def test_bool(self): self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1)) self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false")) - def test_list_list_float(self): - b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]]) - self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]]) - self.assertTrue(all([type(v) == list for v in b.getSplitsArray()])) - self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]])) - self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]])) - self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0])) - self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]])) - class PipelineTests(PySparkTestCase): From 94c67a76ec1fda908a671a47a2a1fa63b3ab1b06 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 26 Jan 2018 15:01:03 -0800 Subject: [PATCH 0214/1161] [SPARK-23207][SQL] Shuffle+Repartition on a DataFrame could lead to incorrect answers ## What changes were proposed in this pull request? Currently shuffle repartition uses RoundRobinPartitioning, the generated result is nondeterministic since the sequence of input rows are not determined. The bug can be triggered when there is a repartition call following a shuffle (which would lead to non-deterministic row ordering), as the pattern shows below: upstream stage -> repartition stage -> result stage (-> indicate a shuffle) When one of the executors process goes down, some tasks on the repartition stage will be retried and generate inconsistent ordering, and some tasks of the result stage will be retried generating different data. The following code returns 931532, instead of 1000000: ``` import scala.sys.process._ import org.apache.spark.TaskContext val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x => if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) { throw new Exception("pkill -f java".!!) } x } res.distinct().count() ``` In this PR, we propose a most straight-forward way to fix this problem by performing a local sort before partitioning, after we make the input row ordering deterministic, the function from rows to partitions is fully deterministic too. The downside of the approach is that with extra local sort inserted, the performance of repartition() will go down, so we add a new config named `spark.sql.execution.sortBeforeRepartition` to control whether this patch is applied. The patch is default enabled to be safe-by-default, but user may choose to manually turn it off to avoid performance regression. This patch also changes the output rows ordering of repartition(), that leads to a bunch of test cases failure because they are comparing the results directly. ## How was this patch tested? Add unit test in ExchangeSuite. With this patch(and `spark.sql.execution.sortBeforeRepartition` set to true), the following query returns 1000000: ``` import scala.sys.process._ import org.apache.spark.TaskContext spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true") val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x => if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) { throw new Exception("pkill -f java".!!) } x } res.distinct().count() res7: Long = 1000000 ``` Author: Xingbo Jiang Closes #20393 from jiangxb1987/shuffle-repartition. --- .../unsafe/sort/RecordComparator.java | 4 +- .../unsafe/sort/UnsafeInMemorySorter.java | 7 +- .../unsafe/sort/UnsafeSorterSpillMerger.java | 4 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 + .../sort/UnsafeExternalSorterSuite.java | 4 +- .../sort/UnsafeInMemorySorterSuite.java | 8 ++- .../spark/ml/feature/Word2VecSuite.scala | 3 +- .../sql/execution/RecordBinaryComparator.java | 70 +++++++++++++++++++ .../execution/UnsafeExternalRowSorter.java | 44 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 14 ++++ .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../exchange/ShuffleExchangeExec.scala | 52 +++++++++++++- .../spark/sql/execution/ExchangeSuite.scala | 26 ++++++- .../datasources/parquet/ParquetIOSuite.scala | 6 +- .../datasources/text/WholeTextFileSuite.scala | 2 +- .../streaming/ForeachSinkSuite.scala | 6 +- 17 files changed, 233 insertions(+), 29 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index 09e4258792204..02b5de8e128c9 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -32,6 +32,8 @@ public abstract class RecordComparator { public abstract int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset); + long rightBaseOffset, + int rightBaseLength); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 951d076420ee6..b3c27d83da172 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -62,12 +62,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { int uaoSize = UnsafeAlignedOffset.getUaoSize(); if (prefixComparisonResult == 0) { final Object baseObject1 = memoryManager.getPage(r1.recordPointer); - // skip length final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); final Object baseObject2 = memoryManager.getPage(r2.recordPointer); - // skip length final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize; - return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); } else { return prefixComparisonResult; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index cf4dfde86ca91..ff0dcc259a4ad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger { prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); if (prefixComparisonResult == 0) { return recordComparator.compare( - left.getBaseObject(), left.getBaseOffset(), - right.getBaseObject(), right.getBaseOffset()); + left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(), + right.getBaseObject(), right.getBaseOffset(), right.getRecordLength()); } else { return prefixComparisonResult; } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7859781e98223..0574abdca32ac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -414,6 +414,8 @@ abstract class RDD[T: ClassTag]( * * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, * which can avoid performing a shuffle. + * + * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207. */ def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { coalesce(numPartitions, shuffle = true) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index af4975c888d65..411cd5cb57331 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 594f07dd780f9..c145532328514 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; @@ -164,8 +166,10 @@ public void freeAfterOOM() { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 6183606a7b2ac..10682ba176aca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val oldModel = new OldWord2VecModel(word2VecMap) val instance = new Word2VecModel("myWord2VecModel", oldModel) val newInstance = testDefaultReadWrite(instance) - assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + assert(newInstance.getVectors.collect().sortBy(_.getString(0)) === + instance.getVectors.collect().sortBy(_.getString(0))) } test("Word2Vec works with input that is non-nullable (NGram)") { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java new file mode 100644 index 0000000000000..bb77b5bf6de2a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; + +public final class RecordBinaryComparator extends RecordComparator { + + // TODO(jiangxb) Add test suite for this. + @Override + public int compare( + Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { + int i = 0; + int res = 0; + + // If the arrays have different length, the longer one is larger. + if (leftLen != rightLen) { + return leftLen - rightLen; + } + + // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since + // we have guaranteed `leftLen` == `rightLen`. + + // check if stars align and we can get both offsets to be aligned + if ((leftOff % 8) == (rightOff % 8)) { + while ((leftOff + i) % 8 != 0 && i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + } + // for architectures that support unaligned accesses, chew it up 8 bytes at a time + if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { + while (i <= leftLen - 8) { + res = (int) ((Platform.getLong(leftObj, leftOff + i) - + Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); + if (res != 0) return res; + i += 8; + } + } + // this will finish off the unaligned comparisons, or do the entire aligned comparison + // whichever is needed. + while (i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + + // The two arrays are equal. + return 0; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 6b002f0d3f8e8..78647b56d621f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.function.Supplier; +import org.apache.spark.sql.catalyst.util.TypeUtils; import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; @@ -56,26 +58,50 @@ public abstract static class PrefixComputer { public static class Prefix { /** Key prefix value, or the null prefix value if isNull = true. **/ - long value; + public long value; /** Whether the key is null. */ - boolean isNull; + public boolean isNull; } /** * Computes prefix for the given row. For efficiency, the returned object may be reused in * further calls to a given PrefixComputer. */ - abstract Prefix computePrefix(InternalRow row); + public abstract Prefix computePrefix(InternalRow row); } - public UnsafeExternalRowSorter( + public static UnsafeExternalRowSorter createWithRecordComparator( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + public static UnsafeExternalRowSorter create( StructType schema, Ordering ordering, PrefixComparator prefixComparator, PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { + Supplier recordComparatorSupplier = + () -> new RowComparator(ordering, schema.length()); + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + private UnsafeExternalRowSorter( + StructType schema, + Supplier recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); @@ -85,7 +111,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - () -> new RowComparator(ordering, schema.length()), + recordComparatorSupplier, prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -206,7 +232,13 @@ private static final class RowComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1, 0); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b0d18b6dced76..76b9d6f6f33bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1145,6 +1145,18 @@ object SQLConf { .checkValues(PartitionOverwriteMode.values.map(_.toString)) .createWithDefault(PartitionOverwriteMode.STATIC.toString) + val SORT_BEFORE_REPARTITION = + buildConf("spark.sql.execution.sortBeforeRepartition") + .internal() + .doc("When perform a repartition following a shuffle, the output row ordering would be " + + "nondeterministic. If some downstream stages fail and some tasks of the repartition " + + "stage retry, these tasks may generate different data, and that can lead to correctness " + + "issues. Turn on this config to insert a local sort before actually doing repartition " + + "to generate consistent repartition results. The performance of repartition() may go " + + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1300,6 +1312,8 @@ class SQLConf extends Serializable with Logging { def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index eb2fe82007af3..b0b5383a081a0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -241,7 +241,13 @@ private static final class KVComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. row1.pointTo(baseObj1, baseOff1 + 4, 0); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ef1bb1c2a4468..ac1c34d41c4f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -84,7 +84,7 @@ case class SortExec( } val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( + val sorter = UnsafeExternalRowSorter.create( schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) if (testSpillFrequency > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 5a1e217082bc2..76c1fa65f924b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.exchange import java.util.Random +import java.util.function.Supplier import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -25,13 +26,15 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} /** * Performs a shuffle that will result in the desired `newPartitioning`. @@ -247,14 +250,57 @@ object ShuffleExchangeExec { case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, + // otherwise a retry task may output different rows and thus lead to data loss. + // + // Currently we following the most straight-forward way that perform a local sort before + // partitioning. + val newRdd = if (SQLConf.get.sortBeforeRepartition && + newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => + val recordComparatorSupplier = new Supplier[RecordComparator] { + override def get: RecordComparator = new RecordBinaryComparator() + } + // The comparator for comparing row hashcode, which should always be Integer. + val prefixComparator = PrefixComparators.LONG + val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + // The prefix computer generates row hashcode as the prefix, so we may decrease the + // probability that the prefixes are equal when input rows choose column values from a + // limited range. + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null. + result.isNull = false + result.value = row.hashCode() + result + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + + val sorter = UnsafeExternalRowSorter.createWithRecordComparator( + StructType.fromAttributes(outputAttributes), + recordComparatorSupplier, + prefixComparator, + prefixComputer, + pageSize, + canUseRadixSort) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + } + } else { + rdd + } + + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index aac8d56ba6201..697d7e6520713 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -101,4 +104,25 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } + + test("SPARK-23207: Make repartition() generate consistent output") { + def assertConsistency(ds: Dataset[java.lang.Long]): Unit = { + ds.persist() + + val exchange = ds.mapPartitions { iter => + Random.shuffle(iter) + }.repartition(111) + val exchange2 = ds.repartition(111) + + assert(exchange.rdd.collectPartitions() === exchange2.rdd.collectPartitions()) + } + + withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") { + // repartition() should generate consistent output. + assertConsistency(spark.range(10000)) + + // case when input contains duplicated rows. + assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 44a8b25c61dfb..f3ece5b15e26a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -662,7 +662,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getInt(0), row.getString(1)) result += v } - assert(data == result) + assert(data.toSet == result.toSet) } finally { reader.close() } @@ -678,7 +678,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val row = reader.getCurrentValue.asInstanceOf[InternalRow] result += row.getString(0) } - assert(data.map(_._2) == result) + assert(data.map(_._2).toSet == result.toSet) } finally { reader.close() } @@ -695,7 +695,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getString(0), row.getInt(1)) result += v } - assert(data.map { x => (x._2, x._1) } == result) + assert(data.map { x => (x._2, x._1) }.toSet == result.toSet) } finally { reader.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index 8bd736bee69de..fff0f82f9bc2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -95,7 +95,7 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { df1.write.option("compression", "gzip").mode("overwrite").text(path) // On reading through wholetext mode, one file will be read as a single row, i.e. not // delimited by "next line" character. - val expected = Row(Range(0, 1000).mkString("", "\n", "\n")) + val expected = Row(df1.collect().map(_.getString(0)).mkString("", "\n", "\n")) Seq(10, 100, 1000).foreach { bytes => withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) { val df2 = spark.read.option("wholetext", "true").format("text").load(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 9137d650e906b..1248c670df45c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -52,13 +52,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf var expectedEventsForPartition0 = Seq( ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 1), + ForeachSinkSuite.Process(value = 2), ForeachSinkSuite.Process(value = 3), ForeachSinkSuite.Close(None) ) var expectedEventsForPartition1 = Seq( ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 2), + ForeachSinkSuite.Process(value = 1), ForeachSinkSuite.Process(value = 4), ForeachSinkSuite.Close(None) ) @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] From 073744985f439ca90afb9bd0bbc1332c53f7b4bb Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 26 Jan 2018 16:09:57 -0800 Subject: [PATCH 0215/1161] [SPARK-23242][SS][TESTS] Don't run tests in KafkaSourceSuiteBase twice ## What changes were proposed in this pull request? KafkaSourceSuiteBase should be abstract class, otherwise KafkaSourceSuiteBase will also run. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20412 from zsxwing/SPARK-23242. --- .../scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 27dbb3f7a8f31..c4cb1bc4a2e18 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -546,7 +546,7 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { } } -class KafkaSourceSuiteBase extends KafkaSourceTest { +abstract class KafkaSourceSuiteBase extends KafkaSourceTest { import testImplicits._ From 5b5447c68ac79715e2256e487e1212861cdab1fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Jan 2018 16:46:51 -0800 Subject: [PATCH 0216/1161] [SPARK-23214][SQL] cached data should not carry extra hint info ## What changes were proposed in this pull request? This is a regression introduced by https://github.com/apache/spark/pull/19864 When we lookup cache, we should not carry the hint info, as this cache entry might be added by a plan having hint info, while the input plan for this lookup may not have hint info, or have different hint info. ## How was this patch tested? a new test. Author: Wenchen Fan Closes #20394 from cloud-fan/cache. --- .../spark/sql/execution/CacheManager.scala | 17 +-- .../execution/columnar/InMemoryRelation.scala | 27 +++-- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../execution/joins/BroadcastJoinSuite.scala | 103 +++++++++++------- 5 files changed, 94 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 432eb59d6fe57..d68aeb275afda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -169,14 +169,17 @@ class CacheManager extends Logging { /** Replaces segments of the given logical plan with cached versions where possible. */ def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { + // Do not lookup the cache by hint node. Hint node is special, we should ignore it when + // canonicalizing plans, so that plans which are same except hint can hit the same cache. + // However, we also want to keep the hint info after cache lookup. Here we skip the hint + // node, so that the returned caching plan won't replace the hint node and drop the hint info + // from the original plan. + case hint: ResolvedHint => hint + case currentFragment => - lookupCachedData(currentFragment).map { cached => - val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) - currentFragment match { - case hint: ResolvedHint => ResolvedHint(cachedPlan, hint.hints) - case _ => cachedPlan - } - }.getOrElse(currentFragment) + lookupCachedData(currentFragment) + .map(_.cachedRepresentation.withOutput(currentFragment.output)) + .getOrElse(currentFragment) } newPlan transformAllExpressions { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 51928d914841e..22e16913d4da9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -62,8 +62,8 @@ case class InMemoryRelation( @transient child: SparkPlan, tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics = null) + val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, + statsOfPlanToCache: Statistics) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -73,11 +73,16 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) override def computeStats(): Statistics = { - if (batchStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache - statsOfPlanToCache + if (sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) } else { - Statistics(sizeInBytes = batchStats.value.longValue) + Statistics(sizeInBytes = sizeInBytesStats.value.longValue) } } @@ -122,7 +127,7 @@ case class InMemoryRelation( rowCount += 1 } - batchStats.add(totalSize) + sizeInBytesStats.add(totalSize) val stats = InternalRow.fromSeq( columnBuilders.flatMap(_.columnStats.collectedStatistics)) @@ -144,7 +149,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, batchStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } override def newInstance(): this.type = { @@ -156,12 +161,12 @@ case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - batchStats, + sizeInBytesStats, statsOfPlanToCache).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, batchStats, statsOfPlanToCache) + Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e52445f28fc1..72fe0f42801f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -368,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.batchStats.id + case i: InMemoryRelation => i.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 2280da927cf70..dc1766fb9a785 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -336,7 +336,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1704bc8376f0d..bcdee792f4c70 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -70,8 +71,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { private def testBroadcastJoin[T: ClassTag]( joinType: String, forceBroadcast: Boolean = false): SparkPlan = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") @@ -109,30 +110,58 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } - test("broadcast hint is retained after using the cached data") { + test("SPARK-23192: broadcast hint should be retained after using the cached data") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - df2.cache() - val df3 = df1.join(broadcast(df2), Seq("key"), "inner") - val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { - case b: BroadcastHashJoinExec => b - }.size - assert(numBroadCastHashJoin === 1) + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + df2.cache() + val df3 = df1.join(broadcast(df2), Seq("key"), "inner") + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + assert(numBroadCastHashJoin === 1) + } finally { + spark.catalog.clearCache() + } + } + } + + test("SPARK-23214: cached data should not carry extra hint info") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + try { + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + broadcast(df2).cache() + + val df3 = df1.join(df2, Seq("key"), "inner") + val numCachedPlan = df3.queryExecution.executedPlan.collect { + case i: InMemoryTableScanExec => i + }.size + // df2 should be cached. + assert(numCachedPlan === 1) + + val numBroadCastHashJoin = df3.queryExecution.executedPlan.collect { + case b: BroadcastHashJoinExec => b + }.size + // df2 should not be broadcasted. + assert(numBroadCastHashJoin === 0) + } finally { + spark.catalog.clearCache() + } } } test("broadcast hint isn't propagated after a join") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key")) - val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value") + val df4 = Seq((1, "5"), (2, "5")).toDF("key", "value") val df5 = df4.join(df3, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) assert(plan.collect { case p: SortMergeJoinExec => p }.size === 1) @@ -140,30 +169,30 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { - val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df1 = Seq((1, "4"), (2, "2")).toDF("key", "value") val joined = df1.join(df, Seq("key"), "inner") - val plan = - EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) + val plan = EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) } test("broadcast hint programming API") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "2")).toDF("key", "value") val broadcasted = broadcast(df2) - val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value") - - val cases = Seq(broadcasted.limit(2), - broadcasted.filter("value < 10"), - broadcasted.sample(true, 0.5), - broadcasted.distinct(), - broadcasted.groupBy("value").agg(min($"key").as("key")), - // except and intersect are semi/anti-joins which won't return more data then - // their left argument, so the broadcast hint should be propagated here - broadcasted.except(df3), - broadcasted.intersect(df3)) + val df3 = Seq((2, "2"), (3, "3")).toDF("key", "value") + + val cases = Seq( + broadcasted.limit(2), + broadcasted.filter("value < 10"), + broadcasted.sample(true, 0.5), + broadcasted.distinct(), + broadcasted.groupBy("value").agg(min($"key").as("key")), + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + broadcasted.except(df3), + broadcasted.intersect(df3)) cases.foreach(assertBroadcastJoin) } @@ -240,9 +269,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't change broadcast join buildSide if user clearly specified") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes @@ -292,9 +320,8 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { test("Shouldn't bias towards build right if user didn't specify") { withTempView("t1", "t2") { - spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") - spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") - .createTempView("t2") + Seq((1, "4"), (2, "2")).toDF("key", "value").createTempView("t1") + Seq((1, "1"), (2, "12.3"), (2, "123")).toDF("key", "value").createTempView("t2") val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes From e7bc9f0524822a08d857c3a5ba57119644ceae85 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 26 Jan 2018 18:57:32 -0600 Subject: [PATCH 0217/1161] [MINOR][SS][DOC] Fix `Trigger` Scala/Java doc examples ## What changes were proposed in this pull request? This PR fixes Scala/Java doc examples in `Trigger.java`. ## How was this patch tested? N/A. Author: Dongjoon Hyun Closes #20401 from dongjoon-hyun/SPARK-TRIGGER. --- .../src/main/java/org/apache/spark/sql/streaming/Trigger.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 33ae9a9e87668..5371a23230c98 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -50,7 +50,7 @@ public static Trigger ProcessingTime(long intervalMs) { * * {{{ * import java.util.concurrent.TimeUnit - * df.writeStream.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.writeStream().trigger(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) * }}} * * @since 2.2.0 @@ -66,7 +66,7 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { * * {{{ * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) + * df.writeStream.trigger(Trigger.ProcessingTime(10.seconds)) * }}} * @since 2.2.0 */ From 6328868e524121bd00595959d6d059f74e038a6b Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 26 Jan 2018 23:06:03 -0800 Subject: [PATCH 0218/1161] [SPARK-23245][SS][TESTS] Don't access `lastExecution.executedPlan` in StreamTest ## What changes were proposed in this pull request? `lastExecution.executedPlan` is lazy val so accessing it in StreamTest may need to acquire the lock of `lastExecution`. It may be waiting forever when the streaming thread is holding it and running a continuous Spark job. This PR changes to check if `s.lastExecution` is null to avoid accessing `lastExecution.executedPlan`. ## How was this patch tested? Jenkins Author: Jose Torres Closes #20413 from zsxwing/SPARK-23245. --- .../test/scala/org/apache/spark/sql/streaming/StreamTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index efdb0e0e7cf1c..d6433562fb29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -472,7 +472,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.awaitInitialization(streamingTimeout.toMillis) currentStream match { case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - s.lastExecution.executedPlan // will fail if lastExecution is null + assert(s.lastExecution != null) } case _ => } From 3227d14feb1a65e95a2bf326cff6ac95615cc5ac Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 27 Jan 2018 11:26:09 -0800 Subject: [PATCH 0219/1161] [SPARK-23233][PYTHON] Reset the cache in asNondeterministic to set deterministic properly ## What changes were proposed in this pull request? Reproducer: ```python from pyspark.sql.functions import udf f = udf(lambda x: x) spark.range(1).select(f("id")) # cache JVM UDF instance. f = f.asNondeterministic() spark.range(1).select(f("id"))._jdf.logicalPlan().projectList().head().deterministic() ``` It should return `False` but the current master returns `True`. Seems it's because we cache the JVM UDF instance and then we reuse it even after setting `deterministic` disabled once it's called. ## How was this patch tested? Manually tested. I am not sure if I should add the test with a lot of JVM accesses with the intetnal stuff .. Let me know if anyone feels so. I will add. Author: hyukjinkwon Closes #20409 from HyukjinKwon/SPARK-23233. --- python/pyspark/sql/tests.py | 13 +++++++++++++ python/pyspark/sql/udf.py | 3 +++ 2 files changed, 16 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a466ab87d882d..ca7bbf8ffe71c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -441,6 +441,19 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf1) pydoc.render_doc(udf(lambda x: x).asNondeterministic) + def test_nondeterministic_udf3(self): + # regression test for SPARK-23233 + from pyspark.sql.functions import udf + f = udf(lambda x: x) + # Here we cache the JVM UDF instance. + self.spark.range(1).select(f("id")) + # This should reset the cache to set the deterministic status correctly. + f = f.asNondeterministic() + # Check the deterministic status of udf. + df = self.spark.range(1).select(f("id")) + deterministic = df._jdf.logicalPlan().projectList().head().deterministic() + self.assertFalse(deterministic) + def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum import random diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index de96846c5c774..4f303304e5600 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -188,6 +188,9 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ + # Here, we explicitly clean the cache to create a JVM UDF instance + # with 'deterministic' updated. See SPARK-23233. + self._judf_placeholder = None self.deterministic = False return self From b8c32dc57368e49baaacf660b7e8836eedab2df7 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 28 Jan 2018 10:33:06 +0900 Subject: [PATCH 0220/1161] [SPARK-23248][PYTHON][EXAMPLES] Relocate module docstrings to the top in PySpark examples ## What changes were proposed in this pull request? This PR proposes to relocate the docstrings in modules of examples to the top. Seems these are mistakes. So, for example, the below codes ```python >>> help(aft_survival_regression) ``` shows the module docstrings for examples as below: **Before** ``` Help on module aft_survival_regression: NAME aft_survival_regression ... DESCRIPTION # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ... (END) ``` **After** ``` Help on module aft_survival_regression: NAME aft_survival_regression ... DESCRIPTION An example demonstrating aft survival regression. Run with: bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py (END) ``` ## How was this patch tested? Manually checked. Author: hyukjinkwon Closes #20416 from HyukjinKwon/module-docstring-example. --- examples/src/main/python/avro_inputformat.py | 14 +++++++------- .../src/main/python/ml/aft_survival_regression.py | 11 +++++------ .../main/python/ml/bisecting_k_means_example.py | 11 +++++------ .../ml/bucketed_random_projection_lsh_example.py | 12 +++++------- .../src/main/python/ml/chi_square_test_example.py | 10 +++++----- .../src/main/python/ml/correlation_example.py | 10 +++++----- examples/src/main/python/ml/cross_validator.py | 15 +++++++-------- examples/src/main/python/ml/fpgrowth_example.py | 9 ++++----- .../main/python/ml/gaussian_mixture_example.py | 11 +++++------ .../ml/generalized_linear_regression_example.py | 11 +++++------ examples/src/main/python/ml/imputer_example.py | 9 ++++----- .../main/python/ml/isotonic_regression_example.py | 9 +++------ examples/src/main/python/ml/kmeans_example.py | 15 +++++++-------- examples/src/main/python/ml/lda_example.py | 12 +++++------- .../ml/logistic_regression_summary_example.py | 11 +++++------ .../src/main/python/ml/min_hash_lsh_example.py | 12 +++++------- .../src/main/python/ml/one_vs_rest_example.py | 13 ++++++------- .../src/main/python/ml/train_validation_split.py | 13 ++++++------- examples/src/main/python/parquet_inputformat.py | 12 ++++++------ examples/src/main/python/sql/basic.py | 11 +++++------ examples/src/main/python/sql/datasource.py | 11 +++++------ examples/src/main/python/sql/hive.py | 11 +++++------ 22 files changed, 115 insertions(+), 138 deletions(-) diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4422f9e7a9589..6286ba6541fbd 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,13 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from functools import reduce -from pyspark.sql import SparkSession - """ Read data file users.avro in local Spark distro: @@ -50,6 +43,13 @@ {u'favorite_color': None, u'name': u'Alyssa'} {u'favorite_color': u'red', u'name': u'Ben'} """ +from __future__ import print_function + +import sys + +from functools import reduce +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: print(""" diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py index 2f0ca995e55c7..0a71f76418ea6 100644 --- a/examples/src/main/python/ml/aft_survival_regression.py +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating aft survival regression. +Run with: + bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py +""" from __future__ import print_function # $example on$ @@ -23,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating aft survival regression. -Run with: - bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py index 1263cb5d177a8..7842d2009e238 100644 --- a/examples/src/main/python/ml/bisecting_k_means_example.py +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating bisecting k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating bisecting k-means clustering. -Run with: - bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py index 1b7a458125cef..610176ea596ca 100644 --- a/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +++ b/examples/src/main/python/ml/bucketed_random_projection_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating BucketedRandomProjectionLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating BucketedRandomProjectionLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/chi_square_test_example.py b/examples/src/main/python/ml/chi_square_test_example.py index 8f25318ded00a..2af7e683cdb72 100644 --- a/examples/src/main/python/ml/chi_square_test_example.py +++ b/examples/src/main/python/ml/chi_square_test_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for Chi-square hypothesis testing. +Run with: + bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -23,11 +28,6 @@ from pyspark.ml.stat import ChiSquareTest # $example off$ -""" -An example for Chi-square hypothesis testing. -Run with: - bin/spark-submit examples/src/main/python/ml/chi_square_test_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/correlation_example.py b/examples/src/main/python/ml/correlation_example.py index 0a9d30da5a42e..1f4e402ac1a51 100644 --- a/examples/src/main/python/ml/correlation_example.py +++ b/examples/src/main/python/ml/correlation_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example for computing correlation matrix. +Run with: + bin/spark-submit examples/src/main/python/ml/correlation_example.py +""" from __future__ import print_function # $example on$ @@ -23,11 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example for computing correlation matrix. -Run with: - bin/spark-submit examples/src/main/python/ml/correlation_example.py -""" if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index db7054307c2e3..6256d11504afb 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" from __future__ import print_function # $example on$ @@ -26,14 +33,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating model selection using CrossValidator. -This example also demonstrates how Pipelines are Estimators. -Run with: - - bin/spark-submit examples/src/main/python/ml/cross_validator.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py index c92c3c27abb21..39092e616d429 100644 --- a/examples/src/main/python/ml/fpgrowth_example.py +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.fpm import FPGrowth -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating FPGrowth. Run with: bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py """ +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py index e4a0d314e9d91..4938a904189f9 100644 --- a/examples/src/main/python/ml/gaussian_mixture_example.py +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Gaussian Mixture Model (GMM). +Run with: + bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -A simple example demonstrating Gaussian Mixture Model (GMM). -Run with: - bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/generalized_linear_regression_example.py b/examples/src/main/python/ml/generalized_linear_regression_example.py index 796752a60f3ab..a52f4650c1c6f 100644 --- a/examples/src/main/python/ml/generalized_linear_regression_example.py +++ b/examples/src/main/python/ml/generalized_linear_regression_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating generalized linear regression. +Run with: + bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.ml.regression import GeneralizedLinearRegression # $example off$ -""" -An example demonstrating generalized linear regression. -Run with: - bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/imputer_example.py b/examples/src/main/python/ml/imputer_example.py index b8437f827e56d..9ba0147763618 100644 --- a/examples/src/main/python/ml/imputer_example.py +++ b/examples/src/main/python/ml/imputer_example.py @@ -15,16 +15,15 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.feature import Imputer -# $example off$ -from pyspark.sql import SparkSession - """ An example demonstrating Imputer. Run with: bin/spark-submit examples/src/main/python/ml/imputer_example.py """ +# $example on$ +from pyspark.ml.feature import Imputer +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/ml/isotonic_regression_example.py b/examples/src/main/python/ml/isotonic_regression_example.py index 6ae15f1b4b0dd..89cba9dfc7e8f 100644 --- a/examples/src/main/python/ml/isotonic_regression_example.py +++ b/examples/src/main/python/ml/isotonic_regression_example.py @@ -17,6 +17,9 @@ """ Isotonic Regression Example. + +Run with: + bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py """ from __future__ import print_function @@ -25,12 +28,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating isotonic regression. -Run with: - bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index 5f77843e3743a..80a878af679f4 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -15,6 +15,13 @@ # limitations under the License. # +""" +An example demonstrating k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/kmeans_example.py + +This example requires NumPy (http://www.numpy.org/). +""" from __future__ import print_function # $example on$ @@ -24,14 +31,6 @@ from pyspark.sql import SparkSession -""" -An example demonstrating k-means clustering. -Run with: - bin/spark-submit examples/src/main/python/ml/kmeans_example.py - -This example requires NumPy (http://www.numpy.org/). -""" - if __name__ == "__main__": spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py index a8b346f72cd6f..97d1a042d1479 100644 --- a/examples/src/main/python/ml/lda_example.py +++ b/examples/src/main/python/ml/lda_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating LDA. +Run with: + bin/spark-submit examples/src/main/python/ml/lda_example.py +""" from __future__ import print_function # $example on$ @@ -23,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating LDA. -Run with: - bin/spark-submit examples/src/main/python/ml/lda_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/logistic_regression_summary_example.py b/examples/src/main/python/ml/logistic_regression_summary_example.py index bd440a1fbe8df..2274ff707b2a3 100644 --- a/examples/src/main/python/ml/logistic_regression_summary_example.py +++ b/examples/src/main/python/ml/logistic_regression_summary_example.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +An example demonstrating Logistic Regression Summary. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py +""" from __future__ import print_function # $example on$ @@ -22,12 +27,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating Logistic Regression Summary. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/min_hash_lsh_example.py b/examples/src/main/python/ml/min_hash_lsh_example.py index 7b1dd611a865b..93136e6ae3cae 100644 --- a/examples/src/main/python/ml/min_hash_lsh_example.py +++ b/examples/src/main/python/ml/min_hash_lsh_example.py @@ -15,7 +15,11 @@ # limitations under the License. # - +""" +An example demonstrating MinHashLSH. +Run with: + bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py +""" from __future__ import print_function # $example on$ @@ -25,12 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example demonstrating MinHashLSH. -Run with: - bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py index 8e00c25d9342e..956e94ae4ab62 100644 --- a/examples/src/main/python/ml/one_vs_rest_example.py +++ b/examples/src/main/python/ml/one_vs_rest_example.py @@ -15,6 +15,12 @@ # limitations under the License. # +""" +An example of Multiclass to Binary Reduction with One Vs Rest, +using Logistic Regression as the base classifier. +Run with: + bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py +""" from __future__ import print_function # $example on$ @@ -23,13 +29,6 @@ # $example off$ from pyspark.sql import SparkSession -""" -An example of Multiclass to Binary Reduction with One Vs Rest, -using Logistic Regression as the base classifier. -Run with: - bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py -""" - if __name__ == "__main__": spark = SparkSession \ .builder \ diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py index d104f7d30a1bf..d4f9184bf576e 100644 --- a/examples/src/main/python/ml/train_validation_split.py +++ b/examples/src/main/python/ml/train_validation_split.py @@ -15,13 +15,6 @@ # limitations under the License. # -# $example on$ -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.ml.regression import LinearRegression -from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit -# $example off$ -from pyspark.sql import SparkSession - """ This example demonstrates applying TrainValidationSplit to split data and preform model selection. @@ -29,6 +22,12 @@ bin/spark-submit examples/src/main/python/ml/train_validation_split.py """ +# $example on$ +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.regression import LinearRegression +from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit +# $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": spark = SparkSession\ diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 52e9662d528d8..a3f86cf8999cf 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -15,12 +15,6 @@ # limitations under the License. # -from __future__ import print_function - -import sys - -from pyspark.sql import SparkSession - """ Read data file users.parquet in local Spark distro: @@ -35,6 +29,12 @@ {u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} <...more log output...> """ +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + if __name__ == "__main__": if len(sys.argv) != 2: print(""" diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py index c07fa8f2752b3..c8fb25d0533b5 100644 --- a/examples/src/main/python/sql/basic.py +++ b/examples/src/main/python/sql/basic.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating basic Spark SQL features. +Run with: + ./bin/spark-submit examples/src/main/python/sql/basic.py +""" from __future__ import print_function # $example on:init_session$ @@ -30,12 +35,6 @@ from pyspark.sql.types import * # $example off:programmatic_schema$ -""" -A simple example demonstrating basic Spark SQL features. -Run with: - ./bin/spark-submit examples/src/main/python/sql/basic.py -""" - def basic_df_example(spark): # $example on:create_df$ diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index b375fa775de39..d8c879dfe02ed 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL data sources. +Run with: + ./bin/spark-submit examples/src/main/python/sql/datasource.py +""" from __future__ import print_function from pyspark.sql import SparkSession @@ -22,12 +27,6 @@ from pyspark.sql import Row # $example off:schema_merging$ -""" -A simple example demonstrating Spark SQL data sources. -Run with: - ./bin/spark-submit examples/src/main/python/sql/datasource.py -""" - def basic_datasource_example(spark): # $example on:generic_load_save_functions$ diff --git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py index 1f83a6fb48b97..33fc2dfbeefa2 100644 --- a/examples/src/main/python/sql/hive.py +++ b/examples/src/main/python/sql/hive.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +A simple example demonstrating Spark SQL Hive integration. +Run with: + ./bin/spark-submit examples/src/main/python/sql/hive.py +""" from __future__ import print_function # $example on:spark_hive$ @@ -24,12 +29,6 @@ from pyspark.sql import Row # $example off:spark_hive$ -""" -A simple example demonstrating Spark SQL Hive integration. -Run with: - ./bin/spark-submit examples/src/main/python/sql/hive.py -""" - if __name__ == "__main__": # $example on:spark_hive$ From c40fda9e4cf32d6cd17af2ace959bbbbe7c782a4 Mon Sep 17 00:00:00 2001 From: Yacine Mazari Date: Sun, 28 Jan 2018 10:27:59 -0600 Subject: [PATCH 0221/1161] [SPARK-23166][ML] Add maxDF Parameter to CountVectorizer ## What changes were proposed in this pull request? Currently, the CountVectorizer has a minDF parameter. It might be useful to also have a maxDF parameter. It will be used as a threshold for filtering all the terms that occur very frequently in a text corpus, because they are not very informative or could even be stop-words. This is analogous to scikit-learn, CountVectorizer, max_df. Other changes: - Refactored code to invoke "filter()" conditioned on maxDF or minDF set. - Refactored code to unpersist input after counting is done. ## How was this patch tested? Unit tests. Author: Yacine Mazari Closes #20367 from ymazari/SPARK-23166. --- .../spark/ml/feature/CountVectorizer.scala | 67 ++++++++++++++--- .../ml/feature/CountVectorizerSuite.scala | 72 +++++++++++++++++++ 2 files changed, 131 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 1ebe29703bc47..60a4f918790a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -69,6 +69,25 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getMinDF: Double = $(minDF) + /** + * Specifies the maximum number of different documents a term must appear in to be included + * in the vocabulary. + * If this is an integer greater than or equal to 1, this specifies the number of documents + * the term must appear in; if this is a double in [0,1), then this specifies the fraction of + * documents. + * + * Default: (2^64^) - 1 + * @group param + */ + val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + ParamValidators.gtEq(0.0)) + + /** @group getParam */ + def getMaxDF: Double = $(maxDF) + /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) @@ -113,7 +132,11 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getBinary: Boolean = $(binary) - setDefault(vocabSize -> (1 << 18), minDF -> 1.0, minTF -> 1.0, binary -> false) + setDefault(vocabSize -> (1 << 18), + minDF -> 1.0, + maxDF -> Long.MaxValue, + minTF -> 1.0, + binary -> false) } /** @@ -142,6 +165,10 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setMinDF(value: Double): this.type = set(minDF, value) + /** @group setParam */ + @Since("2.4.0") + def setMaxDF(value: Double): this.type = set(maxDF, value) + /** @group setParam */ @Since("1.5.0") def setMinTF(value: Double): this.type = set(minTF, value) @@ -155,12 +182,24 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) + val countingRequired = $(minDF) < 1.0 || $(maxDF) < 1.0 + val maybeInputSize = if (countingRequired) { + Some(input.cache().count()) + } else { + None + } val minDf = if ($(minDF) >= 1.0) { $(minDF) } else { - $(minDF) * input.cache().count() + $(minDF) * maybeInputSize.get } - val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) => + val maxDf = if ($(maxDF) >= 1.0) { + $(maxDF) + } else { + $(maxDF) * maybeInputSize.get + } + require(maxDf >= minDf, "maxDF must be >= minDF.") + val allWordCounts = input.flatMap { case (tokens) => val wc = new OpenHashMap[String, Long] tokens.foreach { w => wc.changeValue(w, 1L, _ + 1L) @@ -168,11 +207,23 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) wc.map { case (word, count) => (word, (count, 1)) } }.reduceByKey { case ((wc1, df1), (wc2, df2)) => (wc1 + wc2, df1 + df2) - }.filter { case (word, (wc, df)) => - df >= minDf - }.map { case (word, (count, dfCount)) => - (word, count) - }.cache() + } + + val filteringRequired = isSet(minDF) || isSet(maxDF) + val maybeFilteredWordCounts = if (filteringRequired) { + allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf } + } else { + allWordCounts + } + + val wordCounts = maybeFilteredWordCounts + .map { case (word, (count, _)) => (word, count) } + .cache() + + if (countingRequired) { + input.unpersist() + } + val fullVocabSize = wordCounts.count() val vocab = wordCounts diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index f213145f1ba0a..1784c07ca23e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -119,6 +119,78 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } } + test("CountVectorizer maxDF") { + val df = Seq( + (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0), (2, 1.0)))), + (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(3, Seq())) + ).toDF("id", "words", "expected") + + // maxDF: ignore terms with count more than 3 + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c", "d")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // maxDF: ignore terms with freq > 0.75 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c", "d")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + + test("CountVectorizer using both minDF and maxDF") { + // Ignore terms with count more than 3 AND less than 2 + val df = Seq( + (0, split("a b c d"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (1, split("a b c"), Vectors.sparse(2, Seq((0, 1.0), (1, 1.0)))), + (2, split("a b"), Vectors.sparse(2, Seq((0, 1.0)))), + (3, split("a"), Vectors.sparse(2, Seq())) + ).toDF("id", "words", "expected") + + val cvModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(2) + .setMaxDF(3) + .fit(df) + assert(cvModel.vocabulary === Array("b", "c")) + + cvModel.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + + // Ignore terms with frequency higher than 0.75 AND less than 0.5 + val cvModel2 = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setMinDF(0.5) + .setMaxDF(0.75) + .fit(df) + assert(cvModel2.vocabulary === Array("b", "c")) + + cvModel2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } + } + test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { val df = Seq( From 686a622c93207564635569f054e1e6c921624e96 Mon Sep 17 00:00:00 2001 From: CCInCharge Date: Sun, 28 Jan 2018 14:55:43 -0600 Subject: [PATCH 0222/1161] [SPARK-23250][DOCS] Typo in JavaDoc/ScalaDoc for DataFrameWriter ## What changes were proposed in this pull request? Fix typo in ScalaDoc for DataFrameWriter - originally stated "This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0", should be "starting with Spark 2.1.0". ## How was this patch tested? Check of correct spelling in ScalaDoc Please review http://spark.apache.org/contributing.html before opening a pull request. Author: CCInCharge Closes #20417 from CCInCharge/master. --- .../scala/org/apache/spark/sql/DataFrameWriter.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5f3d4448e4e54..5c02eae05304b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -174,7 +174,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * predicates on the partitioned columns. In order for partitioning to work well, the number * of distinct values in each column should typically be less than tens of thousands. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 1.4.0 */ @@ -188,7 +189,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Buckets the output by the given columns. If specified, the output is laid out on the file * system similar to Hive's bucketing scheme. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ @@ -202,7 +204,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Sorts the output in each bucket by the given columns. * - * This is applicable for all file-based data sources (e.g. Parquet, JSON) staring Spark 2.1.0. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark + * 2.1.0. * * @since 2.0 */ From 49b0207dc9327989c72700b4d04d2a714c92e159 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 29 Jan 2018 13:10:38 +0800 Subject: [PATCH 0223/1161] [SPARK-23196] Unify continuous and microbatch V2 sinks ## What changes were proposed in this pull request? Replace streaming V2 sinks with a unified StreamWriteSupport interface, with a shim to use it with microbatch execution. Add a new SQL config to use for disabling V2 sinks, falling back to the V1 sink implementation. ## How was this patch tested? Existing tests, which in the case of Kafka (the only existing continuous V2 sink) now use V2 for microbatch. Author: Jose Torres Closes #20369 from jose-torres/streaming-sink. --- .../sql/kafka010/KafkaSourceProvider.scala | 16 +-- ...usWriter.scala => KafkaStreamWriter.scala} | 30 ++--- .../kafka010/KafkaContinuousSinkSuite.scala | 8 +- .../spark/sql/kafka010/KafkaSinkSuite.scala | 14 ++- .../spark/sql/kafka010/KafkaSourceSuite.scala | 8 +- .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../v2/streaming/MicroBatchWriteSupport.java | 60 ---------- ...teSupport.java => StreamWriteSupport.java} | 12 +- ...ontinuousWriter.java => StreamWriter.java} | 34 +++++- .../sources/v2/writer/DataSourceV2Writer.java | 4 +- .../datasources/v2/WriteToDataSourceV2.scala | 11 +- .../streaming/MicroBatchExecution.scala | 19 +-- .../sql/execution/streaming/console.scala | 27 ++--- .../continuous/ContinuousExecution.scala | 19 ++- .../continuous/EpochCoordinator.scala | 9 +- .../streaming/sources/ConsoleWriter.scala | 59 ++------- .../streaming/sources/MicroBatchWriter.scala | 54 +++++++++ .../streaming/sources/memoryV2.scala | 29 ++--- .../sql/streaming/DataStreamWriter.scala | 10 +- .../sql/streaming/StreamingQueryManager.scala | 9 +- ...pache.spark.sql.sources.DataSourceRegister | 7 +- .../streaming/MemorySinkV2Suite.scala | 2 +- .../sources/StreamingDataSourceV2Suite.scala | 112 +++++++++--------- 23 files changed, 265 insertions(+), 297 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaContinuousWriter.scala => KafkaStreamWriter.scala} (78%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/{ContinuousWriteSupport.java => StreamWriteSupport.java} (85%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/{ContinuousWriter.java => StreamWriter.java} (50%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 62a998fbfb30b..2deb7fa2cdf1e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -28,11 +28,11 @@ import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySe import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} -import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -46,7 +46,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with ContinuousWriteSupport + with StreamWriteSupport with ContinuousReadSupport with Logging { import KafkaSourceProvider._ @@ -223,11 +223,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createContinuousWriter( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { + options: DataSourceV2Options): StreamWriter = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get @@ -238,7 +238,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister KafkaWriter.validateQuery( schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic) - Optional.of(new KafkaContinuousWriter(topic, producerParams, schema)) + new KafkaStreamWriter(topic, producerParams, schema) } private def strategy(caseInsensitiveParams: Map[String, String]) = diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala similarity index 78% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 9843f469c5b25..a24efdefa4464 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.kafka010 -import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata} import scala.collection.JavaConverters._ -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{BinaryType, StringType, StructType} +import org.apache.spark.sql.types.StructType /** * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we @@ -38,23 +33,24 @@ import org.apache.spark.sql.types.{BinaryType, StringType, StructType} case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory. + * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -class KafkaContinuousWriter( +class KafkaStreamWriter( topic: Option[String], producerParams: Map[String, String], schema: StructType) - extends ContinuousWriter with SupportsWriteInternalRow { + extends StreamWriter with SupportsWriteInternalRow { validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic) - override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory = - KafkaContinuousWriterFactory(topic, producerParams, schema) + override def createInternalRowWriterFactory(): KafkaStreamWriterFactory = + KafkaStreamWriterFactory(topic, producerParams, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} } /** @@ -65,12 +61,12 @@ class KafkaContinuousWriter( * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -case class KafkaContinuousWriterFactory( +case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) extends DataWriterFactory[InternalRow] { override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { - new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes) + new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } } @@ -83,7 +79,7 @@ case class KafkaContinuousWriterFactory( * @param producerParams Parameters to use for the Kafka producer. * @param inputSchema The attributes in the input data. */ -class KafkaContinuousDataWriter( +class KafkaStreamDataWriter( targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute]) extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] { import scala.collection.JavaConverters._ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 8487a69851237..fc890a0cfdac3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.kafka010 import java.util.Locale -import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.ByteArraySerializer import org.scalatest.time.SpanSugar._ import scala.collection.JavaConverters._ -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} -import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{BinaryType, DataType} import org.apache.spark.util.Utils @@ -362,7 +360,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { @@ -424,7 +422,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) val inputSchema = Seq(AttributeReference("value", BinaryType)()) val data = new Array[Byte](15000) // large value - val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema) + val writeTask = new KafkaStreamDataWriter(Some(topic), options.asScala.toMap, inputSchema) try { val fieldTypes: Array[DataType] = Array(BinaryType) val converter = UnsafeProjection.create(fieldTypes) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 2ab336c7ac476..42f8b4c7657e2 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -336,27 +336,31 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { } finally { writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) + assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { val input = MemoryStream[String] var writer: StreamingQuery = null var ex: Exception = null - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'key.serializer' is not supported")) - ex = intercept[IllegalArgumentException] { + ex = intercept[StreamingQueryException] { writer = createKafkaWriter( input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() + input.addData("1") + writer.processAllAvailable() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( "kafka option 'value.serializer' is not supported")) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index c4cb1bc4a2e18..02c87643568bd 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -29,19 +29,17 @@ import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata import org.apache.kafka.common.TopicPartition -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec} +import org.apache.spark.sql.{Dataset, ForeachWriter} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger} +import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} import org.apache.spark.util.Utils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 76b9d6f6f33bd..2c70b004bcff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1127,6 +1127,13 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") + .internal() + .doc("A comma-separated list of fully qualified data source register class names for which" + + " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.") + .stringConf + .createWithDefault("") + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1494,6 +1501,8 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) + def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java deleted file mode 100644 index 53ffa95ae0f4c..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchWriteSupport.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.streaming; - -import java.util.Optional; - -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability and save the data from a microbatch to the data source. - */ -@InterfaceStability.Evolving -public interface MicroBatchWriteSupport extends BaseStreamingSink { - - /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data - * sources can return None if there is no writing needed to be done. - * - * @param queryId A unique string for the writing query. It's possible that there are many writing - * queries running at the same time, and the returned {@link DataSourceV2Writer} - * can use this id to distinguish itself from others. - * @param epochId The unique numeric ID of the batch within this writing query. This is an - * incrementing counter representing a consistent set of data; the same batch may - * be started multiple times in failure recovery scenarios, but it will always - * contain the same records. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive batch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - Optional createMicroBatchWriter( - String queryId, - long epochId, - StructType schema, - OutputMode mode, - DataSourceV2Options options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java similarity index 85% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java index dee493cadb71e..6cd219c67109a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java @@ -17,26 +17,24 @@ package org.apache.spark.sql.sources.v2.streaming; -import java.util.Optional; - import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceV2Options; -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter; +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for continuous stream processing. + * provide data writing ability for structured streaming. */ @InterfaceStability.Evolving -public interface ContinuousWriteSupport extends BaseStreamingSink { +public interface StreamWriteSupport extends BaseStreamingSink { /** - * Creates an optional {@link ContinuousWriter} to save the data to this data source. Data + * Creates an optional {@link StreamWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done. * * @param queryId A unique string for the writing query. It's possible that there are many @@ -48,7 +46,7 @@ public interface ContinuousWriteSupport extends BaseStreamingSink { * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. */ - Optional createContinuousWriter( + StreamWriter createStreamWriter( String queryId, StructType schema, OutputMode mode, diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java similarity index 50% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java index 723395bd1e963..3156c88933e5e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/ContinuousWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java @@ -23,10 +23,14 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceV2Writer} for use with continuous stream processing. + * A {@link DataSourceV2Writer} for use with structured streaming. This writer handles commits and + * aborts relative to an epoch ID determined by the execution engine. + * + * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, + * and so must reset any internal state after a successful commit. */ @InterfaceStability.Evolving -public interface ContinuousWriter extends DataSourceV2Writer { +public interface StreamWriter extends DataSourceV2Writer { /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by @@ -34,11 +38,35 @@ public interface ContinuousWriter extends DataSourceV2Writer { * * If this method fails (by throwing an exception), this writing job is considered to have been * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. + * + * To support exactly-once processing, writer implementations should ensure that this method is + * idempotent. The execution engine may call commit() multiple times for the same epoch + * in some circumstances. */ void commit(long epochId, WriterCommitMessage[] messages); + /** + * Aborts this writing job because some data writers are failed and keep failing when retry, or + * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * + * If this method fails (by throwing an exception), the underlying data source may require manual + * cleanup. + * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(long epochId, WriterCommitMessage[] messages); + default void commit(WriterCommitMessage[] messages) { throw new UnsupportedOperationException( - "Commit without epoch should not be called with ContinuousWriter"); + "Commit without epoch should not be called with StreamWriter"); + } + + default void abort(WriterCommitMessage[] messages) { + throw new UnsupportedOperationException( + "Abort without epoch should not be called with StreamWriter"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java index f1ef411423162..8048f507a1dca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -28,9 +28,7 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport#createMicroBatchWriter( - * String, long, StructType, OutputMode, DataSourceV2Options)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport#createContinuousWriter( + * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceV2Options)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 3dbdae7b4df9f..cd6b3e99b6bcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -26,9 +26,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -62,7 +61,9 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) try { val runTask = writer match { - case w: ContinuousWriter => + // This case means that we're doing continuous processing. In microbatch streaming, the + // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. + case w: StreamWriter => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) @@ -82,13 +83,13 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) (index, message: WriterCommitMessage) => messages(index) = message ) - if (!writer.isInstanceOf[ContinuousWriter]) { + if (!writer.isInstanceOf[StreamWriter]) { logInfo(s"Data source writer $writer is committing.") writer.commit(messages) logInfo(s"Data source writer $writer committed.") } } catch { - case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] => + case _: InterruptedException if writer.isInstanceOf[StreamWriter] => // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 7c3804547b736..975975243a3d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -28,9 +28,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Curre import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -440,15 +442,18 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: MicroBatchWriteSupport => - val writer = s.createMicroBatchWriter( + case s: StreamWriteSupport => + val writer = s.createStreamWriter( s"$runId", - currentBatchId, newAttributePlan.schema, outputMode, new DataSourceV2Options(extraOptions.asJava)) - assert(writer.isPresent, "microbatch writer must always be present") - WriteToDataSourceV2(writer.get, newAttributePlan) + if (writer.isInstanceOf[SupportsWriteInternalRow]) { + WriteToDataSourceV2( + new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) + } else { + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) + } case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -471,7 +476,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case s: MicroBatchWriteSupport => + case _: StreamWriteSupport => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index f2aa3259731d1..d5ac0bd1df52b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.{ConsoleContinuousWriter, ConsoleMicroBatchWriter, ConsoleWriter} +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -35,26 +32,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with MicroBatchWriteSupport - with ContinuousWriteSupport + with StreamWriteSupport with DataSourceRegister with CreatableRelationProvider { - override def createMicroBatchWriter( - queryId: String, - batchId: Long, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - Optional.of(new ConsoleMicroBatchWriter(batchId, schema, options)) - } - - override def createContinuousWriter( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { - Optional.of(new ConsoleContinuousWriter(schema, options)) + options: DataSourceV2Options): StreamWriter = { + new ConsoleWriter(schema, options) } def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 462e7d9721d28..60f880f9c73b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -24,17 +24,16 @@ import java.util.function.UnaryOperator import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} -import org.apache.spark.{SparkEnv, SparkException} -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.SparkEnv +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} @@ -44,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: ContinuousWriteSupport, + sink: StreamWriteSupport, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -195,12 +194,12 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createContinuousWriter( + val writer = sink.createStreamWriter( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceV2Options(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer.get(), triggerLogicalPlan) + val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { case DataSourceV2Relation(_, r: ContinuousReader) => r @@ -230,7 +229,7 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer.get(), reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 90b3584aa0436..84d262116cb46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -17,17 +17,14 @@ package org.apache.spark.sql.execution.streaming.continuous -import java.util.concurrent.atomic.AtomicLong - import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.util.RpcUtils @@ -85,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, epochCoordinatorId: String, @@ -118,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: ContinuousWriter, + writer: StreamWriter, reader: ContinuousReader, query: ContinuousExecution, startEpoch: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 6fb61dff60045..7c1700f1de48c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -trait ConsoleWriter extends Logging { - - def options: DataSourceV2Options +class ConsoleWriter(schema: StructType, options: DataSourceV2Options) + extends StreamWriter with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) @@ -40,14 +39,20 @@ trait ConsoleWriter extends Logging { def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory - def abort(messages: Array[WriterCommitMessage]): Unit = {} + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 + // behavior. + printRows(messages, schema, s"Batch: $epochId") + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} protected def printRows( commitMessages: Array[WriterCommitMessage], schema: StructType, printMessage: String): Unit = { val rows = commitMessages.collect { - case PackedRowCommitMessage(rows) => rows + case PackedRowCommitMessage(rs) => rs }.flatten // scalastyle:off println @@ -59,46 +64,8 @@ trait ConsoleWriter extends Logging { .createDataFrame(spark.sparkContext.parallelize(rows), schema) .show(numRowsToShow, isTruncated) } -} - - -/** - * A [[DataSourceV2Writer]] that collects results from a micro-batch query to the driver and - * prints them in the console. Created by - * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleMicroBatchWriter(batchId: Long, schema: StructType, val options: DataSourceV2Options) - extends DataSourceV2Writer with ConsoleWriter { - - override def commit(messages: Array[WriterCommitMessage]): Unit = { - printRows(messages, schema, s"Batch: $batchId") - } - - override def toString(): String = { - s"ConsoleMicroBatchWriter[numRows=$numRowsToShow, truncate=$isTruncated]" - } -} - - -/** - * A [[DataSourceV2Writer]] that collects results from a continuous query to the driver and - * prints them in the console. Created by - * [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]]. - * - * This sink should not be used for production, as it requires sending all rows to the driver - * and does not support recovery. - */ -class ConsoleContinuousWriter(schema: StructType, val options: DataSourceV2Options) - extends ContinuousWriter with ConsoleWriter { - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { - printRows(messages, schema, s"Continuous processing epoch $epochId") - } override def toString(): String = { - s"ConsoleContinuousWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala new file mode 100644 index 0000000000000..d7f3ba8856982 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} + +/** + * A [[DataSourceV2Writer]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped + * streaming writer. + */ +class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2Writer { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory() +} + +class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) + extends DataSourceV2Writer with SupportsWriteInternalRow { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = + writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => throw new IllegalStateException( + "InternalRowMicroBatchWriter should only be created with base writer support") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index da7c31cf62428..ce55e44d932bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,8 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -40,24 +40,13 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 - with MicroBatchWriteSupport with ContinuousWriteSupport with Logging { - - override def createMicroBatchWriter( +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { + override def createStreamWriter( queryId: String, - batchId: Long, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[DataSourceV2Writer] = { - java.util.Optional.of(new MemoryWriter(this, batchId, mode)) - } - - override def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): java.util.Optional[ContinuousWriter] = { - java.util.Optional.of(new ContinuousMemoryWriter(this, mode)) + options: DataSourceV2Options): StreamWriter = { + new MemoryStreamWriter(this, mode) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -141,8 +130,8 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) - extends ContinuousWriter { +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) + extends StreamWriter { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) @@ -153,7 +142,7 @@ class ContinuousMemoryWriter(val sink: MemorySinkV2, outputMode: OutputMode) sink.write(epochId, outputMode, newRows) } - override def abort(messages: Array[WriterCommitMessage]): Unit = { + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { // Don't accept any of the new input. } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index d24f0ddeab4de..3b5b30d77945c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -281,11 +281,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { trigger = trigger) } else { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) - val sink = (ds.newInstance(), trigger) match { - case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w - case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException( - s"Data source $source does not support continuous writing") - case (w: MicroBatchWriteSupport, _) => w + val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") + val sink = ds.newInstance() match { + case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 4b27e0d4ef47b..fdd709cdb1f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -241,7 +241,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: ContinuousWriteSupport, trigger: ContinuousTrigger) => + case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) new StreamingQueryWrapper(new ContinuousExecution( sparkSession, @@ -254,7 +254,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode, extraOptions, deleteCheckpointOnStop)) - case (_: MicroBatchWriteSupport, _) | (_: Sink, _) => + case _ => new StreamingQueryWrapper(new MicroBatchExecution( sparkSession, userSpecifiedName.orNull, @@ -266,9 +266,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo outputMode, extraOptions, deleteCheckpointOnStop)) - case (_: ContinuousWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => - throw new AnalysisException( - "Sink only supports continuous writes, but a continuous trigger was not specified.") } } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a0b25b4e82364..46b38bed1c0fb 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,7 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly -org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly -org.apache.spark.sql.streaming.sources.FakeWriteBothModes -org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode +org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeNoWrite +org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 00d4f0b8503d8..9be22d94b5654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -40,7 +40,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new ContinuousMemoryWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index f152174b0a7f0..d4f8bae96695d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.streaming.sources import java.util.Optional -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.sources.v2.DataSourceV2Options import org.apache.spark.sql.sources.v2.reader.ReadTask -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport} +import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer -import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger} +import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -64,23 +64,12 @@ trait FakeContinuousReadSupport extends ContinuousReadSupport { options: DataSourceV2Options): ContinuousReader = FakeReader() } -trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport { - def createMicroBatchWriter( +trait FakeStreamWriteSupport extends StreamWriteSupport { + override def createStreamWriter( queryId: String, - epochId: Long, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { - throw new IllegalStateException("fake sink - cannot actually write") - } -} - -trait FakeContinuousWriteSupport extends ContinuousWriteSupport { - def createContinuousWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceV2Options): Optional[ContinuousWriter] = { + options: DataSourceV2Options): StreamWriter = { throw new IllegalStateException("fake sink - cannot actually write") } } @@ -102,23 +91,36 @@ class FakeReadNeitherMode extends DataSourceRegister { override def shortName(): String = "fake-read-neither-mode" } -class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport { - override def shortName(): String = "fake-write-microbatch-only" +class FakeWrite extends DataSourceRegister with FakeStreamWriteSupport { + override def shortName(): String = "fake-write-microbatch-continuous" } -class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport { - override def shortName(): String = "fake-write-continuous-only" +class FakeNoWrite extends DataSourceRegister { + override def shortName(): String = "fake-write-neither-mode" } -class FakeWriteBothModes extends DataSourceRegister - with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport { - override def shortName(): String = "fake-write-microbatch-continuous" + +case class FakeWriteV1FallbackException() extends Exception + +class FakeSink extends Sink { + override def addBatch(batchId: Long, data: DataFrame): Unit = {} } -class FakeWriteNeitherMode extends DataSourceRegister { - override def shortName(): String = "fake-write-neither-mode" +class FakeWriteV1Fallback extends DataSourceRegister + with FakeStreamWriteSupport with StreamSinkProvider { + + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + new FakeSink() + } + + override def shortName(): String = "fake-write-v1-fallback" } + class StreamingDataSourceV2Suite extends StreamTest { override def beforeAll(): Unit = { @@ -133,8 +135,6 @@ class StreamingDataSourceV2Suite extends StreamTest { "fake-read-microbatch-continuous", "fake-read-neither-mode") val writeFormats = Seq( - "fake-write-microbatch-only", - "fake-write-continuous-only", "fake-write-microbatch-continuous", "fake-write-neither-mode") val triggers = Seq( @@ -151,6 +151,7 @@ class StreamingDataSourceV2Suite extends StreamTest { .trigger(trigger) .start() query.stop() + query } private def testNegativeCase( @@ -184,6 +185,24 @@ class StreamingDataSourceV2Suite extends StreamTest { } } + test("disabled v2 write") { + // Ensure the V2 path works normally and generates a V2 sink.. + val v2Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeWriteV1Fallback]) + + // Ensure we create a V1 sink with the config. Note the config is a comma separated + // list, including other fake entries. + val fullSinkName = "org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback" + withSQLConf(SQLConf.DISABLED_V2_STREAMING_WRITERS.key -> s"a,b,c,test,$fullSinkName,d,e") { + val v1Query = testPositiveCase( + "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) + assert(v1Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink + .isInstanceOf[FakeSink]) + } + } + // Get a list of (read, write, trigger) tuples for test cases. val cases = readFormats.flatMap { read => writeFormats.flatMap { write => @@ -199,12 +218,12 @@ class StreamingDataSourceV2Suite extends StreamTest { val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance() (readSource, writeSource, trigger) match { // Valid microbatch queries. - case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t) + case (_: MicroBatchReadSupport, _: StreamWriteSupport, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) => + case (_: ContinuousReadSupport, _: StreamWriteSupport, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) // Invalid - can't read at all @@ -214,31 +233,18 @@ class StreamingDataSourceV2Suite extends StreamTest { testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") - // Invalid - trigger is continuous but writer is not - case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] => - testNegativeCase(read, write, trigger, - s"Data source $write does not support continuous writing") - - // Invalid - can't write at all - case (_, w, _) - if !w.isInstanceOf[MicroBatchWriteSupport] - && !w.isInstanceOf[ContinuousWriteSupport] => + // Invalid - can't write + case (_, w, _) if !w.isInstanceOf[StreamWriteSupport] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") - // Invalid - trigger and writer are continuous but reader is not - case (r, _: ContinuousWriteSupport, _: ContinuousTrigger) + // Invalid - trigger is continuous but reader is not + case (r, _: StreamWriteSupport, _: ContinuousTrigger) if !r.isInstanceOf[ContinuousReadSupport] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") - // Invalid - trigger is microbatch but writer is not - case (_, w, t) - if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] => - testNegativeCase(read, write, trigger, - s"Data source $write does not support streamed writing") - - // Invalid - trigger and writer are microbatch but reader is not + // Invalid - trigger is microbatch but reader is not case (r, _, t) if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] => testPostCreationNegativeCase(read, write, trigger, From 39d2c6b03488895a0acb1dd3c46329db00fdd357 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 29 Jan 2018 21:09:05 +0900 Subject: [PATCH 0224/1161] [SPARK-23238][SQL] Externalize SQLConf configurations exposed in documentation ## What changes were proposed in this pull request? This PR proposes to expose few internal configurations found in the documentation. Also it fixes the description for `spark.sql.execution.arrow.enabled`. It's quite self-explanatory. ## How was this patch tested? N/A Author: hyukjinkwon Closes #20403 from HyukjinKwon/minor-doc-arrow. --- .../org/apache/spark/sql/internal/SQLConf.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2c70b004bcff9..61ea03d395afc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -123,14 +123,12 @@ object SQLConf { .createWithDefault(10) val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") - .internal() .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") .booleanConf .createWithDefault(true) val COLUMN_BATCH_SIZE = buildConf("spark.sql.inMemoryColumnarStorage.batchSize") - .internal() .doc("Controls the size of batches for columnar caching. Larger batch sizes can improve " + "memory utilization and compression, but risk OOMs when caching data.") .intConf @@ -1043,11 +1041,11 @@ object SQLConf { val ARROW_EXECUTION_ENABLE = buildConf("spark.sql.execution.arrow.enabled") - .internal() - .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + - "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + - "StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " + - "LongType, ShortType") + .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + + "for use with pyspark.sql.DataFrame.toPandas, and " + + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " + + "The following data types are unsupported: " + + "MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) From badf0d0e0d1d9aa169ed655176ce9ae684d3905d Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Tue, 30 Jan 2018 00:50:49 +0800 Subject: [PATCH 0225/1161] [SPARK-23219][SQL] Rename ReadTask to DataReaderFactory in data source v2 ## What changes were proposed in this pull request? Currently we have `ReadTask` in data source v2 reader, while in writer we have `DataWriterFactory`. To make the naming consistent and better, renaming `ReadTask` to `DataReaderFactory`. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20397 from gengliangwang/rename. --- .../sql/kafka010/KafkaContinuousReader.scala | 16 ++--- .../execution/UnsafeExternalRowSorter.java | 1 - .../v2/reader/ClusteredDistribution.java | 2 +- .../sql/sources/v2/reader/DataReader.java | 2 +- .../{ReadTask.java => DataReaderFactory.java} | 22 +++---- .../sources/v2/reader/DataSourceV2Reader.java | 11 ++-- .../sql/sources/v2/reader/Distribution.java | 6 +- .../sql/sources/v2/reader/Partitioning.java | 2 +- .../v2/reader/SupportsScanColumnarBatch.java | 11 ++-- .../v2/reader/SupportsScanUnsafeRow.java | 9 +-- .../v2/streaming/MicroBatchReadSupport.java | 4 +- .../v2/streaming/reader/ContinuousReader.java | 14 ++--- .../v2/streaming/reader/MicroBatchReader.java | 6 +- .../datasources/v2/DataSourceRDD.scala | 14 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 25 ++++---- .../ContinuousDataSourceRDDIter.scala | 11 ++-- .../ContinuousRateStreamSource.scala | 10 ++-- .../sources/RateStreamSourceV2.scala | 6 +- .../sources/v2/JavaAdvancedDataSourceV2.java | 20 +++---- .../sql/sources/v2/JavaBatchDataSourceV2.java | 10 ++-- .../v2/JavaPartitionAwareDataSource.java | 10 ++-- .../v2/JavaSchemaRequiredDataSource.java | 4 +- .../sources/v2/JavaSimpleDataSourceV2.java | 14 ++--- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 13 ++-- .../streaming/RateSourceV2Suite.scala | 10 ++-- .../sql/sources/v2/DataSourceV2Suite.scala | 59 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 12 ++-- .../sources/StreamingDataSourceV2Suite.scala | 4 +- 28 files changed, 172 insertions(+), 156 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ReadTask.java => DataReaderFactory.java} (65%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index fc977977504f7..9125cf5799d74 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -63,7 +63,7 @@ class KafkaContinuousReader( private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong - // Initialized when creating read tasks. If this diverges from the partitions at the latest + // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. // Exposed outside this object only for unit tests. private[sql] var knownPartitions: Set[TopicPartition] = _ @@ -89,7 +89,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = { + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -109,9 +109,9 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousReadTask( + KafkaContinuousDataReaderFactory( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[ReadTask[UnsafeRow]] + .asInstanceOf[DataReaderFactory[UnsafeRow]] }.asJava } @@ -149,8 +149,8 @@ class KafkaContinuousReader( } /** - * A read task for continuous Kafka processing. This will be serialized and transformed into a - * full reader on executors. + * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. * @param startOffset The offset to start reading from within the partition. @@ -159,12 +159,12 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousReadTask( +case class KafkaContinuousDataReaderFactory( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] { + failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { override def createDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 78647b56d621f..1b2f5eee5ccdd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.function.Supplier; -import org.apache.spark.sql.catalyst.util.TypeUtils; import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java index 7346500de45b6..27905e325df87 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java @@ -22,7 +22,7 @@ /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link ReadTask}. + * {@link DataReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index 8f58c865b6201..bb9790a1c819e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link ReadTask#createDataReader()} and is responsible for + * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java similarity index 65% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index fa161cdb8b347..077b95b837964 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -22,21 +22,23 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A read task returned by {@link DataSourceV2Reader#createReadTasks()} and is responsible for - * creating the actual data reader. The relationship between {@link ReadTask} and {@link DataReader} + * A reader factory returned by {@link DataSourceV2Reader#createDataReaderFactories()} and is + * responsible for creating the actual data reader. The relationship between + * {@link DataReaderFactory} and {@link DataReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. So {@link ReadTask} must be serializable and - * {@link DataReader} doesn't need to be. + * Note that, the reader factory will be serialized and sent to executors, then the data reader + * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be + * serializable and {@link DataReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface ReadTask extends Serializable { +public interface DataReaderFactory extends Serializable { /** - * The preferred locations where this read task can run faster, but Spark does not guarantee that - * this task will always run on these locations. The implementations should make sure that it can - * be run on any location. The location is a string representing the host name. + * The preferred locations where the data reader returned by this reader factory can run faster, + * but Spark does not guarantee to run the data reader on these locations. + * The implementations should make sure that it can be run on any location. + * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in * the returned locations. By default this method returns empty string array, which means this @@ -50,7 +52,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work for this read task. + * Returns a data reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index f23c3842bf1b1..0180cd9ea47f8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -30,7 +30,8 @@ * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link DataReaderFactory}s that are returned by + * {@link #createDataReaderFactories()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -63,9 +64,9 @@ public interface DataSourceV2Reader { StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for outputting data for one RDD - * partition. That means the number of tasks returned here is same as the number of RDD - * partitions this scan outputs. + * Returns a list of reader factories. Each factory is responsible for creating a data reader to + * output data for one RDD partition. That means the number of factories returned here is same as + * the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before @@ -74,5 +75,5 @@ public interface DataSourceV2Reader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createReadTasks(); + List> createDataReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java index a6201a222f541..b37562167d9ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java @@ -21,9 +21,9 @@ /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the {@link ReadTask}s that are returned by - * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with - * the data ordering inside one partition(the output records of a single {@link ReadTask}). + * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * Note that this interface has nothing to do with the data ordering inside one + * partition(the output records of a single {@link DataReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java index 199e45d4a02ab..5e334d13a1215 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java @@ -29,7 +29,7 @@ public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs. + * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 27cf3a77724f0..67da55554bbf3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -30,21 +30,22 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceV2Reader { @Override - default List> createReadTasks() { + default List> createDataReaderFactories() { throw new IllegalStateException( - "createReadTasks not supported by default within SupportsScanColumnarBatch."); + "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); } /** - * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches. + * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, but returns columnar data + * in batches. */ - List> createBatchReadTasks(); + List> createBatchDataReaderFactories(); /** * Returns true if the concrete data source reader can read data in batch according to the scan * properties like required columns, pushes filters, etc. It's possible that the implementation * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createReadTasks()} to fallback to normal read path under some conditions. + * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. */ default boolean enableBatchRead() { return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 2d3ad0eee65ff..156af69520f77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -33,13 +33,14 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader { @Override - default List> createReadTasks() { + default List> createDataReaderFactories() { throw new IllegalStateException( - "createReadTasks not supported by default within SupportsScanUnsafeRow"); + "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns data in unsafe row format. + * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, + * but returns data in unsafe row format. */ - List> createUnsafeRowReadTasks(); + List> createUnsafeRowReaderFactories(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 3c87a3db68243..3b357c01a29fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -36,8 +36,8 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createReadTasks for each batch to process, and then - * call stop() when the execution is complete. Note that a single query may have multiple + * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * * @param schema the user provided schema, or empty() if none was provided diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 745f1ce502443..3ac979cb0b7b4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -27,7 +27,7 @@ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each read task output is a {@link ContinuousDataReader}. + * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. @@ -47,9 +47,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade Offset deserializeOffset(String json); /** - * Set the desired start offset for read tasks created from this reader. The scan will start - * from the first record after the provided offset, or from an implementation-defined inferred - * starting point if no offset is provided. + * Set the desired start offset for reader factories created from this reader. The scan will + * start from the first record after the provided offset, or from an implementation-defined + * inferred starting point if no offset is provided. */ void setOffset(Optional start); @@ -61,9 +61,9 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reade Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new read tasks need - * to be generated, which may be required if for example the underlying source system has had - * partitions added or removed. + * The execution engine will call this method in every epoch to determine if new reader + * factories need to be generated, which may be required if for example the underlying + * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 02f37cebc7484..68887e569fc1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -33,9 +33,9 @@ @InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { /** - * Set the desired offset range for read tasks created from this reader. Read tasks will - * generate only data within (`start`, `end`]; that is, from the first record after `start` to - * the record with offset `end`. + * Set the desired offset range for reader factories created from this reader. Reader factories + * will generate only data within (`start`, `end`]; that is, from the first record after `start` + * to the record with offset `end`. * * @param start The initial offset to scan from. If not specified, scan from an * implementation-specified start point, such as the earliest available record. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index ac104d7cd0cb3..5ed0ba71e94c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -22,24 +22,24 @@ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) extends Partition with Serializable class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readTasks: java.util.List[ReadTask[T]]) + @transient private val readerFactories: java.util.List[DataReaderFactory[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) + readerFactories.asScala.zipWithIndex.map { + case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false @@ -63,6 +63,6 @@ class DataSourceRDD[T: ClassTag]( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 2c22239e81869..3f808fbb40932 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -51,11 +51,11 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories() case _ => - reader.createReadTasks().asScala.map { - new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + reader.createDataReaderFactories().asScala.map { + new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] }.asJava } @@ -63,18 +63,19 @@ case class DataSourceV2ScanExec( case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories()) + .asInstanceOf[RDD[InternalRow]] case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readTasks.size())) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks) + .askSync[Unit](SetReaderPartitions(readerFactories.size())) + new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) .asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) @@ -99,14 +100,14 @@ case class DataSourceV2ScanExec( } } -class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) - extends ReadTask[UnsafeRow] { +class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) + extends DataReaderFactory[UnsafeRow] { - override def preferredLocations: Array[String] = rowReadTask.preferredLocations + override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations override def createDataReader: DataReader[UnsafeRow] = { new RowToUnsafeDataReader( - rowReadTask.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index cd7065f5e6601..8a7a38b22caca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -39,15 +39,15 @@ import org.apache.spark.util.{SystemClock, ThreadUtils} class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) + @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { - readTasks.asScala.zipWithIndex.map { - case (readTask, index) => new DataSourceRDDPartition(index, readTask) + readerFactories.asScala.zipWithIndex.map { + case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } @@ -57,7 +57,8 @@ class ContinuousDataSourceRDD( throw new ContinuousTaskRetryException() } - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] + .readerFactory.createDataReader() val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) @@ -136,7 +137,7 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index b4b21e7d2052f..61304480f4721 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -68,7 +68,7 @@ class RateStreamContinuousReader(options: DataSourceV2Options) override def getStartOffset(): Offset = offset - override def createReadTasks(): java.util.List[ReadTask[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -86,13 +86,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousReadTask( + RateStreamContinuousDataReaderFactory( start.value, start.runTimeMs, i, numPartitions, perPartitionRate) - .asInstanceOf[ReadTask[Row]] + .asInstanceOf[DataReaderFactory[Row]] }.asJava } @@ -101,13 +101,13 @@ class RateStreamContinuousReader(options: DataSourceV2Options) } -case class RateStreamContinuousReadTask( +case class RateStreamContinuousDataReaderFactory( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ReadTask[Row] { + extends DataReaderFactory[Row] { override def createDataReader(): DataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index c0ed12cec25ef..a25cc4f3b06f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -123,7 +123,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def createReadTasks(): java.util.List[ReadTask[Row]] = { + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { val startMap = start.partitionToValueAndRunTimeMs val endMap = end.partitionToValueAndRunTimeMs endMap.keys.toSeq.map { part => @@ -139,7 +139,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) outTimeMs += msPerPartitionBetweenRows } - RateStreamBatchTask(packedRows).asInstanceOf[ReadTask[Row]] + RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] }.toList.asJava } @@ -147,7 +147,7 @@ class RateStreamMicroBatchReader(options: DataSourceV2Options) override def stop(): Unit = {} } -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends ReadTask[Row] { +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 1cfdc08217e6e..4026ee44bfdb7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -60,8 +60,8 @@ public Filter[] pushedFilters() { } @Override - public List> createReadTasks() { - List> res = new ArrayList<>(); + public List> createDataReaderFactories() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -75,25 +75,25 @@ public List> createReadTasks() { } if (lowerBound == null) { - res.add(new JavaAdvancedReadTask(0, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedReadTask(5, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedReadTask implements ReadTask, DataReader { + static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedReadTask(int start, int end, StructType requiredSchema) { + JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; @@ -101,7 +101,7 @@ static class JavaAdvancedReadTask implements ReadTask, DataReader { @Override public DataReader createDataReader() { - return new JavaAdvancedReadTask(start - 1, end, requiredSchema); + return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index a5d77a90ece42..34e6c63801064 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -42,12 +42,14 @@ public StructType readSchema() { } @Override - public List> createBatchReadTasks() { - return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90)); + public List> createBatchDataReaderFactories() { + return java.util.Arrays.asList( + new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); } } - static class JavaBatchReadTask implements ReadTask, DataReader { + static class JavaBatchDataReaderFactory + implements DataReaderFactory, DataReader { private int start; private int end; @@ -57,7 +59,7 @@ static class JavaBatchReadTask implements ReadTask, DataReader> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Arrays.asList( - new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override @@ -70,12 +70,12 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificReadTask implements ReadTask, DataReader { + static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { private int[] i; private int[] j; private int current = -1; - SpecificReadTask(int[] i, int[] j) { + SpecificDataReaderFactory(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index a174bd8092cbd..f997366af1a64 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2d458b7f7e906..2beed431d301f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.ReadTask; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; import org.apache.spark.sql.types.StructType; @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createReadTasks() { + public List> createDataReaderFactories() { return java.util.Arrays.asList( - new JavaSimpleReadTask(0, 5), - new JavaSimpleReadTask(5, 10)); + new JavaSimpleDataReaderFactory(0, 5), + new JavaSimpleDataReaderFactory(5, 10)); } } - static class JavaSimpleReadTask implements ReadTask, DataReader { + static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { private int start; private int end; - JavaSimpleReadTask(int start, int end) { + JavaSimpleDataReaderFactory(int start, int end) { this.start = start; this.end = end; } @Override public DataReader createDataReader() { - return new JavaSimpleReadTask(start - 1, end); + return new JavaSimpleDataReaderFactory(start - 1, end); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index f6aa00869a681..e8187524ea871 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -38,19 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReadTasks() { + public List> createUnsafeRowReaderFactories() { return java.util.Arrays.asList( - new JavaUnsafeRowReadTask(0, 5), - new JavaUnsafeRowReadTask(5, 10)); + new JavaUnsafeRowDataReaderFactory(0, 5), + new JavaUnsafeRowDataReaderFactory(5, 10)); } } - static class JavaUnsafeRowReadTask implements ReadTask, DataReader { + static class JavaUnsafeRowDataReaderFactory + implements DataReaderFactory, DataReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowReadTask(int start, int end) { + JavaUnsafeRowDataReaderFactory(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,7 +60,7 @@ static class JavaUnsafeRowReadTask implements ReadTask, DataReader createDataReader() { - return new JavaUnsafeRowReadTask(start - 1, end); + return new JavaUnsafeRowDataReaderFactory(start - 1, end); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 85085d43061bd..d2cfe7905f6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -78,7 +78,7 @@ class RateSourceV2Suite extends StreamTest { val reader = new RateStreamMicroBatchReader( new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) } @@ -118,7 +118,7 @@ class RateSourceV2Suite extends StreamTest { val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 1) assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) } @@ -133,7 +133,7 @@ class RateSourceV2Suite extends StreamTest { }.toMap) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) val readData = tasks.asScala @@ -161,12 +161,12 @@ class RateSourceV2Suite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) - val tasks = reader.createReadTasks() + val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousReadTask => + case t: RateStreamContinuousDataReaderFactory => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 0620693b35d16..42c5d3bcea44b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -204,18 +204,20 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createReadTasks(): JList[ReadTask[Row]] = { - java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10)) + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { +class SimpleDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[Row] + with DataReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleReadTask(start, end) + override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) override def next(): Boolean = { current += 1 @@ -252,21 +254,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[ReadTask[Row]] + val res = new ArrayList[DataReaderFactory[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedReadTask(0, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedReadTask(5, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) } res @@ -276,13 +278,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) - extends ReadTask[Row] with DataReader[Row] { +class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) + extends DataReaderFactory[Row] with DataReader[Row] { private var current = start - 1 override def createDataReader(): DataReader[Row] = { - new AdvancedReadTask(start, end, requiredSchema) + new AdvancedDataReaderFactory(start, end, requiredSchema) } override def close(): Unit = {} @@ -307,16 +309,17 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10)) + override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), + new UnsafeRowDataReaderFactory(5, 10)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class UnsafeRowReadTask(start: Int, end: Int) - extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) @@ -341,7 +344,7 @@ class UnsafeRowReadTask(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceV2Reader { - override def createReadTasks(): JList[ReadTask[Row]] = + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = java.util.Collections.emptyList() } @@ -354,16 +357,16 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90)) + override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { + java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) } } override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class BatchReadTask(start: Int, end: Int) - extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] { +class BatchDataReaderFactory(start: Int, end: Int) + extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -406,11 +409,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceV2Reader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( - new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2))) + new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) } override def outputPartitioning(): Partitioning = new MyPartitioning @@ -428,7 +431,9 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader } -class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] { +class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) + extends DataReaderFactory[Row] + with DataReader[Row] { assert(i.length == j.length) private var current = -1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index cd7252eb2e3d6..3310d6dd199d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { override def readSchema(): StructType = schema - override def createReadTasks(): JList[ReadTask[Row]] = { + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,7 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + new SimpleCSVDataReaderFactory( + f.getPath.toUri.toString, + serializableConf): DataReaderFactory[Row] }.toList.asJava } else { Collections.emptyList() @@ -149,8 +151,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) - extends ReadTask[Row] with DataReader[Row] { +class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) + extends DataReaderFactory[Row] with DataReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index d4f8bae96695d..dc8c857018457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.DataSourceV2Options -import org.apache.spark.sql.sources.v2.reader.ReadTask +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter @@ -45,7 +45,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setOffset(start: Optional[Offset]): Unit = {} - def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = { + def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From 54dd7cf4ef921bc9dc12f99cfb90d1da57939901 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 29 Jan 2018 08:56:42 -0800 Subject: [PATCH 0226/1161] [SPARK-23199][SQL] improved Removes repetition from group expressions in Aggregate ## What changes were proposed in this pull request? Currently, all Aggregate operations will go into RemoveRepetitionFromGroupExpressions, but there is no group expression or there is no duplicate group expression in group expression, we not need copy for logic plan. ## How was this patch tested? the existed test case. Author: caoxuewen Closes #20375 from heary-cao/RepetitionGroupExpressions. --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++++-- .../sql/catalyst/optimizer/AggregateOptimizeSuite.scala | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8d207708c12ad..a28b6a0feb8f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1302,8 +1302,12 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { */ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(grouping, _, _) => + case a @ Aggregate(grouping, _, _) if grouping.size > 1 => val newGrouping = ExpressionSet(grouping).toSeq - a.copy(groupingExpressions = newGrouping) + if (newGrouping.size == grouping.size) { + a + } else { + a.copy(groupingExpressions = newGrouping) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index a3184a4266c7c..f8ddc93597070 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -67,10 +67,9 @@ class AggregateOptimizeSuite extends PlanTest { } test("remove repetition in grouping expression") { - val input = LocalRelation('a.int, 'b.int, 'c.int) - val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze + val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } From fbce2ed0fa5c3e9fb2bdf9d9741eb3ff0760f88c Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Mon, 29 Jan 2018 08:58:14 -0800 Subject: [PATCH 0227/1161] [SPARK-23059][SQL][TEST] Correct some improper with view related method usage ## What changes were proposed in this pull request? Correct some improper with view related method usage Only change test cases like: ``` test("list global temp views") { try { sql("CREATE GLOBAL TEMP VIEW v1 AS SELECT 3, 4") sql("CREATE TEMP VIEW v2 AS SELECT 1, 2") checkAnswer(sql(s"SHOW TABLES IN $globalTempDB"), Row(globalTempDB, "v1", true) :: Row("", "v2", true) :: Nil) assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) } finally { spark.catalog.dropTempView("v1") spark.catalog.dropGlobalTempView("v2") } } ``` other change please review the code. ## How was this patch tested? See test case. Author: xubo245 <601450868@qq.com> Closes #20250 from xubo245/DropTempViewError. --- .../org/apache/spark/sql/SQLQuerySuite.scala | 48 ++++++++++--------- .../sql/execution/GlobalTempViewSuite.scala | 4 +- .../spark/sql/execution/SQLViewSuite.scala | 36 ++++++++------ .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../sql/hive/execution/HiveSQLViewSuite.scala | 26 +++++----- .../sql/hive/execution/SQLQuerySuite.scala | 44 +++++++++-------- 7 files changed, 88 insertions(+), 74 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a79ab47f0197e..ffd736d2ebbb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1565,36 +1565,38 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = - sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) - }.getMessage - - // If you use backticks to quote the name then it's OK. - spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index dcc6fa6403f31..972b47e96fe06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -134,8 +134,8 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { assert(spark.catalog.listTables(globalTempDB).collect().toSeq.map(_.name) == Seq("v1", "v2")) } finally { - spark.catalog.dropTempView("v1") - spark.catalog.dropGlobalTempView("v2") + spark.catalog.dropGlobalTempView("v1") + spark.catalog.dropTempView("v2") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index ce8fde28a941c..8269d4d3a285d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -53,15 +53,17 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("create a temp view on a permanent view") { - withView("jtv1", "temp_jtv1") { - sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") - sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") - checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + withView("jtv1") { + withTempView("temp_jtv1") { + sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6") + checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2)) + } } } test("create a temp view on a temp view") { - withView("temp_jtv1", "temp_jtv2") { + withTempView("temp_jtv1", "temp_jtv2") { sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") sql("CREATE TEMPORARY VIEW temp_jtv2 AS SELECT * FROM temp_jtv1 WHERE id < 6") checkAnswer(sql("select count(*) FROM temp_jtv2"), Row(2)) @@ -222,10 +224,12 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("error handling: disallow IF NOT EXISTS for CREATE TEMPORARY VIEW") { - val e = intercept[AnalysisException] { - sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + withTempView("myabcdview") { + val e = intercept[AnalysisException] { + sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + } + assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } - assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) } test("error handling: fail if the temp view sql itself is invalid") { @@ -274,7 +278,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("correctly parse CREATE TEMPORARY VIEW statement") { - withView("testView") { + withTempView("testView") { sql( """CREATE TEMPORARY VIEW |testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') @@ -286,7 +290,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } test("should NOT allow CREATE TEMPORARY VIEW when TEMPORARY VIEW with same name exists") { - withView("testView") { + withTempView("testView") { sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") val e = intercept[AnalysisException] { @@ -299,15 +303,19 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { test("should allow CREATE TEMPORARY VIEW when a permanent VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE VIEW testView AS SELECT id FROM jt") - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE VIEW testView AS SELECT id FROM jt") + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + } } } test("should allow CREATE permanent VIEW when a TEMPORARY VIEW with same name exists") { withView("testView", "default.testView") { - sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") - sql("CREATE VIEW testView AS SELECT id FROM jt") + withTempView("testView") { + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 6ca21b5aa1595..ee3674ba17821 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -739,7 +739,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // starts with 'jar:', and it is an illegal parameter for Path, so here we copy it // to a temp file by withResourceTempPath withResourceTempPath("test-data/cars.csv") { tmpFile => - withView("testview") { + withTempView("testview") { sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + s"OPTIONS (PATH '${tmpFile.toURI}')") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index fade143a1755e..859099a321bf7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1151,7 +1151,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("create a temp view using hive") { val tableName = "tab1" - withTable(tableName) { + withTempView(tableName) { val e = intercept[AnalysisException] { sql( s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 97e4c2b6b2db8..5e6e114fc3fdc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -67,20 +67,22 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName withUserDefinedFunction(tempFunctionName -> true) { sql(s"CREATE TEMPORARY FUNCTION $tempFunctionName AS '$functionClass'") - withView("view1", "tempView1") { - withTable("tab1") { - (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") + withView("view1") { + withTempView("tempView1") { + withTable("tab1") { + (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1") - // temporary view - sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") - checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) + // temporary view + sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1") + checkAnswer(sql("select count(*) FROM tempView1"), Row(10)) - // permanent view - val e = intercept[AnalysisException] { - sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") - }.getMessage - assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + - s"a temporary function `$tempFunctionName`")) + // permanent view + val e = intercept[AnalysisException] { + sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1") + }.getMessage + assert(e.contains("Not allowed to create a permanent view `view1` by referencing " + + s"a temporary function `$tempFunctionName`")) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 33bcae91fdaf4..baabc4a3bca2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1203,35 +1203,37 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("specifying database name for a temporary view is not allowed") { withTempPath { dir => - val path = dir.toURI.toString - val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") - df - .write - .format("parquet") - .save(path) - - // We don't support creating a temporary table while specifying a database - intercept[AnalysisException] { + withTempView("db.t") { + val path = dir.toURI.toString + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + // We don't support creating a temporary table while specifying a database + intercept[AnalysisException] { + spark.sql( + s""" + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + } + + // If you use backticks to quote the name then it's OK. spark.sql( s""" - |CREATE TEMPORARY VIEW db.t + |CREATE TEMPORARY VIEW `db.t` |USING parquet |OPTIONS ( | path '$path' |) """.stripMargin) + checkAnswer(spark.table("`db.t`"), df) } - - // If you use backticks to quote the name then it's OK. - spark.sql( - s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) - """.stripMargin) - checkAnswer(spark.table("`db.t`"), df) } } From 2d903cf9d3a827e54217dfc9f1e4be99d8204387 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 29 Jan 2018 09:00:54 -0800 Subject: [PATCH 0228/1161] [SPARK-23223][SQL] Make stacking dataset transforms more performant ## What changes were proposed in this pull request? It is a common pattern to apply multiple transforms to a `Dataset` (using `Dataset.withColumn` for example. This is currently quite expensive because we run `CheckAnalysis` on the full plan and create an encoder for each intermediate `Dataset`. This PR extends the usage of the `AnalysisBarrier` to include `CheckAnalysis`. By doing this we hide the already analyzed plan from `CheckAnalysis` because barrier is a `LeafNode`. The `AnalysisBarrier` is in the `FinishAnalysis` phase of the optimizer. We also make binding the `Dataset` encoder lazy. The bound encoder is only needed when we materialize the dataset. ## How was this patch tested? Existing test should cover this. Author: Herman van Hovell Closes #20402 from hvanhovell/SPARK-23223. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../sql/catalyst/analysis/AnalysisTest.scala | 3 +-- .../scala/org/apache/spark/sql/Dataset.scala | 8 ++++++-- .../spark/sql/execution/QueryExecution.scala | 16 ++-------------- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 6 files changed, 25 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2b14c8220d43b..91cb0365a0856 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -98,6 +98,19 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + def executeAndCheck(plan: LogicalPlan): LogicalPlan = { + val analyzed = execute(plan) + try { + checkAnalysis(analyzed) + EliminateBarriers(analyzed) + } catch { + case e: AnalysisException => + val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) + ae.setStackTrace(e.getStackTrace) + throw ae + } + } + override def execute(plan: LogicalPlan): LogicalPlan = { AnalysisContext.reset() try { @@ -178,8 +191,7 @@ class Analyzer( Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, - CleanupAliases, - EliminateBarriers) + CleanupAliases) ) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ef91d79f3302c..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -356,6 +356,7 @@ trait CheckAnalysis extends PredicateHelper { } extendedCheckRules.foreach(_(plan)) plan.foreachUp { + case AnalysisBarrier(child) if !child.resolved => checkAnalysis(child) case o if !o.resolved => failAnalysis(s"unresolved operator ${o.simpleString}") case _ => } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 549a4355dfba3..3d7c91870133b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -54,8 +54,7 @@ trait AnalysisTest extends PlanTest { expectedPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { val analyzer = getAnalyzer(caseSensitive) - val actualPlan = analyzer.execute(inputPlan) - analyzer.checkAnalysis(actualPlan) + val actualPlan = analyzer.executeAndCheck(inputPlan) comparePlans(actualPlan, expectedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index edb6644ed5ac0..cc5b647b3f037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -62,7 +62,11 @@ import org.apache.spark.util.Utils private[sql] object Dataset { def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = { - new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]]) + // Eagerly bind the encoder so we verify that the encoder matches the underlying + // schema. The user will get an error if this is not the case. + dataset.deserializer + dataset } def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { @@ -204,7 +208,7 @@ class Dataset[T] private[sql]( // The deserializer expression which can be used to build a projection and turn rows to objects // of type T, after collecting rows to the driver side. - private val deserializer = + private lazy val deserializer = exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer private implicit def classTag = exprEnc.clsTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 8bfe3eff0c3b3..7cae24bf5976c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -44,19 +44,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner - def assertAnalyzed(): Unit = { - // Analyzer is invoked outside the try block to avoid calling it again from within the - // catch block below. - analyzed - try { - sparkSession.sessionState.analyzer.checkAnalysis(analyzed) - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed)) - ae.setStackTrace(e.getStackTrace) - throw ae - } - } + def assertAnalyzed(): Unit = analyzed def assertSupported(): Unit = { if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { @@ -66,7 +54,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { lazy val analyzed: LogicalPlan = { SparkSession.setActiveSession(sparkSession) - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } lazy val withCachedData: LogicalPlan = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7287e20d55bbe..59708e7a0f2ff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -575,7 +575,7 @@ private[hive] class TestHiveQueryExecution( logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(sparkSession.loadTestTable) // Proceed with analysis. - sparkSession.sessionState.analyzer.execute(logical) + sparkSession.sessionState.analyzer.executeAndCheck(logical) } } From 0d60b3213fe9a7ae5e9b208639f92011fdb2ca32 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Jan 2018 10:25:25 -0800 Subject: [PATCH 0229/1161] [SPARK-22221][DOCS] Adding User Documentation for Arrow ## What changes were proposed in this pull request? Adding user facing documentation for working with Arrow in Spark Author: Bryan Cutler Author: Li Jin Author: hyukjinkwon Closes #19575 from BryanCutler/arrow-user-docs-SPARK-2221. --- docs/sql-programming-guide.md | 134 +++++++++++++++++++++++++- examples/src/main/python/sql/arrow.py | 129 +++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/python/sql/arrow.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 502c0a8c37e01..d49c8d869cba6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1640,6 +1640,138 @@ Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` a You may run `./bin/spark-sql --help` for a complete list of all available options. +# PySpark Usage Guide for Pandas with Apache Arrow + +## Apache Arrow in Spark + +Apache Arrow is an in-memory columnar data format that is used in Spark to efficiently transfer +data between JVM and Python processes. This currently is most beneficial to Python users that +work with Pandas/NumPy data. Its usage is not automatic and might require some minor +changes to configuration or code to take full advantage and ensure compatibility. This guide will +give a high-level description of how to use Arrow in Spark and highlight any differences when +working with Arrow-enabled data. + +### Ensure PyArrow Installed + +If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the +SQL module with the command `pip install pyspark[sql]`. Otherwise, you must ensure that PyArrow +is installed and available on all cluster nodes. The current supported version is 0.8.0. +You can install using pip or conda from the conda-forge channel. See PyArrow +[installation](https://arrow.apache.org/docs/python/install.html) for details. + +## Enabling for Conversion to/from Pandas + +Arrow is available as an optimization when converting a Spark DataFrame to a Pandas DataFrame +using the call `toPandas()` and when creating a Spark DataFrame from a Pandas DataFrame with +`createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set +the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. + +
+
+{% include_example dataframe_with_arrow python/sql/arrow.py %} +
+
+ +Using the above optimizations with Arrow will produce the same results as when Arrow is not +enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the +DataFrame to the driver program and should be done on a small subset of the data. Not all Spark +data types are currently supported and an error can be raised if a column has an unsupported type, +see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +Spark will fall back to create the DataFrame without Arrow. + +## Pandas UDFs (a.k.a. Vectorized UDFs) + +Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and +Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator +or to wrap the function, no additional configuration is required. Currently, there are two types of +Pandas UDF: Scalar and Group Map. + +### Scalar + +Scalar Pandas UDFs are used for vectorizing scalar operations. They can be used with functions such +as `select` and `withColumn`. The Python function should take `pandas.Series` as inputs and return +a `pandas.Series` of the same length. Internally, Spark will execute a Pandas UDF by splitting +columns into batches and calling the function for each batch as a subset of the data, then +concatenating the results together. + +The following example shows how to create a scalar Pandas UDF that computes the product of 2 columns. + +
+
+{% include_example scalar_pandas_udf python/sql/arrow.py %} +
+
+ +### Group Map +Group map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. +Split-apply-combine consists of three steps: +* Split the data into groups by using `DataFrame.groupBy`. +* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The + input data contains all the rows and columns for each group. +* Combine the results into a new `DataFrame`. + +To use `groupBy().apply()`, the user needs to define the following: +* A Python function that defines the computation for each group. +* A `StructType` object or a string that defines the schema of the output `DataFrame`. + +Note that all data for a group will be loaded into memory before the function is applied. This can +lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for +[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user +to ensure that the grouped data will fit into the available memory. + +The following example shows how to use `groupby().apply()` to subtract the mean from each value in the group. + +
+
+{% include_example group_map_pandas_udf python/sql/arrow.py %} +
+
+ +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and +[`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). + +## Usage Notes + +### Supported SQL Types + +Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +`ArrayType` of `TimestampType`, and nested `StructType`. + +### Setting Arrow Batch Size + +Data partitions in Spark are converted into Arrow record batches, which can temporarily lead to +high memory usage in the JVM. To avoid possible out of memory exceptions, the size of the Arrow +record batches can be adjusted by setting the conf "spark.sql.execution.arrow.maxRecordsPerBatch" +to an integer that will determine the maximum number of rows for each batch. The default value is +10,000 records per batch. If the number of columns is large, the value should be adjusted +accordingly. Using this limit, each data partition will be made into 1 or more record batches for +processing. + +### Timestamp with Time Zone Semantics + +Spark internally stores timestamps as UTC values, and timestamp data that is brought in without +a specified time zone is converted as local time to UTC with microsecond resolution. When timestamp +data is exported or displayed in Spark, the session time zone is used to localize the timestamp +values. The session time zone is set with the configuration 'spark.sql.session.timeZone' and will +default to the JVM system local time zone if not set. Pandas uses a `datetime64` type with nanosecond +resolution, `datetime64[ns]`, with optional time zone on a per-column basis. + +When timestamp data is transferred from Spark to Pandas it will be converted to nanoseconds +and each column will be converted to the Spark session time zone then localized to that time +zone, which removes the time zone and displays values as local time. This will occur +when calling `toPandas()` or `pandas_udf` with timestamp columns. + +When timestamp data is transferred from Pandas to Spark, it will be converted to UTC microseconds. This +occurs when calling `createDataFrame` with a Pandas DataFrame or when returning a timestamp from a +`pandas_udf`. These conversions are done automatically to ensure Spark will have data in the +expected format, so it is not necessary to do any of these conversions yourself. Any nanosecond +values will be truncated. + +Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime objects, which is +different than a Pandas timestamp. It is recommended to use Pandas time series functionality when +working with timestamps in `pandas_udf`s to get the best performance, see +[here](https://pandas.pydata.org/pandas-docs/stable/timeseries.html) for details. + # Migration Guide ## Upgrading From Spark SQL 2.2 to 2.3 @@ -1788,7 +1920,7 @@ options. Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other combinations of scales and precisions because currently we only infer decimal type like `BigInteger`/`BigInt`. For example, 1.1 is inferred as double type. - In PySpark, now we need Pandas 0.19.2 or upper if you want to use Pandas related functionalities, such as `toPandas`, `createDataFrame` from Pandas DataFrame, etc. - In PySpark, the behavior of timestamp values for Pandas related functionalities was changed to respect session timezone. If you want to use the old behavior, you need to set a configuration `spark.sql.execution.pandas.respectSessionTimeZone` to `False`. See [SPARK-22395](https://issues.apache.org/jira/browse/SPARK-22395) for details. - - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. + - In PySpark, `na.fill()` or `fillna` also accepts boolean and replaces nulls with booleans. In prior Spark versions, PySpark just ignores it and returns the original Dataset/DataFrame. - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489). - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py new file mode 100644 index 0000000000000..6c0028b3f1c1f --- /dev/null +++ b/examples/src/main/python/sql/arrow.py @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A simple example demonstrating Arrow in Spark. +Run with: + ./bin/spark-submit examples/src/main/python/sql/arrow.py +""" + +from __future__ import print_function + +from pyspark.sql import SparkSession +from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version + +require_minimum_pandas_version() +require_minimum_pyarrow_version() + + +def dataframe_with_arrow_example(spark): + # $example on:dataframe_with_arrow$ + import numpy as np + import pandas as pd + + # Enable Arrow-based columnar data transfers + spark.conf.set("spark.sql.execution.arrow.enabled", "true") + + # Generate a Pandas DataFrame + pdf = pd.DataFrame(np.random.rand(100, 3)) + + # Create a Spark DataFrame from a Pandas DataFrame using Arrow + df = spark.createDataFrame(pdf) + + # Convert the Spark DataFrame back to a Pandas DataFrame using Arrow + result_pdf = df.select("*").toPandas() + # $example off:dataframe_with_arrow$ + print("Pandas DataFrame result statistics:\n%s\n" % str(result_pdf.describe())) + + +def scalar_pandas_udf_example(spark): + # $example on:scalar_pandas_udf$ + import pandas as pd + + from pyspark.sql.functions import col, pandas_udf + from pyspark.sql.types import LongType + + # Declare the function and create the UDF + def multiply_func(a, b): + return a * b + + multiply = pandas_udf(multiply_func, returnType=LongType()) + + # The function for a pandas_udf should be able to execute with local Pandas data + x = pd.Series([1, 2, 3]) + print(multiply_func(x, x)) + # 0 1 + # 1 4 + # 2 9 + # dtype: int64 + + # Create a Spark DataFrame, 'spark' is an existing SparkSession + df = spark.createDataFrame(pd.DataFrame(x, columns=["x"])) + + # Execute function as a Spark vectorized UDF + df.select(multiply(col("x"), col("x"))).show() + # +-------------------+ + # |multiply_func(x, x)| + # +-------------------+ + # | 1| + # | 4| + # | 9| + # +-------------------+ + # $example off:scalar_pandas_udf$ + + +def group_map_pandas_udf_example(spark): + # $example on:group_map_pandas_udf$ + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) + + @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + def substract_mean(pdf): + # pdf is a pandas.DataFrame + v = pdf.v + return pdf.assign(v=v - v.mean()) + + df.groupby("id").apply(substract_mean).show() + # +---+----+ + # | id| v| + # +---+----+ + # | 1|-0.5| + # | 1| 0.5| + # | 2|-3.0| + # | 2|-1.0| + # | 2| 4.0| + # +---+----+ + # $example off:group_map_pandas_udf$ + + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("Python Arrow-in-Spark example") \ + .getOrCreate() + + print("Running Pandas to/from conversion example") + dataframe_with_arrow_example(spark) + print("Running pandas_udf scalar example") + scalar_pandas_udf_example(spark) + print("Running pandas_udf group map example") + group_map_pandas_udf_example(spark) + + spark.stop() From e30b34f7bd9a687eb43d636fffeb98fe235fcbf4 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 29 Jan 2018 10:29:42 -0800 Subject: [PATCH 0230/1161] [SPARK-22916][SQL][FOLLOW-UP] Update the Description of Join Selection ## What changes were proposed in this pull request? This PR is to update the description of the join algorithm changes. ## How was this patch tested? N/A Author: gatorsmile Closes #20420 from gatorsmile/followUp22916. --- .../spark/sql/execution/SparkStrategies.scala | 60 +++++++++++++++---- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce512bc46563a..82b4eb9fba242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -91,23 +91,58 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Select the proper physical plan for join based on joining keys and size of logical plan. * * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the - * predicates can be evaluated by matching join keys. If found, Join implementations are chosen + * predicates can be evaluated by matching join keys. If found, join implementations are chosen * with the following precedence: * - * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the - * user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame). - * If both sides have the broadcast hint, we prefer to broadcast the side with a smaller - * estimated physical size. If neither one of the sides has the broadcast hint, - * we only broadcast the join side if its estimated physical size that is smaller than - * the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold. + * - Broadcast hash join (BHJ): + * BHJ is not supported for full outer join. For right outer join, we only can broadcast the + * left side. For left outer, left semi, left anti and the internal join type ExistenceJoin, + * we only can broadcast the right side. For inner like join, we can broadcast both sides. + * Normally, BHJ can perform faster than the other join algorithms when the broadcast side is + * small. However, broadcasting tables is a network-intensive operation. It could cause OOM + * or perform worse than the other join algorithms, especially when the build/broadcast side + * is big. + * + * For the supported cases, users can specify the broadcast hint (e.g. the user applied the + * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame) and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to adjust whether BHJ is used and + * which join side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (only when the type + * is inner like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BHJ is not used. + * * - Shuffle hash join: if the average size of a single partition is small enough to build a hash * table. + * * - Sort merge: if the matching join keys are sortable. * * If there is no joining keys, Join implementations are chosen with the following precedence: - * - BroadcastNestedLoopJoin: if one side of the join could be broadcasted - * - CartesianProduct: for Inner join - * - BroadcastNestedLoopJoin + * - BroadcastNestedLoopJoin (BNLJ): + * BNLJ supports all the join types but the impl is OPTIMIZED for the following scenarios: + * For right outer join, the left side is broadcast. For left outer, left semi, left anti + * and the internal join type ExistenceJoin, the right side is broadcast. For inner like + * joins, either side is broadcast. + * + * Like BHJ, users still can specify the broadcast hint and session-based + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold to impact which side is broadcast. + * + * 1) Broadcast the join side with the broadcast hint, even if the size is larger than + * [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. If both sides have the hint (i.e., just for + * inner-like join), the side with a smaller estimated physical size will be broadcast. + * 2) Respect the [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold and broadcast the side + * whose estimated physical size is smaller than the threshold. If both sides are below the + * threshold, broadcast the smaller side. If neither is smaller, BNLJ is not used. + * + * - CartesianProduct: for inner like join, CartesianProduct is the fallback option. + * + * - BroadcastNestedLoopJoin (BNLJ): + * For the other join types, BNLJ is the fallback option. Here, we just pick the broadcast + * side with the broadcast hint. If neither side has a hint, we broadcast the side with + * the smaller estimated physical size. */ object JoinSelection extends Strategy with PredicateHelper { @@ -140,8 +175,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => true - case j: ExistenceJoin => true + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } @@ -244,7 +278,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- Without joining keys ------------------------------------------------------------ - // Pick BroadcastNestedLoopJoin if one side could be broadcasted + // Pick BroadcastNestedLoopJoin if one side could be broadcast case j @ logical.Join(left, right, joinType, condition) if canBroadcastByHints(joinType, left, right) => val buildSide = broadcastSideByHints(joinType, left, right) From b834446ec1338349f6d974afd96f677db3e8fd1a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Jan 2018 16:09:14 -0600 Subject: [PATCH 0231/1161] [SPARK-23209][core] Allow credential manager to work when Hive not available. The JVM seems to be doing early binding of classes that the Hive provider depends on, causing an error to be thrown before it was caught by the code in the class. The fix wraps the creation of the provider in a try..catch so that the provider can be ignored when dependencies are missing. Added a unit test (which fails without the fix), and also tested that getting tokens still works in a real cluster. Author: Marcelo Vanzin Closes #20399 from vanzin/SPARK-23209. --- .../HadoopDelegationTokenManager.scala | 17 +++++- .../HadoopDelegationTokenManagerSuite.scala | 58 +++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 116a686fe1480..5151df00476f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -64,9 +64,9 @@ private[spark] class HadoopDelegationTokenManager( } private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = { - val providers = List(new HadoopFSDelegationTokenProvider(fileSystems), - new HiveDelegationTokenProvider, - new HBaseDelegationTokenProvider) + val providers = Seq(new HadoopFSDelegationTokenProvider(fileSystems)) ++ + safeCreateProvider(new HiveDelegationTokenProvider) ++ + safeCreateProvider(new HBaseDelegationTokenProvider) // Filter out providers for which spark.security.credentials.{service}.enabled is false. providers @@ -75,6 +75,17 @@ private[spark] class HadoopDelegationTokenManager( .toMap } + private def safeCreateProvider( + createFn: => HadoopDelegationTokenProvider): Option[HadoopDelegationTokenProvider] = { + try { + Some(createFn) + } catch { + case t: Throwable => + logDebug(s"Failed to load built in provider.", t) + None + } + } + def isServiceEnabled(serviceName: String): Boolean = { val key = providerEnabledConfig.format(serviceName) diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index eeffc36070b44..2849a10a2c81e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.security +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials @@ -110,7 +111,64 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { creds.getAllTokens.size should be (0) } + test("SPARK-23209: obtain tokens when Hive classes are not available") { + // This test needs a custom class loader to hide Hive classes which are in the classpath. + // Because the manager code loads the Hive provider directly instead of using reflection, we + // need to drive the test through the custom class loader so a new copy that cannot find + // Hive classes is loaded. + val currentLoader = Thread.currentThread().getContextClassLoader() + val noHive = new ClassLoader() { + override def loadClass(name: String, resolve: Boolean): Class[_] = { + if (name.startsWith("org.apache.hive") || name.startsWith("org.apache.hadoop.hive")) { + throw new ClassNotFoundException(name) + } + + if (name.startsWith("java") || name.startsWith("scala")) { + currentLoader.loadClass(name) + } else { + val classFileName = name.replaceAll("\\.", "/") + ".class" + val in = currentLoader.getResourceAsStream(classFileName) + if (in != null) { + val bytes = IOUtils.toByteArray(in) + defineClass(name, bytes, 0, bytes.length) + } else { + throw new ClassNotFoundException(name) + } + } + } + } + + try { + Thread.currentThread().setContextClassLoader(noHive) + val test = noHive.loadClass(NoHiveTest.getClass.getName().stripSuffix("$")) + test.getMethod("runTest").invoke(null) + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + } + private[spark] def hadoopFSsToAccess(hadoopConf: Configuration): Set[FileSystem] = { Set(FileSystem.get(hadoopConf)) } } + +/** Test code for SPARK-23209 to avoid using too much reflection above. */ +private object NoHiveTest extends Matchers { + + def runTest(): Unit = { + try { + val manager = new HadoopDelegationTokenManager(new SparkConf(), new Configuration(), + _ => Set()) + manager.getServiceDelegationTokenProvider("hive") should be (None) + } catch { + case e: Throwable => + // Throw a better exception in case the test fails, since there may be a lot of nesting. + var cause = e + while (cause.getCause() != null) { + cause = cause.getCause() + } + throw cause + } + } + +} From f235df66a4754cbb64d5b7b5cfd5a52bdd243b8a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 29 Jan 2018 17:37:55 -0800 Subject: [PATCH 0232/1161] [SPARK-22221][SQL][FOLLOWUP] Externalize spark.sql.execution.arrow.maxRecordsPerBatch ## What changes were proposed in this pull request? This is a followup to #19575 which added a section on setting max Arrow record batches and this will externalize the conf that was referenced in the docs. ## How was this patch tested? NA Author: Bryan Cutler Closes #20423 from BryanCutler/arrow-user-doc-externalize-maxRecordsPerBatch-SPARK-22221. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 61ea03d395afc..54a35594f505e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1051,7 +1051,6 @@ object SQLConf { val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") - .internal() .doc("When using Apache Arrow, limit the maximum number of records that can be written " + "to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.") .intConf From 31bd1dab1301d27a16c9d5d1b0b3301d618b0516 Mon Sep 17 00:00:00 2001 From: Paul Mackles Date: Tue, 30 Jan 2018 11:15:27 +0800 Subject: [PATCH 0233/1161] [SPARK-23088][CORE] History server not showing incomplete/running applications ## What changes were proposed in this pull request? History server not showing incomplete/running applications when spark.history.ui.maxApplications property is set to a value that is smaller than the total number of applications. ## How was this patch tested? Verified manually against master and 2.2.2 branch. Author: Paul Mackles Closes #20335 from pmackles/SPARK-23088. --- .../resources/org/apache/spark/ui/static/historypage.js | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2cde66b081a1c..f0b2a5a833a99 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -108,7 +108,12 @@ $(document).ready(function() { requestedIncomplete = getParameterByName("showIncomplete", searchString); requestedIncomplete = (requestedIncomplete == "true" ? true : false); - $.getJSON("api/v1/applications?limit=" + appLimit, function(response,status,jqXHR) { + appParams = { + limit: appLimit, + status: (requestedIncomplete ? "running" : "completed") + }; + + $.getJSON("api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { From b375397b1678b7fe20a0b7f87a7e8b37ae5646ef Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 30 Jan 2018 11:40:42 +0800 Subject: [PATCH 0234/1161] [SPARK-23207][SQL][FOLLOW-UP] Don't perform local sort for DataFrame.repartition(1) ## What changes were proposed in this pull request? In `ShuffleExchangeExec`, we don't need to insert extra local sort before round-robin partitioning, if the new partitioning has only 1 partition, because under that case all output rows go to the same partition. ## How was this patch tested? The existing test cases. Author: Xingbo Jiang Closes #20426 from jiangxb1987/repartition1. --- .../spark/sql/execution/exchange/ShuffleExchangeExec.scala | 4 ++++ .../spark/sql/execution/streaming/ForeachSinkSuite.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 76c1fa65f924b..4d95ee34f30de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -257,7 +257,11 @@ object ShuffleExchangeExec { // // Currently we following the most straight-forward way that perform a local sort before // partitioning. + // + // Note that we don't perform local sort if the new partitioning has only 1 partition, under + // that case all output rows go to the same partition. val newRdd = if (SQLConf.get.sortBeforeRepartition && + newPartitioning.numPartitions > 1 && newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => val recordComparatorSupplier = new Supplier[RecordComparator] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 1248c670df45c..41434e6d8b974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] From 8b983243e45dfe2617c043a3229a7d87f4c4b44b Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 29 Jan 2018 22:19:59 -0800 Subject: [PATCH 0235/1161] [SPARK-23157][SQL] Explain restriction on column expression in withColumn() ## What changes were proposed in this pull request? It's not obvious from the comments that any added column must be a function of the dataset that we are adding it to. Add a comment to that effect to Scala, Python and R Data* methods. Author: Henry Robinson Closes #20429 from henryr/SPARK-23157. --- R/pkg/R/DataFrame.R | 3 ++- python/pyspark/sql/dataframe.py | 4 ++++ sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 29f3e986eaab6..547b5ea48a555 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2090,7 +2090,8 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression, or an atomic vector in the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this DataFrame), or an atomic vector in +#' the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ac403080acfdf..055b2c4a0ffec 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1829,11 +1829,15 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. + The column expression must be an expression over this dataframe; attempting to add + a column from some other dataframe will raise an error. + :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] + """ assert isinstance(col, Column), "col should be Column" return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cc5b647b3f037..d47cd0aecf56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2150,6 +2150,9 @@ class Dataset[T] private[sql]( * Returns a new Dataset by adding a column or replacing the existing column that has * the same name. * + * `column`'s expression must only refer to attributes supplied by this Dataset. It is an + * error to add a column that refers to some other Dataset. + * * @group untypedrel * @since 2.0.0 */ From 5056877e8bea56dd0f4dc9e3385669e1e78b2925 Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 30 Jan 2018 09:02:16 +0200 Subject: [PATCH 0236/1161] [SPARK-23138][ML][DOC] Multiclass logistic regression summary example and user guide ## What changes were proposed in this pull request? User guide and examples are updated to reflect multiclass logistic regression summary which was added in [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139). I did not make a separate summary example, but added the summary code to the multiclass example that already existed. I don't see the need for a separate example for the summary. ## How was this patch tested? Docs and examples only. Ran all examples locally using spark-submit. Author: sethah Closes #20332 from sethah/multiclass_summary_example. --- docs/ml-classification-regression.md | 22 +++---- .../JavaLogisticRegressionSummaryExample.java | 17 ++--- ...gisticRegressionWithElasticNetExample.java | 62 +++++++++++++++++++ ...ss_logistic_regression_with_elastic_net.py | 38 ++++++++++++ .../ml/LogisticRegressionSummaryExample.scala | 15 ++--- ...isticRegressionWithElasticNetExample.scala | 43 +++++++++++++ 6 files changed, 164 insertions(+), 33 deletions(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index bf979f3c73a52..ddd2f4b49ca07 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -87,7 +87,7 @@ More details on parameters can be found in the [R API documentation](api/R/spark The `spark.ml` implementation of logistic regression also supports extracting a summary of the model over the training set. Note that the predictions and metrics which are stored as `DataFrame` in -`BinaryLogisticRegressionSummary` are annotated `@transient` and hence +`LogisticRegressionSummary` are annotated `@transient` and hence only available on the driver.
@@ -97,10 +97,9 @@ only available on the driver. [`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) provides a summary for a [`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). -This will likely change when multiclass classification is supported. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -111,10 +110,9 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) provides a summary for a [`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). -Currently, only binary classification is supported and the -summary must be explicitly cast to -[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). -Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. The binary summary can be accessed via the +`binarySummary` method. See [`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). Continuing the earlier example: @@ -125,7 +123,8 @@ Continuing the earlier example: [`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary) provides a summary for a [`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel). -Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future. +In the case of binary classification, certain additional metrics are +available, e.g. ROC curve. See [`BinaryLogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary). Continuing the earlier example: @@ -162,7 +161,8 @@ For a detailed derivation please see [here](https://en.wikipedia.org/wiki/Multin **Examples** The following example shows how to train a multiclass logistic regression -model with elastic net regularization. +model with elastic net regularization, as well as extract the multiclass +training summary for evaluating the model.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java index dee56799d8aee..1529da16f051f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java @@ -18,10 +18,9 @@ package org.apache.spark.examples.ml; // $example on$ -import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; -import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -50,7 +49,7 @@ public static void main(String[] args) { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + BinaryLogisticRegressionTrainingSummary trainingSummary = lrModel.binarySummary(); // Obtain the loss per iteration. double[] objectiveHistory = trainingSummary.objectiveHistory(); @@ -58,21 +57,15 @@ public static void main(String[] args) { System.out.println(lossPerIteration); } - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary - // classification problem. - BinaryLogisticRegressionSummary binarySummary = - (BinaryLogisticRegressionSummary) trainingSummary; - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. - Dataset roc = binarySummary.roc(); + Dataset roc = trainingSummary.roc(); roc.show(); roc.select("FPR").show(); - System.out.println(binarySummary.areaUnderROC()); + System.out.println(trainingSummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with // this selected threshold. - Dataset fMeasure = binarySummary.fMeasureByThreshold(); + Dataset fMeasure = trainingSummary.fMeasureByThreshold(); double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)) .select("threshold").head().getDouble(0); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java index da410cba2b3f1..801a82cd2f24f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMulticlassLogisticRegressionWithElasticNetExample.java @@ -20,6 +20,7 @@ // $example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; @@ -48,6 +49,67 @@ public static void main(String[] args) { // Print the coefficients and intercept for multinomial logistic regression System.out.println("Coefficients: \n" + lrModel.coefficientMatrix() + " \nIntercept: " + lrModel.interceptVector()); + LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); + + // Obtain the loss per iteration. + double[] objectiveHistory = trainingSummary.objectiveHistory(); + for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); + } + + // for multiclass, we can inspect metrics on a per-label basis + System.out.println("False positive rate by label:"); + int i = 0; + double[] fprLabel = trainingSummary.falsePositiveRateByLabel(); + for (double fpr : fprLabel) { + System.out.println("label " + i + ": " + fpr); + i++; + } + + System.out.println("True positive rate by label:"); + i = 0; + double[] tprLabel = trainingSummary.truePositiveRateByLabel(); + for (double tpr : tprLabel) { + System.out.println("label " + i + ": " + tpr); + i++; + } + + System.out.println("Precision by label:"); + i = 0; + double[] precLabel = trainingSummary.precisionByLabel(); + for (double prec : precLabel) { + System.out.println("label " + i + ": " + prec); + i++; + } + + System.out.println("Recall by label:"); + i = 0; + double[] recLabel = trainingSummary.recallByLabel(); + for (double rec : recLabel) { + System.out.println("label " + i + ": " + rec); + i++; + } + + System.out.println("F-measure by label:"); + i = 0; + double[] fLabel = trainingSummary.fMeasureByLabel(); + for (double f : fLabel) { + System.out.println("label " + i + ": " + f); + i++; + } + + double accuracy = trainingSummary.accuracy(); + double falsePositiveRate = trainingSummary.weightedFalsePositiveRate(); + double truePositiveRate = trainingSummary.weightedTruePositiveRate(); + double fMeasure = trainingSummary.weightedFMeasure(); + double precision = trainingSummary.weightedPrecision(); + double recall = trainingSummary.weightedRecall(); + System.out.println("Accuracy: " + accuracy); + System.out.println("FPR: " + falsePositiveRate); + System.out.println("TPR: " + truePositiveRate); + System.out.println("F-measure: " + fMeasure); + System.out.println("Precision: " + precision); + System.out.println("Recall: " + recall); // $example off$ spark.stop(); diff --git a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py index bb9cd82d6ba27..bec9860c79a2d 100644 --- a/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/multiclass_logistic_regression_with_elastic_net.py @@ -43,6 +43,44 @@ # Print the coefficients and intercept for multinomial logistic regression print("Coefficients: \n" + str(lrModel.coefficientMatrix)) print("Intercept: " + str(lrModel.interceptVector)) + + trainingSummary = lrModel.summary + + # Obtain the objective per iteration + objectiveHistory = trainingSummary.objectiveHistory + print("objectiveHistory:") + for objective in objectiveHistory: + print(objective) + + # for multiclass, we can inspect metrics on a per-label basis + print("False positive rate by label:") + for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("True positive rate by label:") + for i, rate in enumerate(trainingSummary.truePositiveRateByLabel): + print("label %d: %s" % (i, rate)) + + print("Precision by label:") + for i, prec in enumerate(trainingSummary.precisionByLabel): + print("label %d: %s" % (i, prec)) + + print("Recall by label:") + for i, rec in enumerate(trainingSummary.recallByLabel): + print("label %d: %s" % (i, rec)) + + print("F-measure by label:") + for i, f in enumerate(trainingSummary.fMeasureByLabel()): + print("label %d: %s" % (i, f)) + + accuracy = trainingSummary.accuracy + falsePositiveRate = trainingSummary.weightedFalsePositiveRate + truePositiveRate = trainingSummary.weightedTruePositiveRate + fMeasure = trainingSummary.weightedFMeasure() + precision = trainingSummary.weightedPrecision + recall = trainingSummary.weightedRecall + print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s" + % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall)) # $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala index 1740a0d3f9d12..0368dcba460b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.ml // $example on$ -import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} +import org.apache.spark.ml.classification.LogisticRegression // $example off$ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.max @@ -47,25 +47,20 @@ object LogisticRegressionSummaryExample { // $example on$ // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier // example - val trainingSummary = lrModel.summary + val trainingSummary = lrModel.binarySummary // Obtain the objective per iteration. val objectiveHistory = trainingSummary.objectiveHistory println("objectiveHistory:") objectiveHistory.foreach(loss => println(loss)) - // Obtain the metrics useful to judge performance on test data. - // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a - // binary classification problem. - val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] - // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. - val roc = binarySummary.roc + val roc = trainingSummary.roc roc.show() - println(s"areaUnderROC: ${binarySummary.areaUnderROC}") + println(s"areaUnderROC: ${trainingSummary.areaUnderROC}") // Set the model threshold to maximize F-Measure - val fMeasure = binarySummary.fMeasureByThreshold + val fMeasure = trainingSummary.fMeasureByThreshold val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) .select("threshold").head().getDouble(0) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 3e61dbe628c20..1f7dbddd454e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -49,6 +49,49 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") println(s"Intercepts: \n${lrModel.interceptVector}") + + val trainingSummary = lrModel.summary + + // Obtain the objective per iteration + val objectiveHistory = trainingSummary.objectiveHistory + println("objectiveHistory:") + objectiveHistory.foreach(println) + + // for multiclass, we can inspect metrics on a per-label basis + println("False positive rate by label:") + trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("True positive rate by label:") + trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) => + println(s"label $label: $rate") + } + + println("Precision by label:") + trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) => + println(s"label $label: $prec") + } + + println("Recall by label:") + trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) => + println(s"label $label: $rec") + } + + + println("F-measure by label:") + trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) => + println(s"label $label: $f") + } + + val accuracy = trainingSummary.accuracy + val falsePositiveRate = trainingSummary.weightedFalsePositiveRate + val truePositiveRate = trainingSummary.weightedTruePositiveRate + val fMeasure = trainingSummary.weightedFMeasure + val precision = trainingSummary.weightedPrecision + val recall = trainingSummary.weightedRecall + println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" + + s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall") // $example off$ spark.stop() From 0a9ac0248b6514a1e83ff7e4c522424f01b8b78d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jan 2018 19:43:17 +0800 Subject: [PATCH 0237/1161] [SPARK-23260][SPARK-23262][SQL] several data source v2 naming cleanup ## What changes were proposed in this pull request? All other classes in the reader/writer package doesn't have `V2` in their names, and the streaming reader/writer don't have `V2` either. It's more consistent to remove `V2` from `DataSourceV2Reader` and `DataSourceVWriter`. Also rename `DataSourceV2Option` to remote the `V2`, we should only have `V2` in the root interface: `DataSourceV2`. This PR also fixes some places that the mix-in interface doesn't extend the interface it aimed to mix in. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20427 from cloud-fan/ds-v2. --- .../sql/kafka010/KafkaContinuousReader.scala | 2 +- .../sql/kafka010/KafkaSourceProvider.scala | 6 ++--- ...eV2Options.java => DataSourceOptions.java} | 8 +++---- .../spark/sql/sources/v2/ReadSupport.java | 8 +++---- .../sql/sources/v2/ReadSupportWithSchema.java | 8 +++---- .../sql/sources/v2/SessionConfigSupport.java | 2 +- .../spark/sql/sources/v2/WriteSupport.java | 12 +++++----- .../sources/v2/reader/DataReaderFactory.java | 2 +- ...rceV2Reader.java => DataSourceReader.java} | 11 +++++---- .../SupportsPushDownCatalystFilters.java | 4 ++-- .../v2/reader/SupportsPushDownFilters.java | 4 ++-- .../SupportsPushDownRequiredColumns.java | 6 ++--- .../v2/reader/SupportsReportPartitioning.java | 4 ++-- .../v2/reader/SupportsReportStatistics.java | 4 ++-- .../v2/reader/SupportsScanColumnarBatch.java | 6 ++--- .../v2/reader/SupportsScanUnsafeRow.java | 6 ++--- .../v2/streaming/ContinuousReadSupport.java | 4 ++-- .../v2/streaming/MicroBatchReadSupport.java | 4 ++-- .../v2/streaming/StreamWriteSupport.java | 10 ++++---- .../v2/streaming/reader/ContinuousReader.java | 6 ++--- .../v2/streaming/reader/MicroBatchReader.java | 6 ++--- .../v2/streaming/writer/StreamWriter.java | 6 ++--- ...rceV2Writer.java => DataSourceWriter.java} | 8 +++---- .../sql/sources/v2/writer/DataWriter.java | 12 +++++----- .../sources/v2/writer/DataWriterFactory.java | 2 +- .../v2/writer/SupportsWriteInternalRow.java | 4 ++-- .../v2/writer/WriterCommitMessage.java | 4 ++-- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../v2/DataSourceReaderHolder.scala | 2 +- .../datasources/v2/DataSourceV2Relation.scala | 6 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../datasources/v2/WriteToDataSourceV2.scala | 4 ++-- .../streaming/MicroBatchExecution.scala | 6 ++--- .../streaming/RateSourceProvider.scala | 2 +- .../sql/execution/streaming/console.scala | 4 ++-- .../continuous/ContinuousExecution.scala | 6 ++--- .../ContinuousRateStreamSource.scala | 7 +++--- .../streaming/sources/ConsoleWriter.scala | 4 ++-- .../streaming/sources/MicroBatchWriter.scala | 8 +++---- .../sources/PackedRowWriterFactory.scala | 4 ++-- .../sources/RateStreamSourceV2.scala | 6 ++--- .../streaming/sources/memoryV2.scala | 6 ++--- .../sql/streaming/DataStreamReader.scala | 4 ++-- .../sources/v2/JavaAdvancedDataSourceV2.java | 6 ++--- .../sql/sources/v2/JavaBatchDataSourceV2.java | 6 ++--- .../v2/JavaPartitionAwareDataSource.java | 6 ++--- .../v2/JavaSchemaRequiredDataSource.java | 8 +++---- .../sources/v2/JavaSimpleDataSourceV2.java | 8 +++---- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 6 ++--- .../streaming/RateSourceV2Suite.scala | 18 +++++++------- ...ite.scala => DataSourceOptionsSuite.scala} | 16 ++++++------- .../sql/sources/v2/DataSourceV2Suite.scala | 24 +++++++++---------- .../sources/v2/SimpleWritableDataSource.scala | 12 +++++----- .../sources/StreamingDataSourceV2Suite.scala | 8 +++---- 55 files changed, 176 insertions(+), 176 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{DataSourceV2Options.java => DataSourceOptions.java} (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataSourceV2Reader.java => DataSourceReader.java} (91%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/{DataSourceV2Writer.java => DataSourceWriter.java} (96%) rename sql/core/src/test/scala/org/apache/spark/sql/sources/v2/{DataSourceV2OptionsSuite.scala => DataSourceOptionsSuite.scala} (80%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 9125cf5799d74..8c733426b256f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which + * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceOptions]] params which * are not Kafka consumer params. * @param metadataPath Path to a directory this reader can use for writing metadata. * @param initialOffsets The Kafka offsets to start reading data at. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 2deb7fa2cdf1e..85e96b6783327 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode @@ -109,7 +109,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def createContinuousReader( schema: Optional[StructType], metadataPath: String, - options: DataSourceV2Options): KafkaContinuousReader = { + options: DataSourceOptions): KafkaContinuousReader = { val parameters = options.asMap().asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned @@ -227,7 +227,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { import scala.collection.JavaConverters._ val spark = SparkSession.getActiveSession.get diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index ddc2acca693ac..c32053580f016 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -29,18 +29,18 @@ * data source options. */ @InterfaceStability.Evolving -public class DataSourceV2Options { +public class DataSourceOptions { private final Map keyLowerCasedMap; private String toLowerCase(String key) { return key.toLowerCase(Locale.ROOT); } - public static DataSourceV2Options empty() { - return new DataSourceV2Options(new HashMap<>()); + public static DataSourceOptions empty() { + return new DataSourceOptions(new HashMap<>()); } - public DataSourceV2Options(Map originalMap) { + public DataSourceOptions(Map originalMap) { keyLowerCasedMap = new HashMap<>(originalMap.size()); for (Map.Entry entry : originalMap.entrySet()) { keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 948e20bacf4a2..0ea4dc6b5def3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -18,17 +18,17 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; /** * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to * provide data reading ability and scan the data from the data source. */ @InterfaceStability.Evolving -public interface ReadSupport { +public interface ReadSupport extends DataSourceV2 { /** - * Creates a {@link DataSourceV2Reader} to scan the data from this data source. + * Creates a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -36,5 +36,5 @@ public interface ReadSupport { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(DataSourceV2Options options); + DataSourceReader createReader(DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index b69c6bed8d1b5..3801402268af1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; /** @@ -30,10 +30,10 @@ * supports both schema inference and user-specified schema. */ @InterfaceStability.Evolving -public interface ReadSupportWithSchema { +public interface ReadSupportWithSchema extends DataSourceV2 { /** - * Create a {@link DataSourceV2Reader} to scan the data from this data source. + * Create a {@link DataSourceReader} to scan the data from this data source. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. @@ -45,5 +45,5 @@ public interface ReadSupportWithSchema { * @param options the options for the returned data source reader, which is an immutable * case-insensitive string-to-string map. */ - DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); + DataSourceReader createReader(StructType schema, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index 3cb020d2e0836..9d66805d79b9e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -25,7 +25,7 @@ * session. */ @InterfaceStability.Evolving -public interface SessionConfigSupport { +public interface SessionConfigSupport extends DataSourceV2 { /** * Key prefix of the session configs to propagate. Spark will extract all session configs that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 1e3b644d8c4ae..cab56453816cc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -21,7 +21,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.types.StructType; /** @@ -29,17 +29,17 @@ * provide data writing ability and save the data to the data source. */ @InterfaceStability.Evolving -public interface WriteSupport { +public interface WriteSupport extends DataSourceV2 { /** - * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing - * jobs running at the same time, and the returned {@link DataSourceV2Writer} can + * jobs running at the same time, and the returned {@link DataSourceWriter} can * use this job id to distinguish itself from other jobs. * @param schema the schema of the data to be written. * @param mode the save mode which determines what to do when the data are already in this data @@ -47,6 +47,6 @@ public interface WriteSupport { * @param options the options for the returned data source writer, which is an immutable * case-insensitive string-to-string map. */ - Optional createWriter( - String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java index 077b95b837964..32e98e8f5d8bd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceV2Reader#createDataReaderFactories()} and is + * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is * responsible for creating the actual data reader. The relationship between * {@link DataReaderFactory} and {@link DataReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java similarity index 91% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index 0180cd9ea47f8..a470bccc5aad2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -21,14 +21,15 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.types.StructType; /** * A data source reader that is returned by - * {@link org.apache.spark.sql.sources.v2.ReadSupport#createReader( - * org.apache.spark.sql.sources.v2.DataSourceV2Options)} or - * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( - * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. + * {@link ReadSupport#createReader(DataSourceOptions)} or + * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan * logic is delegated to {@link DataReaderFactory}s that are returned by * {@link #createDataReaderFactories()}. @@ -52,7 +53,7 @@ * issues the scan request and does the actual data reading. */ @InterfaceStability.Evolving -public interface DataSourceV2Reader { +public interface DataSourceReader { /** * Returns the actual schema of this data source reader, which may be different from the physical diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index f76c687f450c8..98224102374aa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down arbitrary expressions as predicates to the data source. * This is an experimental and unstable interface as {@link Expression} is not public and may get * changed in the future Spark versions. @@ -31,7 +31,7 @@ * process this interface. */ @InterfaceStability.Unstable -public interface SupportsPushDownCatalystFilters { +public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** * Pushes down filters, and returns unsupported filters. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 6b0c9d417eeae..f35c711b0387a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.sources.Filter; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down filters to the data source and reduce the size of the data to be read. * * Note that, if data source readers implement both this interface and @@ -29,7 +29,7 @@ * {@link SupportsPushDownCatalystFilters}. */ @InterfaceStability.Evolving -public interface SupportsPushDownFilters { +public interface SupportsPushDownFilters extends DataSourceReader { /** * Pushes down filters, and returns unsupported filters. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java index fe0ac8ee0ee32..427b4d00a1128 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -21,12 +21,12 @@ import org.apache.spark.sql.types.StructType; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to push down required columns to the data source and only read these columns during * scan to reduce the size of the data to be read. */ @InterfaceStability.Evolving -public interface SupportsPushDownRequiredColumns { +public interface SupportsPushDownRequiredColumns extends DataSourceReader { /** * Applies column pruning w.r.t. the given requiredSchema. @@ -35,7 +35,7 @@ public interface SupportsPushDownRequiredColumns { * also OK to do the pruning partially, e.g., a data source may not be able to prune nested * fields, and only prune top-level columns. * - * Note that, data source readers should update {@link DataSourceV2Reader#readSchema()} after + * Note that, data source readers should update {@link DataSourceReader#readSchema()} after * applying column pruning. */ void pruneColumns(StructType requiredSchema); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index f786472ccf345..a2383a9d7d680 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -20,11 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. */ @InterfaceStability.Evolving -public interface SupportsReportPartitioning { +public interface SupportsReportPartitioning extends DataSourceReader { /** * Returns the output data partitioning that this reader guarantees. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index c019d2f819ab7..11bb13fd3b211 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -20,11 +20,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. */ @InterfaceStability.Evolving -public interface SupportsReportStatistics { +public interface SupportsReportStatistics extends DataSourceReader { /** * Returns the basic statistics of this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 67da55554bbf3..2e5cfa78511f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -24,11 +24,11 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to output {@link ColumnarBatch} and make the scan faster. */ @InterfaceStability.Evolving -public interface SupportsScanColumnarBatch extends DataSourceV2Reader { +public interface SupportsScanColumnarBatch extends DataSourceReader { @Override default List> createDataReaderFactories() { throw new IllegalStateException( @@ -36,7 +36,7 @@ default List> createDataReaderFactories() { } /** - * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, but returns columnar data + * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data * in batches. */ List> createBatchDataReaderFactories(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 156af69520f77..9cd749e8e4ce9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get * changed in the future Spark versions. */ @InterfaceStability.Unstable -public interface SupportsScanUnsafeRow extends DataSourceV2Reader { +public interface SupportsScanUnsafeRow extends DataSourceReader { @Override default List> createDataReaderFactories() { @@ -39,7 +39,7 @@ default List> createDataReaderFactories() { } /** - * Similar to {@link DataSourceV2Reader#createDataReaderFactories()}, + * Similar to {@link DataSourceReader#createDataReaderFactories()}, * but returns data in unsafe row format. */ List> createUnsafeRowReaderFactories(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java index 9a93a806b0efc..f79424e036a52 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java @@ -21,7 +21,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; import org.apache.spark.sql.types.StructType; @@ -44,5 +44,5 @@ public interface ContinuousReadSupport extends DataSourceV2 { ContinuousReader createContinuousReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java index 3b357c01a29fe..22660e42ad850 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java @@ -20,8 +20,8 @@ import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; import org.apache.spark.sql.types.StructType; @@ -50,5 +50,5 @@ public interface MicroBatchReadSupport extends DataSourceV2 { MicroBatchReader createMicroBatchReader( Optional schema, String checkpointLocation, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java index 6cd219c67109a..7c5f304425093 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java @@ -19,10 +19,10 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; @@ -31,7 +31,7 @@ * provide data writing ability for structured streaming. */ @InterfaceStability.Evolving -public interface StreamWriteSupport extends BaseStreamingSink { +public interface StreamWriteSupport extends DataSourceV2, BaseStreamingSink { /** * Creates an optional {@link StreamWriter} to save the data to this data source. Data @@ -39,7 +39,7 @@ public interface StreamWriteSupport extends BaseStreamingSink { * * @param queryId A unique string for the writing query. It's possible that there are many * writing queries running at the same time, and the returned - * {@link DataSourceV2Writer} can use this id to distinguish itself from others. + * {@link DataSourceWriter} can use this id to distinguish itself from others. * @param schema the schema of the data to be written. * @param mode the output mode which determines what successive epoch output means to this * sink, please refer to {@link OutputMode} for more details. @@ -50,5 +50,5 @@ StreamWriter createStreamWriter( String queryId, StructType schema, OutputMode mode, - DataSourceV2Options options); + DataSourceOptions options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java index 3ac979cb0b7b4..6e5177ee83a62 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java @@ -19,12 +19,12 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. @@ -33,7 +33,7 @@ * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ @InterfaceStability.Evolving -public interface ContinuousReader extends BaseStreamingSource, DataSourceV2Reader { +public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each * partition to a single global offset. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java index 68887e569fc1d..fcec446d892f5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java @@ -18,20 +18,20 @@ package org.apache.spark.sql.sources.v2.streaming.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; import java.util.Optional; /** - * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to indicate they allow micro-batch streaming reads. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. */ @InterfaceStability.Evolving -public interface MicroBatchReader extends DataSourceV2Reader, BaseStreamingSource { +public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** * Set the desired offset range for reader factories created from this reader. Reader factories * will generate only data within (`start`, `end`]; that is, from the first record after `start` diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java index 3156c88933e5e..915ee6c4fb390 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java @@ -18,19 +18,19 @@ package org.apache.spark.sql.sources.v2.streaming.writer; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.DataWriter; import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceV2Writer} for use with structured streaming. This writer handles commits and + * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and * aborts relative to an epoch ID determined by the execution engine. * * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, * and so must reset any internal state after a successful commit. */ @InterfaceStability.Evolving -public interface StreamWriter extends DataSourceV2Writer { +public interface StreamWriter extends DataSourceWriter { /** * Commits this writing job for the specified epoch with a list of commit messages. The commit * messages are collected from successful data writers and are produced by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 8048f507a1dca..d89d27d0e5b1b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -20,16 +20,16 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.WriteSupport; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; /** * A data source writer that is returned by - * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}/ + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( - * String, StructType, OutputMode, DataSourceV2Options)}. + * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * @@ -52,7 +52,7 @@ * Please refer to the documentation of commit/abort methods for detailed specifications. */ @InterfaceStability.Evolving -public interface DataSourceV2Writer { +public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 04b03e63de500..53941a89ba94e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -33,11 +33,11 @@ * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark will retry this writing task for some times, * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the @@ -69,11 +69,11 @@ public interface DataWriter { /** * Commits this writer after all records are written successfully, returns a commit message which * will be sent back to driver side and passed to - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * The written data should only be visible to data source readers after - * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method - * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * {@link DataSourceWriter#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceWriter} at driver side to * do the final commit via {@link WriterCommitMessage}. * * If this method fails (by throwing an exception), {@link #abort()} will be called and this @@ -91,7 +91,7 @@ public interface DataWriter { * failed. * * If this method fails(by throwing an exception), the underlying data source may have garbage - * that need to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, + * that need to be cleaned by {@link DataSourceWriter#abort(WriterCommitMessage[])} or manually, * but these garbage should not be visible to data source readers. * * @throws IOException if failure happens during disk/network IO like writing files. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 18ec792f5a2c9..ea95442511ce5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * A factory of {@link DataWriter} returned by {@link DataSourceWriter#createWriterFactory()}, * which is responsible for creating and initializing the actual data writer at executor side. * * Note that, the writer factory will be serialized and sent to executors, then the data writer diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java index 3e0518814f458..d2cf7e01c08c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -22,14 +22,14 @@ import org.apache.spark.sql.catalyst.InternalRow; /** - * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get * changed in the future Spark versions. */ @InterfaceStability.Unstable -public interface SupportsWriteInternalRow extends DataSourceV2Writer { +public interface SupportsWriteInternalRow extends DataSourceWriter { @Override default DataWriterFactory createWriterFactory() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 082d6b5dc409f..9e38836c0edf9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -23,10 +23,10 @@ /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side - * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * as the input parameter of {@link DataSourceWriter#commit(WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it in - * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * their {@link DataWriter#commit()} and {@link DataSourceWriter#commit(WriterCommitMessage[])} * implementations. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b714a46b5f786..46b5f54a33f74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -186,7 +186,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance() - val options = new DataSourceV2Options((extraOptions ++ + val options = new DataSourceOptions((extraOptions ++ DataSourceV2Utils.extractSessionConfigs( ds = ds.asInstanceOf[DataSourceV2], conf = sparkSession.sessionState.conf)).asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5c02eae05304b..ed7a9100cc7f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -243,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val ds = cls.newInstance() ds match { case ws: WriteSupport => - val options = new DataSourceV2Options((extraOptions ++ + val options = new DataSourceOptions((extraOptions ++ DataSourceV2Utils.extractSessionConfigs( ds = ds.asInstanceOf[DataSourceV2], conf = df.sparkSession.sessionState.conf)).asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 6093df26630cd..6460c97abe344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -35,7 +35,7 @@ trait DataSourceReaderHolder { /** * The held data source reader. */ - def reader: DataSourceV2Reader + def reader: DataSourceReader /** * The metadata of this data source reader that can be used for equality test. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cba20dd902007..3d4c64981373d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + reader: DataSourceReader) extends LeafNode with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] @@ -41,12 +41,12 @@ case class DataSourceV2Relation( */ class StreamingDataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceV2Reader) extends DataSourceV2Relation(fullOutput, reader) { + reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { override def isStreaming: Boolean = true } object DataSourceV2Relation { - def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { + def apply(reader: DataSourceReader): DataSourceV2Relation = { new DataSourceV2Relation(reader.readSchema().toAttributes, reader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3f808fbb40932..ee085820b0775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], - @transient reader: DataSourceV2Reader) + @transient reader: DataSourceReader) extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index cd6b3e99b6bcb..c544adbf32cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.Utils /** * The logical plan for writing data into data source v2. */ -case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { +case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -43,7 +43,7 @@ case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) e /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 975975243a3d1..93572f7a63132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow @@ -89,7 +89,7 @@ class MicroBatchExecution( val reader = source.createMicroBatchReader( Optional.empty(), // user specified schema metadataPath, - new DataSourceV2Options(options.asJava)) + new DataSourceOptions(options.asJava)) nextSourceId += 1 StreamingExecutionRelation(reader, output)(sparkSession) }) @@ -447,7 +447,7 @@ class MicroBatchExecution( s"$runId", newAttributePlan.schema, outputMode, - new DataSourceV2Options(extraOptions.asJava)) + new DataSourceOptions(extraOptions.asJava)) if (writer.isInstanceOf[SupportsWriteInternalRow]) { WriteToDataSourceV2( new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 66eb0169ac1ec..5e3fee633f591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -111,7 +111,7 @@ class RateSourceProvider extends StreamSourceProvider with DataSourceRegister override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): ContinuousReader = { + options: DataSourceOptions): ContinuousReader = { new RateStreamContinuousReader(options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index d5ac0bd1df52b..3f5bb489d6528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.streaming.OutputMode @@ -40,7 +40,7 @@ class ConsoleSinkProvider extends DataSourceV2 queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { new ConsoleWriter(schema, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 60f880f9c73b8..9402d7c1dcefd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -160,7 +160,7 @@ class ContinuousExecution( dataSource.createContinuousReader( java.util.Optional.empty[StructType](), metadataPath, - new DataSourceV2Options(extraReaderOptions.asJava)) + new DataSourceOptions(extraReaderOptions.asJava)) } uniqueSources = continuousSources.distinct @@ -198,7 +198,7 @@ class ContinuousExecution( s"$runId", triggerLogicalPlan.schema, outputMode, - new DataSourceV2Options(extraOptions.asJava)) + new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 61304480f4721..ff028ebc4236a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -23,19 +23,18 @@ import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceV2Options) +class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index 7c1700f1de48c..d46f4d7b86360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriter(schema: StructType, options: DataSourceV2Options) +class ConsoleWriter(schema: StructType, options: DataSourceOptions) extends StreamWriter with Logging { // Number of rows to display, by default 20 rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d7f3ba8856982..d7ce9a7b84479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter -import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} /** - * A [[DataSourceV2Writer]] used to hook V2 stream writers into a microbatch plan. It implements + * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped * streaming writer. */ -class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2Writer { +class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWriter { override def commit(messages: Array[WriterCommitMessage]): Unit = { writer.commit(batchId, messages) } @@ -38,7 +38,7 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceV2 } class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter) - extends DataSourceV2Writer with SupportsWriteInternalRow { + extends DataSourceWriter with SupportsWriteInternalRow { override def commit(messages: Array[WriterCommitMessage]): Unit = { writer.commit(batchId, messages) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 9282ba05bdb7b..248295e401a0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,11 +21,11 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver. + * to a [[DataSourceWriter]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index a25cc4f3b06f8..43949e6180aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} @@ -44,14 +44,14 @@ class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with override def createMicroBatchReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReader = { new RateStreamMicroBatchReader(options) } override def shortName(): String = "ratev2" } -class RateStreamMicroBatchReader(options: DataSourceV2Options) +class RateStreamMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index ce55e44d932bd..58767261dc684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ @@ -45,7 +45,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { new MemoryStreamWriter(this, mode) } @@ -114,7 +114,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) - extends DataSourceV2Writer with Logging { + extends DataSourceWriter with Logging { override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9f5ca9f914284..f1b3f93c4e1fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -158,7 +158,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance() - val options = new DataSourceV2Options(extraOptions.asJava) + val options = new DataSourceOptions(extraOptions.asJava) // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 4026ee44bfdb7..d421f7d19563f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,15 +24,15 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsPushDownRequiredColumns, + class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); @@ -131,7 +131,7 @@ public void close() throws IOException { @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index 34e6c63801064..c55093768105b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -21,8 +21,8 @@ import java.util.List; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.DataTypes; @@ -33,7 +33,7 @@ public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch { + class Reader implements DataSourceReader, SupportsScanColumnarBatch { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -108,7 +108,7 @@ public void close() throws IOException { @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index d0c87503ab455..99cca0f6dd626 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -23,15 +23,15 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsReportPartitioning { + class Reader implements DataSourceReader, SupportsReportPartitioning { private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override @@ -104,7 +104,7 @@ public DataReader createDataReader() { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index f997366af1a64..048d078dfaac4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -20,16 +20,16 @@ import java.util.List; import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema; Reader(StructType schema) { @@ -48,7 +48,7 @@ public List> createDataReaderFactories() { } @Override - public DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options) { + public DataSourceReader createReader(StructType schema, DataSourceOptions options) { return new Reader(schema); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2beed431d301f..96f55b8a76811 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -23,16 +23,16 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataReader; import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; -import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader { + class Reader implements DataSourceReader { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -80,7 +80,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index e8187524ea871..c3916e0b370b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -21,15 +21,15 @@ import java.util.List; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceV2Options; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceV2Reader, SupportsScanUnsafeRow { + class Reader implements DataSourceReader, SupportsScanUnsafeRow { private final StructType schema = new StructType().add("i", "int").add("j", "int"); @Override @@ -83,7 +83,7 @@ public void close() throws IOException { } @Override - public DataSourceV2Reader createReader(DataSourceV2Options options) { + public DataSourceReader createReader(DataSourceOptions options) { return new Reader(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index d2cfe7905f6fa..b060aeeef811d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -49,7 +49,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch in registry") { DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty()) + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) assert(reader.isInstanceOf[RateStreamMicroBatchReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") @@ -76,14 +76,14 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - numPartitions propagated") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) reader.setOffsetRange(Optional.empty(), Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 11) } test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceV2Options.empty()) + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -93,7 +93,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - infer offsets") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) reader.setOffsetRange(Optional.empty(), Optional.empty()) reader.getStartOffset() match { @@ -114,7 +114,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - predetermined batch size") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) @@ -125,7 +125,7 @@ class RateSourceV2Suite extends StreamTest { test("microbatch - data read") { val reader = new RateStreamMicroBatchReader( - new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => @@ -150,7 +150,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceV2Options.empty()) + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) assert(reader.isInstanceOf[RateStreamContinuousReader]) case _ => throw new IllegalStateException("Could not find v2 read support for rate") @@ -159,7 +159,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous data") { val reader = new RateStreamContinuousReader( - new DataSourceV2Options(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setOffset(Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 90d92864b26fa..31dfc55b23361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -22,24 +22,24 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite /** - * A simple test suite to verify `DataSourceV2Options`. + * A simple test suite to verify `DataSourceOptions`. */ -class DataSourceV2OptionsSuite extends SparkFunSuite { +class DataSourceOptionsSuite extends SparkFunSuite { test("key is case-insensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("foo" -> "bar").asJava) assert(options.get("foo").get() == "bar") assert(options.get("FoO").get() == "bar") assert(!options.get("abc").isPresent) } test("value is case-sensitive") { - val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) + val options = new DataSourceOptions(Map("foo" -> "bAr").asJava) assert(options.get("foo").get == "bAr") } test("getInt") { - val options = new DataSourceV2Options(Map("numFOo" -> "1", "foo" -> "bar").asJava) + val options = new DataSourceOptions(Map("numFOo" -> "1", "foo" -> "bar").asJava) assert(options.getInt("numFOO", 10) == 1) assert(options.getInt("numFOO2", 10) == 10) @@ -49,7 +49,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getBoolean") { - val options = new DataSourceV2Options( + val options = new DataSourceOptions( Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) assert(options.getBoolean("isFoo", false)) assert(!options.getBoolean("isFoo2", true)) @@ -59,7 +59,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getLong") { - val options = new DataSourceV2Options(Map("numFoo" -> "9223372036854775807", + val options = new DataSourceOptions(Map("numFoo" -> "9223372036854775807", "foo" -> "bar").asJava) assert(options.getLong("numFOO", 0L) == 9223372036854775807L) assert(options.getLong("numFoo2", -1L) == -1L) @@ -70,7 +70,7 @@ class DataSourceV2OptionsSuite extends SparkFunSuite { } test("getDouble") { - val options = new DataSourceV2Options(Map("numFoo" -> "922337.1", + val options = new DataSourceOptions(Map("numFoo" -> "922337.1", "foo" -> "bar").asJava) assert(options.getDouble("numFOO", 0d) == 922337.1d) assert(options.getDouble("numFoo2", -1.02d) == -1.02d) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 42c5d3bcea44b..ee50e8a92270b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -201,7 +201,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader { + class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -209,7 +209,7 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class SimpleDataReaderFactory(start: Int, end: Int) @@ -233,7 +233,7 @@ class SimpleDataReaderFactory(start: Int, end: Int) class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader + class Reader extends DataSourceReader with SupportsPushDownRequiredColumns with SupportsPushDownFilters { var requiredSchema = new StructType().add("i", "int").add("j", "int") @@ -275,7 +275,7 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) @@ -306,7 +306,7 @@ class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { + class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { @@ -315,7 +315,7 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class UnsafeRowDataReaderFactory(start: Int, end: Int) @@ -343,18 +343,18 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { - class Reader(val readSchema: StructType) extends DataSourceV2Reader { + class Reader(val readSchema: StructType) extends DataSourceReader { override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = java.util.Collections.emptyList() } - override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = + override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = new Reader(schema) } class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch { + class Reader extends DataSourceReader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { @@ -362,7 +362,7 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class BatchDataReaderFactory(start: Int, end: Int) @@ -406,7 +406,7 @@ class BatchDataReaderFactory(start: Int, end: Int) class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceV2Reader with SupportsReportPartitioning { + class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -428,7 +428,7 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 3310d6dd199d6..a131b16953e3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceV2Reader} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -42,7 +42,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { @@ -64,7 +64,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } @@ -104,7 +104,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } - override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + override def createReader(options: DataSourceOptions): DataSourceReader = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration new Reader(path.toUri.toString, conf) @@ -114,7 +114,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS jobId: String, schema: StructType, mode: SaveMode, - options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + options: DataSourceOptions): Optional[DataSourceWriter] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -141,7 +141,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } private def createWriter( - jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = { val pathStr = path.toUri.toString if (internal) { new InternalRowWriter(jobId, pathStr, conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index dc8c857018457..3127d664d32dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, Streami import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.DataSourceV2Options +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.streaming._ import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} @@ -54,14 +54,14 @@ trait FakeMicroBatchReadSupport extends MicroBatchReadSupport { override def createMicroBatchReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): MicroBatchReader = FakeReader() + options: DataSourceOptions): MicroBatchReader = FakeReader() } trait FakeContinuousReadSupport extends ContinuousReadSupport { override def createContinuousReader( schema: Optional[StructType], checkpointLocation: String, - options: DataSourceV2Options): ContinuousReader = FakeReader() + options: DataSourceOptions): ContinuousReader = FakeReader() } trait FakeStreamWriteSupport extends StreamWriteSupport { @@ -69,7 +69,7 @@ trait FakeStreamWriteSupport extends StreamWriteSupport { queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceV2Options): StreamWriter = { + options: DataSourceOptions): StreamWriter = { throw new IllegalStateException("fake sink - cannot actually write") } } From 7a2ada223e14d09271a76091be0338b2d375081e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 21:55:55 +0900 Subject: [PATCH 0238/1161] [SPARK-23261][PYSPARK] Rename Pandas UDFs ## What changes were proposed in this pull request? Rename the public APIs and names of pandas udfs. - `PANDAS SCALAR UDF` -> `SCALAR PANDAS UDF` - `PANDAS GROUP MAP UDF` -> `GROUPED MAP PANDAS UDF` - `PANDAS GROUP AGG UDF` -> `GROUPED AGG PANDAS UDF` ## How was this patch tested? The existing tests Author: gatorsmile Closes #20428 from gatorsmile/renamePandasUDFs. --- .../spark/api/python/PythonRunner.scala | 12 +-- docs/sql-programming-guide.md | 8 +- examples/src/main/python/sql/arrow.py | 12 +-- python/pyspark/rdd.py | 6 +- python/pyspark/sql/functions.py | 34 +++---- python/pyspark/sql/group.py | 10 +- python/pyspark/sql/tests.py | 92 +++++++++---------- python/pyspark/sql/udf.py | 25 ++--- python/pyspark/worker.py | 24 ++--- .../sql/catalyst/expressions/PythonUDF.scala | 4 +- .../sql/catalyst/planning/patterns.scala | 1 - .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../python/AggregateInPandasExec.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ExtractPythonUDFs.scala | 2 +- .../python/FlatMapGroupsInPandasExec.scala | 2 +- 16 files changed, 120 insertions(+), 120 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 29148a7ee558b..f075a7e0eb0b4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -37,16 +37,16 @@ private[spark] object PythonEvalType { val SQL_BATCHED_UDF = 100 - val SQL_PANDAS_SCALAR_UDF = 200 - val SQL_PANDAS_GROUP_MAP_UDF = 201 - val SQL_PANDAS_GROUP_AGG_UDF = 202 + val SQL_SCALAR_PANDAS_UDF = 200 + val SQL_GROUPED_MAP_PANDAS_UDF = 201 + val SQL_GROUPED_AGG_PANDAS_UDF = 202 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" - case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF" - case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF" - case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF" + case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" + case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" } } diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d49c8d869cba6..a0e221b39cc34 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1684,7 +1684,7 @@ Spark will fall back to create the DataFrame without Arrow. Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator or to wrap the function, no additional configuration is required. Currently, there are two types of -Pandas UDF: Scalar and Group Map. +Pandas UDF: Scalar and Grouped Map. ### Scalar @@ -1702,8 +1702,8 @@ The following example shows how to create a scalar Pandas UDF that computes the
-### Group Map -Group map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. +### Grouped Map +Grouped map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. Split-apply-combine consists of three steps: * Split the data into groups by using `DataFrame.groupBy`. * Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The @@ -1723,7 +1723,7 @@ The following example shows how to use `groupby().apply()` to subtract the mean
-{% include_example group_map_pandas_udf python/sql/arrow.py %} +{% include_example grouped_map_pandas_udf python/sql/arrow.py %}
diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 6c0028b3f1c1f..4c5aefb6ff4a6 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -86,15 +86,15 @@ def multiply_func(a, b): # $example off:scalar_pandas_udf$ -def group_map_pandas_udf_example(spark): - # $example on:group_map_pandas_udf$ +def grouped_map_pandas_udf_example(spark): + # $example on:grouped_map_pandas_udf$ from pyspark.sql.functions import pandas_udf, PandasUDFType df = spark.createDataFrame( [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) - @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) + @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) def substract_mean(pdf): # pdf is a pandas.DataFrame v = pdf.v @@ -110,7 +110,7 @@ def substract_mean(pdf): # | 2|-1.0| # | 2| 4.0| # +---+----+ - # $example off:group_map_pandas_udf$ + # $example off:grouped_map_pandas_udf$ if __name__ == "__main__": @@ -123,7 +123,7 @@ def substract_mean(pdf): dataframe_with_arrow_example(spark) print("Running pandas_udf scalar example") scalar_pandas_udf_example(spark) - print("Running pandas_udf group map example") - group_map_pandas_udf_example(spark) + print("Running pandas_udf grouped map example") + grouped_map_pandas_udf_example(spark) spark.stop() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6b018c3a38444..93b8974a7e64a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -68,9 +68,9 @@ class PythonEvalType(object): SQL_BATCHED_UDF = 100 - SQL_PANDAS_SCALAR_UDF = 200 - SQL_PANDAS_GROUP_MAP_UDF = 201 - SQL_PANDAS_GROUP_AGG_UDF = 202 + SQL_SCALAR_PANDAS_UDF = 200 + SQL_GROUPED_MAP_PANDAS_UDF = 201 + SQL_GROUPED_AGG_PANDAS_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a291c9b71913f..3c8fb4c4d19e7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1737,8 +1737,8 @@ def translate(srcCol, matching, replace): def create_map(*cols): """Creates a new map column. - :param cols: list of column names (string) or list of :class:`Column` expressions that grouped - as key-value pairs, e.g. (key1, value1, key2, value2, ...). + :param cols: list of column names (string) or list of :class:`Column` expressions that are + grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). >>> df.select(create_map('name', 'age').alias("map")).collect() [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})] @@ -2085,11 +2085,11 @@ def map_values(col): class PandasUDFType(object): """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`. """ - SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF + SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF - GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF + GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF - GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF @since(1.3) @@ -2193,20 +2193,20 @@ def pandas_udf(f=None, returnType=None, functionType=None): Therefore, this can be used, for example, to ensure the length of each returned `pandas.Series`, and can not be used as the column length. - 2. GROUP_MAP + 2. GROUPED_MAP - A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` + A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. The length of the returned `pandas.DataFrame` can be arbitrary. - Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. + Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -2223,9 +2223,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - 3. GROUP_AGG + 3. GROUPED_AGG - A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar + A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. @@ -2239,7 +2239,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def mean_udf(v): ... return v.mean() >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP @@ -2285,21 +2285,21 @@ def pandas_udf(f=None, returnType=None, functionType=None): eval_type = returnType else: # @pandas_udf(dataType) or @pandas_udf(returnType=dataType) - eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF else: return_type = returnType if functionType is not None: eval_type = functionType else: - eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF + eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF if return_type is None: raise ValueError("Invalid returnType: returnType can not be None") - if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]: + if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f90a909d7c2b1..ab646535c864c 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -98,7 +98,7 @@ def agg(self, *exprs): [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP + >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() >>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP @@ -235,14 +235,14 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. - :param udf: a group map user-defined function returned by + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) - >>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP ... def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) @@ -262,9 +262,9 @@ def apply(self, udf): """ # Columns are special because hasattr always return True if isinstance(udf, Column) or not hasattr(udf, 'func') \ - or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type " - "GROUP_MAP.") + "GROUPED_MAP.") df = self._df udf_column = udf(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ca7bbf8ffe71c..dc80870d3cd9f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3621,34 +3621,34 @@ def test_pandas_udf_basic(self): udf = pandas_udf(lambda x: x, DoubleType()) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR) self.assertEqual(udf.returnType, DoubleType()) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]), - PandasUDFType.GROUP_MAP) + PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP) + udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, 'v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) udf = pandas_udf(lambda x: x, returnType='v double', - functionType=PandasUDFType.GROUP_MAP) + functionType=PandasUDFType.GROUPED_MAP) self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())])) - self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_pandas_udf_decorator(self): from pyspark.rdd import PythonEvalType @@ -3659,45 +3659,45 @@ def test_pandas_udf_decorator(self): def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=DoubleType()) def foo(x): return x self.assertEqual(foo.returnType, DoubleType()) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) schema = StructType([StructField("v", DoubleType())]) - @pandas_udf(schema, PandasUDFType.GROUP_MAP) + @pandas_udf(schema, PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf('v double', PandasUDFType.GROUP_MAP) + @pandas_udf('v double', PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) - @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) def foo(x): return x self.assertEqual(foo.returnType, schema) - self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF) + self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) def test_udf_wrong_arg(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -3724,15 +3724,15 @@ def zero_with_type(): return 1 with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): - @pandas_udf(returnType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): - @pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): - @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP) + @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) def foo(k, v): return k @@ -3804,11 +3804,11 @@ def test_register_nondeterministic_vectorized_udf_basic(self): random_pandas_udf = pandas_udf( lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic() self.assertEqual(random_pandas_udf.deterministic, False) - self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) nondeterministic_pandas_udf = self.spark.catalog.registerFunction( "randomPandasUDF", random_pandas_udf) self.assertEqual(nondeterministic_pandas_udf.deterministic, False) - self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() self.assertEqual(row[0], 7) @@ -4206,7 +4206,7 @@ def test_register_vectorized_udf_basic(self): col('id').cast('int').alias('b')) original_add = pandas_udf(lambda x, y: x + y, IntegerType()) self.assertEqual(original_add.deterministic, True) - self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF) + self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) new_add = self.spark.catalog.registerFunction("add1", original_add) res1 = df.select(new_add(col('a'), col('b'))) res2 = self.spark.sql( @@ -4237,20 +4237,20 @@ def test_simple(self): StructField('v', IntegerType()), StructField('v1', DoubleType()), StructField('v2', LongType())]), - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) self.assertPandasEqual(expected, result) - def test_register_group_map_udf(self): + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP) + foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or ' - 'SQL_PANDAS_SCALAR_UDF'): + 'SQL_SCALAR_PANDAS_UDF'): self.spark.catalog.registerFunction("foo_udf", foo_udf) def test_decorator(self): @@ -4259,7 +4259,7 @@ def test_decorator(self): @pandas_udf( 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def foo(pdf): return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) @@ -4275,7 +4275,7 @@ def test_coerce(self): foo = pandas_udf( lambda pdf: pdf, 'id long, v double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo).sort('id').toPandas() @@ -4289,7 +4289,7 @@ def test_complex_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4308,7 +4308,7 @@ def test_empty_groupby(self): @pandas_udf( 'id long, v int, norm double', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) def normalize(pdf): v = pdf.v @@ -4328,7 +4328,7 @@ def test_datatype_string(self): foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), 'id long, v int, v1 double, v2 long', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) result = df.groupby('id').apply(foo_udf).sort('id').toPandas() @@ -4342,7 +4342,7 @@ def test_wrong_return_type(self): foo = pandas_udf( lambda pdf: pdf, 'id long, v map', - PandasUDFType.GROUP_MAP + PandasUDFType.GROUPED_MAP ) with QuietTest(self.sc): @@ -4368,7 +4368,7 @@ def test_wrong_args(self): with self.assertRaisesRegexp(ValueError, 'Invalid udf'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) - with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'): + with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), PandasUDFType.SCALAR)) @@ -4379,7 +4379,7 @@ def test_unsupported_types(self): [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP) + f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() @@ -4422,7 +4422,7 @@ def plus_two(v): def pandas_agg_mean_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def avg(v): return v.mean() return avg @@ -4431,7 +4431,7 @@ def avg(v): def pandas_agg_sum_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def sum(v): return v.sum() return sum @@ -4441,7 +4441,7 @@ def pandas_agg_weighted_mean_udf(self): import numpy as np from pyspark.sql.functions import pandas_udf, PandasUDFType - @pandas_udf('double', PandasUDFType.GROUP_AGG) + @pandas_udf('double', PandasUDFType.GROUPED_AGG) def weighted_mean(v, w): return np.average(v, weights=w) return weighted_mean @@ -4505,19 +4505,19 @@ def test_unsupported_types(self): with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG) + @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG) + @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return v.mean(), v.std() with QuietTest(self.sc): with self.assertRaisesRegex(NotImplementedError, 'not supported'): - @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG) + @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 4f303304e5600..0f759c448b8a7 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -37,9 +37,9 @@ def _wrap_function(sc, func, returnType): def _create_udf(f, returnType, evalType): - if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): import inspect from pyspark.sql.utils import require_minimum_pyarrow_version @@ -47,16 +47,16 @@ def _create_udf(f, returnType, evalType): require_minimum_pyarrow_version() argspec = inspect.getargspec(f) - if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ + if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: raise ValueError( "Invalid function: 0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: raise ValueError( - "Invalid function: pandas_udfs with function type GROUP_MAP " + "Invalid function: pandas_udfs with function type GROUPED_MAP " "must take a single arg that is a pandas DataFrame." ) @@ -112,14 +112,15 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ and not isinstance(self._returnType_placeholder, StructType): raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUP_MAP") - elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \ + "pandas_udf with function type GROUPED_MAP") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \ and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): raise NotImplementedError( - "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG") + "ArrayType, StructType and MapType are not supported with " + "PandasUDFType.GROUPED_AGG") return self._returnType_placeholder @@ -292,9 +293,9 @@ def register(self, name, f, returnType=None): "Invalid returnType: data type can not be specified when f is" "a user-defined function, but got %s." % returnType) if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF]: + PythonEvalType.SQL_SCALAR_PANDAS_UDF]: raise ValueError( - "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF") + "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF") register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name, evalType=f.evalType, deterministic=f.deterministic) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 173d8fb2856fa..121b3dd1aeec9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -74,7 +74,7 @@ def wrap_udf(f, return_type): return lambda *a: f(*a) -def wrap_pandas_scalar_udf(f, return_type): +def wrap_scalar_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): @@ -90,7 +90,7 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_pandas_group_map_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type): def wrapped(*series): import pandas as pd @@ -110,7 +110,7 @@ def wrapped(*series): return wrapped -def wrap_pandas_group_agg_udf(f, return_type): +def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def wrapped(*series): @@ -133,12 +133,12 @@ def read_single_udf(pickleSer, infile, eval_type): row_func = chain(row_func, f) # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF: - return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: - return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) - elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF: - return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type) + if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(row_func, return_type) else: @@ -163,9 +163,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF): + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 4ba8ff6e3802f..efd664dde725a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DataType object PythonUDF { private[this] val SCALAR_TYPES = Set( PythonEvalType.SQL_BATCHED_UDF, - PythonEvalType.SQL_PANDAS_SCALAR_UDF + PythonEvalType.SQL_SCALAR_PANDAS_UDF ) def isScalarPythonUDF(e: Expression): Boolean = { @@ -36,7 +36,7 @@ object PythonUDF { def isGroupAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && - e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF + e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 132241061d510..626f905707191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.api.python.PythonEvalType import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d320c1c359411..7147798d99533 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -449,8 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, - "Must pass a group map udf") + require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the udf must be a StructType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 18e5f8605c60d..8e01e8e56a5bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -136,7 +136,7 @@ case class AggregateInPandasExec( val columnarBatchIter = new ArrowPythonRunner( pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(projectedRowIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 47b146f076b62..c4de214679ae4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(batchIter, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 4ae4e164830be..9d56f48249982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -160,7 +160,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } val evaluation = validUdfs.partition( - _.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF + _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF ) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 59db66bd7adf1..c798fe5a92c54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -96,7 +96,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) From 84bcf9dc88ffeae6fba4cfad9455ad75bed6e6f6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jan 2018 21:00:29 +0800 Subject: [PATCH 0239/1161] [SPARK-23222][SQL] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? It is reported that the test `Cancelling stage in a query with Range` in `DataFrameRangeSuite` fails a few times in unrelated PRs. I personally also saw it too in my PR. This test is not very flaky actually but only fails occasionally. Based on how the test works, I guess that is because `range` finishes before the listener calls `cancelStage`. I increase the range number from `1000000000L` to `100000000000L` and count the range in one partition. I also reduce the `interval` of checking stage id. Hopefully it can make the test not flaky anymore. ## How was this patch tested? The modified tests. Author: Liang-Chi Hsieh Closes #20431 from viirya/SPARK-23222. --- .../scala/org/apache/spark/sql/DataFrameRangeSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 45afbd29d1907..57a930dfaf320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -154,7 +154,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) @@ -166,7 +166,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(1000000000L).map { x => + spark.range(0, 100000000000L, 1, 1).map { x => DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() x }.toDF("id").agg(sum("id")).collect() @@ -184,6 +184,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { From a23187f53037425c61f1180b5e7990a116f86a42 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 31 Jan 2018 00:51:00 +0900 Subject: [PATCH 0240/1161] [SPARK-23174][BUILD][PYTHON][FOLLOWUP] Add pycodestyle*.py to .gitignore file. ## What changes were proposed in this pull request? This is a follow-up pr of #20338 which changed the downloaded file name of the python code style checker but it's not contained in .gitignore file so the file remains as an untracked file for git after running the checker. This pr adds the file name to .gitignore file. ## How was this patch tested? Tested manually. Author: Takuya UESHIN Closes #20432 from ueshin/issues/SPARK-23174/fup1. --- dev/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/.gitignore b/dev/.gitignore index 4a6027429e0d3..c673922f36d23 100644 --- a/dev/.gitignore +++ b/dev/.gitignore @@ -1 +1,2 @@ pep8*.py +pycodestyle*.py From 31c00ad8b090d7eddc4622e73dc4440cd32624de Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 11:33:30 -0800 Subject: [PATCH 0241/1161] [SPARK-23267][SQL] Increase spark.sql.codegen.hugeMethodLimit to 65535 ## What changes were proposed in this pull request? Still saw the performance regression introduced by `spark.sql.codegen.hugeMethodLimit` in our internal workloads. There are two major issues in the current solution. - The size of the complied byte code is not identical to the bytecode size of the method. The detection is still not accurate. - The bytecode size of a single operator (e.g., `SerializeFromObject`) could still exceed 8K limit. We saw the performance regression in such scenario. Since it is close to the release of 2.3, we decide to increase it to 64K for avoiding the perf regression. ## How was this patch tested? N/A Author: gatorsmile Closes #20434 from gatorsmile/revertConf. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 11 ++++++----- .../spark/sql/execution/WholeStageCodegenSuite.scala | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 54a35594f505e..7394a0d7cf983 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -660,12 +660,13 @@ object SQLConf { val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + - "codegen. When the compiled function exceeds this threshold, " + - "the whole-stage codegen is deactivated for this subtree of the current query plan. " + - s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + - "this is a limit in the OpenJDK JVM implementation.") + "codegen. When the compiled function exceeds this threshold, the whole-stage codegen is " + + "deactivated for this subtree of the current query plan. The default value is 65535, which " + + "is the largest bytecode size possible for a valid Java method. When running on HotSpot, " + + s"it may be preferable to set the value to ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} " + + "to match HotSpot's implementation.") .intConf - .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) + .createWithDefault(65535) val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR = buildConf("spark.sql.codegen.splitConsumeFuncByOperator") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 28ad712feaae6..6e8d5a70d5a8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -202,7 +202,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21871 check if we can get large code size when compiling too long functions") { + ignore("SPARK-21871 check if we can get large code size when compiling too long functions") { val codeWithShortFunctions = genGroupByCode(3) val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) @@ -211,7 +211,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } - test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + ignore("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath From 58fcb5a95ee0b91300138cd23f3ce2165fab597f Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 30 Jan 2018 14:11:06 -0800 Subject: [PATCH 0242/1161] [SPARK-23275][SQL] hive/tests have been failing when run locally on the laptop (Mac) with OOM ## What changes were proposed in this pull request? hive tests have been failing when they are run locally (Mac Os) after a recent change in the trunk. After running the tests for some time, the test fails with OOM with Error: unable to create new native thread. I noticed the thread count goes all the way up to 2000+ after which we start getting these OOM errors. Most of the threads seem to be related to the connection pool in hive metastore (BoneCP-xxxxx-xxxx ). This behaviour change is happening after we made the following change to HiveClientImpl.reset() ``` SQL def reset(): Unit = withHiveState { try { // code } finally { runSqlHive("USE default") ===> this is causing the issue } ``` I am proposing to temporarily back-out part of a fix made to address SPARK-23000 to resolve this issue while we work-out the exact reason for this sudden increase in thread counts. ## How was this patch tested? Ran hive/test multiple times in different machines. (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Dilip Biswal Closes #20441 from dilipbiswal/hive_tests. --- .../sql/hive/client/HiveClientImpl.scala | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 39d839059be75..6c0f4144992ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -825,23 +825,19 @@ private[hive] class HiveClientImpl( } def reset(): Unit = withHiveState { - try { - client.getAllTables("default").asScala.foreach { t => - logDebug(s"Deleting table $t") - val table = client.getTable("default", t) - client.getIndexes("default", t, 255).asScala.foreach { index => - shim.dropIndex(client, "default", t, index.getIndexName) - } - if (!table.isIndexTable) { - client.dropTable("default", t) - } + client.getAllTables("default").asScala.foreach { t => + logDebug(s"Deleting table $t") + val table = client.getTable("default", t) + client.getIndexes("default", t, 255).asScala.foreach { index => + shim.dropIndex(client, "default", t, index.getIndexName) } - client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => - logDebug(s"Dropping Database: $db") - client.dropDatabase(db, true, false, true) + if (!table.isIndexTable) { + client.dropTable("default", t) } - } finally { - runSqlHive("USE default") + } + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + client.dropDatabase(db, true, false, true) } } } From 9623a98248837da302ba4ec240335d1c4268ee21 Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Wed, 31 Jan 2018 07:37:25 +0900 Subject: [PATCH 0243/1161] [MINOR] Fix typos in dev/* scripts. ## What changes were proposed in this pull request? Consistency in style, grammar and removal of extraneous characters. ## How was this patch tested? Manually as this is a doc change. Author: Shashwat Anand Closes #20436 from ashashwat/SPARK-23174. --- dev/appveyor-guide.md | 6 +++--- dev/lint-python | 12 ++++++------ dev/run-pip-tests | 4 ++-- dev/run-tests-jenkins | 2 +- dev/sparktestsupport/modules.py | 8 ++++---- dev/sparktestsupport/toposort.py | 6 +++--- dev/tests/pr_merge_ability.sh | 4 ++-- dev/tests/pr_public_classes.sh | 4 ++-- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/dev/appveyor-guide.md b/dev/appveyor-guide.md index d2e00b484727d..a842f39b3049a 100644 --- a/dev/appveyor-guide.md +++ b/dev/appveyor-guide.md @@ -1,6 +1,6 @@ # AppVeyor Guides -Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documenation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. +Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor.com). This page describes how to set up AppVeyor with Spark, how to run the build, check the status and stop the build via this tool. There is the documentation for AppVeyor [here](https://www.appveyor.com/docs). Please refer this for full details. ### Setting up AppVeyor @@ -45,7 +45,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 16 35 -- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access to the Github logs (e.g. commits). +- Since we will use Github here, click the "GITHUB" button and then click "Authorize Github" so that AppVeyor can access the Github logs (e.g. commits). 2016-09-04 11 10 22 @@ -87,7 +87,7 @@ Currently, SparkR on Windows is being tested with [AppVeyor](https://ci.appveyor 2016-08-30 12 29 41 -- If the build is running, "CANCEL BUILD" buttom appears. Click this button top cancel the current build. +- If the build is running, "CANCEL BUILD" button appears. Click this button to cancel the current build. 2016-08-30 1 11 13 diff --git a/dev/lint-python b/dev/lint-python index e069cafa1b8c6..f738af9c49763 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -34,8 +34,8 @@ python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYCODESTYLE_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. -#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 -# Updated to latest official version for pep8. pep8 is formally renamed to pycodestyle. +# See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 +# Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. PYCODESTYLE_VERSION="2.3.1" PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$PYCODESTYLE_VERSION.py" PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$PYCODESTYLE_VERSION/pycodestyle.py" @@ -60,9 +60,9 @@ export "PYLINT_HOME=$PYTHONPATH" export "PATH=$PYTHONPATH:$PATH" # There is no need to write this output to a file -#+ first, but we do so so that the check status can -#+ be output before the report, like with the -#+ scalastyle and RAT checks. +# first, but we do so so that the check status can +# be output before the report, like with the +# scalastyle and RAT checks. python "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $PATHS_TO_CHECK >> "$PYCODESTYLE_REPORT_PATH" pycodestyle_status="${PIPESTATUS[0]}" @@ -73,7 +73,7 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "PYCODESTYLE checks failed." + echo "pycodestyle checks failed." cat "$PYCODESTYLE_REPORT_PATH" rm "$PYCODESTYLE_REPORT_PATH" exit "$lint_status" diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c5..1321c2be4c192 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -25,10 +25,10 @@ shopt -s nullglob FWDIR="$(cd "$(dirname "$0")"/..; pwd)" cd "$FWDIR" -echo "Constucting virtual env for testing" +echo "Constructing virtual env for testing" VIRTUALENV_BASE=$(mktemp -d) -# Clean up the virtual env enviroment used if we created one. +# Clean up the virtual env environment used if we created one. function delete_virtualenv() { echo "Cleaning up temporary directory - $VIRTUALENV_BASE" rm -rf "$VIRTUALENV_BASE" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 03fd6ff0fba40..5bc03e41d1f2d 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -20,7 +20,7 @@ # Wrapper script that runs the Spark tests then reports QA results # to github via its API. # Environment variables are populated by the code here: -#+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 +# https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 FWDIR="$( cd "$( dirname "$0" )/.." && pwd )" cd "$FWDIR" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b900f0bd913c3..dfea762db98c6 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -25,10 +25,10 @@ @total_ordering class Module(object): """ - A module is the basic abstraction in our test runner script. Each module consists of a set of - source files, a set of test commands, and a set of dependencies on other modules. We use modules - to define a dependency graph that lets determine which tests to run based on which files have - changed. + A module is the basic abstraction in our test runner script. Each module consists of a set + of source files, a set of test commands, and a set of dependencies on other modules. We use + modules to define a dependency graph that let us determine which tests to run based on which + files have changed. """ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py index 6c67b4504bc3b..8b2688d20039f 100644 --- a/dev/sparktestsupport/toposort.py +++ b/dev/sparktestsupport/toposort.py @@ -43,8 +43,8 @@ def toposort(data): """Dependencies are expressed as a dictionary whose keys are items and whose values are a set of dependent items. Output is a list of sets in topological order. The first set consists of items with no -dependences, each subsequent set consists of items that depend upon -items in the preceeding sets. +dependencies, each subsequent set consists of items that depend upon +items in the preceding sets. """ # Special case empty input. @@ -59,7 +59,7 @@ def toposort(data): v.discard(k) # Find all items that don't depend on anything. extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) - # Add empty dependences where needed. + # Add empty dependencies where needed. data.update({item: set() for item in extra_items_in_deps}) while True: ordered = set(item for item, dep in data.items() if len(dep) == 0) diff --git a/dev/tests/pr_merge_ability.sh b/dev/tests/pr_merge_ability.sh index d9a347fe24a8c..25fdbccac4dd8 100755 --- a/dev/tests/pr_merge_ability.sh +++ b/dev/tests/pr_merge_ability.sh @@ -23,9 +23,9 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` # Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` +# known as `sha1` in `run-tests-jenkins` # ghprbActualCommit="$1" diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh index 41c5d3ee8cb3c..479d1851fe0b8 100755 --- a/dev/tests/pr_public_classes.sh +++ b/dev/tests/pr_public_classes.sh @@ -23,7 +23,7 @@ # found at dev/run-tests-jenkins. # # Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` +# known as `ghprbActualCommit` in `run-tests-jenkins` ghprbActualCommit="$1" @@ -31,7 +31,7 @@ ghprbActualCommit="$1" # master commit and the tip of the pull request branch. # By diffing$ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only -# non-test files, we can gets us changes introduced in the PR and not anything else added to master +# non-test files, we can get changes introduced in the PR and not anything else added to master # since the PR was branched. # Handle differences between GNU and BSD sed From 77866167330a665e174ae08a2f8902ef9dc3438b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 30 Jan 2018 17:14:17 -0800 Subject: [PATCH 0244/1161] [SPARK-23276][SQL][TEST] Enable UDT tests in (Hive)OrcHadoopFsRelationSuite ## What changes were proposed in this pull request? Like Parquet, ORC test suites should enable UDT tests. ## How was this patch tested? Pass the Jenkins with newly enabled test cases. Author: Dongjoon Hyun Closes #20440 from dongjoon-hyun/SPARK-23276. --- .../apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index a1f054b8e3f44..3b82a6c458ce4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -34,11 +34,10 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName - // ORC does not play well with NullType and UDT. + // ORC does not play well with NullType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false case _: CalendarIntervalType => false - case _: UserDefinedType[_] => false case _ => true } From ca04c3ff2387bf0a4308a4b010154e6761827278 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 30 Jan 2018 20:05:57 -0800 Subject: [PATCH 0245/1161] [SPARK-23274][SQL] Fix ReplaceExceptWithFilter when the right's Filter contains the references that are not in the left output ## What changes were proposed in this pull request? This PR is to fix the `ReplaceExceptWithFilter` rule when the right's Filter contains the references that are not in the left output. Before this PR, we got the error like ``` java.util.NoSuchElementException: key not found: a at scala.collection.MapLike$class.default(MapLike.scala:228) at scala.collection.AbstractMap.default(Map.scala:59) at scala.collection.MapLike$class.apply(MapLike.scala:141) at scala.collection.AbstractMap.apply(Map.scala:59) ``` After this PR, `ReplaceExceptWithFilter ` will not take an effect in this case. ## How was this patch tested? Added tests Author: gatorsmile Closes #20444 from gatorsmile/fixReplaceExceptWithFilter. --- .../optimizer/ReplaceExceptWithFilter.scala | 17 +++++++++++++---- .../optimizer/ReplaceOperatorSuite.scala | 15 +++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 89bfcee078fba..45edf266bbce4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -46,18 +46,27 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { } plan.transform { - case Except(left, right) if isEligible(left, right) => - Distinct(Filter(Not(transformCondition(left, skipProject(right))), left)) + case e @ Except(left, right) if isEligible(left, right) => + val newCondition = transformCondition(left, skipProject(right)) + newCondition.map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Expression = { + private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { val filterCondition = InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - filterCondition.transform { case a : AttributeReference => attributeNameMap(a.name) } + if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { + Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + } else { + None + } } // TODO: This can be further extended in the future. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index e9701ffd2c54b..52dc2e9fb076c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -168,6 +168,21 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("replace Except with Filter when only right filter can be applied to the left") { + val table = LocalRelation(Seq('a.int, 'b.int)) + val left = table.where('b < 1).select('a).as("left") + val right = table.where('b < 3).select('a).as("right") + + val query = Except(left, right) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(left.output, right.output, + Join(left, right, LeftAnti, Option($"left.a" <=> $"right.a"))).analyze + + comparePlans(optimized, correctAnswer) + } + test("replace Distinct with Aggregate") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 33707080c1301..8b66f77b2f923 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -589,6 +589,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-23274: except between two projects without references used in filter") { + val df = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c") + val df1 = df.filter($"a" === 1) + val df2 = df.filter($"a" === 2) + checkAnswer(df1.select("b").except(df2.select("b")), Row(3) :: Nil) + checkAnswer(df1.select("b").except(df2.select("c")), Row(2) :: Nil) + } + test("except distinct - SQL compliance") { val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") val df_right = Seq(1, 3).toDF("id") From 8c6a9c90a36a938372f28ee8be72178192fbc313 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 Jan 2018 13:59:21 +0800 Subject: [PATCH 0246/1161] [SPARK-23279][SS] Avoid triggering distributed job for Console sink ## What changes were proposed in this pull request? Console sink will redistribute collected local data and trigger a distributed job in each batch, this is not necessary, so here change to local job. ## How was this patch tested? Existing UT and manual verification. Author: jerryshao Closes #20447 from jerryshao/console-minor. --- .../spark/sql/execution/streaming/sources/ConsoleWriter.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index d46f4d7b86360..c57bdc4a28905 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources +import scala.collection.JavaConverters._ + import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceOptions @@ -61,7 +63,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions) println("-------------------------------------------") // scalastyle:off println spark - .createDataFrame(spark.sparkContext.parallelize(rows), schema) + .createDataFrame(rows.toList.asJava, schema) .show(numRowsToShow, isTruncated) } From 695f7146bca342a0ee192d8c7f5ec48d4d8577a8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 31 Jan 2018 15:13:15 +0800 Subject: [PATCH 0247/1161] [SPARK-23272][SQL] add calendar interval type support to ColumnVector ## What changes were proposed in this pull request? `ColumnVector` is aimed to support all the data types, but `CalendarIntervalType` is missing. Actually we do support interval type for inner fields, e.g. `ColumnarRow`, `ColumnarArray` both support interval type. It's weird if we don't support interval type at the top level. This PR adds the interval type support. This PR also makes `ColumnVector.getChild` protect. We need it public because `MutableColumnaRow.getInterval` needs it. Now the interval implementation is in `ColumnVector.getInterval`. ## How was this patch tested? a new test. Author: Wenchen Fan Closes #20438 from cloud-fan/interval. --- .../vectorized/MutableColumnarRow.java | 4 +- .../sql/vectorized/ArrowColumnVector.java | 2 +- .../spark/sql/vectorized/ColumnVector.java | 26 ++++++++++- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarRow.java | 4 +- .../vectorized/ColumnarBatchSuite.scala | 45 +++++++++++++++++-- 6 files changed, 70 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 2bab095d4d951..66668f3753604 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -146,9 +146,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (columns[ordinal].isNullAt(rowId)) return null; - final int months = columns[ordinal].getChild(0).getInt(rowId); - final long microseconds = columns[ordinal].getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return columns[ordinal].getInterval(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 9803c3dec6de2..a75d76bd0f82e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * A column vector backed by Apache Arrow. Currently time interval type and map type are not + * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 4b955ceddd0f2..111f5d9b358d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -20,6 +20,7 @@ import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -195,6 +196,7 @@ public double[] getDoubles(int rowId, int count) { * struct field. */ public final ColumnarRow getStruct(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarRow(this, rowId); } @@ -236,9 +238,29 @@ public MapData getMap(int ordinal) { public abstract byte[] getBinary(int rowId); /** - * Returns the ordinal's child column vector. + * Returns the calendar interval type value for rowId. + * + * In Spark, calendar interval type value is basically an integer value representing the number of + * months in this interval, and a long value representing the number of microseconds in this + * interval. An interval type vector is the same as a struct type vector with 2 fields: `months` + * and `microseconds`. + * + * To support interval type, implementations must implement {@link #getChild(int)} and define 2 + * child vectors: the first child vector is an int type vector, containing all the month values of + * all the interval values in this vector. The second child vector is a long type vector, + * containing all the microsecond values of all the interval values in this vector. + */ + public final CalendarInterval getInterval(int rowId) { + if (isNullAt(rowId)) return null; + final int months = getChild(0).getInt(rowId); + final long microseconds = getChild(1).getLong(rowId); + return new CalendarInterval(months, microseconds); + } + + /** + * @return child [[ColumnVector]] at the given ordinal. */ - public abstract ColumnVector getChild(int ordinal); + protected abstract ColumnVector getChild(int ordinal); /** * Data type for this column. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 0d2c3ec8648d3..72c07ee7cad3f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -135,9 +135,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { - int month = data.getChild(0).getInt(offset + ordinal); - long microseconds = data.getChild(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); + return data.getInterval(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 25db7e09d20d0..6ca749d7c6e85 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -139,9 +139,7 @@ public byte[] getBinary(int ordinal) { @Override public CalendarInterval getInterval(int ordinal) { if (data.getChild(ordinal).isNullAt(rowId)) return null; - final int months = data.getChild(ordinal).getChild(0).getInt(rowId); - final long microseconds = data.getChild(ordinal).getChild(1).getLong(rowId); - return new CalendarInterval(months, microseconds); + return data.getChild(ordinal).getInterval(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 1873c24ab063c..925c101fe1fee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -620,6 +620,39 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.arrayData().elementsAppended == 0) } + testVector("CalendarInterval APIs", 4, CalendarIntervalType) { + column => + val reference = mutable.ArrayBuffer.empty[CalendarInterval] + + val months = column.getChild(0) + val microseconds = column.getChild(1) + assert(months.dataType() == IntegerType) + assert(microseconds.dataType() == LongType) + + months.putInt(0, 1) + microseconds.putLong(0, 100) + reference += new CalendarInterval(1, 100) + + months.putInt(1, 0) + microseconds.putLong(1, 2000) + reference += new CalendarInterval(0, 2000) + + column.putNull(2) + reference += null + + months.putInt(3, 20) + microseconds.putLong(3, 0) + reference += new CalendarInterval(20, 0) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getInterval(i), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { column => @@ -739,14 +772,20 @@ class ColumnarBatchSuite extends SparkFunSuite { c1.putInt(0, 123) c2.putDouble(0, 3.45) - c1.putInt(1, 456) - c2.putDouble(1, 5.67) + + column.putNull(1) + + c1.putInt(2, 456) + c2.putDouble(2, 5.67) val s = column.getStruct(0) assert(s.getInt(0) == 123) assert(s.getDouble(1) == 3.45) - val s2 = column.getStruct(1) + assert(column.isNullAt(1)) + assert(column.getStruct(1) == null) + + val s2 = column.getStruct(2) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) } From 161a3f2ae324271a601500e3d2900db9359ee2ef Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Wed, 31 Jan 2018 10:37:37 +0200 Subject: [PATCH 0248/1161] [SPARK-23112][DOC] Update ML migration guide with breaking and behavior changes. Add breaking changes, as well as update behavior changes, to `2.3` ML migration guide. ## How was this patch tested? Doc only Author: Nick Pentreath Closes #20421 from MLnick/SPARK-23112-ml-guide. --- docs/ml-guide.md | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b957445579ffd..702bcf748fc74 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -108,7 +108,13 @@ and the migration guide below will explain all changes between releases. ### Breaking changes -There are no breaking changes. +* The class and trait hierarchy for logistic regression model summaries was changed to be cleaner +and better accommodate the addition of the multi-class summary. This is a breaking change for user +code that casts a `LogisticRegressionTrainingSummary` to a +` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail +(_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which +will still work correctly for both multinomial and binary cases. ### Deprecations and changes of behavior @@ -123,8 +129,19 @@ new [`OneHotEncoderEstimator`](ml-features.html#onehotencoderestimator) **Changes of behavior** * [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027): - We are now setting the default parallelism used in `OneVsRest` to be 1 (i.e. serial). In 2.2 and + The default parallelism used in `OneVsRest` is now set to 1 (i.e. serial). In `2.2` and earlier versions, the level of parallelism was set to the default threadpool size in Scala. +* [SPARK-22156](https://issues.apache.org/jira/browse/SPARK-22156): + The learning rate update for `Word2Vec` was incorrect when `numIterations` was set greater than + `1`. This will cause training results to be different between `2.3` and earlier versions. +* [SPARK-21681](https://issues.apache.org/jira/browse/SPARK-21681): + Fixed an edge case bug in multinomial logistic regression that resulted in incorrect coefficients + when some features had zero variance. +* [SPARK-16957](https://issues.apache.org/jira/browse/SPARK-16957): + Tree algorithms now use mid-points for split values. This may change results from model training. +* [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657): + Fixed an issue where the features generated by `RFormula` without an intercept were inconsistent + with the output in R. This may change results from model training in this scenario. ## Previous Spark versions From 3d0911bbe47f76c341c090edad3737e88a67e3d7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 31 Jan 2018 20:04:51 +0900 Subject: [PATCH 0249/1161] [SPARK-23228][PYSPARK] Add Python Created jsparkSession to JVM's defaultSession ## What changes were proposed in this pull request? In the current PySpark code, Python created `jsparkSession` doesn't add to JVM's defaultSession, this `SparkSession` object cannot be fetched from Java side, so the below scala code will be failed when loaded in PySpark application. ```scala class TestSparkSession extends SparkListener with Logging { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case CreateTableEvent(db, table) => val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) assert(session.isDefined) val tableInfo = session.get.sharedState.externalCatalog.getTable(db, table) logInfo(s"Table info ${tableInfo}") case e => logInfo(s"event $e") } } } ``` So here propose to add fresh create `jsparkSession` to `defaultSession`. ## How was this patch tested? Manual verification. Author: jerryshao Author: hyukjinkwon Author: Saisai Shao Closes #20404 from jerryshao/SPARK-23228. --- python/pyspark/sql/session.py | 10 +++++++++- python/pyspark/sql/tests.py | 28 +++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 6c84023c43fb6..1ed04298bc899 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -213,7 +213,12 @@ def __init__(self, sparkContext, jsparkSession=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm if jsparkSession is None: - jsparkSession = self._jvm.SparkSession(self._jsc.sc()) + if self._jvm.SparkSession.getDefaultSession().isDefined() \ + and not self._jvm.SparkSession.getDefaultSession().get() \ + .sparkContext().isStopped(): + jsparkSession = self._jvm.SparkSession.getDefaultSession().get() + else: + jsparkSession = self._jvm.SparkSession(self._jsc.sc()) self._jsparkSession = jsparkSession self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) @@ -225,6 +230,7 @@ def __init__(self, sparkContext, jsparkSession=None): if SparkSession._instantiatedSession is None \ or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + self._jvm.SparkSession.setDefaultSession(self._jsparkSession) def _repr_html_(self): return """ @@ -759,6 +765,8 @@ def stop(self): """Stop the underlying :class:`SparkContext`. """ self._sc.stop() + # We should clean the default session up. See SPARK-23228. + self._jvm.SparkSession.clearDefaultSession() SparkSession._instantiatedSession = None @since(2.0) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dc80870d3cd9f..dc26b96334c7a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -69,7 +69,7 @@ from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings from pyspark.sql.types import _merge_type -from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests +from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -2925,6 +2925,32 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class SparkSessionTests(PySparkTestCase): + + # This test is separate because it's closely related with session's start and stop. + # See SPARK-23228. + def test_set_jvm_default_session(self): + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + finally: + spark.stop() + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty()) + + def test_jvm_default_session_already_set(self): + # Here, we assume there is the default session already set in JVM. + jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc()) + self.sc._jvm.SparkSession.setDefaultSession(jsession) + + spark = SparkSession.builder.getOrCreate() + try: + self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined()) + # The session should be the same with the exiting one. + self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get())) + finally: + spark.stop() + + class UDFInitializationTests(unittest.TestCase): def tearDown(self): if SparkSession._instantiatedSession is not None: From 48dd6a4c79e33a8f2dba8349b58aa07e4796a925 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 00:24:42 +0800 Subject: [PATCH 0250/1161] revert [SPARK-22785][SQL] remove ColumnVector.anyNullsSet ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/19980 , we thought `anyNullsSet` can be simply implemented by `numNulls() > 0`. This is logically true, but may have performance problems. `OrcColumnVector` is an example. It doesn't have the `numNulls` property, only has a `noNulls` property. We will lose a lot of performance if we use `numNulls() > 0` to check null. This PR simply revert #19980, with a renaming to call it `hasNull`. Better name suggestions are welcome, e.g. `nullable`? ## How was this patch tested? existing test Author: Wenchen Fan Closes #20452 from cloud-fan/null. --- .../execution/datasources/orc/OrcColumnVector.java | 5 +++++ .../execution/vectorized/OffHeapColumnVector.java | 2 +- .../sql/execution/vectorized/OnHeapColumnVector.java | 2 +- .../execution/vectorized/WritableColumnVector.java | 7 ++++++- .../spark/sql/vectorized/ArrowColumnVector.java | 5 +++++ .../apache/spark/sql/vectorized/ColumnVector.java | 5 +++++ .../vectorized/ArrowColumnVectorSuite.scala | 12 ++++++++++++ .../execution/vectorized/ColumnarBatchSuite.scala | 9 +++++++++ 8 files changed, 44 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 5078bc7922ee2..78203e3145c62 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -77,6 +77,11 @@ public void close() { } + @Override + public boolean hasNull() { + return !baseData.noNulls; + } + @Override public int numNulls() { if (baseData.isRepeating) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 1c45b846790b6..fa52e4a354786 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -123,7 +123,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; long offset = nulls + rowId; for (int i = 0; i < count; ++i, ++offset) { Platform.putByte(null, offset, (byte) 0); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 1d538fe4181b7..cccef78aebdc8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -119,7 +119,7 @@ public void putNulls(int rowId, int count) { @Override public void putNotNulls(int rowId, int count) { - if (numNulls == 0) return; + if (!hasNull()) return; for (int i = 0; i < count; ++i) { nulls[rowId + i] = (byte)0; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index a8ec8ef2aadf8..8ebc1adf59c8b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -59,8 +59,8 @@ public void reset() { elementsAppended = 0; if (numNulls > 0) { putNotNulls(0, capacity); + numNulls = 0; } - numNulls = 0; } @Override @@ -102,6 +102,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { throw new RuntimeException(message, cause); } + @Override + public boolean hasNull() { + return numNulls > 0; + } + @Override public int numNulls() { return numNulls; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index a75d76bd0f82e..5ff6474c161f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -37,6 +37,11 @@ public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; private ArrowColumnVector[] childColumns; + @Override + public boolean hasNull() { + return accessor.getNullCount() > 0; + } + @Override public int numNulls() { return accessor.getNullCount(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 111f5d9b358d4..d588956208047 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -65,6 +65,11 @@ public abstract class ColumnVector implements AutoCloseable { @Override public abstract void close(); + /** + * Returns true if this column vector contains any null values. + */ + public abstract boolean hasNull(); + /** * Returns the number of nulls in this column vector. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index e794f50781ff2..b55489cb2678a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -42,6 +42,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BooleanType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -69,6 +70,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ByteType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -96,6 +98,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ShortType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -123,6 +126,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === IntegerType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -150,6 +154,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === LongType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -177,6 +182,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === FloatType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -204,6 +210,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === DoubleType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -232,6 +239,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === StringType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -258,6 +266,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === BinaryType) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) (0 until 10).foreach { i => @@ -300,6 +309,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === ArrayType(IntegerType)) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val array0 = columnVector.getArray(0) @@ -344,6 +354,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(!columnVector.hasNull) assert(columnVector.numNulls === 0) val row0 = columnVector.getStruct(0) @@ -396,6 +407,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { val columnVector = new ArrowColumnVector(vector) assert(columnVector.dataType === schema) + assert(columnVector.hasNull) assert(columnVector.numNulls === 1) val row0 = columnVector.getStruct(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 925c101fe1fee..168bc5e3e480b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -66,22 +66,27 @@ class ColumnarBatchSuite extends SparkFunSuite { column => val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNull() reference += false + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) + assert(!column.hasNull) assert(column.numNulls() == 0) column.appendNull() reference += true + assert(column.hasNull) assert(column.numNulls() == 1) column.appendNulls(3) (1 to 3).foreach(_ => reference += true) + assert(column.hasNull) assert(column.numNulls() == 4) idx = column.elementsAppended @@ -89,11 +94,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putNotNull(idx) reference += false idx += 1 + assert(column.hasNull) assert(column.numNulls() == 4) column.putNull(idx) reference += true idx += 1 + assert(column.hasNull) assert(column.numNulls() == 5) column.putNulls(idx, 3) @@ -101,6 +108,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += true reference += true idx += 3 + assert(column.hasNull) assert(column.numNulls() == 8) column.putNotNulls(idx, 4) @@ -109,6 +117,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += false reference += false idx += 4 + assert(column.hasNull) assert(column.numNulls() == 8) reference.zipWithIndex.foreach { v => From 8c21170decfb9ca4d3233e1ea13bd1b6e3199ed9 Mon Sep 17 00:00:00 2001 From: Glen Takahashi Date: Thu, 1 Feb 2018 01:14:01 +0800 Subject: [PATCH 0251/1161] [SPARK-23249][SQL] Improved block merging logic for partitions ## What changes were proposed in this pull request? Change DataSourceScanExec so that when grouping blocks together into partitions, also checks the end of the sorted list of splits to more efficiently fill out partitions. ## How was this patch tested? Updated old test to reflect the new logic, which causes the # of partitions to drop from 4 -> 3 Also, a current test exists to test large non-splittable files at https://github.com/glentakahashi/spark/blob/c575977a5952bf50b605be8079c9be1e30f3bd36/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala#L346 ## Rationale The current bin-packing method of next-fit descending for blocks into partitions is sub-optimal in a lot of cases and will result in extra partitions, un-even distribution of block-counts across partitions, and un-even distribution of partition sizes. As an example, 128 files ranging from 1MB, 2MB,...127MB,128MB. will result in 82 partitions with the current algorithm, but only 64 using this algorithm. Also in this example, the max # of blocks per partition in NFD is 13, while in this algorithm is is 2. More generally, running a simulation of 1000 runs using 128MB blocksize, between 1-1000 normally distributed file sizes between 1-500Mb, you can see an improvement of approx 5% reduction of partition counts, and a large reduction in standard deviation of blocks per partition. This algorithm also runs in O(n) time as NFD does, and in every case is strictly better results than NFD. Overall, the more even distribution of blocks across partitions and therefore reduced partition counts should result in a small but significant performance increase across the board Author: Glen Takahashi Closes #20372 from glentakahashi/feature/improved-block-merging. --- .../sql/execution/DataSourceScanExec.scala | 29 ++++++++++++++----- .../datasources/FileSourceStrategySuite.scala | 15 ++++------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index aa66ee7e948ea..f7732e2098c29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -445,16 +445,29 @@ case class FileSourceScanExec( currentSize = 0 } - // Assign files to partitions using "Next Fit Decreasing" - splitFiles.foreach { file => - if (currentSize + file.length > maxSplitBytes) { - closePartition() + def addFile(file: PartitionedFile): Unit = { + currentFiles += file + currentSize += file.length + openCostInBytes + } + + var frontIndex = 0 + var backIndex = splitFiles.length - 1 + + while (frontIndex <= backIndex) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + while (frontIndex <= backIndex && + currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { + addFile(splitFiles(frontIndex)) + frontIndex += 1 + } + while (backIndex > frontIndex && + currentSize + splitFiles(backIndex).length <= maxSplitBytes) { + addFile(splitFiles(backIndex)) + backIndex -= 1 } - // Add the given file to the current partition. - currentSize += file.length + openCostInBytes - currentFiles += file + closePartition() } - closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..bfccc9335b361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,16 +141,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] - assert(partitions.size == 4, "when checking partitions") - assert(partitions(0).files.size == 1, "when checking partition 1") + // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] + assert(partitions.size == 3, "when checking partitions") + assert(partitions(0).files.size == 2, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") - assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1) + // First partition reads (file1, file6) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) + assert(partitions(0).files(1).start == 0) + assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -163,10 +164,6 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) - - // Final partition reads (file6) - assert(partitions(3).files(0).start == 0) - assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) From dd242bad39cc6df7ff6c6b16642bdc92dccca6ac Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 31 Jan 2018 11:48:19 -0800 Subject: [PATCH 0252/1161] [SPARK-21525][STREAMING] Check error code from supervisor RPC. The code was ignoring the error code from the AddBlock RPC, which means that a failure to write to the WAL was being ignored by the receiver, and would lead to the block being acked (in the case of the Flume receiver) and data potentially lost. Author: Marcelo Vanzin Closes #20161 from vanzin/SPARK-21525. --- .../spark/streaming/receiver/ReceiverSupervisorImpl.scala | 4 +++- .../apache/spark/streaming/scheduler/ReceiverTracker.scala | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 27644a645727c..5d38c56aa5873 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -159,7 +159,9 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) - trackerEndpoint.askSync[Boolean](AddBlock(blockInfo)) + if (!trackerEndpoint.askSync[Boolean](AddBlock(blockInfo))) { + throw new SparkException("Failed to add block to receiver tracker.") + } logDebug(s"Reported block $blockId") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6f130c803f310..c74ca1918a81d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -521,7 +521,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (active) { context.reply(addBlock(receivedBlockInfo)) } else { - throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + context.sendFailure( + new IllegalStateException("ReceiverTracker RpcEndpoint already shut down.")) } } }) From 9ff1d96f01e2c89acfd248db917e068b93f519a6 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 31 Jan 2018 13:52:47 -0800 Subject: [PATCH 0253/1161] [SPARK-23281][SQL] Query produces results in incorrect order when a composite order by clause refers to both original columns and aliases ## What changes were proposed in this pull request? Here is the test snippet. ``` SQL scala> Seq[(Integer, Integer)]( | (1, 1), | (1, 3), | (2, 3), | (3, 3), | (4, null), | (5, null) | ).toDF("key", "value").createOrReplaceTempView("src") scala> sql( | """ | |SELECT MAX(value) as value, key as col2 | |FROM src | |GROUP BY key | |ORDER BY value desc, key | """.stripMargin).show +-----+----+ |value|col2| +-----+----+ | 3| 3| | 3| 2| | 3| 1| | null| 5| | null| 4| +-----+----+ ```SQL Here is the explain output : ```SQL == Parsed Logical Plan == 'Sort ['value DESC NULLS LAST, 'key ASC NULLS FIRST], true +- 'Aggregate ['key], ['MAX('value) AS value#9, 'key AS col2#10] +- 'UnresolvedRelation `src` == Analyzed Logical Plan == value: int, col2: int Project [value#9, col2#10] +- Sort [value#9 DESC NULLS LAST, col2#10 DESC NULLS LAST], true +- Aggregate [key#5], [max(value#6) AS value#9, key#5 AS col2#10] +- SubqueryAlias src +- Project [_1#2 AS key#5, _2#3 AS value#6] +- LocalRelation [_1#2, _2#3] ``` SQL The sort direction is being wrongly changed from ASC to DSC while resolving ```Sort``` in resolveAggregateFunctions. The above testcase models TPCDS-Q71 and thus we have the same issue in Q71 as well. ## How was this patch tested? A few tests are added in SQLQuerySuite. Author: Dilip Biswal Closes #20453 from dilipbiswal/local_spark. --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 41 ++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 91cb0365a0856..251099f750cf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1493,7 +1493,7 @@ class Analyzer( // to push down this ordering expression and can reference the original aggregate // expression instead. val needsPushDown = ArrayBuffer.empty[NamedExpression] - val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + val evaluatedOrderings = resolvedAliasedOrdering.zip(unresolvedSortOrders).map { case (evaluated, order) => val index = originalAggExprs.indexWhere { case Alias(child, _) => child semanticEquals evaluated.child diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ffd736d2ebbb6..8f14575c3325f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.File -import java.math.MathContext import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -1618,6 +1617,46 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23281: verify the correctness of sort direction on composite order by clause") { + withTempView("src") { + Seq[(Integer, Integer)]( + (1, 1), + (1, 3), + (2, 3), + (3, 3), + (4, null), + (5, null) + ).toDF("key", "value").createOrReplaceTempView("src") + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key + """.stripMargin), + Seq(Row(3, 1), Row(3, 2), Row(3, 3), Row(null, 4), Row(null, 5))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value desc, key desc + """.stripMargin), + Seq(Row(3, 3), Row(3, 2), Row(3, 1), Row(null, 5), Row(null, 4))) + + checkAnswer(sql( + """ + |SELECT MAX(value) as value, key as col2 + |FROM src + |GROUP BY key + |ORDER BY value asc, key desc + """.stripMargin), + Seq(Row(null, 5), Row(null, 4), Row(3, 3), Row(3, 2), Row(3, 1))) + } + } + test("run sql directly on files") { val df = spark.range(100).toDF() withTempPath(f => { From f470df2fcf14e6234c577dc1bdfac27d49b441f5 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Thu, 1 Feb 2018 11:15:17 +0900 Subject: [PATCH 0254/1161] [SPARK-23157][SQL][FOLLOW-UP] DataFrame -> SparkDataFrame in R comment Author: Henry Robinson Closes #20443 from henryr/SPARK-23157. --- R/pkg/R/DataFrame.R | 4 ++-- python/pyspark/sql/dataframe.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 547b5ea48a555..41c3c3a89fa72 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2090,8 +2090,8 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression (which must refer only to this DataFrame), or an atomic vector in -#' the length of 1 as literal value. +#' @param col a Column expression (which must refer only to this SparkDataFrame), or an atomic +#' vector in the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions #' @aliases withColumn,SparkDataFrame,character-method diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 055b2c4a0ffec..1496cba91b90e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1829,7 +1829,7 @@ def withColumn(self, colName, col): Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. - The column expression must be an expression over this dataframe; attempting to add + The column expression must be an expression over this DataFrame; attempting to add a column from some other dataframe will raise an error. :param colName: string, name of the new column. From 52e00f70663a87b5837235bdf72a3e6f84e11411 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 11:56:06 +0800 Subject: [PATCH 0255/1161] [SPARK-23280][SQL] add map type support to ColumnVector ## What changes were proposed in this pull request? Fill the last missing piece of `ColumnVector`: the map type support. The idea is similar to the array type support. A map is basically 2 arrays: keys and values. We ask the implementations to provide a key array, a value array, and an offset and length to specify the range of this map in the key/value array. In `WritableColumnVector`, we put the key array in first child vector, and value array in second child vector, and offsets and lengths in the current vector, which is very similar to how array type is implemented here. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20450 from cloud-fan/map. --- .../datasources/orc/OrcColumnVector.java | 6 ++ .../vectorized/ColumnVectorUtils.java | 15 ++++ .../vectorized/OffHeapColumnVector.java | 4 +- .../vectorized/OnHeapColumnVector.java | 4 +- .../vectorized/WritableColumnVector.java | 13 ++++ .../sql/vectorized/ArrowColumnVector.java | 5 ++ .../spark/sql/vectorized/ColumnVector.java | 14 +++- .../spark/sql/vectorized/ColumnarArray.java | 4 +- .../spark/sql/vectorized/ColumnarMap.java | 53 ++++++++++++++ .../spark/sql/vectorized/ColumnarRow.java | 5 +- .../vectorized/ColumnarBatchSuite.scala | 70 ++++++++++++++----- 11 files changed, 166 insertions(+), 27 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 78203e3145c62..c8add4c9f486c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.types.TimestampType; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; /** @@ -177,6 +178,11 @@ public ColumnarArray getArray(int rowId) { throw new UnsupportedOperationException(); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { throw new UnsupportedOperationException(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index a2853bbadc92b..829f3ce750fe6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -20,8 +20,10 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.sql.Date; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; @@ -30,6 +32,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -109,6 +112,18 @@ public static int[] toJavaIntArray(ColumnarArray array) { return array.toIntArray(); } + public static Map toJavaIntMap(ColumnarMap map) { + int[] keys = toJavaIntArray(map.keyArray()); + int[] values = toJavaIntArray(map.valueArray()); + assert keys.length == values.length; + + Map result = new HashMap<>(); + for (int i = 0; i < keys.length; i++) { + result.put(keys[i], values[i]); + } + return result; + } + private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { if (t instanceof CalendarIntervalType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index fa52e4a354786..754c26579ff08 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -60,7 +60,7 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long nulls; private long data; - // Set iff the type is array. + // Only set if type is Array or Map. private long lengthData; private long offsetData; @@ -530,7 +530,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { @Override protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; - if (isArray()) { + if (isArray() || type instanceof MapType) { this.lengthData = Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index cccef78aebdc8..23dcc104e67c4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -69,7 +69,7 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private float[] floatData; private double[] doubleData; - // Only set if type is Array. + // Only set if type is Array or Map. private int[] arrayLengths; private int[] arrayOffsets; @@ -503,7 +503,7 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { // Spilt this function out since it is the slow path. @Override protected void reserveInternal(int newCapacity) { - if (isArray()) { + if (isArray() || type instanceof MapType) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8ebc1adf59c8b..c2e595455549c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -612,6 +613,13 @@ public final ColumnarArray getArray(int rowId) { return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } + // `WritableColumnVector` puts the key array in the first child column vector, value array in the + // second child column vector, and puts the offsets and lengths in the current column vector. + @Override + public final ColumnarMap getMap(int rowId) { + return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); + } + public WritableColumnVector arrayData() { return childColumns[0]; } @@ -705,6 +713,11 @@ protected WritableColumnVector(int capacity, DataType type) { for (int i = 0; i < childColumns.length; ++i) { this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } + } else if (type instanceof MapType) { + MapType mapType = (MapType) type; + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType()); + this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType()); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. this.childColumns = new WritableColumnVector[2]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 5ff6474c161f3..f3ece538c3b80 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -119,6 +119,11 @@ public ColumnarArray getArray(int rowId) { return accessor.getArray(rowId); } + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + @Override public ArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index d588956208047..05271ec1f46ab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -220,10 +220,18 @@ public final ColumnarRow getStruct(int rowId) { /** * Returns the map type value for rowId. + * + * In Spark, map type value is basically a key data array and a value data array. A key from the + * key array with a index and a value from the value array with the same index contribute to + * an entry of this map type value. + * + * To support map type, implementations must construct an {@link ColumnarMap} and return it in + * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all + * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data + * of all the values of all the maps in this vector, and a pair of offset and length which + * specify the range of the key/value array that belongs to the map type value at rowId. */ - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } + public abstract ColumnarMap getMap(int ordinal); /** * Returns the decimal type value for rowId. diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 72c07ee7cad3f..7c7a1c806a2b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -149,8 +149,8 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java new file mode 100644 index 0000000000000..35648e386c4f1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.vectorized; + +import org.apache.spark.sql.catalyst.util.MapData; + +/** + * Map abstraction in {@link ColumnVector}. + */ +public final class ColumnarMap extends MapData { + private final ColumnarArray keys; + private final ColumnarArray values; + private final int length; + + public ColumnarMap(ColumnVector keys, ColumnVector values, int offset, int length) { + this.length = length; + this.keys = new ColumnarArray(keys, offset, length); + this.values = new ColumnarArray(values, offset, length); + } + + @Override + public int numElements() { return length; } + + @Override + public ColumnarArray keyArray() { + return keys; + } + + @Override + public ColumnarArray valueArray() { + return values; + } + + @Override + public ColumnarMap copy() { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 6ca749d7c6e85..0c9e92ed11fbd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -155,8 +155,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (data.getChild(ordinal).isNullAt(rowId)) return null; + return data.getChild(ordinal).getMap(rowId); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 168bc5e3e480b..8fe2985836f2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -673,35 +673,37 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } - // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + // Populate it with arrays [0], [1, 2], null, [], [3, 4, 5] column.putArray(0, 0, 1) column.putArray(1, 1, 2) - column.putArray(2, 2, 0) - column.putArray(3, 3, 3) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getArray(0).numElements == 1) + assert(column.getArray(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getArray(3).numElements == 0) + assert(column.getArray(4).numElements == 3) val a1 = ColumnVectorUtils.toJavaIntArray(column.getArray(0)) val a2 = ColumnVectorUtils.toJavaIntArray(column.getArray(1)) - val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(2)) - val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a3 = ColumnVectorUtils.toJavaIntArray(column.getArray(3)) + val a4 = ColumnVectorUtils.toJavaIntArray(column.getArray(4)) assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) - // Verify the ArrayData APIs - assert(column.getArray(0).numElements() == 1) + // Verify the ArrayData get APIs assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).numElements() == 0) - - assert(column.getArray(3).numElements() == 3) - assert(column.getArray(3).getInt(0) == 3) - assert(column.getArray(3).getInt(1) == 4) - assert(column.getArray(3).getInt(2) == 5) + assert(column.getArray(4).getInt(0) == 3) + assert(column.getArray(4).getInt(1) == 4) + assert(column.getArray(4).getInt(2) == 5) // Add a longer array which requires resizing column.reset() @@ -711,8 +713,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) - === array) + assert(ColumnVectorUtils.toJavaIntArray(column.getArray(0)) === array) } test("toArray for primitive types") { @@ -770,6 +771,43 @@ class ColumnarBatchSuite extends SparkFunSuite { } } + test("Int Map") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = allocate(10, new MapType(IntegerType, IntegerType, false), memMode) + (0 to 1).foreach { colIndex => + val data = column.getChild(colIndex) + (0 to 5).foreach {i => + data.putInt(i, i * (colIndex + 1)) + } + } + + // Populate it with maps [0->0], [1->2, 2->4], null, [], [3->6, 4->8, 5->10] + column.putArray(0, 0, 1) + column.putArray(1, 1, 2) + column.putNull(2) + column.putArray(3, 3, 0) + column.putArray(4, 3, 3) + + assert(column.getMap(0).numElements == 1) + assert(column.getMap(1).numElements == 2) + assert(column.isNullAt(2)) + assert(column.getMap(3).numElements == 0) + assert(column.getMap(4).numElements == 3) + + val a1 = ColumnVectorUtils.toJavaIntMap(column.getMap(0)) + val a2 = ColumnVectorUtils.toJavaIntMap(column.getMap(1)) + val a4 = ColumnVectorUtils.toJavaIntMap(column.getMap(3)) + val a5 = ColumnVectorUtils.toJavaIntMap(column.getMap(4)) + + assert(a1.asScala == Map(0 -> 0)) + assert(a2.asScala == Map(1 -> 2, 2 -> 4)) + assert(a4.asScala == Map()) + assert(a5.asScala == Map(3 -> 6, 4 -> 8, 5 -> 10)) + + column.close() + } + } + testVector( "Struct Column", 10, From 2ac895be909de7e58e1051dc2a1bba98a25bf4be Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Thu, 1 Feb 2018 12:05:12 +0800 Subject: [PATCH 0256/1161] [SPARK-23247][SQL] combines Unsafe operations and statistics operations in Scan Data Source ## What changes were proposed in this pull request? Currently, we scan the execution plan of the data source, first the unsafe operation of each row of data, and then re traverse the data for the count of rows. In terms of performance, this is not necessary. this PR combines the two operations and makes statistics on the number of rows while performing the unsafe operation. Before modified, ``` val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) iter.map(proj) } val numOutputRows = longMetric("numOutputRows") unsafeRow.map { r => numOutputRows += 1 r } ``` After modified, val numOutputRows = longMetric("numOutputRows") rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) iter.map( r => { numOutputRows += 1 proj(r) }) } ## How was this patch tested? the existed test cases. Author: caoxuewen Closes #20415 from heary-cao/DataSourceScanExec. --- .../sql/execution/DataSourceScanExec.scala | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index f7732e2098c29..ba1157d5b6a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -90,16 +90,15 @@ case class RowDataSourceScanExec( Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { - val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) => + val numOutputRows = longMetric("numOutputRows") + + rdd.mapPartitionsWithIndexInternal { (index, iter) => val proj = UnsafeProjection.create(schema) proj.initialize(index) - iter.map(proj) - } - - val numOutputRows = longMetric("numOutputRows") - unsafeRow.map { r => - numOutputRows += 1 - r + iter.map( r => { + numOutputRows += 1 + proj(r) + }) } } @@ -326,22 +325,22 @@ case class FileSourceScanExec( // 2) the number of columns should be smaller than spark.sql.codegen.maxFields WholeStageCodegenExec(this)(codegenStageId = 0).execute() } else { - val unsafeRows = { - val scan = inputRDD - if (needsUnsafeRowConversion) { - scan.mapPartitionsWithIndexInternal { (index, iter) => - val proj = UnsafeProjection.create(schema) - proj.initialize(index) - iter.map(proj) - } - } else { - scan - } - } val numOutputRows = longMetric("numOutputRows") - unsafeRows.map { r => - numOutputRows += 1 - r + + if (needsUnsafeRowConversion) { + inputRDD.mapPartitionsWithIndexInternal { (index, iter) => + val proj = UnsafeProjection.create(schema) + proj.initialize(index) + iter.map( r => { + numOutputRows += 1 + proj(r) + }) + } + } else { + inputRDD.map { r => + numOutputRows += 1 + r + } } } } From 56ae32657e9e5d1e30b62afe77d9e14eb07cf4fb Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Wed, 31 Jan 2018 20:33:51 -0800 Subject: [PATCH 0257/1161] [SPARK-23268][SQL] Reorganize packages in data source V2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. create a new package for partitioning/distribution related classes. As Spark will add new concrete implementations of `Distribution` in new releases, it is good to have a new package for partitioning/distribution related classes. 2. move streaming related class to package `org.apache.spark.sql.sources.v2.reader/writer.streaming`, instead of `org.apache.spark.sql.sources.v2.streaming.reader/writer`. So that the there won't be package reader/writer inside package streaming, which is quite confusing. Before change: ``` v2 ├── reader ├── streaming │   ├── reader │   └── writer └── writer ``` After change: ``` v2 ├── reader │   └── streaming └── writer └── streaming ``` ## How was this patch tested? Unit test. Author: Wang Gengliang Closes #20435 from gengliangwang/new_pkg. --- .../spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- .../apache/spark/sql/kafka010/KafkaSourceOffset.scala | 2 +- .../spark/sql/kafka010/KafkaSourceProvider.scala | 5 +++-- .../apache/spark/sql/kafka010/KafkaStreamWriter.scala | 2 +- .../{streaming => reader}/ContinuousReadSupport.java | 4 ++-- .../{streaming => reader}/MicroBatchReadSupport.java | 4 ++-- .../sources/v2/reader/SupportsReportPartitioning.java | 1 + .../{ => partitioning}/ClusteredDistribution.java | 3 ++- .../v2/reader/{ => partitioning}/Distribution.java | 3 ++- .../v2/reader/{ => partitioning}/Partitioning.java | 4 +++- .../streaming}/ContinuousDataReader.java | 2 +- .../reader => reader/streaming}/ContinuousReader.java | 2 +- .../reader => reader/streaming}/MicroBatchReader.java | 2 +- .../{streaming/reader => reader/streaming}/Offset.java | 2 +- .../reader => reader/streaming}/PartitionOffset.java | 2 +- .../spark/sql/sources/v2/writer/DataSourceWriter.java | 2 +- .../v2/{streaming => writer}/StreamWriteSupport.java | 5 ++--- .../writer => writer/streaming}/StreamWriter.java | 2 +- .../datasources/v2/DataSourcePartitioning.scala | 2 +- .../datasources/v2/DataSourceV2ScanExec.scala | 2 +- .../execution/datasources/v2/WriteToDataSourceV2.scala | 2 +- .../sql/execution/streaming/MicroBatchExecution.scala | 6 +++--- .../sql/execution/streaming/RateSourceProvider.scala | 5 ++--- .../sql/execution/streaming/RateStreamOffset.scala | 2 +- .../sql/execution/streaming/StreamingRelation.scala | 2 +- .../apache/spark/sql/execution/streaming/console.scala | 4 ++-- .../continuous/ContinuousDataSourceRDDIter.scala | 10 +++------- .../streaming/continuous/ContinuousExecution.scala | 5 +++-- .../continuous/ContinuousRateStreamSource.scala | 2 +- .../streaming/continuous/EpochCoordinator.scala | 4 ++-- .../execution/streaming/sources/ConsoleWriter.scala | 2 +- .../execution/streaming/sources/MicroBatchWriter.scala | 2 +- .../streaming/sources/RateStreamSourceV2.scala | 3 +-- .../sql/execution/streaming/sources/memoryV2.scala | 3 +-- .../apache/spark/sql/streaming/DataStreamReader.scala | 2 +- .../apache/spark/sql/streaming/DataStreamWriter.scala | 2 +- .../spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../sql/sources/v2/JavaPartitionAwareDataSource.java | 3 +++ .../sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../spark/sql/sources/v2/DataSourceV2Suite.scala | 1 + .../streaming/sources/StreamingDataSourceV2Suite.scala | 8 ++++---- 41 files changed, 64 insertions(+), 61 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => reader}/ContinuousReadSupport.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => reader}/MicroBatchReadSupport.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/ClusteredDistribution.java (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/Distribution.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ => partitioning}/Partitioning.java (90%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/ContinuousDataReader.java (96%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/ContinuousReader.java (98%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/MicroBatchReader.java (98%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/Offset.java (97%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/reader => reader/streaming}/PartitionOffset.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming => writer}/StreamWriteSupport.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{streaming/writer => writer/streaming}/StreamWriter.java (98%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 8c733426b256f..41c443bc12120 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRo import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index c82154cfbad7f..8d41c0da2b133 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, PartitionOffset} /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 85e96b6783327..694ca76e24964 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index a24efdefa4464..9307bfc001c03 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java index f79424e036a52..0c1d5d1a9577a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader; +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java index 22660e42ad850..5e8f0c0dafdcf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.reader; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.reader.MicroBatchReader; +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index a2383a9d7d680..5405a916951b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 27905e325df87..2d0ee50212b56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index b37562167d9ef..f6b111fdf220d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReader; /** * An interface to represent data distribution requirement, which specifies how the records should diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java similarity index 90% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 5e334d13a1215..309d9e5de0a0f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -15,9 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** * An interface to represent the output data partitioning for a data source, which is returned by diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java index 3f13a4dbf5793..47d26440841fd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 6e5177ee83a62..d1d1e7ffd1dd4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSource; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index fcec446d892f5..67ebde30d61a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index abba3e7188b13..e41c0351edc82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java index 4688b85f49f5f..383e73db6762b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/reader/PartitionOffset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.reader; +package org.apache.spark.sql.sources.v2.reader.streaming; import java.io.Serializable; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index d89d27d0e5b1b..7096aec0d22c2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -28,7 +28,7 @@ /** * A data source writer that is returned by * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceOptions)}/ - * {@link org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport#createStreamWriter( + * {@link StreamWriteSupport#createStreamWriter( * String, StructType, OutputMode, DataSourceOptions)}. * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java index 7c5f304425093..1c0e2e12f8d51 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming; +package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter; -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java similarity index 98% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 915ee6c4fb390..4913341bd505d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/streaming/writer/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.streaming.writer; +package org.apache.spark.sql.sources.v2.writer.streaming; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 943d0100aca56..017a6737161a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} import org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Partitioning} /** * An adapter from public data source partitioning to catalyst internal `Partitioning`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index ee085820b0775..df469af2c262a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index c544adbf32cdf..6592bd72fa338 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 93572f7a63132..d9aa8573ba930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow +import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index 5e3fee633f591..ce5e63f5bde85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -30,11 +30,10 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.execution.streaming.sources.RateStreamMicroBatchReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala index 261d69bbd9843..02fed50485b94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateStreamOffset.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.sources.v2 case class RateStreamOffset(partitionToValueAndRunTimeMs: Map[Int, ValueRunTimeMsPair]) - extends v2.streaming.reader.Offset { + extends v2.reader.streaming.Offset { implicit val defaultFormats: DefaultFormats = DefaultFormats override val json = Serialization.write(partitionToValueAndRunTimeMs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index a0ee683a895d8..845c8d2c14e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.streaming.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 3f5bb489d6528..db600866067bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index 8a7a38b22caca..cf02c0dda25d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -18,23 +18,19 @@ package org.apache.spark.sql.execution.streaming.continuous import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.sql.streaming.ProcessingTime -import org.apache.spark.util.{SystemClock, ThreadUtils} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 9402d7c1dcefd..08c81419a9d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index ff028ebc4236a..0eaaa4889ba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamO import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 84d262116cb46..cc6808065c0cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala index c57bdc4a28905..d276403190b3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -22,8 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala index d7ce9a7b84479..56f7ff25cbed0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 43949e6180aaa..1315885da8a6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport -import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 58767261dc684..3411edbc53412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f1b3f93c4e1fc..116ac3da07b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 3b5b30d77945c..9aac360fd4bbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index fdd709cdb1f38..ddb1edc433d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.streaming.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 99cca0f6dd626..32fad59b97ff6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -27,6 +27,9 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; +import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; import org.apache.spark.sql.types.StructType; public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index b060aeeef811d..3158995ec62f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index ee50e8a92270b..2f49b07018aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 3127d664d32dc..cb873ab688e96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.streaming._ -import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.streaming.writer.StreamWriter +import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, DataReaderFactory, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils From b2e7677f4d3d8f47f5f148680af39d38f2b558f0 Mon Sep 17 00:00:00 2001 From: Atallah Hezbor Date: Wed, 31 Jan 2018 20:45:55 -0800 Subject: [PATCH 0258/1161] [SPARK-21396][SQL] Fixes MatchError when UDTs are passed through Hive Thriftserver Signed-off-by: Atallah Hezbor ## What changes were proposed in this pull request? This PR proposes modifying the match statement that gets the columns of a row in HiveThriftServer. There was previously no case for `UserDefinedType`, so querying a table that contained them would throw a match error. The changes catch that case and return the string representation. ## How was this patch tested? While I would have liked to add a unit test, I couldn't easily incorporate UDTs into the ``HiveThriftServer2Suites`` pipeline. With some guidance I would be happy to push a commit with tests. Instead I did a manual test by loading a `DataFrame` with Point UDT in a spark shell with a HiveThriftServer. Then in beeline, connecting to the server and querying that table. Here is the result before the change ``` 0: jdbc:hive2://localhost:10000> select * from chicago; Error: scala.MatchError: org.apache.spark.sql.PointUDT2d980dc3 (of class org.apache.spark.sql.PointUDT) (state=,code=0) ``` And after the change: ``` 0: jdbc:hive2://localhost:10000> select * from chicago; +---------------------------------------+--------------+------------------------+---------------------+--+ | __fid__ | case_number | dtg | geom | +---------------------------------------+--------------+------------------------+---------------------+--+ | 109602f9-54f8-414b-8c6f-42b1a337643e | 2 | 2016-01-01 19:00:00.0 | POINT (-77 38) | | 709602f9-fcff-4429-8027-55649b6fd7ed | 1 | 2015-12-31 19:00:00.0 | POINT (-76.5 38.5) | | 009602f9-fcb5-45b1-a867-eb8ba10cab40 | 3 | 2016-01-02 19:00:00.0 | POINT (-78 39) | +---------------------------------------+--------------+------------------------+---------------------+--+ ``` Author: Atallah Hezbor Closes #20385 from atallahhezbor/udts_over_hive. --- .../thriftserver/SparkExecuteStatementOperation.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 1 + .../scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala | 8 +++++++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 664bc20601eaa..3cfc81b8a9579 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -102,7 +102,7 @@ private[hive] class SparkExecuteStatementOperation( to += from.getAs[Timestamp](ordinal) case BinaryType => to += from.getAs[Array[Byte]](ordinal) - case _: ArrayType | _: StructType | _: MapType => + case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] => val hiveString = HiveUtils.toHiveString((from.get(ordinal), dataTypes(ordinal))) to += hiveString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c7717d70c996f..d9627eb9790eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -460,6 +460,7 @@ private[spark] object HiveUtils extends Logging { case (decimal: java.math.BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString HiveDecimal.create(decimal).toString + case (other, _ : UserDefinedType[_]) => other.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index 8697d47e89e89..f2b75e4b23f02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SQLTestUtils} import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -62,4 +62,10 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton Thread.currentThread().setContextClassLoader(contextClassLoader) } } + + test("toHiveString correctly handles UDTs") { + val point = new ExamplePoint(50.0, 50.0) + val tpe = new ExamplePointUDT() + assert(HiveUtils.toHiveString((point, tpe)) === "(50.0, 50.0)") + } } From cc41245fa3f954f961541bf4b4275c28473042b8 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 1 Feb 2018 12:56:07 +0800 Subject: [PATCH 0259/1161] [SPARK-23188][SQL] Make vectorized columar reader batch size configurable ## What changes were proposed in this pull request? This PR include the following changes: - Make the capacity of `VectorizedParquetRecordReader` configurable; - Make the capacity of `OrcColumnarBatchReader` configurable; - Update the error message when required capacity in writable columnar vector cannot be fulfilled. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #20361 from jiangxb1987/vectorCapacity. --- .../apache/spark/sql/internal/SQLConf.scala | 16 ++++++++++++++ .../orc/OrcColumnarBatchReader.java | 22 ++++++++++--------- .../VectorizedParquetRecordReader.java | 20 ++++++++--------- .../vectorized/WritableColumnVector.java | 7 ++++-- .../datasources/orc/OrcFileFormat.scala | 3 ++- .../parquet/ParquetFileFormat.scala | 3 ++- .../parquet/ParquetEncodingSuite.scala | 12 +++++++--- .../datasources/parquet/ParquetIOSuite.scala | 21 +++++++++++++----- .../parquet/ParquetReadBenchmark.scala | 11 +++++++--- 9 files changed, 78 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7394a0d7cf983..90654e67457e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -375,6 +375,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.parquet.columnarReaderBatchSize") + .doc("The number of rows to include in a parquet vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + "`orc.compress` is specified in the table-specific options/properties, the precedence " + @@ -398,6 +404,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_VECTORIZED_READER_BATCH_SIZE = buildConf("spark.sql.orc.columnarReaderBatchSize") + .doc("The number of rows to include in a orc vectorized reader batch. The number should " + + "be carefully chosen to minimize overhead and avoid OOMs in reading data.") + .intConf + .createWithDefault(4096) + val ORC_COPY_BATCH_TO_SPARK = buildConf("spark.sql.orc.copyBatchToSpark") .doc("Whether or not to copy the ORC columnar batch to Spark columnar batch in the " + "vectorized ORC reader.") @@ -1250,10 +1262,14 @@ class SQLConf extends Serializable with Logging { def orcVectorizedReaderEnabled: Boolean = getConf(ORC_VECTORIZED_READER_ENABLED) + def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE) + def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 5e7cad470e1d1..dcebdc39f0aa2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -49,8 +49,9 @@ * After creating, `initialize` and `initBatch` should be called sequentially. */ public class OrcColumnarBatchReader extends RecordReader { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; // Vectorized ORC Row Batch private VectorizedRowBatch batch; @@ -81,9 +82,10 @@ public class OrcColumnarBatchReader extends RecordReader { // Whether or not to copy the ORC columnar batch to Spark columnar batch. private final boolean copyToSpark; - public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark) { + public OrcColumnarBatchReader(boolean useOffHeap, boolean copyToSpark, int capacity) { MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; this.copyToSpark = copyToSpark; + this.capacity = capacity; } @@ -148,7 +150,7 @@ public void initBatch( StructField[] requiredFields, StructType partitionSchema, InternalRow partitionValues) { - batch = orcSchema.createRowBatch(CAPACITY); + batch = orcSchema.createRowBatch(capacity); assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. this.requiredFields = requiredFields; @@ -162,15 +164,15 @@ public void initBatch( if (copyToSpark) { if (MEMORY_MODE == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, resultSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema); } // Initialize the missing columns once. for (int i = 0; i < requiredFields.length; i++) { if (requestedColIds[i] == -1) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -193,8 +195,8 @@ public void initBatch( int colId = requestedColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(CAPACITY, dt); - missingCol.putNulls(0, CAPACITY); + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; } else { @@ -206,7 +208,7 @@ public void initBatch( int partitionIdx = requiredFields.length; for (int i = 0; i < partitionValues.numFields(); i++) { DataType dt = partitionSchema.fields()[i].dataType(); - OnHeapColumnVector partitionCol = new OnHeapColumnVector(CAPACITY, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, i); partitionCol.setIsConstant(); orcVectorWrappers[partitionIdx + i] = partitionCol; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index bb1b23611a7d7..5934a23db8af1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -50,8 +50,9 @@ * TODO: make this always return ColumnarBatches. */ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase { - // TODO: make this configurable. - private static final int CAPACITY = 4 * 1024; + + // The capacity of vectorized batch. + private int capacity; /** * Batch of rows that we assemble and the current index we've returned. Every time this @@ -115,13 +116,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private final MemoryMode MEMORY_MODE; - public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap) { + public VectorizedParquetRecordReader(TimeZone convertTz, boolean useOffHeap, int capacity) { this.convertTz = convertTz; MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP; - } - - public VectorizedParquetRecordReader(boolean useOffHeap) { - this(null, useOffHeap); + this.capacity = capacity; } /** @@ -199,9 +197,9 @@ private void initBatch( } if (memMode == MemoryMode.OFF_HEAP) { - columnVectors = OffHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); } else { - columnVectors = OnHeapColumnVector.allocateColumns(CAPACITY, batchSchema); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); } columnarBatch = new ColumnarBatch(columnVectors); if (partitionColumns != null) { @@ -215,7 +213,7 @@ private void initBatch( // Initialize missing columns with nulls. for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnVectors[i].putNulls(0, CAPACITY); + columnVectors[i].putNulls(0, capacity); columnVectors[i].setIsConstant(); } } @@ -257,7 +255,7 @@ public boolean nextBatch() throws IOException { if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) CAPACITY, totalCountLoadedSoFar - rowsReturned); + int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; columnReaders[i].readBatch(num, columnVectors[i]); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index c2e595455549c..9d447cdc79063 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -98,8 +98,11 @@ public void reserve(int requiredCapacity) { private void throwUnsupportedException(int requiredCapacity, Throwable cause) { String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + - " to false."; + "vectorized reader, or increase the vectorized reader batch size. For parquet file " + + "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 2dd314d165348..dbf3bc6f0ee6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -151,6 +151,7 @@ class OrcFileFormat val sqlConf = sparkSession.sessionState.conf val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader = supportBatch(sparkSession, resultSchema) + val capacity = sqlConf.orcVectorizedReaderBatchSize val copyToSpark = sparkSession.sessionState.conf.getConf(SQLConf.ORC_COPY_BATCH_TO_SPARK) val broadcastedConf = @@ -186,7 +187,7 @@ class OrcFileFormat val taskContext = Option(TaskContext.get()) if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( - enableOffHeapColumnVector && taskContext.isDefined, copyToSpark) + enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index f53a97ba45a26..ba69f9a26c968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -350,6 +350,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetRecordFilterEnabled val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val capacity = sqlConf.parquetVectorizedReaderBatchSize // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) @@ -396,7 +397,7 @@ class ParquetFileFormat val taskContext = Option(TaskContext.get()) val parquetReader = if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( - convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined) + convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index edb1290ee2eb0..db73bfa149aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -40,7 +40,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -65,7 +67,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -94,7 +98,9 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) reader.initialize(file, null /* set columns to null to project all columns */) val column = reader.resultBatch().column(0) assert(reader.nextBatch()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f3ece5b15e26a..3af80930ec807 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -653,7 +653,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -670,7 +672,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -686,7 +690,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -703,7 +709,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val reader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) try { reader.initialize(file, List[String]().asJava) var result = 0 @@ -742,8 +750,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { dataTypes.zip(constantValues).foreach { case (dt, v) => val schema = StructType(StructField("pcol", dt) :: Nil) - val vectorizedReader = - new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled) + val conf = sqlContext.conf + val vectorizedReader = new VectorizedParquetRecordReader( + null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize) val partitionValues = new GenericInternalRow(Array(v)) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 86a3c71a3c4f6..e43336d947364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -76,6 +76,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) @@ -96,7 +97,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -119,7 +121,8 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -262,6 +265,7 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize spark.range(values).createOrReplaceTempView("t1") spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") @@ -279,7 +283,8 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader(enableOffHeapColumnVector) + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() From b6b50efc854f298d5b3e11c05dca995a85bec962 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 31 Jan 2018 20:59:19 -0800 Subject: [PATCH 0260/1161] [SQL][MINOR] Inline SpecifiedWindowFrame.defaultWindowFrame(). ## What changes were proposed in this pull request? SpecifiedWindowFrame.defaultWindowFrame(hasOrderSpecification, acceptWindowFrame) was designed to handle the cases when some Window functions don't support setting a window frame (e.g. rank). However this param is never used. We may inline the whole of this function to simplify the code. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #20463 from jiangxb1987/defaultWindowFrame. --- .../sql/catalyst/analysis/Analyzer.scala | 6 +++++- .../expressions/windowExpressions.scala | 21 ------------------- .../catalyst/ExpressionSQLBuilderSuite.scala | 5 +---- 3 files changed, 6 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 251099f750cf6..7848f88bda1c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2038,7 +2038,11 @@ class Analyzer( WindowExpression(wf, s.copy(frameSpecification = wf.frame)) case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if e.resolved => - val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + val frame = if (o.nonEmpty) { + SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) + } else { + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) + } we.copy(windowSpec = s.copy(frameSpecification = frame)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index dd13d9a3bba51..78895f1c2f6f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -265,27 +265,6 @@ case class SpecifiedWindowFrame( } } -object SpecifiedWindowFrame { - /** - * @param hasOrderSpecification If the window spec has order by expressions. - * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return the default window frame. - */ - def defaultWindowFrame( - hasOrderSpecification: Boolean, - acceptWindowFrame: Boolean): SpecifiedWindowFrame = { - if (hasOrderSpecification && acceptWindowFrame) { - // If order spec is defined and the window function supports user specified window frames, - // the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. - SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) - } else { - // Otherwise, the default frame is - // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. - SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) - } - } -} - case class UnresolvedWindowExpression( child: Expression, windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index d9cf1f361c1d6..61f9179042fe4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -108,10 +108,7 @@ class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton { } test("window specification") { - val frame = SpecifiedWindowFrame.defaultWindowFrame( - hasOrderSpecification = true, - acceptWindowFrame = true - ) + val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), From 4b7cd479a28b274f5a0802c9b017b3eb15002c21 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 1 Feb 2018 13:58:13 +0800 Subject: [PATCH 0261/1161] Revert "[SPARK-23200] Reset Kubernetes-specific config on Checkpoint restore" This reverts commit d1721816d26bedee3c72eeb75db49da500568376. The patch is not fully tested and out-of-date. So revert it. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ed2a896033749..aed67a5027433 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -53,21 +53,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.host", "spark.driver.bindAddress", "spark.driver.port", - "spark.kubernetes.driver.pod.name", - "spark.kubernetes.executor.podNamePrefix", - "spark.kubernetes.initcontainer.executor.configmapname", - "spark.kubernetes.initcontainer.executor.configmapkey", - "spark.kubernetes.initcontainer.downloadJarsResourceIdentifier", - "spark.kubernetes.initcontainer.downloadJarsSecretLocation", - "spark.kubernetes.initcontainer.downloadFilesResourceIdentifier", - "spark.kubernetes.initcontainer.downloadFilesSecretLocation", - "spark.kubernetes.initcontainer.remoteJars", - "spark.kubernetes.initcontainer.remoteFiles", - "spark.kubernetes.mountdependencies.jarsDownloadDir", - "spark.kubernetes.mountdependencies.filesDownloadDir", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.name", - "spark.kubernetes.initcontainer.executor.stagingServerSecret.mountDir", - "spark.kubernetes.executor.limit.cores", "spark.master", "spark.yarn.jars", "spark.yarn.keytab", @@ -81,7 +66,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.bindAddress") - .remove("spark.kubernetes.driver.pod.name") .remove("spark.driver.port") val newReloadConf = new SparkConf(loadDefaults = true) propertiesToReload.foreach { prop => From 07cee33736aabf9e9a4a89344eda2b8ea29b27ea Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 31 Jan 2018 22:26:27 -0800 Subject: [PATCH 0262/1161] [SPARK-22274][PYTHON][SQL][FOLLOWUP] Use `assertRaisesRegexp` instead of `assertRaisesRegex`. ## What changes were proposed in this pull request? This is a follow-up pr of #19872 which uses `assertRaisesRegex` but it doesn't exist in Python 2, so some tests fail when running tests in Python 2 environment. Unfortunately, we missed it because currently Python 2 environment of the pr builder doesn't have proper versions of pandas or pyarrow, so the tests were skipped. This pr modifies to use `assertRaisesRegexp` instead of `assertRaisesRegex`. ## How was this patch tested? Tested manually in my local environment. Author: Takuya UESHIN Closes #20467 from ueshin/issues/SPARK-22274/fup1. --- python/pyspark/sql/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dc26b96334c7a..b27363023ae77 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4530,19 +4530,19 @@ def test_unsupported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return [v.mean(), v.std()] with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return v.mean(), v.std() with QuietTest(self.sc): - with self.assertRaisesRegex(NotImplementedError, 'not supported'): + with self.assertRaisesRegexp(NotImplementedError, 'not supported'): @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} From e15da5b14c8d845028365a609c0c66731d024ee7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 1 Feb 2018 11:25:01 +0200 Subject: [PATCH 0263/1161] [SPARK-23107][ML] ML 2.3 QA: New Scala APIs, docs. ## What changes were proposed in this pull request? Audit new APIs and docs in 2.3.0. ## How was this patch tested? No test. Author: Yanbo Liang Closes #20459 from yanboliang/SPARK-23107. --- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- .../scala/org/apache/spark/ml/regression/LinearRegression.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1155ea5fdd85b..22e7b8bbf1ff5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -74,7 +74,7 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with * @group param */ @Since("2.3.0") - final override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data (unseen or NULL values) in features and label column of string " + "type. Options are 'skip' (filter out rows with invalid data), error (throw an error), " + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index a5873d03b4161..6d3fe7a6c748c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -645,7 +645,7 @@ class LinearRegressionModel private[ml] ( extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams with MLWritable { - def this(uid: String, coefficients: Vector, intercept: Double) = + private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) private var trainingSummary: Option[LinearRegressionTrainingSummary] = None From 8bb70b068ea782e799e45238fcb093a6acb0fc9f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Feb 2018 21:25:02 +0900 Subject: [PATCH 0264/1161] [SPARK-23280][SQL][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20450 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java:[21,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. [ERROR] src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java:[22,8] (imports) UnusedImports: Unused import - org.apache.spark.sql.catalyst.util.MapData. ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20468 from ueshin/issues/SPARK-23280/fup1. --- .../main/java/org/apache/spark/sql/vectorized/ColumnVector.java | 1 - .../main/java/org/apache/spark/sql/vectorized/ColumnarArray.java | 1 - .../main/java/org/apache/spark/sql/vectorized/ColumnarRow.java | 1 - 3 files changed, 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 05271ec1f46ab..530d4d23d4eaf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.vectorized; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 7c7a1c806a2b7..72a192d089b9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -18,7 +18,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 0c9e92ed11fbd..b400f7f93c1fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -19,7 +19,6 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; From 89e8d556b93d1bf1b28fe153fd284f154045b0ee Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 1 Feb 2018 21:28:53 +0900 Subject: [PATCH 0265/1161] [SPARK-23280][SQL][FOLLOWUP] Enable `MutableColumnarRow.getMap()`. ## What changes were proposed in this pull request? This is a followup pr of #20450. We should've enabled `MutableColumnarRow.getMap()` as well. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20471 from ueshin/issues/SPARK-23280/fup2. --- .../spark/sql/execution/vectorized/MutableColumnarRow.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 66668f3753604..307c19032dee5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -21,10 +21,10 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; @@ -162,8 +162,9 @@ public ColumnarArray getArray(int ordinal) { } @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); + public ColumnarMap getMap(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; + return columns[ordinal].getMap(rowId); } @Override From ffbca84519011a747e0552632e88f5e4956e493d Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 1 Feb 2018 20:39:15 +0800 Subject: [PATCH 0266/1161] [SPARK-23202][SQL] Add new API in DataSourceWriter: onDataWriterCommit ## What changes were proposed in this pull request? The current DataSourceWriter API makes it hard to implement `onTaskCommit(taskCommit: TaskCommitMessage)` in `FileCommitProtocol`. In general, on receiving commit message, driver can start processing messages(e.g. persist messages into files) before all the messages are collected. The proposal to add a new API: `add(WriterCommitMessage message)`: Handles a commit message on receiving from a successful data writer. This should make the whole API of DataSourceWriter compatible with `FileCommitProtocol`, and more flexible. There was another radical attempt in #20386. This one should be more reasonable. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20454 from gengliangwang/write_api. --- .../sources/v2/writer/DataSourceWriter.java | 14 +++++++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++++- .../sql/sources/v2/DataSourceV2Suite.scala | 21 ++++++++++++++++++- .../sources/v2/SimpleWritableDataSource.scala | 21 +++++++++++++++++++ 4 files changed, 57 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 7096aec0d22c2..52324b3792b8a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -62,6 +62,14 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Handles a commit message on receiving from a successful data writer. + * + * If this method fails (by throwing an exception), this writing job is considered to to have been + * failed, and {@link #abort(WriterCommitMessage[])} would be called. + */ + default void onDataWriterCommit(WriterCommitMessage message) {} + /** * Commits this writing job with a list of commit messages. The commit messages are collected from * successful data writers and are produced by {@link DataWriter#commit()}. @@ -78,8 +86,10 @@ public interface DataSourceWriter { void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * Aborts this writing job because some data writers are failed and keep failing when retry, + * or the Spark job fails with some unknown reasons, + * or {@link #onDataWriterCommit(WriterCommitMessage)} fails, + * or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 6592bd72fa338..eefbcf4c0e087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -80,7 +80,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e rdd, runTask, rdd.partitions.indices, - (index, message: WriterCommitMessage) => messages(index) = message + (index, message: WriterCommitMessage) => { + messages(index) = message + writer.onDataWriterCommit(message) + } ) if (!writer.isInstanceOf[StreamWriter]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 2f49b07018aaf..1c3ba7826f7de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,7 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -198,6 +198,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple counter in writer with onDataWriterCommit") { + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + val numPartition = 6 + spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + assert(SimpleCounter.getCounter == numPartition, + "method onDataWriterCommit should be called as many as the number of partitions") + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a131b16953e3b..36dd2a350a055 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -66,9 +66,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { + SimpleCounter.resetCounter new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } + override def onDataWriterCommit(message: WriterCommitMessage): Unit = { + SimpleCounter.increaseCounter + } + override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) @@ -183,6 +188,22 @@ class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) } } +private[v2] object SimpleCounter { + private var count: Int = 0 + + def increaseCounter: Unit = { + count += 1 + } + + def getCounter: Int = { + count + } + + def resetCounter: Unit = { + count = 0 + } +} + class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { From ec63e2d0743a4f75e1cce21d0fe2b54407a86a4a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 1 Feb 2018 21:00:47 +0800 Subject: [PATCH 0267/1161] [SPARK-23289][CORE] OneForOneBlockFetcher.DownloadCallback.onData should write the buffer fully ## What changes were proposed in this pull request? `channel.write(buf)` may not write the whole buffer since the underlying channel is a FileChannel, we should retry until the whole buffer is written. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #20461 from zsxwing/SPARK-23289. --- .../apache/spark/network/shuffle/OneForOneBlockFetcher.java | 4 +++- core/src/test/scala/org/apache/spark/FileSuite.scala | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9cac7d00cc6b6..0bc571874f07c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -171,7 +171,9 @@ private class DownloadCallback implements StreamCallback { @Override public void onData(String streamId, ByteBuffer buf) throws IOException { - channel.write(buf); + while (buf.hasRemaining()) { + channel.write(buf); + } } @Override diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index e9539dc73f6fa..55a9122cf9026 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -244,7 +244,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until testOutputCopies) { // Shift values by i so that they're different in the output val alteredOutput = testOutput.map(b => (b + i).toByte) - channel.write(ByteBuffer.wrap(alteredOutput)) + val buffer = ByteBuffer.wrap(alteredOutput) + while (buffer.hasRemaining) { + channel.write(buffer) + } } channel.close() file.close() From f051f834036e63d5e480d86440ce39924f979e82 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Feb 2018 10:36:31 -0800 Subject: [PATCH 0268/1161] [SPARK-13983][SQL] Fix HiveThriftServer2 can not get "--hiveconf" and ''--hivevar" variables since 2.0 ## What changes were proposed in this pull request? `--hiveconf` and `--hivevar` variables no longer work since Spark 2.0. The `spark-sql` client has fixed by [SPARK-15730](https://issues.apache.org/jira/browse/SPARK-15730) and [SPARK-18086](https://issues.apache.org/jira/browse/SPARK-18086). but `beeline`/[`Spark SQL HiveThriftServer2`](https://github.com/apache/spark/blob/v2.1.1/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala) is still broken. This pull request fix it. This pull request works for both `JDBC client` and `beeline`. ## How was this patch tested? unit tests for `JDBC client` manual tests for `beeline`: ``` git checkout origin/pr/17886 dev/make-distribution.sh --mvn mvn --tgz -Phive -Phive-thriftserver -Phadoop-2.6 -DskipTests tar -zxf spark-2.3.0-SNAPSHOT-bin-2.6.5.tgz && cd spark-2.3.0-SNAPSHOT-bin-2.6.5 sbin/start-thriftserver.sh ``` ``` cat < test.sql select '\${a}', '\${b}'; EOF beeline -u jdbc:hive2://localhost:10000 --hiveconf a=avalue --hivevar b=bvalue -f test.sql ``` Author: Yuming Wang Closes #17886 from wangyum/SPARK-13983-dev. --- .../service/cli/session/HiveSessionImpl.java | 74 ++++++++++++++++++- .../server/SparkSQLOperationManager.scala | 12 +++ .../HiveThriftServer2Suites.scala | 23 +++++- 3 files changed, 105 insertions(+), 4 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index 108074cce3d6d..fc818bc69c761 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -44,7 +44,7 @@ import org.apache.hadoop.hive.ql.history.HiveHistory; import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.processors.SetProcessor; +import org.apache.hadoop.hive.ql.parse.VariableSubstitution; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hive.common.util.HiveVersionInfo; @@ -71,6 +71,12 @@ import org.apache.hive.service.cli.thrift.TProtocolVersion; import org.apache.hive.service.server.ThreadWithGarbageCleanup; +import static org.apache.hadoop.hive.conf.SystemVariables.ENV_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVECONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.HIVEVAR_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.METACONF_PREFIX; +import static org.apache.hadoop.hive.conf.SystemVariables.SYSTEM_PREFIX; + /** * HiveSession * @@ -209,7 +215,7 @@ private void configureSession(Map sessionConfMap) throws HiveSQL String key = entry.getKey(); if (key.startsWith("set:")) { try { - SetProcessor.setVariable(key.substring(4), entry.getValue()); + setVariable(key.substring(4), entry.getValue()); } catch (Exception e) { throw new HiveSQLException(e); } @@ -221,6 +227,70 @@ private void configureSession(Map sessionConfMap) throws HiveSQL } } + // Copy from org.apache.hadoop.hive.ql.processors.SetProcessor, only change: + // setConf(varname, propName, varvalue, true) when varname.startsWith(HIVECONF_PREFIX) + public static int setVariable(String varname, String varvalue) throws Exception { + SessionState ss = SessionState.get(); + if (varvalue.contains("\n")){ + ss.err.println("Warning: Value had a \\n character in it."); + } + varname = varname.trim(); + if (varname.startsWith(ENV_PREFIX)){ + ss.err.println("env:* variables can not be set."); + return 1; + } else if (varname.startsWith(SYSTEM_PREFIX)){ + String propName = varname.substring(SYSTEM_PREFIX.length()); + System.getProperties().setProperty(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(HIVECONF_PREFIX)){ + String propName = varname.substring(HIVECONF_PREFIX.length()); + setConf(varname, propName, varvalue, true); + } else if (varname.startsWith(HIVEVAR_PREFIX)) { + String propName = varname.substring(HIVEVAR_PREFIX.length()); + ss.getHiveVariables().put(propName, + new VariableSubstitution().substitute(ss.getConf(),varvalue)); + } else if (varname.startsWith(METACONF_PREFIX)) { + String propName = varname.substring(METACONF_PREFIX.length()); + Hive hive = Hive.get(ss.getConf()); + hive.setMetaConf(propName, new VariableSubstitution().substitute(ss.getConf(), varvalue)); + } else { + setConf(varname, varname, varvalue, true); + } + return 0; + } + + // returns non-null string for validation fail + private static void setConf(String varname, String key, String varvalue, boolean register) + throws IllegalArgumentException { + HiveConf conf = SessionState.get().getConf(); + String value = new VariableSubstitution().substitute(conf, varvalue); + if (conf.getBoolVar(HiveConf.ConfVars.HIVECONFVALIDATION)) { + HiveConf.ConfVars confVars = HiveConf.getConfVars(key); + if (confVars != null) { + if (!confVars.isType(value)) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED because ").append(key).append(" expects "); + message.append(confVars.typeString()).append(" type value."); + throw new IllegalArgumentException(message.toString()); + } + String fail = confVars.validate(value); + if (fail != null) { + StringBuilder message = new StringBuilder(); + message.append("'SET ").append(varname).append('=').append(varvalue); + message.append("' FAILED in validation : ").append(fail).append('.'); + throw new IllegalArgumentException(message.toString()); + } + } else if (key.startsWith("hive.")) { + throw new IllegalArgumentException("hive configuration " + key + " does not exists."); + } + } + conf.verifyAndSet(key, value); + if (register) { + SessionState.get().getOverriddenConfigurations().put(key, value); + } + } + @Override public void setOperationLogSessionDir(File operationLogRootDir) { if (!operationLogRootDir.exists()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a0e5012633f5e..bf7c01f60fb5c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} +import org.apache.spark.sql.internal.SQLConf /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -50,6 +51,9 @@ private[thriftserver] class SparkSQLOperationManager() require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + s" initialized or had already closed.") val conf = sqlContext.sessionState.conf + val hiveSessionState = parentSession.getSessionState + setConfMap(conf, hiveSessionState.getOverriddenConfigurations) + setConfMap(conf, hiveSessionState.getHiveVariables) val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(sqlContext, sessionToActivePool) @@ -58,4 +62,12 @@ private[thriftserver] class SparkSQLOperationManager() s"runInBackground=$runInBackground") operation } + + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { + val iterator = confMap.entrySet().iterator() + while (iterator.hasNext) { + val kv = iterator.next() + conf.setConfString(kv.getKey, kv.getValue) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 7289da71a3365..496f8c82a6c61 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -135,6 +135,22 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("Support beeline --hiveconf and --hivevar") { + withJdbcStatement() { statement => + executeTest(hiveConfList) + executeTest(hiveVarList) + def executeTest(hiveList: String): Unit = { + hiveList.split(";").foreach{ m => + val kv = m.split("=") + // select "${a}"; ---> avalue + val resultSet = statement.executeQuery("select \"${" + kv(0) + "}\"") + resultSet.next() + assert(resultSet.getString(1) === kv(1)) + } + } + } + } + test("JDBC query execution") { withJdbcStatement("test") { statement => val queries = Seq( @@ -740,10 +756,11 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { s"""jdbc:hive2://localhost:$serverPort/ |default? |hive.server2.transport.mode=http; - |hive.server2.thrift.http.path=cliservice + |hive.server2.thrift.http.path=cliservice; + |${hiveConfList}#${hiveVarList} """.stripMargin.split("\n").mkString.trim } else { - s"jdbc:hive2://localhost:$serverPort/" + s"jdbc:hive2://localhost:$serverPort/?${hiveConfList}#${hiveVarList}" } def withMultipleConnectionJdbcStatement(tableNames: String*)(fs: (Statement => Unit)*) { @@ -779,6 +796,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private var listeningPort: Int = _ protected def serverPort: Int = listeningPort + protected val hiveConfList = "a=avalue;b=bvalue" + protected val hiveVarList = "c=cvalue;d=dvalue" protected def user = System.getProperty("user.name") protected var warehousePath: File = _ From 73da3b6968630d9e2cafc742ccb6d4eb54957df4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 10:48:34 -0800 Subject: [PATCH 0269/1161] [SPARK-23293][SQL] fix data source v2 self join ## What changes were proposed in this pull request? `DataSourceV2Relation` should extend `MultiInstanceRelation`, to take care of self-join. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20466 from cloud-fan/dsv2-selfjoin. --- .../execution/datasources/v2/DataSourceV2Relation.scala | 8 +++++++- .../apache/spark/sql/sources/v2/DataSourceV2Suite.scala | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3d4c64981373d..eebfa29f91b99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends LeafNode with DataSourceReaderHolder { + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] @@ -33,6 +35,10 @@ case class DataSourceV2Relation( case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } + + override def newInstance(): DataSourceV2Relation = { + copy(fullOutput = fullOutput.map(_.newInstance())) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1c3ba7826f7de..23147fffe8a08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -217,6 +217,12 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23293: data source v2 self join") { + val df = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + val df2 = df.select(($"i" + 1).as("k"), $"j") + checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { From 4bcfdefb9f6d5ba88335953683a1dabbee83e9ea Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 1 Feb 2018 14:56:40 -0800 Subject: [PATCH 0270/1161] [INFRA] Close stale PRs. Closes #20334 Closes #20262 From 032c11b83f0d276bf8085992229b8c598f02798a Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Thu, 1 Feb 2018 15:26:59 -0800 Subject: [PATCH 0271/1161] [SPARK-23296][YARN] Include stacktrace in YARN-app diagnostic ## What changes were proposed in this pull request? Include stacktrace in the diagnostics message upon abnormal unregister from RM ## How was this patch tested? Tested with a failing job, and confirmed a stacktrace in the client output and YARN webUI. Author: Gera Shegalov Closes #20470 from gerashegalov/gera/stacktrace-diagnostics. --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4d5e3bb043671..2f88feb0f1fdf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -718,7 +719,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("User class threw exception: " + cause, cause) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS, - "User class threw exception: " + cause) + "User class threw exception: " + StringUtils.stringifyException(cause)) } sparkContextPromise.tryFailure(e.getCause()) } finally { From 90848d507457d30abb36e3ba07618dfc87c34cd6 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 2 Feb 2018 10:18:32 +0800 Subject: [PATCH 0272/1161] [SPARK-23284][SQL] Document the behavior of several ColumnVector's get APIs when accessing null slot ## What changes were proposed in this pull request? For some ColumnVector get APIs such as getDecimal, getBinary, getStruct, getArray, getInterval, getUTF8String, we should clearly document their behaviors when accessing null slot. They should return null in this case. Then we can remove null checks from the places using above APIs. For the APIs of primitive values like getInt, getInts, etc., this also documents their behaviors when accessing null slots. Their returning values are undefined and can be anything. ## How was this patch tested? Added tests into `ColumnarBatchSuite`. Author: Liang-Chi Hsieh Closes #20455 from viirya/SPARK-23272-followup. --- .../datasources/orc/OrcColumnVector.java | 3 + .../vectorized/MutableColumnarRow.java | 7 -- .../vectorized/WritableColumnVector.java | 5 ++ .../sql/vectorized/ArrowColumnVector.java | 4 + .../spark/sql/vectorized/ColumnVector.java | 63 ++++++++++------ .../spark/sql/vectorized/ColumnarRow.java | 7 -- .../vectorized/ColumnarBatchSuite.scala | 74 ++++++++++++++++++- 7 files changed, 124 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index c8add4c9f486c..12f4d658b1868 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -154,12 +154,14 @@ public double getDouble(int rowId) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); return Decimal.apply(data, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; int index = getRowIndex(rowId); BytesColumnVector col = bytesData; return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); @@ -167,6 +169,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; int index = getRowIndex(rowId); byte[] binary = new byte[bytesData.length[index]]; System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 307c19032dee5..4e4242fe8d9b9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -127,43 +127,36 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getArray(rowId); } @Override public ColumnarMap getMap(int ordinal) { - if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getMap(rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 9d447cdc79063..5275e4a91eac0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -341,6 +341,7 @@ public final int putByteArray(int rowId, byte[] value) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; if (precision <= Decimal.MAX_INT_DIGITS()) { return Decimal.createUnsafe(getInt(rowId), precision, scale); } else if (precision <= Decimal.MAX_LONG_DIGITS()) { @@ -367,6 +368,7 @@ public void putDecimal(int rowId, Decimal value, int precision) { @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytesAsUTF8String(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -384,6 +386,7 @@ public UTF8String getUTF8String(int rowId) { @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; if (dictionary == null) { return arrayData().getBytes(getArrayOffset(rowId), getArrayLength(rowId)); } else { @@ -613,6 +616,7 @@ public final int appendStruct(boolean isNull) { // array offsets and lengths in the current column vector. @Override public final ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); } @@ -620,6 +624,7 @@ public final ColumnarArray getArray(int rowId) { // second child column vector, and puts the offsets and lengths in the current column vector. @Override public final ColumnarMap getMap(int rowId) { + if (isNullAt(rowId)) return null; return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f3ece538c3b80..f8e37e995a17f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -101,21 +101,25 @@ public double getDouble(int rowId) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getBinary(rowId); } @Override public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; return accessor.getArray(rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 530d4d23d4eaf..ad99b450a4809 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -80,12 +80,14 @@ public abstract class ColumnVector implements AutoCloseable { public abstract boolean isNullAt(int rowId); /** - * Returns the boolean type value for rowId. + * Returns the boolean type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract boolean getBoolean(int rowId); /** - * Gets boolean type values from [rowId, rowId + count) + * Gets boolean type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public boolean[] getBooleans(int rowId, int count) { boolean[] res = new boolean[count]; @@ -96,12 +98,14 @@ public boolean[] getBooleans(int rowId, int count) { } /** - * Returns the byte type value for rowId. + * Returns the byte type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract byte getByte(int rowId); /** - * Gets byte type values from [rowId, rowId + count) + * Gets byte type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public byte[] getBytes(int rowId, int count) { byte[] res = new byte[count]; @@ -112,12 +116,14 @@ public byte[] getBytes(int rowId, int count) { } /** - * Returns the short type value for rowId. + * Returns the short type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract short getShort(int rowId); /** - * Gets short type values from [rowId, rowId + count) + * Gets short type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public short[] getShorts(int rowId, int count) { short[] res = new short[count]; @@ -128,12 +134,14 @@ public short[] getShorts(int rowId, int count) { } /** - * Returns the int type value for rowId. + * Returns the int type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract int getInt(int rowId); /** - * Gets int type values from [rowId, rowId + count) + * Gets int type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public int[] getInts(int rowId, int count) { int[] res = new int[count]; @@ -144,12 +152,14 @@ public int[] getInts(int rowId, int count) { } /** - * Returns the long type value for rowId. + * Returns the long type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract long getLong(int rowId); /** - * Gets long type values from [rowId, rowId + count) + * Gets long type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public long[] getLongs(int rowId, int count) { long[] res = new long[count]; @@ -160,12 +170,14 @@ public long[] getLongs(int rowId, int count) { } /** - * Returns the float type value for rowId. + * Returns the float type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract float getFloat(int rowId); /** - * Gets float type values from [rowId, rowId + count) + * Gets float type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public float[] getFloats(int rowId, int count) { float[] res = new float[count]; @@ -176,12 +188,14 @@ public float[] getFloats(int rowId, int count) { } /** - * Returns the double type value for rowId. + * Returns the double type value for rowId. The return value is undefined and can be anything, + * if the slot for rowId is null. */ public abstract double getDouble(int rowId); /** - * Gets double type values from [rowId, rowId + count) + * Gets double type values from [rowId, rowId + count). The return values for the null slots + * are undefined and can be anything. */ public double[] getDoubles(int rowId, int count) { double[] res = new double[count]; @@ -192,7 +206,7 @@ public double[] getDoubles(int rowId, int count) { } /** - * Returns the struct type value for rowId. + * Returns the struct type value for rowId. If the slot for rowId is null, it should return null. * * To support struct type, implementations must implement {@link #getChild(int)} and make this * vector a tree structure. The number of child vectors must be same as the number of fields of @@ -205,7 +219,7 @@ public final ColumnarRow getStruct(int rowId) { } /** - * Returns the array type value for rowId. + * Returns the array type value for rowId. If the slot for rowId is null, it should return null. * * To support array type, implementations must construct an {@link ColumnarArray} and return it in * this method. {@link ColumnarArray} requires a {@link ColumnVector} that stores the data of all @@ -218,13 +232,13 @@ public final ColumnarRow getStruct(int rowId) { public abstract ColumnarArray getArray(int rowId); /** - * Returns the map type value for rowId. + * Returns the map type value for rowId. If the slot for rowId is null, it should return null. * * In Spark, map type value is basically a key data array and a value data array. A key from the * key array with a index and a value from the value array with the same index contribute to * an entry of this map type value. * - * To support map type, implementations must construct an {@link ColumnarMap} and return it in + * To support map type, implementations must construct a {@link ColumnarMap} and return it in * this method. {@link ColumnarMap} requires a {@link ColumnVector} that stores the data of all * the keys of all the maps in this vector, and another {@link ColumnVector} that stores the data * of all the values of all the maps in this vector, and a pair of offset and length which @@ -233,24 +247,25 @@ public final ColumnarRow getStruct(int rowId) { public abstract ColumnarMap getMap(int ordinal); /** - * Returns the decimal type value for rowId. + * Returns the decimal type value for rowId. If the slot for rowId is null, it should return null. */ public abstract Decimal getDecimal(int rowId, int precision, int scale); /** - * Returns the string type value for rowId. Note that the returned UTF8String may point to the - * data of this column vector, please copy it if you want to keep it after this column vector is - * freed. + * Returns the string type value for rowId. If the slot for rowId is null, it should return null. + * Note that the returned UTF8String may point to the data of this column vector, please copy it + * if you want to keep it after this column vector is freed. */ public abstract UTF8String getUTF8String(int rowId); /** - * Returns the binary type value for rowId. + * Returns the binary type value for rowId. If the slot for rowId is null, it should return null. */ public abstract byte[] getBinary(int rowId); /** - * Returns the calendar interval type value for rowId. + * Returns the calendar interval type value for rowId. If the slot for rowId is null, it should + * return null. * * In Spark, calendar interval type value is basically an integer value representing the number of * months in this interval, and a long value representing the number of microseconds in this diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index b400f7f93c1fe..f2f2279590023 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -119,43 +119,36 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getInterval(rowId); } @Override public ColumnarRow getStruct(int ordinal, int numFields) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getStruct(rowId); } @Override public ColumnarArray getArray(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getArray(rowId); } @Override public ColumnarMap getMap(int ordinal) { - if (data.getChild(ordinal).isNullAt(rowId)) return null; return data.getChild(ordinal).getMap(rowId); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8fe2985836f2e..772f687526008 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -572,7 +572,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } } - testVector("String APIs", 6, StringType) { + testVector("String APIs", 7, StringType) { column => val reference = mutable.ArrayBuffer.empty[String] @@ -619,6 +619,10 @@ class ColumnarBatchSuite extends SparkFunSuite { idx += 1 assert(column.arrayData().elementsAppended == 17 + (s + s).length) + column.putNull(idx) + assert(column.getUTF8String(idx) == null) + idx += 1 + reference.zipWithIndex.foreach { v => val errMsg = "VectorType=" + column.getClass.getSimpleName assert(v._1.length == column.getArrayLength(v._2), errMsg) @@ -647,6 +651,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference += new CalendarInterval(0, 2000) column.putNull(2) + assert(column.getInterval(2) == null) reference += null months.putInt(3, 20) @@ -683,6 +688,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(0).numElements == 1) assert(column.getArray(1).numElements == 2) assert(column.isNullAt(2)) + assert(column.getArray(2) == null) assert(column.getArray(3).numElements == 0) assert(column.getArray(4).numElements == 3) @@ -785,6 +791,7 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(0, 0, 1) column.putArray(1, 1, 2) column.putNull(2) + assert(column.getMap(2) == null) column.putArray(3, 3, 0) column.putArray(4, 3, 3) @@ -821,6 +828,7 @@ class ColumnarBatchSuite extends SparkFunSuite { c2.putDouble(0, 3.45) column.putNull(1) + assert(column.getStruct(1) == null) c1.putInt(2, 456) c2.putDouble(2, 5.67) @@ -1261,4 +1269,68 @@ class ColumnarBatchSuite extends SparkFunSuite { batch.close() allocator.close() } + + testVector("Decimal API", 4, DecimalType.IntDecimal) { + column => + + val reference = mutable.ArrayBuffer.empty[Decimal] + + var idx = 0 + column.putDecimal(idx, new Decimal().set(10), 10) + reference += new Decimal().set(10) + idx += 1 + + column.putDecimal(idx, new Decimal().set(20), 10) + reference += new Decimal().set(20) + idx += 1 + + column.putNull(idx) + assert(column.getDecimal(idx, 10, 0) == null) + reference += null + idx += 1 + + column.putDecimal(idx, new Decimal().set(30), 10) + reference += new Decimal().set(30) + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + assert(v == column.getDecimal(i, 10, 0), errMsg) + if (v == null) assert(column.isNullAt(i), errMsg) + } + + column.close() + } + + testVector("Binary APIs", 4, BinaryType) { + column => + + val reference = mutable.ArrayBuffer.empty[String] + var idx = 0 + column.putByteArray(idx, "Hello".getBytes(StandardCharsets.UTF_8)) + reference += "Hello" + idx += 1 + + column.putByteArray(idx, "World".getBytes(StandardCharsets.UTF_8)) + reference += "World" + idx += 1 + + column.putNull(idx) + reference += null + idx += 1 + + column.putByteArray(idx, "abc".getBytes(StandardCharsets.UTF_8)) + reference += "abc" + + reference.zipWithIndex.foreach { case (v, i) => + val errMsg = "VectorType=" + column.getClass.getSimpleName + if (v != null) { + assert(v == new String(column.getBinary(i)), errMsg) + } else { + assert(column.isNullAt(i), errMsg) + assert(column.getBinary(i) == null, errMsg) + } + } + + column.close() + } } From 969eda4a02faa7ca6cf3aff5cd10e6d51026b845 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 2 Feb 2018 11:43:22 +0800 Subject: [PATCH 0273/1161] [SPARK-23020][CORE] Fix another race in the in-process launcher test. First the bad news: there's an unfixable race in the launcher code. (By unfixable I mean it would take a lot more effort than this change to fix it.) The good news is that it should only affect super short lived applications, such as the one run by the flaky test, so it's possible to work around it in our test. The fix also uncovered an issue with the recently added "closeAndWait()" method; closing the connection would still possibly cause data loss, so this change waits a while for the connection to finish itself, and closes the socket if that times out. The existing connection timeout is reused so that if desired it's possible to control how long to wait. As part of that I also restored the old behavior that disconnect() would force a disconnection from the child app; the "wait for data to arrive" approach is only taken when disposing of the handle. I tested this by inserting a bunch of sleeps in the test and the socket handling code in the launcher library; with those I was able to reproduce the error from the jenkins jobs. With the changes, even with all the sleeps still in place, all tests pass. Author: Marcelo Vanzin Closes #20462 from vanzin/SPARK-23020. --- .../spark/launcher/SparkLauncherSuite.java | 40 ++++++++++++++--- .../spark/launcher/AbstractAppHandle.java | 45 ++++++++++++------- .../spark/launcher/ChildProcAppHandle.java | 2 +- .../spark/launcher/InProcessAppHandle.java | 2 +- .../apache/spark/launcher/LauncherServer.java | 30 ++++++++----- .../spark/launcher/LauncherServerSuite.java | 2 +- 6 files changed, 87 insertions(+), 34 deletions(-) diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 1543f4fdb0162..2225591a4ff75 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -157,12 +157,24 @@ private void inProcessLauncherTestImpl() throws Exception { SparkAppHandle handle = null; try { - handle = new InProcessLauncher() - .setMaster("local") - .setAppResource(SparkLauncher.NO_RESOURCE) - .setMainClass(InProcessTestApp.class.getName()) - .addAppArgs("hello") - .startApplication(listener); + synchronized (InProcessTestApp.LOCK) { + handle = new InProcessLauncher() + .setMaster("local") + .setAppResource(SparkLauncher.NO_RESOURCE) + .setMainClass(InProcessTestApp.class.getName()) + .addAppArgs("hello") + .startApplication(listener); + + // SPARK-23020: see doc for InProcessTestApp.LOCK for a description of the race. Here + // we wait until we know that the connection between the app and the launcher has been + // established before allowing the app to finish. + final SparkAppHandle _handle = handle; + eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + assertNotEquals(SparkAppHandle.State.UNKNOWN, _handle.getState()); + }); + + InProcessTestApp.LOCK.wait(5000); + } waitFor(handle); assertEquals(SparkAppHandle.State.FINISHED, handle.getState()); @@ -193,10 +205,26 @@ public static void main(String[] args) throws Exception { public static class InProcessTestApp { + /** + * SPARK-23020: there's a race caused by a child app finishing too quickly. This would cause + * the InProcessAppHandle to dispose of itself even before the child connection was properly + * established, so no state changes would be detected for the application and its final + * state would be LOST. + * + * It's not really possible to fix that race safely in the handle code itself without changing + * the way in-process apps talk to the launcher library, so we work around that in the test by + * synchronizing on this object. + */ + public static final Object LOCK = new Object(); + public static void main(String[] args) throws Exception { assertNotEquals(0, args.length); assertEquals(args[0], "hello"); new SparkContext().stop(); + + synchronized (LOCK) { + LOCK.notifyAll(); + } } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java index 84a25a5254151..9cbebdaeb33d3 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java @@ -18,22 +18,22 @@ package org.apache.spark.launcher; import java.io.IOException; -import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; abstract class AbstractAppHandle implements SparkAppHandle { - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(AbstractAppHandle.class.getName()); private final LauncherServer server; private LauncherServer.ServerConnection connection; private List listeners; private AtomicReference state; - private String appId; + private volatile String appId; private volatile boolean disposed; protected AbstractAppHandle(LauncherServer server) { @@ -44,7 +44,7 @@ protected AbstractAppHandle(LauncherServer server) { @Override public synchronized void addListener(Listener l) { if (listeners == null) { - listeners = new ArrayList<>(); + listeners = new CopyOnWriteArrayList<>(); } listeners.add(l); } @@ -71,16 +71,14 @@ public void stop() { @Override public synchronized void disconnect() { - if (!isDisposed()) { - if (connection != null) { - try { - connection.closeAndWait(); - } catch (IOException ioe) { - // no-op. - } + if (connection != null && connection.isOpen()) { + try { + connection.close(); + } catch (IOException ioe) { + // no-op. } - dispose(); } + dispose(); } void setConnection(LauncherServer.ServerConnection connection) { @@ -97,10 +95,25 @@ boolean isDisposed() { /** * Mark the handle as disposed, and set it as LOST in case the current state is not final. + * + * This method should be called only when there's a reasonable expectation that the communication + * with the child application is not needed anymore, either because the code managing the handle + * has said so, or because the child application is finished. */ synchronized void dispose() { if (!isDisposed()) { + // First wait for all data from the connection to be read. Then unregister the handle. + // Otherwise, unregistering might cause the server to be stopped and all child connections + // to be closed. + if (connection != null) { + try { + connection.waitForClose(); + } catch (IOException ioe) { + // no-op. + } + } server.unregister(this); + // Set state to LOST if not yet final. setState(State.LOST, false); this.disposed = true; @@ -127,11 +140,13 @@ void setState(State s, boolean force) { current = state.get(); } - LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", - new Object[] { current, s }); + if (s != State.LOST) { + LOG.log(Level.WARNING, "Backend requested transition from final state {0} to {1}.", + new Object[] { current, s }); + } } - synchronized void setAppId(String appId) { + void setAppId(String appId) { this.appId = appId; fireEvent(true); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index 5e3c95676ecbe..5609f8492f4f4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -112,7 +112,7 @@ void monitorChild() { } } - disconnect(); + dispose(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index b8030e0063a37..4b740d3fad20e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -66,7 +66,7 @@ synchronized void start(String appName, Method main, String[] args) { setState(State.FAILED); } - disconnect(); + dispose(); }); app.setName(String.format(THREAD_NAME_FMT, THREAD_IDS.incrementAndGet(), appName)); diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index f4ecd52fdeab8..607879fd02ea9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -238,6 +238,7 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); + clientConnection.setConnectionThread(clientThread); synchronized (clients) { clients.add(clientConnection); } @@ -290,17 +291,15 @@ class ServerConnection extends LauncherConnection { private TimerTask timeout; private volatile Thread connectionThread; - volatile AbstractAppHandle handle; + private volatile AbstractAppHandle handle; ServerConnection(Socket socket, TimerTask timeout) throws IOException { super(socket); this.timeout = timeout; } - @Override - public void run() { - this.connectionThread = Thread.currentThread(); - super.run(); + void setConnectionThread(Thread t) { + this.connectionThread = t; } @Override @@ -361,19 +360,30 @@ public void close() throws IOException { } /** - * Close the connection and wait for any buffered data to be processed before returning. + * Wait for the remote side to close the connection so that any pending data is processed. * This ensures any changes reported by the child application take effect. + * + * This method allows a short period for the above to happen (same amount of time as the + * connection timeout, which is configurable). This should be fine for well-behaved + * applications, where they close the connection arond the same time the app handle detects the + * app has finished. + * + * In case the connection is not closed within the grace period, this method forcefully closes + * it and any subsequent data that may arrive will be ignored. */ - public void closeAndWait() throws IOException { - close(); - + public void waitForClose() throws IOException { Thread connThread = this.connectionThread; if (Thread.currentThread() != connThread) { try { - connThread.join(); + connThread.join(getConnectionTimeout()); } catch (InterruptedException ie) { // Ignore. } + + if (connThread.isAlive()) { + LOG.log(Level.WARNING, "Timed out waiting for child connection to close."); + close(); + } } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 024efac33c391..d16337a319be3 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -94,8 +94,8 @@ public void infoChanged(SparkAppHandle handle) { Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS); assertTrue(stopMsg instanceof Stop); } finally { - handle.kill(); close(client); + handle.kill(); client.clientThread.join(); } } From b3a04283f490020c13b6750de021af734c449c3a Mon Sep 17 00:00:00 2001 From: Zhan Zhang Date: Fri, 2 Feb 2018 12:21:06 +0800 Subject: [PATCH 0274/1161] [SPARK-23306] Fix the oom caused by contention ## What changes were proposed in this pull request? here is race condition in TaskMemoryManger, which may cause OOM. The memory released may be taken by another task because there is a gap between releaseMemory and acquireMemory, e.g., UnifiedMemoryManager, causing the OOM. if the current is the only one that can perform spill. It can happen to BytesToBytesMap, as it only spill required bytes. Loop on current consumer if it still has memory to release. ## How was this patch tested? The race contention is hard to reproduce, but the current logic seems causing the issue. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Zhan Zhang Closes #20480 from zhzhan/oom. --- .../org/apache/spark/memory/TaskMemoryManager.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 632d718062212..d07faf1da1248 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -172,10 +172,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { currentEntry = sortedConsumers.lastEntry(); } List cList = currentEntry.getValue(); - MemoryConsumer c = cList.remove(cList.size() - 1); - if (cList.isEmpty()) { - sortedConsumers.remove(currentEntry.getKey()); - } + MemoryConsumer c = cList.get(cList.size() - 1); try { long released = c.spill(required - got, consumer); if (released > 0) { @@ -185,6 +182,11 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got >= required) { break; } + } else { + cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } } } catch (ClosedByInterruptException e) { // This called by user to kill a task (e.g: speculative task). From 19c7c7ebdef6c1c7a02ebac9af6a24f521b52c37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 1 Feb 2018 20:44:46 -0800 Subject: [PATCH 0275/1161] [SPARK-23301][SQL] data source column pruning should work for arbitrary expressions ## What changes were proposed in this pull request? This PR fixes a mistake in the `PushDownOperatorsToDataSource` rule, the column pruning logic is incorrect about `Project`. ## How was this patch tested? a new test case for column pruning with arbitrary expressions, and improve the existing tests to make sure the `PushDownOperatorsToDataSource` really works. Author: Wenchen Fan Closes #20476 from cloud-fan/push-down. --- .../v2/PushDownOperatorsToDataSource.scala | 53 ++++---- .../sources/v2/JavaAdvancedDataSourceV2.java | 29 ++++- .../sql/sources/v2/DataSourceV2Suite.scala | 113 ++++++++++++++++-- 3 files changed, 155 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index df034adf1e7d6..566a48394f02e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -81,35 +81,34 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: add more push down rules. - // TODO: nested fields pruning - def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { - plan match { - case Project(projectList, child) => - val required = projectList.filter(requiredByParent.contains).flatMap(_.references) - pushDownRequiredColumns(child, required) - - case Filter(condition, child) => - val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) - - case DataSourceV2Relation(fullOutput, reader) => reader match { - case r: SupportsPushDownRequiredColumns => - // Match original case of attributes. - val attrMap = AttributeMap(fullOutput.zip(fullOutput)) - val requiredColumns = requiredByParent.map(attrMap) - r.pruneColumns(requiredColumns.toStructType) - case _ => - } + pushDownRequiredColumns(filterPushed, filterPushed.outputSet) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + // TODO: nested fields pruning + private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.flatMap(_.references) + pushDownRequiredColumns(child, AttributeSet(required)) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) - // TODO: there may be more operators can be used to calculate required columns, we can add - // more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + case relation: DataSourceV2Relation => relation.reader match { + case reader: SupportsPushDownRequiredColumns => + val requiredColumns = relation.output.filter(requiredByParent.contains) + reader.pruneColumns(requiredColumns.toStructType) + + case _ => } - } - pushDownRequiredColumns(filterPushed, filterPushed.output) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) + // TODO: there may be more operators that can be used to calculate the required columns. We + // can add more and more in the future. + case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) + } } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index d421f7d19563f..172e5d5eebcbe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -32,11 +32,12 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, + public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, SupportsPushDownFilters { - private StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); - private Filter[] filters = new Filter[0]; + // Exposed for testing. + public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); + public Filter[] filters = new Filter[0]; @Override public StructType readSchema() { @@ -50,8 +51,26 @@ public void pruneColumns(StructType requiredSchema) { @Override public Filter[] pushFilters(Filter[] filters) { - this.filters = filters; - return new Filter[0]; + Filter[] supported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return gt.attribute().equals("i") && gt.value() instanceof Integer; + } else { + return false; + } + }).toArray(Filter[]::new); + + Filter[] unsupported = Arrays.stream(filters).filter(f -> { + if (f instanceof GreaterThan) { + GreaterThan gt = (GreaterThan) f; + return !gt.attribute().equals("i") || !(gt.value() instanceof Integer); + } else { + return true; + } + }).toArray(Filter[]::new); + + this.filters = supported; + return unsupported; } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 23147fffe8a08..eccd45442a3b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,11 +21,13 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} @@ -48,14 +50,72 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] + }.head + } + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i))) - checkAnswer(df.select('i).filter('i > 10), Nil) + + val q1 = df.select('j) + checkAnswer(q1, (0 until 10).map(i => Row(-i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q1) + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } + + val q2 = df.filter('i > 3) + checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } else { + val reader = getJavaReader(q2) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + } + + val q3 = df.select('i).filter('i > 6) + checkAnswer(q3, (7 until 10).map(i => Row(i))) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } else { + val reader = getJavaReader(q3) + assert(reader.filters.flatMap(_.references).toSet == Set("i")) + assert(reader.requiredSchema.fieldNames === Seq("i")) + } + + val q4 = df.select('j).filter('j < -10) + checkAnswer(q4, Nil) + if (cls == classOf[AdvancedDataSourceV2]) { + val reader = getReader(q4) + // 'j < 10 is not supported by the testing data source. + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } else { + val reader = getJavaReader(q4) + // 'j < 10 is not supported by the testing data source. + assert(reader.filters.isEmpty) + assert(reader.requiredSchema.fieldNames === Seq("j")) + } } } } @@ -223,6 +283,39 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df2 = df.select(($"i" + 1).as("k"), $"j") checkAnswer(df.join(df2, "j"), (0 until 10).map(i => Row(-i, i, i + 1))) } + + test("SPARK-23301: column pruning with arbitrary expressions") { + def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] + }.head + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + + val q1 = df.select('i + 1) + checkAnswer(q1, (1 until 11).map(i => Row(i))) + val reader1 = getReader(q1) + assert(reader1.requiredSchema.fieldNames === Seq("i")) + + val q2 = df.select(lit(1)) + checkAnswer(q2, (0 until 10).map(i => Row(1))) + val reader2 = getReader(q2) + assert(reader2.requiredSchema.isEmpty) + + // 'j === 1 can't be pushed down, but we should still be able do column pruning + val q3 = df.filter('j === -1).select('j * 2) + checkAnswer(q3, Row(-2)) + val reader3 = getReader(q3) + assert(reader3.filters.isEmpty) + assert(reader3.requiredSchema.fieldNames === Seq("j")) + + // column pruning should work with other operators. + val q4 = df.sort('i).limit(1).select('i + 1) + checkAnswer(q4, Row(1)) + val reader4 = getReader(q4) + assert(reader4.requiredSchema.fieldNames === Seq("i")) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { @@ -270,8 +363,12 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { } override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - Array.empty + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported } override def pushedFilters(): Array[Filter] = filters From b9503fcbb3f4a3ce263164d1f11a8e99b9ca5710 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Feb 2018 22:43:28 +0800 Subject: [PATCH 0276/1161] [SPARK-23312][SQL] add a config to turn off vectorized cache reader ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-23309 reported a performance regression about cached table in Spark 2.3. While the investigating is still going on, this PR adds a conf to turn off the vectorized cache reader, to unblock the 2.3 release. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20483 from cloud-fan/cache. --- .../org/apache/spark/sql/internal/SQLConf.scala | 8 ++++++++ .../columnar/InMemoryTableScanExec.scala | 2 +- .../org/apache/spark/sql/CachedTableSuite.scala | 15 +++++++++++++-- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 90654e67457e0..1e2501ee7757d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -141,6 +141,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CACHE_VECTORIZED_READER_ENABLED = + buildConf("spark.sql.inMemoryColumnarStorage.enableVectorizedReader") + .doc("Enables vectorized reader for columnar caching.") + .booleanConf + .createWithDefault(true) + val COLUMN_VECTOR_OFFHEAP_ENABLED = buildConf("spark.sql.columnVector.offheap.enabled") .internal() @@ -1272,6 +1278,8 @@ class SQLConf extends Serializable with Logging { def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) + def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED) + def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) def targetPostShuffleInputSize: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index c167f1e7dc621..e972f8b30d87c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -54,7 +54,7 @@ case class InMemoryTableScanExec( override val supportsBatch: Boolean = { // In the initial implementation, for ease of review // support only primitive data types and # of fields is less than wholeStageMaxNumFields - relation.schema.fields.forall(f => f.dataType match { + conf.cacheVectorizedReaderEnabled && relation.schema.fields.forall(f => f.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true case _ => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 72fe0f42801f1..9f27fa09127af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -21,8 +21,6 @@ import scala.collection.mutable.HashSet import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.concurrent.Eventually._ - import org.apache.spark.CleanerListener import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression @@ -30,6 +28,7 @@ import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{AccumulatorContext, Utils} @@ -782,4 +781,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(getNumInMemoryRelations(cachedDs2) == 1) } } + + test("SPARK-23312: vectorized cache reader can be disabled") { + Seq(true, false).foreach { vectorized => + withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { + val df = spark.range(10).cache() + df.queryExecution.executedPlan.foreach { + case i: InMemoryTableScanExec => assert(i.supportsBatch == vectorized) + case _ => + } + } + } + } } From dd52681bf542386711609cb037a55b3d264eddef Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 2 Feb 2018 09:10:50 -0600 Subject: [PATCH 0277/1161] [SPARK-23253][CORE][SHUFFLE] Only write shuffle temporary index file when there is not an existing one ## What changes were proposed in this pull request? Shuffle Index temporay file is used for atomic creating shuffle index file, it is not needed when the index file already exists after another attempts of same task had it done. ## How was this patch tested? exitsting ut cc squito Author: Kent Yao Closes #20422 from yaooqinn/SPARK-23253. --- .../shuffle/IndexShuffleBlockResolver.scala | 27 ++++----- .../sort/IndexShuffleBlockResolverSuite.scala | 59 ++++++++++++++----- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 266ee42e39cca..c5f3f6e2b42b6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -141,19 +141,6 @@ private[spark] class IndexShuffleBlockResolver( val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) try { - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length - out.writeLong(offset) - } - } { - out.close() - } - val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. @@ -166,10 +153,22 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } - indexTmp.delete() } else { // This is the first successful attempt in writing the map outputs for this task, // so override any existing index and data files with the ones we wrote. + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L + out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() + } + if (indexFile.exists()) { indexFile.delete() } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index d21ce73f4021e..4ce379b76b551 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import java.io.{File, FileInputStream, FileOutputStream} +import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -64,6 +64,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } test("commit shuffle files multiple times") { + val shuffleId = 1 + val mapId = 2 + val idxName = s"shuffle_${shuffleId}_${mapId}_0.index" val resolver = new IndexShuffleBlockResolver(conf, blockManager) val lengths = Array[Long](10, 0, 20) val dataTmp = File.createTempFile("shuffle", null, tempDir) @@ -73,9 +76,13 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) - val dataFile = resolver.getDataFile(1, 2) + val indexFile = new File(tempDir.getAbsolutePath, idxName) + val dataFile = resolver.getDataFile(shuffleId, mapId) + + assert(indexFile.exists()) + assert(indexFile.length() === (lengths.length + 1) * 8) assert(dataFile.exists()) assert(dataFile.length() === 30) assert(!dataTmp.exists()) @@ -89,7 +96,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out2.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) + + assert(indexFile.length() === (lengths.length + 1) * 8) assert(lengths2.toSeq === lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 30) @@ -97,18 +106,27 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa // The dataFile should be the previous one val firstByte = new Array[Byte](1) - val in = new FileInputStream(dataFile) + val dataIn = new FileInputStream(dataFile) Utils.tryWithSafeFinally { - in.read(firstByte) + dataIn.read(firstByte) } { - in.close() + dataIn.close() } assert(firstByte(0) === 0) + // The index file should not change + val indexIn = new DataInputStream(new FileInputStream(indexFile)) + Utils.tryWithSafeFinally { + indexIn.readLong() // the first offset is always 0 + assert(indexIn.readLong() === 10, "The index file should not change") + } { + indexIn.close() + } + // remove data file dataFile.delete() - val lengths3 = Array[Long](10, 10, 15) + val lengths3 = Array[Long](7, 10, 15, 3) val dataTmp3 = File.createTempFile("shuffle", null, tempDir) val out3 = new FileOutputStream(dataTmp3) Utils.tryWithSafeFinally { @@ -117,20 +135,29 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out3.close() } - resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) + resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) + assert(indexFile.length() === (lengths3.length + 1) * 8) assert(lengths3.toSeq != lengths.toSeq) assert(dataFile.exists()) assert(dataFile.length() === 35) - assert(!dataTmp2.exists()) + assert(!dataTmp3.exists()) - // The dataFile should be the previous one - val firstByte2 = new Array[Byte](1) - val in2 = new FileInputStream(dataFile) + // The dataFile should be the new one, since we deleted the dataFile from the first attempt + val dataIn2 = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + dataIn2.read(firstByte) + } { + dataIn2.close() + } + assert(firstByte(0) === 2) + + // The index file should be updated, since we deleted the dataFile from the first attempt + val indexIn2 = new DataInputStream(new FileInputStream(indexFile)) Utils.tryWithSafeFinally { - in2.read(firstByte2) + indexIn2.readLong() // the first offset is always 0 + assert(indexIn2.readLong() === 7, "The index file should be updated") } { - in2.close() + indexIn2.close() } - assert(firstByte2(0) === 2) } } From eefec93d193d43d5b71b8f8a4b1060286da971dd Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 2 Feb 2018 10:17:51 -0600 Subject: [PATCH 0278/1161] [SPARK-23295][BUILD][MINOR] Exclude Waring message when generating versions in make-distribution.sh ## What changes were proposed in this pull request? When we specified a wrong profile to make a spark distribution, such as `-Phadoop1000`, we will get an odd package named like `spark-[WARNING] The requested profile "hadoop1000" could not be activated because it does not exist.-bin-hadoop-2.7.tgz`, which actually should be `"spark-$VERSION-bin-$NAME.tgz"` ## How was this patch tested? ### before ``` build/mvn help:evaluate -Dexpression=scala.binary.version -Phadoop1000 2>/dev/null | grep -v "INFO" | tail -n 1 [WARNING] The requested profile "hadoop1000" could not be activated because it does not exist. ``` ``` build/mvn help:evaluate -Dexpression=project.version -Phadoop1000 2>/dev/null | grep -v "INFO" | tail -n 1 [WARNING] The requested profile "hadoop1000" could not be activated because it does not exist. ``` ### after ``` build/mvn help:evaluate -Dexpression=project.version -Phadoop1000 2>/dev/null | grep -v "INFO" | grep -v "WARNING" | tail -n 1 2.4.0-SNAPSHOT ``` ``` build/mvn help:evaluate -Dexpression=scala.binary.version -Dscala.binary.version=2.11.1 2>/dev/null | grep -v "INFO" | grep -v "WARNING" | tail -n 1 2.11.1 ``` cloud-fan srowen Author: Kent Yao Closes #20469 from yaooqinn/dist-minor. --- dev/make-distribution.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 7245163ea2a51..8b02446b2f15f 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -117,15 +117,21 @@ if [ ! "$(command -v "$MVN")" ] ; then exit -1; fi -VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null\ + | grep -v "INFO"\ + | grep -v "WARNING"\ + | tail -n 1) SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | tail -n 1) SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ + | grep -v "WARNING"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ # because we use "set -o pipefail" From eaf35de2471fac4337dd2920026836d52b1ec847 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Feb 2018 17:37:51 -0800 Subject: [PATCH 0279/1161] [SPARK-23064][SS][DOCS] Stream-stream joins Documentation - follow up ## What changes were proposed in this pull request? Further clarification of caveats in using stream-stream outer joins. ## How was this patch tested? N/A Author: Tathagata Das Closes #20494 from tdas/SPARK-23064-2. --- docs/structured-streaming-programming-guide.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 62589a62ac4c4..48d6d0b542cc0 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -1346,10 +1346,20 @@ joined <- join( -However, note that the outer NULL results will be generated with a delay (depends on the specified -watermark delay and the time range condition) because the engine has to wait for that long to ensure + +There are a few points to note regarding outer joins. + +- *The outer NULL results will be generated with a delay that depends on the specified watermark +delay and the time range condition.* This is because the engine has to wait for that long to ensure there were no matches and there will be no more matches in future. +- In the current implementation in the micro-batch engine, watermarks are advanced at the end of a +micro-batch, and the next micro-batch uses the updated watermark to clean up state and output +outer results. Since we trigger a micro-batch only when there is new data to be processed, the +generation of the outer result may get delayed if there no new data being received in the stream. +*In short, if any of the two input streams being joined does not receive data for a while, the +outer (both cases, left or right) output may get delayed.* + ##### Support matrix for joins in streaming queries
{executor.map(_.isBlacklisted).getOrElse(false)}for applicationfor stagefalse
From 3ff83ad43a704cc3354ef9783e711c065e2a1a22 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 2 Feb 2018 20:36:27 -0800 Subject: [PATCH 0280/1161] [SQL] Minor doc update: Add an example in DataFrameReader.schema ## What changes were proposed in this pull request? This patch adds a small example to the schema string definition of schema function. It isn't obvious how to use it, so an example would be useful. ## How was this patch tested? N/A - doc only. Author: Reynold Xin Closes #20491 from rxin/schema-doc. --- .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 46b5f54a33f74..fcaf8d618c168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -74,6 +74,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * infer the input schema automatically from data. By specifying the schema here, the underlying * data source can skip the schema inference step, and thus speed up data loading. * + * {{{ + * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") + * }}} + * * @since 2.3.0 */ def schema(schemaString: String): DataFrameReader = { From fe73cb4b439169f16cc24cd851a11fd398ce7edf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 2 Feb 2018 20:49:08 -0800 Subject: [PATCH 0281/1161] [SPARK-23317][SQL] rename ContinuousReader.setOffset to setStartOffset ## What changes were proposed in this pull request? In the document of `ContinuousReader.setOffset`, we say this method is used to specify the start offset. We also have a `ContinuousReader.getStartOffset` to get the value back. I think it makes more sense to rename `ContinuousReader.setOffset` to `setStartOffset`. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20486 from cloud-fan/rename. --- .../org/apache/spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- .../sql/sources/v2/reader/streaming/ContinuousReader.java | 4 ++-- .../execution/streaming/continuous/ContinuousExecution.scala | 2 +- .../streaming/continuous/ContinuousRateStreamSource.scala | 2 +- .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../sql/streaming/sources/StreamingDataSourceV2Suite.scala | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 41c443bc12120..b049a054cb40e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -71,7 +71,7 @@ class KafkaContinuousReader( override def readSchema: StructType = KafkaOffsetReader.kafkaSchema private var offset: Offset = _ - override def setOffset(start: ju.Optional[Offset]): Unit = { + override def setStartOffset(start: ju.Optional[Offset]): Unit = { offset = start.orElse { val offsets = initialOffsets match { case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets()) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index d1d1e7ffd1dd4..7fe7f00ac2fa8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -51,12 +51,12 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader * start from the first record after the provided offset, or from an implementation-defined * inferred starting point if no offset is provided. */ - void setOffset(Optional start); + void setStartOffset(Optional start); /** * Return the specified or inferred start offset for this reader. * - * @throws IllegalStateException if setOffset has not been called + * @throws IllegalStateException if setStartOffset has not been called */ Offset getStartOffset(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 08c81419a9d34..ed22b9100497a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -181,7 +181,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setOffset(java.util.Optional.ofNullable(realOffset.orNull)) + reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) new StreamingDataSourceV2Relation(newOutput, reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 0eaaa4889ba9e..b63d8d3e20650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -61,7 +61,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) private var offset: Offset = _ - override def setOffset(offset: java.util.Optional[Offset]): Unit = { + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 3158995ec62f1..0d68d9c3138aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -160,7 +160,7 @@ class RateSourceV2Suite extends StreamTest { test("continuous data") { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setOffset(Optional.empty()) + reader.setStartOffset(Optional.empty()) val tasks = reader.createDataReaderFactories() assert(tasks.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index cb873ab688e96..51f44fa6285e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -43,7 +43,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def readSchema(): StructType = StructType(Seq()) def stop(): Unit = {} def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) - def setOffset(start: Optional[Offset]): Unit = {} + def setStartOffset(start: Optional[Offset]): Unit = {} def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { throw new IllegalStateException("fake source - cannot actually read") From 63b49fa2e599080c2ba7d5189f9dde20a2e01fb4 Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Sat, 3 Feb 2018 00:02:03 -0800 Subject: [PATCH 0282/1161] [SPARK-23311][SQL][TEST] add FilterFunction test case for test CombineTypedFilters ## What changes were proposed in this pull request? In the current test case for CombineTypedFilters, we lack the test of FilterFunction, so let's add it. In addition, in TypedFilterOptimizationSuite's existing test cases, Let's extract a common LocalRelation. ## How was this patch tested? add new test cases. Author: caoxuewen Closes #20482 from heary-cao/TypedFilterOptimizationSuite. --- .../spark/sql/catalyst/dsl/package.scala | 3 + .../TypedFilterOptimizationSuite.scala | 95 ++++++++++++++++--- 2 files changed, 84 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 59cb26d5e6c36..efb2eba655e15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -301,6 +302,8 @@ package object dsl { def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan) + def filter[T : Encoder](func: FilterFunction[T]): LogicalPlan = TypedFilter(func, logicalPlan) + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 56f096f3ecf8c..5fc99a3a57c0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -38,18 +39,19 @@ class TypedFilterOptimizationSuite extends PlanTest { implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + val testRelation = LocalRelation('_1.int, '_2.int) + test("filter after serialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -58,10 +60,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter after serialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .deserialize[(Int, Int)] .serialize[(Int, Int)] .filter(f).analyze @@ -70,17 +71,16 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze val optimized = Optimize.execute(query) - val expected = input + val expected = testRelation .deserialize[(Int, Int)] .where(callFunction(f, BooleanType, 'obj)) .serialize[(Int, Int)].analyze @@ -89,10 +89,9 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("filter before deserialize with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f = (i: OtherTuple) => i._1 > 0 - val query = input + val query = testRelation .filter(f) .deserialize[(Int, Int)] .serialize[(Int, Int)].analyze @@ -101,21 +100,89 @@ class TypedFilterOptimizationSuite extends PlanTest { } test("back to back filter with the same object type") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: (Int, Int)) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 1) } test("back to back filter with different object types") { - val input = LocalRelation('_1.int, '_2.int) val f1 = (i: (Int, Int)) => i._1 > 0 val f2 = (i: OtherTuple) => i._2 > 0 - val query = input.filter(f1).filter(f2).analyze + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("back to back FilterFunction with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("back to back FilterFunction with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("FilterFunction and filter with the same object type") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("FilterFunction and filter with different object types") { + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._1 > 0 + } + val f2 = (i: OtherTuple) => i._2 > 0 + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 2) + } + + test("filter and FilterFunction with the same object type") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[(Int, Int)] { + override def call(value: (Int, Int)): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze + val optimized = Optimize.execute(query) + assert(optimized.collect { case t: TypedFilter => t }.length == 1) + } + + test("filter and FilterFunction with different object types") { + val f2 = (i: (Int, Int)) => i._1 > 0 + val f1 = new FilterFunction[OtherTuple] { + override def call(value: OtherTuple): Boolean = value._2 > 0 + } + + val query = testRelation.filter(f1).filter(f2).analyze val optimized = Optimize.execute(query) assert(optimized.collect { case t: TypedFilter => t }.length == 2) } From 522e0b1866a0298669c83de5a47ba380dc0b7c84 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 3 Feb 2018 00:04:00 -0800 Subject: [PATCH 0283/1161] [SPARK-23305][SQL][TEST] Test `spark.sql.files.ignoreMissingFiles` for all file-based data sources ## What changes were proposed in this pull request? Like Parquet, all file-based data source handles `spark.sql.files.ignoreMissingFiles` correctly. We had better have a test coverage for feature parity and in order to prevent future accidental regression for all data sources. ## How was this patch tested? Pass Jenkins with a newly added test case. Author: Dongjoon Hyun Closes #20479 from dongjoon-hyun/SPARK-23305. --- .../spark/sql/FileBasedDataSourceSuite.scala | 37 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 33 ----------------- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index c272c99ae45a8..640d6b1583663 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { @@ -92,4 +96,37 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { } } } + + allFileBasedDataSources.foreach { format => + testQuietly(s"Enabling/disabling ignoreMissingFiles using $format") { + def testIgnoreMissingFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) + Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val df = spark.read.format(format).load( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + assert(fs.delete(thirdPath, true)) + checkAnswer(df, Seq(Row("0"), Row("1"))) + } + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { + testIgnoreMissingFiles() + } + + withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreMissingFiles() + } + assert(exception.getMessage().contains("does not exist")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 6ad88ed997ce7..55b0f729be8ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -355,39 +355,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } - testQuietly("Enabling/disabling ignoreMissingFiles") { - def testIgnoreMissingFiles(): Unit = { - withTempDir { dir => - val basePath = dir.getCanonicalPath - spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) - spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) - val thirdPath = new Path(basePath, "third") - spark.range(2, 3).toDF("a").write.parquet(thirdPath.toString) - val df = spark.read.parquet( - new Path(basePath, "first").toString, - new Path(basePath, "second").toString, - new Path(basePath, "third").toString) - - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) - fs.delete(thirdPath, true) - checkAnswer( - df, - Seq(Row(0), Row(1))) - } - } - - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "true") { - testIgnoreMissingFiles() - } - - withSQLConf(SQLConf.IGNORE_MISSING_FILES.key -> "false") { - val exception = intercept[SparkException] { - testIgnoreMissingFiles() - } - assert(exception.getMessage().contains("does not exist")) - } - } - /** * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop * to increase the chance of failure From 4aaa7d40bf495317e740b6d6f9c2a55dfd03521b Mon Sep 17 00:00:00 2001 From: Shashwat Anand Date: Sat, 3 Feb 2018 10:31:04 -0800 Subject: [PATCH 0284/1161] [MINOR][DOC] Use raw triple double quotes around docstrings where there are occurrences of backslashes. From [PEP 257](https://www.python.org/dev/peps/pep-0257/): > For consistency, always use """triple double quotes""" around docstrings. Use r"""raw triple double quotes""" if you use any backslashes in your docstrings. For Unicode docstrings, use u"""Unicode triple-quoted strings""". For example, this is what help (kafka_wordcount) shows: ``` DESCRIPTION Counts words in UTF8 encoded, ' ' delimited text received from the network every second. Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py localhost:2181 test` ``` This is what it shows, after the fix: ``` DESCRIPTION Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: kafka_wordcount.py To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart and then run the example `$ bin/spark-submit --jars \ external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` ``` The thing worth noticing is no linebreak here in the help. ## What changes were proposed in this pull request? Change triple double quotes to raw triple double quotes when there are occurrences of backslashes in docstrings. ## How was this patch tested? Manually as this is a doc fix. Author: Shashwat Anand Closes #20497 from ashashwat/docstring-fixes. --- .../main/python/sql/streaming/structured_network_wordcount.py | 2 +- .../sql/streaming/structured_network_wordcount_windowed.py | 2 +- examples/src/main/python/streaming/direct_kafka_wordcount.py | 2 +- examples/src/main/python/streaming/flume_wordcount.py | 2 +- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- examples/src/main/python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/network_wordjoinsentiments.py | 2 +- examples/src/main/python/streaming/sql_network_wordcount.py | 2 +- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index afde2550587ca..c3284c1d01017 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network. Usage: structured_network_wordcount.py and describe the TCP server that Structured Streaming diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index 02a7d3363d780..db672551504b5 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network over a sliding window of configurable duration. Each line from the network is tagged with a timestamp that is used to determine the windows into which it falls. diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 7097f7f4502bd..425df309011a0 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. Usage: direct_kafka_wordcount.py diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index d75bc6daac138..5d6e6dc36d6f9 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: flume_wordcount.py diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 8d697f620f467..704f6602e2297 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: kafka_wordcount.py diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 2b48bcfd55db0..9010fafb425e6 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: network_wordcount.py and describe the TCP server that Spark Streaming would connect to receive data. diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index b309d9fad33f5..d51a380a5d5f9 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Shows the most positive words in UTF8 encoded, '\n' delimited text directly received the network every 5 seconds. The streaming data is joined with a static RDD of the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 398ac8d2d8f5e..7f12281c0e3fe 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the network every second. diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index f8bbc659c2ea7..d7bb61e729f18 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -15,7 +15,7 @@ # limitations under the License. # -""" +r""" Counts words in UTF8 encoded, '\n' delimited text received from the network every second. From 551dff2bccb65e9b3f77b986f167aec90d9a6016 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 3 Feb 2018 10:40:21 -0800 Subject: [PATCH 0285/1161] [SPARK-21658][SQL][PYSPARK] Revert "[] Add default None for value in na.replace in PySpark" This reverts commit 0fcde87aadc9a92e138f11583119465ca4b5c518. See the discussion in [SPARK-21658](https://issues.apache.org/jira/browse/SPARK-21658), [SPARK-19454](https://issues.apache.org/jira/browse/SPARK-19454) and https://github.com/apache/spark/pull/16793 Author: hyukjinkwon Closes #20496 from HyukjinKwon/revert-SPARK-21658. --- python/pyspark/sql/dataframe.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1496cba91b90e..2e55407b5397b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1577,16 +1577,6 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ - >>> df4.na.replace('Alice').show() - +----+------+----+ - | age|height|name| - +----+------+----+ - | 10| 80|null| - | 5| null| Bob| - |null| null| Tom| - |null| null|null| - +----+------+----+ - >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -2055,7 +2045,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ From 715047b02df0ac9ec16ab2a73481ab7f36ffc6ca Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 4 Feb 2018 17:53:31 +0900 Subject: [PATCH 0286/1161] [SPARK-23256][ML][PYTHON] Add columnSchema method to PySpark image reader ## What changes were proposed in this pull request? This PR proposes to add `columnSchema` in Python side too. ```python >>> from pyspark.ml.image import ImageSchema >>> ImageSchema.columnSchema.simpleString() 'struct' ``` ## How was this patch tested? Manually tested and unittest was added in `python/pyspark/ml/tests.py`. Author: hyukjinkwon Closes #20475 from HyukjinKwon/SPARK-23256. --- python/pyspark/ml/image.py | 20 +++++++++++++++++++- python/pyspark/ml/tests.py | 1 + 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 2d86c7f03860c..45c936645f2a8 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -40,6 +40,7 @@ class _ImageSchema(object): def __init__(self): self._imageSchema = None self._ocvTypes = None + self._columnSchema = None self._imageFields = None self._undefinedImageType = None @@ -49,7 +50,7 @@ def imageSchema(self): Returns the image schema. :return: a :class:`StructType` with a single column of images - named "image" (nullable). + named "image" (nullable) and having the same type returned by :meth:`columnSchema`. .. versionadded:: 2.3.0 """ @@ -75,6 +76,23 @@ def ocvTypes(self): self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()) return self._ocvTypes + @property + def columnSchema(self): + """ + Returns the schema for the image column. + + :return: a :class:`StructType` for image column, + ``struct``. + + .. versionadded:: 2.4.0 + """ + + if self._columnSchema is None: + ctx = SparkContext._active_spark_context + jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.columnSchema() + self._columnSchema = _parse_datatype_json_string(jschema.json()) + return self._columnSchema + @property def imageFields(self): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..75d04785a0710 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1852,6 +1852,7 @@ def test_read_images(self): self.assertEqual(len(array), first_row[1]) self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row) self.assertEqual(df.schema, ImageSchema.imageSchema) + self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema) expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24} self.assertEqual(ImageSchema.ocvTypes, expected) expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data'] From 6fb3fd15365d43733aefdb396db205d7ccf57f75 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 4 Feb 2018 09:15:48 -0800 Subject: [PATCH 0287/1161] [SPARK-22036][SQL][FOLLOWUP] Fix decimalArithmeticOperations.sql ## What changes were proposed in this pull request? Fix decimalArithmeticOperations.sql test ## How was this patch tested? N/A Author: Yuming Wang Author: wangyum Author: Yuming Wang Closes #20498 from wangyum/SPARK-22036. --- .../native/decimalArithmeticOperations.sql | 6 +- .../decimalArithmeticOperations.sql.out | 140 ++++++++++-------- 2 files changed, 80 insertions(+), 66 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c6d8a49d4b93a..9be7fcdadfea8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -48,8 +48,9 @@ select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; -- arithmetic operations causing a precision loss are truncated +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; select 123456789123456789.1234567890 * 1.123456789123456789; -select 0.001 / 9876543210987654321098765432109876543.2 +select 12345678912345.123456789123 / 0.000000012345678; -- return NULL instead of rounding, according to old Spark versions' behavior set spark.sql.decimalOperations.allowPrecisionLoss=false; @@ -74,7 +75,8 @@ select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; select 123456789123456789.1234567890 * 1.123456789123456789; -select 0.001 / 9876543210987654321098765432109876543.2 +select 12345678912345.123456789123 / 0.000000012345678; drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index 4d70fe19d539f..6bfdb84548d4d 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 36 -- !query 0 @@ -146,146 +146,158 @@ NULL -- !query 17 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 -- !query 17 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> -- !query 17 output -138698367904130467.654320988515622621 +10012345678912345678912345678911.246907 -- !query 18 -select 0.001 / 9876543210987654321098765432109876543.2 - -set spark.sql.decimalOperations.allowPrecisionLoss=false +select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 18 schema -struct<> +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 18 output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input 'spark' expecting (line 3, pos 4) - -== SQL == -select 0.001 / 9876543210987654321098765432109876543.2 - -set spark.sql.decimalOperations.allowPrecisionLoss=false -----^^^ +138698367904130467.654320988515622621 -- !query 19 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 12345678912345.123456789123 / 0.000000012345678 -- !query 19 schema -struct +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> -- !query 19 output -1 1099 -899 99900 0.1001 -2 24690.246 0 152402061.885129 1 -3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 -4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 +1000000073899961059796.725866332 -- !query 20 -select id, a*10, b/10 from decimals_test order by id +set spark.sql.decimalOperations.allowPrecisionLoss=false -- !query 20 schema -struct +struct -- !query 20 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.112345678912345679 +spark.sql.decimalOperations.allowPrecisionLoss false -- !query 21 -select 10.3 * 3.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 21 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +struct -- !query 21 output -30.9 +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 22 -select 10.3000 * 3.0 +select id, a*10, b/10 from decimals_test order by id -- !query 22 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +struct -- !query 22 output -30.9 +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 23 -select 10.30000 * 30.0 +select 10.3 * 3.0 -- !query 23 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 23 output -309 +30.9 -- !query 24 -select 10.300000000000000000 * 3.000000000000000000 +select 10.3000 * 3.0 -- !query 24 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 24 output 30.9 -- !query 25 -select 10.300000000000000000 * 3.0000000000000000000 +select 10.30000 * 30.0 -- !query 25 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 25 output -30.9 +309 -- !query 26 -select (5e36 + 0.1) + 5e36 +select 10.300000000000000000 * 3.000000000000000000 -- !query 26 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> -- !query 26 output -NULL +30.9 -- !query 27 -select (-4e36 - 0.1) - 7e36 +select 10.300000000000000000 * 3.0000000000000000000 -- !query 27 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> -- !query 27 output NULL -- !query 28 -select 12345678901234567890.0 * 12345678901234567890.0 +select (5e36 + 0.1) + 5e36 -- !query 28 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 28 output NULL -- !query 29 -select 1e35 / 0.1 +select (-4e36 - 0.1) - 7e36 -- !query 29 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 29 output NULL -- !query 30 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 30 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 30 output -138698367904130467.654320988515622621 +NULL -- !query 31 -select 0.001 / 9876543210987654321098765432109876543.2 - -drop table decimals_test +select 1e35 / 0.1 -- !query 31 schema -struct<> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> -- !query 31 output -org.apache.spark.sql.catalyst.parser.ParseException +NULL -mismatched input 'table' expecting (line 3, pos 5) -== SQL == -select 0.001 / 9876543210987654321098765432109876543.2 +-- !query 32 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 32 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 32 output +NULL + + +-- !query 33 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 33 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 33 output +NULL + +-- !query 34 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 34 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 34 output +NULL + + +-- !query 35 drop table decimals_test ------^^^ +-- !query 35 schema +struct<> +-- !query 35 output + From a6bf3db20773ba65cbc4f2775db7bd215e78829a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 5 Feb 2018 18:41:49 +0800 Subject: [PATCH 0288/1161] [SPARK-23307][WEBUI] Sort jobs/stages/tasks/queries with the completed timestamp before cleaning up them ## What changes were proposed in this pull request? Sort jobs/stages/tasks/queries with the completed timestamp before cleaning up them to make the behavior consistent with 2.2. ## How was this patch tested? - Jenkins. - Manually ran the following codes and checked the UI for jobs/stages/tasks/queries. ``` spark.ui.retainedJobs 10 spark.ui.retainedStages 10 spark.sql.ui.retainedExecutions 10 spark.ui.retainedTasks 10 ``` ``` new Thread() { override def run() { spark.range(1, 2).foreach { i => Thread.sleep(10000) } } }.start() Thread.sleep(5000) for (_ <- 1 to 20) { new Thread() { override def run() { spark.range(1, 2).foreach { i => } } }.start() } Thread.sleep(15000) spark.range(1, 2).foreach { i => } sc.makeRDD(1 to 100, 100).foreach { i => } ``` Author: Shixiong Zhu Closes #20481 from zsxwing/SPARK-23307. --- .../spark/status/AppStatusListener.scala | 13 +-- .../org/apache/spark/status/storeTypes.scala | 7 ++ .../spark/status/AppStatusListenerSuite.scala | 90 +++++++++++++++++++ .../execution/ui/SQLAppStatusListener.scala | 4 +- .../sql/execution/ui/SQLAppStatusStore.scala | 9 +- .../ui/SQLAppStatusListenerSuite.scala | 45 ++++++++++ 6 files changed, 158 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 3e34bdc0c7b63..ab01cddfca5b0 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -875,8 +875,8 @@ private[spark] class AppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[JobDataWrapper]), - countToDelete.toInt) { j => + val view = kvstore.view(classOf[JobDataWrapper]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt) { j => j.info.status != JobExecutionStatus.RUNNING && j.info.status != JobExecutionStatus.UNKNOWN } toDelete.foreach { j => kvstore.delete(j.getClass(), j.info.jobId) } @@ -888,8 +888,8 @@ private[spark] class AppStatusListener( return } - val stages = KVUtils.viewToSeq(kvstore.view(classOf[StageDataWrapper]), - countToDelete.toInt) { s => + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L) + val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } @@ -945,8 +945,9 @@ private[spark] class AppStatusListener( val countToDelete = calculateNumberToRemove(stage.savedTasks.get(), maxTasksPerStage).toInt if (countToDelete > 0) { val stageKey = Array(stage.info.stageId, stage.info.attemptNumber) - val view = kvstore.view(classOf[TaskDataWrapper]).index("stage").first(stageKey) - .last(stageKey) + val view = kvstore.view(classOf[TaskDataWrapper]) + .index(TaskIndexNames.COMPLETION_TIME) + .parent(stageKey) // Try to delete finished tasks only. val toDelete = KVUtils.viewToSeq(view, countToDelete) { t => diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index c9cb996a55fcc..412644d3657b5 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -73,6 +73,8 @@ private[spark] class JobDataWrapper( @JsonIgnore @KVIndex private def id: Int = info.jobId + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) } private[spark] class StageDataWrapper( @@ -90,6 +92,8 @@ private[spark] class StageDataWrapper( @JsonIgnore @KVIndex("active") private def active: Boolean = info.status == StageStatus.ACTIVE + @JsonIgnore @KVIndex("completionTime") + private def completionTime: Long = info.completionTime.map(_.getTime).getOrElse(-1L) } /** @@ -134,6 +138,7 @@ private[spark] object TaskIndexNames { final val STAGE = "stage" final val STATUS = "sta" final val TASK_INDEX = "idx" + final val COMPLETION_TIME = "ct" } /** @@ -337,6 +342,8 @@ private[spark] class TaskDataWrapper( @JsonIgnore @KVIndex(value = TaskIndexNames.ERROR, parent = TaskIndexNames.STAGE) private def error: String = if (errorMessage.isDefined) errorMessage.get else "" + @JsonIgnore @KVIndex(value = TaskIndexNames.COMPLETION_TIME, parent = TaskIndexNames.STAGE) + private def completionTime: Long = launchTime + duration } private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 042bba7f226fd..b74d6ee2ec836 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1010,6 +1010,96 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("eviction should respect job completion time") { + val testConf = conf.clone().set(MAX_RETAINED_JOBS, 2) + val listener = new AppStatusListener(store, testConf, true) + + // Start job 1 and job 2 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Nil, null)) + time += 1 + listener.onJobStart(SparkListenerJobStart(2, time, Nil, null)) + + // Stop job 2 before job 1 + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Start job 3 and job 2 should be evicted. + time += 1 + listener.onJobStart(SparkListenerJobStart(3, time, Nil, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[JobDataWrapper], 2) + } + } + + test("eviction should respect stage completion time") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + // Start stage 1 and stage 2 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + + // Stop stage 2 before stage 1 + time += 1 + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Start stage 3 and stage 2 should be evicted. + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + + test("eviction should respect task completion time") { + val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + // Start task 1 and task 2 + val tasks = createTasks(3, Array("1")) + tasks.take(2).foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, task)) + } + + // Stop task 2 before task 1 + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(1), null)) + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd( + SparkListenerTaskEnd(stage1.stageId, stage1.attemptId, "taskType", Success, tasks(0), null)) + + // Start task 3 and task 2 should be evicted. + listener.onTaskStart(SparkListenerTaskStart(stage1.stageId, stage1.attemptNumber, tasks(2))) + assert(store.count(classOf[TaskDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[TaskDataWrapper], tasks(1).id) + } + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 73a105266e1c1..53fb9a0cc21cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -332,8 +332,8 @@ class SQLAppStatusListener( return } - val toDelete = KVUtils.viewToSeq(kvstore.view(classOf[SQLExecutionUIData]), - countToDelete.toInt) { e => e.completionTime.isDefined } + val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) + val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 910f2e52fdbb3..9a76584717f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -23,11 +23,12 @@ import java.util.Date import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import com.fasterxml.jackson.annotation.JsonIgnore import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus import org.apache.spark.status.KVUtils.KVIndexParam -import org.apache.spark.util.kvstore.KVStore +import org.apache.spark.util.kvstore.{KVIndex, KVStore} /** * Provides a view of a KVStore with methods that make it easy to query SQL-specific state. There's @@ -90,7 +91,11 @@ class SQLExecutionUIData( * from the SQL listener instance. */ @JsonDeserialize(keyAs = classOf[JLong]) - val metricValues: Map[Long, String]) + val metricValues: Map[Long, String]) { + + @JsonIgnore @KVIndex("completionTime") + private def completionTimeIndex: Long = completionTime.map(_.getTime).getOrElse(-1L) +} class SparkPlanGraphWrapper( @KVIndexParam val executionId: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 7d84f45d36bee..85face3994fd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.internal.StaticSQLConf.UI_RETAINED_EXECUTIONS import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.status.ElementTrackingStore import org.apache.spark.status.config._ @@ -510,6 +511,50 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with } } + test("eviction should respect execution completion time") { + val conf = sparkContext.conf.clone().set(UI_RETAINED_EXECUTIONS.key, "2") + val store = new ElementTrackingStore(new InMemoryStore, conf) + val listener = new SQLAppStatusListener(conf, store, live = true) + val statusStore = new SQLAppStatusStore(store, Some(listener)) + + var time = 0 + val df = createTestDataFrame + // Start execution 1 and execution 2 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 1, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 2, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + + // Stop execution 2 before execution 1 + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(2, time)) + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionEnd(1, time)) + + // Start execution 3 and execution 2 should be evicted. + time += 1 + listener.onOtherEvent(SparkListenerSQLExecutionStart( + 3, + "test", + "test", + df.queryExecution.toString, + SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), + time)) + assert(statusStore.executionsCount === 2) + assert(statusStore.execution(2) === None) + } } From 03b7e120dd7ff7848c936c7a23644da5bd7219ab Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Mon, 5 Feb 2018 10:19:18 -0800 Subject: [PATCH 0289/1161] [SPARK-23310][CORE] Turn off read ahead input stream for unshafe shuffle reader To fix regression for TPC-DS queries Author: Sital Kedia Closes #20492 from sitalkedia/turn_off_async_inputstream. --- .../util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index e2f48e5508af6..71e7c7a95ebdb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -76,8 +76,10 @@ public UnsafeSorterSpillReader( SparkEnv.get() == null ? 0.5 : SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); + // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf regression for + // TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && - SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); From c2766b07b4b9ed976931966a79c65043e81cf694 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 5 Feb 2018 14:17:11 -0800 Subject: [PATCH 0290/1161] [SPARK-23330][WEBUI] Spark UI SQL executions page throws NPE ## What changes were proposed in this pull request? Spark SQL executions page throws the following error and the page crashes: ``` HTTP ERROR 500 Problem accessing /SQL/. Reason: Server Error Caused by: java.lang.NullPointerException at scala.collection.immutable.StringOps$.length$extension(StringOps.scala:47) at scala.collection.immutable.StringOps.length(StringOps.scala:47) at scala.collection.IndexedSeqOptimized$class.isEmpty(IndexedSeqOptimized.scala:27) at scala.collection.immutable.StringOps.isEmpty(StringOps.scala:29) at scala.collection.TraversableOnce$class.nonEmpty(TraversableOnce.scala:111) at scala.collection.immutable.StringOps.nonEmpty(StringOps.scala:29) at org.apache.spark.sql.execution.ui.ExecutionTable.descriptionCell(AllExecutionsPage.scala:182) at org.apache.spark.sql.execution.ui.ExecutionTable.row(AllExecutionsPage.scala:155) at org.apache.spark.sql.execution.ui.ExecutionTable$$anonfun$8.apply(AllExecutionsPage.scala:204) at org.apache.spark.sql.execution.ui.ExecutionTable$$anonfun$8.apply(AllExecutionsPage.scala:204) at org.apache.spark.ui.UIUtils$$anonfun$listingTable$2.apply(UIUtils.scala:339) at org.apache.spark.ui.UIUtils$$anonfun$listingTable$2.apply(UIUtils.scala:339) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.ui.UIUtils$.listingTable(UIUtils.scala:339) at org.apache.spark.sql.execution.ui.ExecutionTable.toNodeSeq(AllExecutionsPage.scala:203) at org.apache.spark.sql.execution.ui.AllExecutionsPage.render(AllExecutionsPage.scala:67) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.JettyUtils$$anon$3.doGet(JettyUtils.scala:90) at javax.servlet.http.HttpServlet.service(HttpServlet.java:687) at javax.servlet.http.HttpServlet.service(HttpServlet.java:790) at org.eclipse.jetty.servlet.ServletHolder.handle(ServletHolder.java:848) at org.eclipse.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584) at org.eclipse.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180) at org.eclipse.jetty.servlet.ServletHandler.doScope(ServletHandler.java:512) at org.eclipse.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112) at org.eclipse.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141) at org.eclipse.jetty.server.handler.ContextHandlerCollection.handle(ContextHandlerCollection.java:213) at org.eclipse.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134) at org.eclipse.jetty.server.Server.handle(Server.java:534) at org.eclipse.jetty.server.HttpChannel.handle(HttpChannel.java:320) at org.eclipse.jetty.server.HttpConnection.onFillable(HttpConnection.java:251) at org.eclipse.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:283) at org.eclipse.jetty.io.FillInterest.fillable(FillInterest.java:108) at org.eclipse.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148) at org.eclipse.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136) at org.eclipse.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671) at org.eclipse.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589) at java.lang.Thread.run(Thread.java:748) ``` One of the possible reason that this page fails may be the `SparkListenerSQLExecutionStart` event get dropped before processed, so the execution description and details don't get updated. This was not a issue in 2.2 because it would ignore any job start event that arrives before the corresponding execution start event, which doesn't sound like a good decision. We shall try to handle the null values in the front page side, that is, try to give a default value when `execution.details` or `execution.description` is null. Another possible approach is not to spill the `LiveExecutionData` in `SQLAppStatusListener.update(exec: LiveExecutionData)` if `exec.details` is null. This is not ideal because this way you will not see the execution if `SparkListenerSQLExecutionStart` event is lost, because `AllExecutionsPage` only read executions from KVStore. ## How was this patch tested? After the change, the page shows the following: ![image](https://user-images.githubusercontent.com/4784782/35775480-28cc5fde-093e-11e8-8ccc-f58c2ef4a514.png) Author: Xingbo Jiang Closes #20502 from jiangxb1987/executionPage. --- .../apache/spark/sql/execution/ui/AllExecutionsPage.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 7019d98e1619f..e751ce39cd5d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -179,7 +179,7 @@ private[ui] abstract class ExecutionTable( } private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { - val details = if (execution.details.nonEmpty) { + val details = if (execution.details != null && execution.details.nonEmpty) { +details ++ @@ -190,8 +190,10 @@ private[ui] abstract class ExecutionTable( Nil } - val desc = { + val desc = if (execution.description != null && execution.description.nonEmpty) { {execution.description} + } else { + {execution.executionId} }
{desc} {details}
From f3f1e14bb73dfdd2927d95b12d7d61d22de8a0ac Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 6 Feb 2018 14:42:42 +0800 Subject: [PATCH 0291/1161] [SPARK-23326][WEBUI] schedulerDelay should return 0 when the task is running ## What changes were proposed in this pull request? When a task is still running, metrics like executorRunTime are not available. Then `schedulerDelay` will be almost the same as `duration` and that's confusing. This PR makes `schedulerDelay` return 0 when the task is running which is the same behavior as 2.2. ## How was this patch tested? `AppStatusUtilsSuite.schedulerDelay` Author: Shixiong Zhu Closes #20493 from zsxwing/SPARK-23326. --- .../apache/spark/status/AppStatusUtils.scala | 11 ++- .../spark/status/AppStatusUtilsSuite.scala | 89 +++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala index 341bd4e0cd016..87f434daf4870 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusUtils.scala @@ -17,16 +17,23 @@ package org.apache.spark.status -import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} +import org.apache.spark.status.api.v1.TaskData private[spark] object AppStatusUtils { + private val TASK_FINISHED_STATES = Set("FAILED", "KILLED", "SUCCESS") + + private def isTaskFinished(task: TaskData): Boolean = { + TASK_FINISHED_STATES.contains(task.status) + } + def schedulerDelay(task: TaskData): Long = { - if (task.taskMetrics.isDefined && task.duration.isDefined) { + if (isTaskFinished(task) && task.taskMetrics.isDefined && task.duration.isDefined) { val m = task.taskMetrics.get schedulerDelay(task.launchTime.getTime(), fetchStart(task), task.duration.get, m.executorDeserializeTime, m.resultSerializationTime, m.executorRunTime) } else { + // The task is still running and the metrics like executorRunTime are not available. 0L } } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala new file mode 100644 index 0000000000000..9e74e86ad54b9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusUtilsSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.status + +import java.util.Date + +import org.apache.spark.SparkFunSuite +import org.apache.spark.status.api.v1.{TaskData, TaskMetrics} + +class AppStatusUtilsSuite extends SparkFunSuite { + + test("schedulerDelay") { + val runningTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "RUNNING", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 0L, + executorDeserializeCpuTime = 0L, + executorRunTime = 0L, + executorCpuTime = 0L, + resultSize = 0L, + jvmGcTime = 0L, + resultSerializationTime = 0L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 0L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(runningTask) === 0L) + + val finishedTask = new TaskData( + taskId = 0, + index = 0, + attempt = 0, + launchTime = new Date(1L), + resultFetchStart = None, + duration = Some(100L), + executorId = "1", + host = "localhost", + status = "SUCCESS", + taskLocality = "PROCESS_LOCAL", + speculative = false, + accumulatorUpdates = Nil, + errorMessage = None, + taskMetrics = Some(new TaskMetrics( + executorDeserializeTime = 5L, + executorDeserializeCpuTime = 3L, + executorRunTime = 90L, + executorCpuTime = 10L, + resultSize = 100L, + jvmGcTime = 10L, + resultSerializationTime = 2L, + memoryBytesSpilled = 0L, + diskBytesSpilled = 0L, + peakExecutionMemory = 100L, + inputMetrics = null, + outputMetrics = null, + shuffleReadMetrics = null, + shuffleWriteMetrics = null))) + assert(AppStatusUtils.schedulerDelay(finishedTask) === 3L) + } +} From a24c03138a6935a442b983c8a4c721b26df3f9e2 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 14:52:25 +0800 Subject: [PATCH 0292/1161] [SPARK-23290][SQL][PYTHON] Use datetime.date for date type when converting Spark DataFrame to Pandas DataFrame. ## What changes were proposed in this pull request? In #18664, there was a change in how `DateType` is being returned to users ([line 1968 in dataframe.py](https://github.com/apache/spark/pull/18664/files#diff-6fc344560230bf0ef711bb9b5573f1faR1968)). This can cause client code which works in Spark 2.2 to fail. See [SPARK-23290](https://issues.apache.org/jira/browse/SPARK-23290?focusedCommentId=16350917&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16350917) for an example. This pr modifies to use `datetime.date` for date type as Spark 2.2 does. ## How was this patch tested? Tests modified to fit the new behavior and existing tests. Author: Takuya UESHIN Closes #20506 from ueshin/issues/SPARK-23290. --- python/pyspark/serializers.py | 9 ++++-- python/pyspark/sql/dataframe.py | 7 ++-- python/pyspark/sql/tests.py | 57 ++++++++++++++++++++++++--------- python/pyspark/sql/types.py | 15 +++++++++ 4 files changed, 66 insertions(+), 22 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 88d6a191babca..e870325d202ca 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -267,12 +267,15 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) + schema = from_arrow_schema(reader.schema) for batch in reader: - # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 - pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone) + pdf = batch.to_pandas() + pdf = _check_dataframe_convert_date(pdf, schema) + pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) yield [c for _, c in pdf.iteritems()] def __repr__(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2e55407b5397b..59a417015b949 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1923,7 +1923,8 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: - from pyspark.sql.types import _check_dataframe_localize_timestamps + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps from pyspark.sql.utils import require_minimum_pyarrow_version import pyarrow require_minimum_pyarrow_version() @@ -1931,6 +1932,7 @@ def toPandas(self): if tables: table = pyarrow.concat_tables(tables) pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) @@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. - NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 - elif type(dt) == DateType: - return 'datetime64[ns]' else: return None diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b27363023ae77..545ec5aee08ff 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2816,7 +2816,7 @@ def test_to_pandas(self): self.assertEquals(types[1], np.object) self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) - self.assertEquals(types[4], 'datetime64[ns]') + self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") @@ -3388,7 +3388,7 @@ class ArrowTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): - from datetime import datetime + from datetime import date, datetime from decimal import Decimal ReusedSQLTestCase.setUpClass() @@ -3410,11 +3410,11 @@ def setUpClass(cls): StructField("7_date_t", DateType(), True), StructField("8_timestamp_t", TimestampType(), True)]) cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"), - datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), + date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), - datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] @classmethod def tearDownClass(cls): @@ -3461,7 +3461,9 @@ def _toPandas_arrow_toggle(self, df): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow, pdf) + expected = self.create_pandas_data_frame() + self.assertPandasEqual(expected, pdf) + self.assertPandasEqual(expected, pdf_arrow) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -4062,18 +4064,42 @@ def test_vectorized_udf_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.select(f(col('map'))).collect() - def test_vectorized_udf_null_date(self): + def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col from datetime import date - schema = StructType().add("date", DateType()) - data = [(date(1969, 1, 1),), - (date(2012, 2, 2),), - (None,), - (date(2100, 4, 4),)] + schema = StructType().add("idx", LongType()).add("date", DateType()) + data = [(0, date(1969, 1, 1),), + (1, date(2012, 2, 2),), + (2, None,), + (3, date(2100, 4, 4),)] df = self.spark.createDataFrame(data, schema=schema) - date_f = pandas_udf(lambda t: t, returnType=DateType()) - res = df.select(date_f(col("date"))) - self.assertEquals(df.collect(), res.collect()) + + date_copy = pandas_udf(lambda t: t, returnType=DateType()) + df = df.withColumn("date_copy", date_copy(col("date"))) + + @pandas_udf(returnType=StringType()) + def check_data(idx, date, date_copy): + import pandas as pd + msgs = [] + is_equal = date.isnull() + for i in range(len(idx)): + if (is_equal[i] and data[idx[i]][1] is None) or \ + date[i] == data[idx[i]][1]: + msgs.append(None) + else: + msgs.append( + "date values are not equal (date='%s': data[%d][1]='%s')" + % (date[i], idx[i], data[idx[i]][1])) + return pd.Series(msgs) + + result = df.withColumn("check_data", + check_data(col("idx"), col("date"), col("date_copy"))).collect() + + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "date" col + self.assertEquals(data[i][1], result[i][2]) # "date_copy" col + self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_timestamps(self): from pyspark.sql.functions import pandas_udf, col @@ -4114,6 +4140,7 @@ def check_data(idx, timestamp, timestamp_copy): self.assertEquals(len(data), len(result)) for i in range(len(result)): self.assertEquals(data[i][1], result[i][1]) # "timestamp" col + self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col self.assertIsNone(result[i][3]) # "check_data" col def test_vectorized_udf_return_timestamp_tz(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0dc5823f72a3c..093dae5a22e1f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_dataframe_convert_date(pdf, schema): + """ Correct date type value to use datetime.date. + + Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should + use datetime.date to match the behavior with when Arrow optimization is disabled. + + :param pdf: pandas.DataFrame + :param schema: a Spark schema of the pandas.DataFrame + """ + for field in schema: + if type(field.dataType) == DateType: + pdf[field.name] = pdf[field.name].dt.date + return pdf + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone From 8141c3e3ddb55586906b9bc79ef515142c2b551a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 6 Feb 2018 16:08:15 +0900 Subject: [PATCH 0293/1161] [SPARK-23300][TESTS] Prints out if Pandas and PyArrow are installed or not in PySpark SQL tests ## What changes were proposed in this pull request? This PR proposes to log if PyArrow and Pandas are installed or not so we can check if related tests are going to be skipped or not. ## How was this patch tested? Manually tested: I don't have PyArrow installed in PyPy. ```bash $ ./run-tests --python-executables=python3 ``` ``` ... Will test against the following Python executables: ['python3'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python3' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python3' in 'pyspark-sql' module. Starting test(python3): pyspark.mllib.tests Starting test(python3): pyspark.sql.tests Starting test(python3): pyspark.streaming.tests Starting test(python3): pyspark.tests ``` ```bash $ ./run-tests --modules=pyspark-streaming ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-streaming'] Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.streaming.util Starting test(python2.7): pyspark.streaming.tests Starting test(python2.7): pyspark.streaming.util ``` ```bash $ ./run-tests ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 0.8.0 is required; however, PyArrow was not found. Will test Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.tests Starting test(python2.7): pyspark.mllib.tests ``` ```bash $ ./run-tests --modules=pyspark-sql --python-executables=pypy ``` ``` ... Will test against the following Python executables: ['pypy'] Will test the following Python modules: ['pyspark-sql'] Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 0.8.0 is required; however, PyArrow was not found. Will test Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.sql.catalog Starting test(pypy): pyspark.sql.column Starting test(pypy): pyspark.sql.conf ``` After some modification to produce other cases: ```bash $ ./run-tests ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will skip PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. PyArrow >= 20.0.0 is required; however, PyArrow 0.8.0 was found. Will skip Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Pandas >= 20.0.0 is required; however, Pandas 0.20.2 was found. Will skip PyArrow related features against Python executable 'pypy' in 'pyspark-sql' module. PyArrow >= 20.0.0 is required; however, PyArrow was not found. Will skip Pandas related features against Python executable 'pypy' in 'pyspark-sql' module. Pandas >= 20.0.0 is required; however, Pandas 0.22.0 was found. Starting test(pypy): pyspark.sql.tests Starting test(pypy): pyspark.streaming.tests Starting test(pypy): pyspark.tests Starting test(python2.7): pyspark.mllib.tests ``` ```bash ./run-tests-with-coverage ``` ``` ... Will test against the following Python executables: ['python2.7', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Will test PyArrow related features against Python executable 'python2.7' in 'pyspark-sql' module. Will test Pandas related features against Python executable 'python2.7' in 'pyspark-sql' module. Coverage is not installed in Python executable 'pypy' but 'COVERAGE_PROCESS_START' environment variable is set, exiting. ``` Author: hyukjinkwon Closes #20473 from HyukjinKwon/SPARK-23300. --- python/run-tests.py | 73 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index f03284c334285..6b41b5ee22814 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,6 +31,7 @@ import Queue else: import queue as Queue +from distutils.version import LooseVersion # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -38,8 +39,8 @@ from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) -from sparktestsupport.shellutils import which, subprocess_check_output, run_cmd # noqa -from sparktestsupport.modules import all_modules # noqa +from sparktestsupport.shellutils import which, subprocess_check_output # noqa +from sparktestsupport.modules import all_modules, pyspark_sql # noqa python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') @@ -151,6 +152,67 @@ def parse_opts(): return opts +def _check_dependencies(python_exec, modules_to_test): + if "COVERAGE_PROCESS_START" in os.environ: + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) + + # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and + # explicitly prints out. See SPARK-23300. + if pyspark_sql in modules_to_test: + # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. + minimum_pyarrow_version = '0.8.0' + minimum_pandas_version = '0.19.2' + + try: + pyarrow_version = subprocess_check_output( + [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], + universal_newlines=True, + stderr=open(os.devnull, 'w')).strip() + if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): + LOGGER.info("Will test PyArrow related features against Python executable " + "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) + else: + LOGGER.warning( + "Will skip PyArrow related features against Python executable " + "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " + "%s was found." % ( + python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) + except: + LOGGER.warning( + "Will skip PyArrow related features against Python executable " + "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " + "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) + + try: + pandas_version = subprocess_check_output( + [python_exec, "-c", "import pandas; print(pandas.__version__)"], + universal_newlines=True, + stderr=open(os.devnull, 'w')).strip() + if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): + LOGGER.info("Will test Pandas related features against Python executable " + "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) + else: + LOGGER.warning( + "Will skip Pandas related features against Python executable " + "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " + "%s was found." % ( + python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) + except: + LOGGER.warning( + "Will skip Pandas related features against Python executable " + "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " + "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) + + def main(): opts = parse_opts() if (opts.verbose): @@ -175,9 +237,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - run_cmd([python_exec, "-c", "import coverage"]) + # Check if the python executable has proper dependencies installed to run tests + # for given modules properly. + _check_dependencies(python_exec, modules_to_test) + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() From 63c5bf13ce5cd3b8d7e7fb88de881ed207fde720 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 18:30:50 +0900 Subject: [PATCH 0294/1161] [SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2. ## What changes were proposed in this pull request? In Python 2, when `pandas_udf` tries to return string type value created in the udf with `".."`, the execution fails. E.g., ```python from pyspark.sql.functions import pandas_udf, col import pandas as pd df = spark.range(10) str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string") df.select(str_f(col('id'))).show() ``` raises the following exception: ``` ... java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType at scala.Predef$.assert(Predef.scala:170) at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.(ArrowEvalPythonExec.scala:93) ... ``` Seems like pyarrow ignores `type` parameter for `pa.Array.from_pandas()` and consider it as binary type when the type is string type and the string values are `str` instead of `unicode` in Python 2. This pr adds a workaround for the case. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20507 from ueshin/issues/SPARK-23334. --- python/pyspark/serializers.py | 4 ++++ python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index e870325d202ca..91a7f093cec19 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -230,6 +230,10 @@ def create_array(s, t): s = _check_series_convert_timestamps_internal(s.fillna(0), timezone) # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False) + elif t is not None and pa.types.is_string(t) and sys.version < '3': + # TODO: need decode before converting to Arrow in Python 2 + return pa.Array.from_pandas(s.apply( + lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t) return pa.Array.from_pandas(s, mask=mask, type=t) arrs = [create_array(s, t) for s, t in series] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 545ec5aee08ff..89b7c2182d2d1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3922,6 +3922,15 @@ def test_vectorized_udf_null_string(self): res = df.select(str_f(col('str'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_string_in_udf(self): + from pyspark.sql.functions import pandas_udf, col + import pandas as pd + df = self.spark.range(10) + str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType()) + actual = df.select(str_f(col('id'))) + expected = df.select(col('id').cast('string')) + self.assertEquals(expected.collect(), actual.collect()) + def test_vectorized_udf_datatype_string(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( From 7db9979babe52d15828967c86eb77e3fb2791579 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 6 Feb 2018 10:46:48 -0800 Subject: [PATCH 0295/1161] [SPARK-23310][CORE][FOLLOWUP] Fix Java style check issues. ## What changes were proposed in this pull request? This is a follow-up of #20492 which broke lint-java checks. This pr fixes the lint-java issues. ``` [ERROR] src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java:[79] (sizes) LineLength: Line is longer than 100 characters (found 114). ``` ## How was this patch tested? Checked manually in my local environment. Author: Takuya UESHIN Closes #20514 from ueshin/issues/SPARK-23310/fup1. --- .../util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 71e7c7a95ebdb..2c53c8d809d2e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -76,8 +76,8 @@ public UnsafeSorterSpillReader( SparkEnv.get() == null ? 0.5 : SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf regression for - // TPC-DS queries. + // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf + // regression for TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); From ac7454cac04a1d9252b3856360eda5c3e8bcb8da Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Feb 2018 12:27:37 -0800 Subject: [PATCH 0296/1161] [SPARK-23312][SQL][FOLLOWUP] add a config to turn off vectorized cache reader ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/20483 tried to provide a way to turn off the new columnar cache reader, to restore the behavior in 2.2. However even we turn off that config, the behavior is still different than 2.2. If the output data are rows, we still enable whole stage codegen for the scan node, which is different with 2.2, we should also fix it. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20513 from cloud-fan/cache. --- .../spark/sql/execution/columnar/InMemoryTableScanExec.scala | 3 +++ .../src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 3 ++- .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e972f8b30d87c..a93e8a1ad954d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -61,6 +61,9 @@ case class InMemoryTableScanExec( }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) } + // TODO: revisit this. Shall we always turn off whole stage codegen if the output data are rows? + override def supportCodegen: Boolean = supportsBatch + override protected def needsUnsafeRowConversion: Boolean = false private val columnIndices = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 9f27fa09127af..669e5f2bf4e65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -787,7 +787,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { val df = spark.range(10).cache() df.queryExecution.executedPlan.foreach { - case i: InMemoryTableScanExec => assert(i.supportsBatch == vectorized) + case i: InMemoryTableScanExec => + assert(i.supportsBatch == vectorized && i.supportCodegen == vectorized) case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 6e8d5a70d5a8f..ef16292a8e75c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -137,7 +137,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val dsStringFilter = dsString.filter(_ == "1") val planString = dsStringFilter.queryExecution.executedPlan assert(planString.collect { - case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => () + case i: InMemoryTableScanExec if !i.supportsBatch => () }.length == 1) assert(dsStringFilter.collect() === Array("1")) } From caf30445632de6aec810309293499199e7a20892 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 6 Feb 2018 12:30:04 -0800 Subject: [PATCH 0297/1161] [MINOR][TEST] Fix class name for Pandas UDF tests ## What changes were proposed in this pull request? In https://github.com/apache/spark/commit/b2ce17b4c9fea58140a57ca1846b2689b15c0d61, I mistakenly renamed `VectorizedUDFTests` to `ScalarPandasUDF`. This PR fixes the mistake. ## How was this patch tested? Existing tests. Author: Li Jin Closes #20489 from icexelloss/fix-scalar-udf-tests. --- python/pyspark/sql/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 89b7c2182d2d1..53da7dd45c2f2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3766,7 +3766,7 @@ def foo(k, v): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class ScalarPandasUDF(ReusedSQLTestCase): +class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod def setUpClass(cls): @@ -4279,7 +4279,7 @@ def test_register_vectorized_udf_basic(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyApplyPandasUDFTests(ReusedSQLTestCase): +class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property def data(self): @@ -4448,7 +4448,7 @@ def test_unsupported_types(self): @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") -class GroupbyAggPandasUDFTests(ReusedSQLTestCase): +class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property def data(self): From b96a083b1c6ff0d2c588be9499b456e1adce97dc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Feb 2018 12:43:45 -0800 Subject: [PATCH 0298/1161] [SPARK-23315][SQL] failed to get output from canonicalized data source v2 related plans ## What changes were proposed in this pull request? `DataSourceV2Relation` keeps a `fullOutput` and resolves the real output on demand by column name lookup. i.e. ``` lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => fullOutput.find(_.name == name).get } ``` This will be broken after we canonicalize the plan, because all attribute names become "None", see https://github.com/apache/spark/blob/v2.3.0-rc1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala#L42 To fix this, `DataSourceV2Relation` should just keep `output`, and update the `output` when doing column pruning. ## How was this patch tested? a new test case Author: Wenchen Fan Closes #20485 from cloud-fan/canonicalize. --- .../v2/DataSourceReaderHolder.scala | 12 +++----- .../datasources/v2/DataSourceV2Relation.scala | 8 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 4 +-- .../v2/PushDownOperatorsToDataSource.scala | 29 +++++++++++++------ .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- 5 files changed, 48 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 6460c97abe344..81219e9771bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.Objects -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.sources.v2.reader._ /** @@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._ trait DataSourceReaderHolder { /** - * The full output of the data source reader, without column pruning. + * The output of the data source reader, w.r.t. column pruning. */ - def fullOutput: Seq[AttributeReference] + def output: Seq[Attribute] /** * The held data source reader. @@ -46,7 +46,7 @@ trait DataSourceReaderHolder { case s: SupportsPushDownFilters => s.pushedFilters().toSet case _ => Nil } - Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + Seq(output, reader.getClass, filters) } def canEqual(other: Any): Boolean @@ -61,8 +61,4 @@ trait DataSourceReaderHolder { override def hashCode(): Int = { metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) } - - lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => - fullOutput.find(_.name == name).get - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index eebfa29f91b99..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], reader: DataSourceReader) extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { @@ -37,7 +37,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(fullOutput = fullOutput.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -46,8 +46,8 @@ case class DataSourceV2Relation( * to the non-streaming relation. */ class StreamingDataSourceV2Relation( - fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { override def isStreaming: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index df469af2c262a..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType * Physical plan node for scanning data from a data source. */ case class DataSourceV2ScanExec( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 566a48394f02e..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: add more push down rules. - pushDownRequiredColumns(filterPushed, filterPushed.outputSet) + val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) + RemoveRedundantProject(columnPruned) } // TODO: nested fields pruning - private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { + private def pushDownRequiredColumns( + plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { plan match { - case Project(projectList, child) => + case p @ Project(projectList, child) => val required = projectList.flatMap(_.references) - pushDownRequiredColumns(child, AttributeSet(required)) + p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - case Filter(condition, child) => + case f @ Filter(condition, child) => val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) + f.copy(child = pushDownRequiredColumns(child, required)) case relation: DataSourceV2Relation => relation.reader match { case reader: SupportsPushDownRequiredColumns => + // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now + // it's possible that the mutable reader being updated by someone else, and we need to + // always call `reader.pruneColumns` here to correct it. + // assert(relation.output.toStructType == reader.readSchema(), + // "Schema of data source reader does not match the relation plan.") + val requiredColumns = relation.output.filter(requiredByParent.contains) reader.pruneColumns(requiredColumns.toStructType) - case _ => + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + val newOutput = reader.readSchema().map(_.name).map(nameToAttr) + relation.copy(output = newOutput) + + case _ => relation } // TODO: there may be more operators that can be used to calculate the required columns. We // can add more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) + case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index eccd45442a3b2..a1c87fb15542c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ @@ -316,6 +316,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val reader4 = getReader(q4) assert(reader4.requiredSchema.fieldNames === Seq("i")) } + + test("SPARK-23315: get output from canonicalized data source v2 related plans") { + def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + val logical = df.queryExecution.optimizedPlan.collect { + case d: DataSourceV2Relation => d + }.head + assert(logical.canonicalized.output.length == numOutput) + + val physical = df.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d + }.head + assert(physical.canonicalized.output.length == numOutput) + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + checkCanonicalizedOutput(df, 2) + checkCanonicalizedOutput(df.select('i), 1) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { From c36fecc3b416c38002779c3cf40b6a665ac4bf13 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 6 Feb 2018 16:46:43 -0800 Subject: [PATCH 0299/1161] [SPARK-23327][SQL] Update the description and tests of three external API or functions ## What changes were proposed in this pull request? Update the description and tests of three external API or functions `createFunction `, `length` and `repartitionByRange ` ## How was this patch tested? N/A Author: gatorsmile Closes #20495 from gatorsmile/updateFunc. --- R/pkg/R/functions.R | 4 +++- python/pyspark/sql/functions.py | 8 ++++--- .../sql/catalyst/catalog/SessionCatalog.scala | 7 ++++-- .../expressions/stringExpressions.scala | 23 ++++++++++--------- .../scala/org/apache/spark/sql/Dataset.scala | 2 ++ .../sql/execution/command/functions.scala | 14 +++++++---- .../org/apache/spark/sql/functions.scala | 4 +++- .../execution/command/DDLParserSuite.scala | 10 ++++---- 8 files changed, 44 insertions(+), 28 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 55365a41d774b..9f7c6317cd924 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1026,7 +1026,9 @@ setMethod("last_day", }) #' @details -#' \code{length}: Computes the length of a given string or binary column. +#' \code{length}: Computes the character length of a string data or number of bytes +#' of a binary data. The length of string data includes the trailing spaces. +#' The length of binary data includes binary zeros. #' #' @rdname column_string_functions #' @aliases length length,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3c8fb4c4d19e7..05031f5ec87d7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1705,10 +1705,12 @@ def unhex(col): @ignore_unicode_prefix @since(1.5) def length(col): - """Calculates the length of a string or binary expression. + """Computes the character length of string data or number of bytes of binary data. + The length of character data includes the trailing spaces. The length of binary data + includes binary zeros. - >>> spark.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() - [Row(length=3)] + >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() + [Row(length=4)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.length(_to_java_column(col))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a129896230775..4b119c75260a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -988,8 +988,11 @@ class SessionCatalog( // ------------------------------------------------------- /** - * Create a metastore function in the database specified in `funcDefinition`. + * Create a function in the database specified in `funcDefinition`. * If no such database is specified, create it in the current database. + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. */ def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) @@ -1061,7 +1064,7 @@ class SessionCatalog( } /** - * Check if the specified function exists. + * Check if the function with the specified name exists */ def functionExists(name: FunctionIdentifier): Boolean = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5cf783f1a5979..d7612e30b4c57 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1653,19 +1653,19 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + + "binary data. The length of string data includes the trailing spaces. The length of binary " + + "data includes binary zeros.", examples = """ Examples: - > SELECT _FUNC_('Spark SQL'); - 9 - > SELECT CHAR_LENGTH('Spark SQL'); - 9 - > SELECT CHARACTER_LENGTH('Spark SQL'); - 9 + > SELECT _FUNC_('Spark SQL '); + 10 + > SELECT CHAR_LENGTH('Spark SQL '); + 10 + > SELECT CHARACTER_LENGTH('Spark SQL '); + 10 """) -// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1687,7 +1687,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn * A function that returns the bit length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the bit length of `expr` or number of bits in binary data.", + usage = "_FUNC_(expr) - Returns the bit length of string data or number of bits of binary data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); @@ -1716,7 +1716,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas * A function that returns the byte length of the given string or binary expression. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the byte length of `expr` or number of bytes in binary data.", + usage = "_FUNC_(expr) - Returns the byte length of string data or number of bytes of binary " + + "data.", examples = """ Examples: > SELECT _FUNC_('Spark SQL'); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d47cd0aecf56a..0aee1d7be5788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2825,6 +2825,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 @@ -2848,6 +2849,7 @@ class Dataset[T] private[sql]( * * At least one partition-by expression must be specified. * When no explicit sort order is specified, "ascending nulls first" is assumed. + * Note, the rows are not sorted in each partition of the resulting Dataset. * * @group typedrel * @since 2.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 4f92ffee687aa..1f7808c2f8e80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -40,6 +40,10 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [databaseName.]functionName * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} + * + * @param ignoreIfExists: When true, ignore if the function with the specified name exists + * in the specified database. + * @param replace: When true, alter the function with the specified name */ case class CreateFunctionCommand( databaseName: Option[String], @@ -47,17 +51,17 @@ case class CreateFunctionCommand( className: String, resources: Seq[FunctionResource], isTemp: Boolean, - ifNotExists: Boolean, + ignoreIfExists: Boolean, replace: Boolean) extends RunnableCommand { - if (ifNotExists && replace) { + if (ignoreIfExists && replace) { throw new AnalysisException("CREATE FUNCTION with both IF NOT EXISTS and REPLACE" + " is not allowed.") } // Disallow to define a temporary function with `IF NOT EXISTS` - if (ifNotExists && isTemp) { + if (ignoreIfExists && isTemp) { throw new AnalysisException( "It is not allowed to define a TEMPORARY function with IF NOT EXISTS.") } @@ -79,12 +83,12 @@ case class CreateFunctionCommand( // Handles `CREATE OR REPLACE FUNCTION AS ... USING ...` if (replace && catalog.functionExists(func.identifier)) { // alter the function in the metastore - catalog.alterFunction(CatalogFunction(func.identifier, className, resources)) + catalog.alterFunction(func) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. - catalog.createFunction(CatalogFunction(func.identifier, className, resources), ifNotExists) + catalog.createFunction(func, ignoreIfExists) } } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d11682d80a3c..0d54c02c3d06f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2267,7 +2267,9 @@ object functions { } /** - * Computes the length of a given string or binary column. + * Computes the character length of a given string or number of bytes of a binary string. + * The length of character strings include the trailing spaces. The length of binary strings + * includes binary zeros. * * @group string_funcs * @since 1.5.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 2b1aea08b1223..e0ccae15f1d05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -236,7 +236,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = false) + isTemp = true, ignoreIfExists = false, replace = false) val expected2 = CreateFunctionCommand( Some("hello"), "world", @@ -244,7 +244,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = false) + isTemp = false, ignoreIfExists = false, replace = false) val expected3 = CreateFunctionCommand( None, "helloworld3", @@ -252,7 +252,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), - isTemp = true, ifNotExists = false, replace = true) + isTemp = true, ignoreIfExists = false, replace = true) val expected4 = CreateFunctionCommand( Some("hello"), "world1", @@ -260,7 +260,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = false, replace = true) + isTemp = false, ignoreIfExists = false, replace = true) val expected5 = CreateFunctionCommand( Some("hello"), "world2", @@ -268,7 +268,7 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { Seq( FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), - isTemp = false, ifNotExists = true, replace = false) + isTemp = false, ignoreIfExists = true, replace = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) From 9775df67f924663598d51723a878557ddafb8cfd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 7 Feb 2018 23:24:16 +0900 Subject: [PATCH 0300/1161] [SPARK-23122][PYSPARK][FOLLOWUP] Replace registerTempTable by createOrReplaceTempView ## What changes were proposed in this pull request? Replace `registerTempTable` by `createOrReplaceTempView`. ## How was this patch tested? N/A Author: gatorsmile Closes #20523 from gatorsmile/updateExamples. --- python/pyspark/sql/udf.py | 2 +- .../src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 0f759c448b8a7..08c6b9e521e82 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -356,7 +356,7 @@ def registerJavaUDAF(self, name, javaClassName): >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg") >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) - >>> df.registerTempTable("df") + >>> df.createOrReplaceTempView("df") >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] """ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java index ddbaa45a483cb..08dc129f27a0c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -46,7 +46,7 @@ public void tearDown() { @SuppressWarnings("unchecked") @Test public void udf1Test() { - spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.range(1, 10).toDF("value").createOrReplaceTempView("df"); spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); From 71cfba04aeec5ae9b85a507b13996e80f8750edc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 7 Feb 2018 23:28:10 +0900 Subject: [PATCH 0301/1161] [SPARK-23319][TESTS] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test) ## What changes were proposed in this pull request? This PR proposes to explicitly specify Pandas and PyArrow versions in PySpark tests to skip or test. We declared the extra dependencies: https://github.com/apache/spark/blob/b8bfce51abf28c66ba1fc67b0f25fe1617c81025/python/setup.py#L204 In case of PyArrow: Currently we only check if pyarrow is installed or not without checking the version. It already fails to run tests. For example, if PyArrow 0.7.0 is installed: ``` ====================================================================== ERROR: test_vectorized_udf_wrong_return_type (pyspark.sql.tests.ScalarPandasUDF) ---------------------------------------------------------------------- Traceback (most recent call last): File "/.../spark/python/pyspark/sql/tests.py", line 4019, in test_vectorized_udf_wrong_return_type f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) File "/.../spark/python/pyspark/sql/functions.py", line 2309, in pandas_udf return _create_udf(f=f, returnType=return_type, evalType=eval_type) File "/.../spark/python/pyspark/sql/udf.py", line 47, in _create_udf require_minimum_pyarrow_version() File "/.../spark/python/pyspark/sql/utils.py", line 132, in require_minimum_pyarrow_version "however, your version was %s." % pyarrow.__version__) ImportError: pyarrow >= 0.8.0 must be installed on calling Python process; however, your version was 0.7.0. ---------------------------------------------------------------------- Ran 33 tests in 8.098s FAILED (errors=33) ``` In case of Pandas: There are few tests for old Pandas which were tested only when Pandas version was lower, and I rewrote them to be tested when both Pandas version is lower and missing. ## How was this patch tested? Manually tested by modifying the condition: ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 1.19.2 must be installed; however, your version was 0.19.2.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 1.8.0 must be installed; however, your version was 0.8.0.' ``` ``` test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_respect_session_timezone (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ``` Author: hyukjinkwon Closes #20487 from HyukjinKwon/pyarrow-pandas-skip. --- pom.xml | 4 ++ python/pyspark/sql/dataframe.py | 3 ++ python/pyspark/sql/session.py | 3 ++ python/pyspark/sql/tests.py | 87 ++++++++++++++++++--------------- python/pyspark/sql/utils.py | 30 +++++++++--- python/setup.py | 10 +++- 6 files changed, 89 insertions(+), 48 deletions(-) diff --git a/pom.xml b/pom.xml index 666d5d7169a15..d18831df1db6d 100644 --- a/pom.xml +++ b/pom.xml @@ -185,6 +185,10 @@ 2.8 1.8 1.0.0 + 0.8.0 ${java.home} diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 59a417015b949..8ec24db8717b2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1913,6 +1913,9 @@ def toPandas(self): 0 2 Alice 1 5 Bob """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 1ed04298bc899..b3af9b82953f3 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -646,6 +646,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr except Exception: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ == "true": timezone = self.conf.get("spark.sql.session.timeZone") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 53da7dd45c2f2..58359b61dc83a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -48,19 +48,26 @@ else: import unittest -_have_pandas = False -_have_old_pandas = False +_pandas_requirement_message = None try: - import pandas - try: - from pyspark.sql.utils import require_minimum_pandas_version - require_minimum_pandas_version() - _have_pandas = True - except: - _have_old_pandas = True -except: - # No Pandas, but that's okay, we'll skip those tests - pass + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() +except ImportError as e: + from pyspark.util import _exception_message + # If Pandas version requirement is not satisfied, skip related tests. + _pandas_requirement_message = _exception_message(e) + +_pyarrow_requirement_message = None +try: + from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() +except ImportError as e: + from pyspark.util import _exception_message + # If Arrow version requirement is not satisfied, skip related tests. + _pyarrow_requirement_message = _exception_message(e) + +_have_pandas = _pandas_requirement_message is None +_have_pyarrow = _pyarrow_requirement_message is None from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row @@ -75,15 +82,6 @@ from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException -_have_arrow = False -try: - import pyarrow - _have_arrow = True -except: - # No Arrow, but that's okay, we'll skip those tests - pass - - class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -2794,7 +2792,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"): def _to_pandas(self): from datetime import datetime, date - import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ .add("c", BooleanType()).add("d", FloatType())\ .add("dt", DateType()).add("ts", TimestampType()) @@ -2807,7 +2804,7 @@ def _to_pandas(self): df = self.spark.createDataFrame(data, schema) return df.toPandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas(self): import numpy as np pdf = self._to_pandas() @@ -2819,13 +2816,13 @@ def test_to_pandas(self): self.assertEquals(types[4], np.object) # datetime.date self.assertEquals(types[5], 'datetime64[ns]') - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_to_pandas_old(self): + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_to_pandas_required_pandas_not_found(self): with QuietTest(self.sc): with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self._to_pandas() - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_to_pandas_avoid_astype(self): import numpy as np schema = StructType().add("a", IntegerType()).add("b", StringType())\ @@ -2843,7 +2840,7 @@ def test_create_dataframe_from_array_of_long(self): df = self.spark.createDataFrame(data) self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) - @unittest.skipIf(not _have_pandas, "Pandas not installed") + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) def test_create_dataframe_from_pandas_with_timestamp(self): import pandas as pd from datetime import datetime @@ -2858,14 +2855,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self): self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType)) self.assertTrue(isinstance(df.schema['d'].dataType, DateType)) - @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") - def test_create_dataframe_from_old_pandas(self): - import pandas as pd - from datetime import datetime - pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], - "d": [pd.Timestamp.now().date()]}) + @unittest.skipIf(_have_pandas, "Required Pandas was found.") + def test_create_dataframe_required_pandas_not_found(self): with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): + with self.assertRaisesRegexp( + ImportError, + '(Pandas >= .* must be installed|No module named pandas)'): + import pandas as pd + from datetime import datetime + pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], + "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) @@ -3383,7 +3382,9 @@ def __init__(self, **kwargs): _make_type_verifier(data_type, nullable=False)(obj) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class ArrowTests(ReusedSQLTestCase): @classmethod @@ -3641,7 +3642,9 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df_arrow.columns) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType @@ -3765,7 +3768,9 @@ def foo(k, v): return k -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class ScalarPandasUDFTests(ReusedSQLTestCase): @classmethod @@ -4278,7 +4283,9 @@ def test_register_vectorized_udf_basic(self): self.assertEquals(expected.collect(), res2.collect()) -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class GroupedMapPandasUDFTests(ReusedSQLTestCase): @property @@ -4447,7 +4454,9 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() -@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) class GroupedAggPandasUDFTests(ReusedSQLTestCase): @property diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 08c34c6dccc5e..578298632dd4c 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr): def require_minimum_pandas_version(): """ Raise ImportError if minimum version of Pandas is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pandas_version = "0.19.2" + from distutils.version import LooseVersion - import pandas - if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): - raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; " - "however, your version was %s." % pandas.__version__) + try: + import pandas + except ImportError: + raise ImportError("Pandas >= %s must be installed; however, " + "it was not found." % minimum_pandas_version) + if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): + raise ImportError("Pandas >= %s must be installed; however, " + "your version was %s." % (minimum_pandas_version, pandas.__version__)) def require_minimum_pyarrow_version(): """ Raise ImportError if minimum version of pyarrow is not installed """ + # TODO(HyukjinKwon): Relocate and deduplicate the version specification. + minimum_pyarrow_version = "0.8.0" + from distutils.version import LooseVersion - import pyarrow - if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'): - raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; " - "however, your version was %s." % pyarrow.__version__) + try: + import pyarrow + except ImportError: + raise ImportError("PyArrow >= %s must be installed; however, " + "it was not found." % minimum_pyarrow_version) + if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): + raise ImportError("PyArrow >= %s must be installed; however, " + "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) diff --git a/python/setup.py b/python/setup.py index 251d4526d4dd0..6a98401941d8d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -100,6 +100,11 @@ def _supports_symlinks(): file=sys.stderr) exit(-1) +# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and +# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. +_minimum_pandas_version = "0.19.2" +_minimum_pyarrow_version = "0.8.0" + try: # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts # find it where expected. The rest of the files aren't copied because they are accessed @@ -201,7 +206,10 @@ def _supports_symlinks(): extras_require={ 'ml': ['numpy>=1.7'], 'mllib': ['numpy>=1.7'], - 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0'] + 'sql': [ + 'pandas>=%s' % _minimum_pandas_version, + 'pyarrow>=%s' % _minimum_pyarrow_version, + ] }, classifiers=[ 'Development Status :: 5 - Production/Stable', From 9841ae0313cbee1f083f131f9446808c90ed5a7b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Feb 2018 09:48:49 -0800 Subject: [PATCH 0302/1161] [SPARK-23345][SQL] Remove open stream record even closing it fails ## What changes were proposed in this pull request? When `DebugFilesystem` closes opened stream, if any exception occurs, we still need to remove the open stream record from `DebugFilesystem`. Otherwise, it goes to report leaked filesystem connection. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20524 from viirya/SPARK-23345. --- core/src/test/scala/org/apache/spark/DebugFilesystem.scala | 7 +++++-- .../org/apache/spark/sql/test/SharedSparkSession.scala | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index 91355f7362900..a5bdc95790722 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -103,8 +103,11 @@ class DebugFilesystem extends LocalFileSystem { override def markSupported(): Boolean = wrapped.markSupported() override def close(): Unit = { - wrapped.close() - removeOpenStream(wrapped) + try { + wrapped.close() + } finally { + removeOpenStream(wrapped) + } } override def read(): Int = wrapped.read() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 0b4629a51b425..e758c865b908f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -111,7 +111,7 @@ trait SharedSparkSession spark.sharedState.cacheManager.clearCache() // files can be closed from other threads, so wait a bit // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { + eventually(timeout(10.seconds), interval(2.seconds)) { DebugFilesystem.assertNoOpenStreams() } } From 30295bf5a6754d0ae43334f7bf00e7a29ed0f1af Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 7 Feb 2018 15:22:53 -0800 Subject: [PATCH 0303/1161] [SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs ## What changes were proposed in this pull request? This PR migrates the MemoryStream to DataSourceV2 APIs. One additional change is in the reported keys in StreamingQueryProgress.durationMs. "getOffset" and "getBatch" replaced with "setOffsetRange" and "getEndOffset" as tracking these make more sense. Unit tests changed accordingly. ## How was this patch tested? Existing unit tests, few updated unit tests. Author: Tathagata Das Author: Burak Yavuz Closes #20445 from tdas/SPARK-23092. --- .../sql/execution/streaming/LongOffset.scala | 4 +- .../streaming/MicroBatchExecution.scala | 27 ++-- .../sql/execution/streaming/memory.scala | 132 +++++++++++------- .../sources/RateStreamSourceV2.scala | 2 +- .../streaming/ForeachSinkSuite.scala | 55 +++----- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../StreamingQueryListenerSuite.scala | 5 +- .../sql/streaming/StreamingQuerySuite.scala | 70 ++++++---- 9 files changed, 171 insertions(+), 134 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 5f0b195fcfcb8..3ff5b86ac45d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} + /** * A simple offset for sources that produce a single linear stream of data. */ -case class LongOffset(offset: Long) extends Offset { +case class LongOffset(offset: Long) extends OffsetV2 { override val json = offset.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index d9aa8573ba930..045d2b4b9569c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -270,16 +270,17 @@ class MicroBatchExecution( } case s: MicroBatchReader => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - - (s, Some(s.getEndOffset)) + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } + + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -401,10 +402,14 @@ class MicroBatchExecution( case (reader: MicroBatchReader, available) if committedOffsets.get(reader).map(_ != available).getOrElse(true) => val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + val availableV2: OffsetV2 = available match { + case v1: SerializedOffset => reader.deserializeOffset(v1.json) + case v2: OffsetV2 => v2 + } reader.setOffsetRange( toJava(current), - Optional.of(available.asInstanceOf[OffsetV2])) - logDebug(s"Retrieving data from $reader: $current -> $available") + Optional.of(availableV2)) + logDebug(s"Retrieving data from $reader: $current -> $availableV2") Some(reader -> new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 509a69dd922fb..352d4ce9fbcaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,21 +17,23 @@ package org.apache.spark.sql.execution.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -51,9 +53,10 @@ object MemoryStream { * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends Source with Logging { + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { protected val encoder = encoderFor[A] - protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession) + private val attributes = encoder.schema.toAttributes + protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -61,11 +64,17 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[Dataset[A]] + protected val batches = new ListBuffer[Array[UnsafeRow]] @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + @GuardedBy("this") + private var startOffset = new LongOffset(-1) + + @GuardedBy("this") + private var endOffset = new LongOffset(-1) + /** * Last offset that was discarded, or -1 if no commits have occurred. Note that the value * -1 is used in calculations below and isn't just an arbitrary constant. @@ -73,8 +82,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def schema: StructType = encoder.schema - def toDS(): Dataset[A] = { Dataset(sqlContext.sparkSession, logicalPlan) } @@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } def addData(data: TraversableOnce[A]): Offset = { - val encoded = data.toVector.map(d => encoder.toRow(d).copy()) - val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true) - val ds = Dataset[A](sqlContext.sparkSession, plan) - logDebug(s"Adding ds: $ds") + val objects = data.toSeq + val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray + logDebug(s"Adding: $objects") this.synchronized { currentOffset = currentOffset + 1 - batches += ds + batches += rows currentOffset } } override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None - } else { - Some(currentOffset) + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] + endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] } } - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 - - // Internal buffer only holds the batches after lastCommittedOffset. - val newBlocks = synchronized { - val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 - val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 - assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") - batches.slice(sliceStart, sliceEnd) - } + override def readSchema(): StructType = encoder.schema - if (newBlocks.isEmpty) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } + override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) + + override def getStartOffset: OffsetV2 = synchronized { + if (startOffset.offset == -1) null else startOffset + } - logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal)) + override def getEndOffset: OffsetV2 = synchronized { + if (endOffset.offset == -1) null else endOffset + } - newBlocks - .map(_.toDF()) - .reduceOption(_ union _) - .getOrElse { - sys.error("No data selected!") + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) + val startOrdinal = startOffset.offset.toInt + 1 + val endOrdinal = endOffset.offset.toInt + 1 + + // Internal buffer only holds the batches after lastCommittedOffset. + val newBlocks = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd") + batches.slice(sliceStart, sliceEnd) } + + logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) + + newBlocks.map { block => + new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + }.asJava + } } private def generateDebugString( - blocks: TraversableOnce[Dataset[A]], + rows: Seq[UnsafeRow], startOrdinal: Int, endOrdinal: Int): String = { - val originalUnsupportedCheck = - sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck") - try { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false") - s"MemoryBatch [$startOrdinal, $endOrdinal]: " + - s"${blocks.flatMap(_.collect()).mkString(", ")}" - } finally { - sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck) - } + val fromRow = encoder.resolveAndBind().fromRow _ + s"MemoryBatch [$startOrdinal, $endOrdinal]: " + + s"${rows.map(row => fromRow(row)).mkString(", ")}" } - override def commit(end: Offset): Unit = synchronized { + override def commit(end: OffsetV2): Unit = synchronized { def check(newOffset: LongOffset): Unit = { val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt @@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def reset(): Unit = synchronized { batches.clear() + startOffset = LongOffset(-1) + endOffset = LongOffset(-1) currentOffset = new LongOffset(-1) lastOffsetCommitted = new LongOffset(-1) } } + +class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) + extends DataReaderFactory[UnsafeRow] { + override def createDataReader(): DataReader[UnsafeRow] = { + new DataReader[UnsafeRow] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the array. + currentIndex += 1 + currentIndex < records.length + } + + override def get(): UnsafeRow = records(currentIndex) + + override def close(): Unit = {} + } + } +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 1315885da8a6f..077a255946a6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -151,7 +151,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactor } class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - var currentIndex = -1 + private var currentIndex = -1 override def next(): Boolean = { // Return true as long as the new index is in the seq. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 41434e6d8b974..b249dd41a84a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter()) .start() - // -- batch 0 --------------------------------------- - input.addData(1, 2, 3, 4) - query.processAllAvailable() + def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { + import ForeachSinkSuite._ - var expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) - ) - var expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 1), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) - ) + val events = ForeachSinkSuite.allEvents() + assert(events.size === 2) // one seq of events for each of the 2 partitions - var allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + // Verify both seq of events have an Open event as the first event + assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion))) + + // Verify all the Process event correspond to the expected data + val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]])) + assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet) + + // Verify both seq of events have a Close event as the last event + assert(events.map(_.last).toSet === Set(Close(None), Close(None))) + } + // -- batch 0 --------------------------------------- ForeachSinkSuite.clear() + input.addData(1, 2, 3, 4) + query.processAllAvailable() + verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- + ForeachSinkSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() - - expectedEventsForPartition0 = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 5), - ForeachSinkSuite.Process(value = 7), - ForeachSinkSuite.Close(None) - ) - expectedEventsForPartition1 = Seq( - ForeachSinkSuite.Open(partition = 1, version = 1), - ForeachSinkSuite.Process(value = 6), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) - ) - - allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 2) - assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + verifyOutput(expectedVersion = 1, expectedData = 5 to 8) query.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c65e5d3dd75c2..d1a04833390f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index d6433562fb29b..37fe595529baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { (source, source.addData(data)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 79d65192a14aa..b96f2bcbdd644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.util.JsonProtocol @@ -298,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getOffset: Option[Offset] = { + override def getEndOffset: OffsetV2 = { numTriggers += 1 - super.getOffset + super.getEndOffset } } val clock = new StreamManualClock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 76201c63a2701..3f9aa0d1fa5be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,25 +17,27 @@ package org.apache.spark.sql.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils -import org.mockito.Mockito._ import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ManualClock class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { @@ -206,19 +208,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi /** Custom MemoryStream that waits for manual clock to reach a time */ val inputData = new MemoryStream[Int](0, sqlContext) { - // getOffset should take 50 ms the first time it is called - override def getOffset: Option[Offset] = { - val offset = super.getOffset - if (offset.nonEmpty) { - clock.waitTillTime(1050) + + private def dataAdded: Boolean = currentOffset.offset != -1 + + // setOffsetRange should take 50 ms the first time it is called after data is added + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.setOffsetRange(start, end) } - offset + } + + // getEndOffset should take 100 ms the first time it is called after data is added + override def getEndOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1150) + super.getEndOffset() } // getBatch should take 100 ms the first time it is called - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - if (start.isEmpty) clock.waitTillTime(1150) - super.getBatch(start, end) + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + synchronized { + clock.waitTillTime(1350) + super.createUnsafeRowReaderFactories() + } } } @@ -258,39 +270,44 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while offset is being fetched + // Test status and progress when setOffsetRange is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being fetched - AdvanceManualClock(50), // time = 1050 to unblock getOffset + AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getBatch that needs 1150 + AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 + AssertOnQuery(_.status.isDataAvailable === false), + AssertOnQuery(_.status.isTriggerActive === true), + AssertOnQuery(_.status.message.startsWith("Getting offsets from")), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), + + AdvanceManualClock(100), // time = 1150 to unblock getEndOffset + AssertClockTime(1150), + AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress while batch is being processed - AdvanceManualClock(100), // time = 1150 to unblock getBatch - AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500 + AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AssertClockTime(1350), + AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AssertOnQuery { _ => clock.getTimeMillis() === 1150 }, - AdvanceManualClock(350), // time = 1500 to unblock job + AdvanceManualClock(150), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), - AssertStreamExecThreadIsWaitingForTime(2000), + AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), @@ -307,10 +324,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("getOffset") === 50) - assert(progress.durationMs.get("getBatch") === 100) + assert(progress.durationMs.get("setOffsetRange") === 50) + assert(progress.durationMs.get("getEndOffset") === 100) assert(progress.durationMs.get("queryPlanning") === 0) assert(progress.durationMs.get("walCommit") === 0) + assert(progress.durationMs.get("addBatch") === 350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) From a62f30d3fa032ff75bc2b7bebbd0813e67ea5fd5 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 8 Feb 2018 12:46:10 +0900 Subject: [PATCH 0304/1161] [SPARK-23319][TESTS][FOLLOWUP] Fix a test for Python 3 without pandas. ## What changes were proposed in this pull request? This is a followup pr of #20487. When importing module but it doesn't exists, the error message is slightly different between Python 2 and 3. E.g., in Python 2: ``` No module named pandas ``` in Python 3: ``` No module named 'pandas' ``` So, one test to check an import error fails in Python 3 without pandas. This pr fixes it. ## How was this patch tested? Tested manually in my local environment. Author: Takuya UESHIN Closes #20538 from ueshin/issues/SPARK-23319/fup1. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 58359b61dc83a..90ff084fed55e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2860,7 +2860,7 @@ def test_create_dataframe_required_pandas_not_found(self): with QuietTest(self.sc): with self.assertRaisesRegexp( ImportError, - '(Pandas >= .* must be installed|No module named pandas)'): + "(Pandas >= .* must be installed|No module named '?pandas'?)"): import pandas as pd from datetime import datetime pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], From 3473fda6dc77bdfd84b3de95d2082856ad4f8626 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 8 Feb 2018 12:21:18 +0800 Subject: [PATCH 0305/1161] Revert [SPARK-22279][SQL] Turn on spark.sql.hive.convertMetastoreOrc by default ## What changes were proposed in this pull request? This is to revert the changes made in https://github.com/apache/spark/pull/19499 , because this causes a regression. We should not ignore the table-specific compression conf when the Hive serde tables are converted to the data source tables. ## How was this patch tested? The existing tests. Author: gatorsmile Closes #20536 from gatorsmile/revert22279. --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index d9627eb9790eb..93f3f38e52aa9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -109,7 +109,7 @@ private[spark] object HiveUtils extends Logging { .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + From 7f5f5fb1296275a38da0adfa05125dd8ebf729ff Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Feb 2018 00:08:54 -0800 Subject: [PATCH 0306/1161] [SPARK-23348][SQL] append data using saveAsTable should adjust the data types ## What changes were proposed in this pull request? For inserting/appending data to an existing table, Spark should adjust the data types of the input query according to the table schema, or fail fast if it's uncastable. There are several ways to insert/append data: SQL API, `DataFrameWriter.insertInto`, `DataFrameWriter.saveAsTable`. The first 2 ways create `InsertIntoTable` plan, and the last way creates `CreateTable` plan. However, we only adjust input query data types for `InsertIntoTable`, and users may hit weird errors when appending data using `saveAsTable`. See the JIRA for the error case. This PR fixes this bug by adjusting data types for `CreateTable` too. ## How was this patch tested? new test. Author: Wenchen Fan Closes #20527 from cloud-fan/saveAsTable. --- .../sql/execution/datasources/rules.scala | 72 +++++++++++-------- .../sql/execution/command/DDLSuite.scala | 28 ++++++++ 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5dbcf4a915cbf..5cc21eeaeaa94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -178,7 +178,8 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi c.copy( tableDesc = existingTable, - query = Some(newQuery)) + query = Some(DDLPreprocessingUtils.castAndRenameQueryOutput( + newQuery, existingTable.schema.toAttributes, conf))) // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity // config, and do various checks: @@ -316,7 +317,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -336,6 +337,8 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit s"including ${staticPartCols.size} partition column(s) having constant value(s).") } + val newQuery = DDLPreprocessingUtils.castAndRenameQueryOutput( + insert.query, expectedColumns, conf) if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw new AnalysisException( @@ -346,37 +349,11 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit """.stripMargin) } - castAndRenameChildOutput(insert.copy(partition = normalizedPartSpec), expectedColumns) + insert.copy(query = newQuery, partition = normalizedPartSpec) } else { // All partition columns are dynamic because the InsertIntoTable command does // not explicitly specify partitioning columns. - castAndRenameChildOutput(insert, expectedColumns) - .copy(partition = partColNames.map(_ -> None).toMap) - } - } - - private def castAndRenameChildOutput( - insert: InsertIntoTable, - expectedOutput: Seq[Attribute]): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(insert.query.output).map { - case (expected, actual) => - if (expected.dataType.sameType(actual.dataType) && - expected.name == actual.name && - expected.metadata == actual.metadata) { - actual - } else { - // Renaming is needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - Alias(cast(actual, expected.dataType), expected.name)( - explicitMetadata = Option(expected.metadata)) - } - } - - if (newChildOutput == insert.query.output) { - insert - } else { - insert.copy(query = Project(newChildOutput, insert.query)) + insert.copy(query = newQuery, partition = partColNames.map(_ -> None).toMap) } } @@ -491,3 +468,36 @@ object PreWriteCheck extends (LogicalPlan => Unit) { } } } + +object DDLPreprocessingUtils { + + /** + * Adjusts the name and data type of the input query output columns, to match the expectation. + */ + def castAndRenameQueryOutput( + query: LogicalPlan, + expectedOutput: Seq[Attribute], + conf: SQLConf): LogicalPlan = { + val newChildOutput = expectedOutput.zip(query.output).map { + case (expected, actual) => + if (expected.dataType.sameType(actual.dataType) && + expected.name == actual.name && + expected.metadata == actual.metadata) { + actual + } else { + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Alias( + Cast(actual, expected.dataType, Option(conf.sessionLocalTimeZone)), + expected.name)(explicitMetadata = Option(expected.metadata)) + } + } + + if (newChildOutput == query.output) { + query + } else { + Project(newChildOutput, query) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index ee3674ba17821..f76bfd2fda2b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -37,6 +37,8 @@ import org.apache.spark.util.Utils class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ + override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test @@ -132,6 +134,32 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo checkAnswer(spark.table("t"), Row(Row("a", 1)) :: Nil) } } + + // TODO: This test is copied from HiveDDLSuite, unify it later. + test("SPARK-23348: append data to data source table with saveAsTable") { + withTable("t", "t1") { + Seq(1 -> "a").toDF("i", "j").write.saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a")) + + sql("INSERT INTO t SELECT 2, 'b'") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil) + + Seq(3 -> "c").toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Nil) + + Seq("c" -> 3).toDF("i", "j").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") + :: Row(null, "3") :: Nil) + + Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") + + val e = intercept[AnalysisException] { + Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + } + assert(e.message.contains("The format of the existing table default.t1 is " + + "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + } + } } abstract class DDLSuite extends QueryTest with SQLTestUtils { From a75f927173632eee1316879447cb62c8cf30ae37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 8 Feb 2018 19:20:11 +0800 Subject: [PATCH 0307/1161] [SPARK-23268][SQL][FOLLOWUP] Reorganize packages in data source V2 ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/20435. While reorganizing the packages for streaming data source v2, the top level stream read/write support interfaces should not be in the reader/writer package, but should be in the `sources.v2` package, to follow the `ReadSupport`, `WriteSupport`, etc. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20509 from cloud-fan/followup. --- .../org/apache/spark/sql/kafka010/KafkaSourceProvider.scala | 4 +--- .../sql/sources/v2/{reader => }/ContinuousReadSupport.java | 4 +--- .../sql/sources/v2/{reader => }/MicroBatchReadSupport.java | 4 +--- .../sql/sources/v2/{writer => }/StreamWriteSupport.java | 5 ++--- .../apache/spark/sql/sources/v2/writer/DataSourceWriter.java | 1 + .../spark/sql/execution/streaming/MicroBatchExecution.scala | 5 ++--- .../spark/sql/execution/streaming/RateSourceProvider.scala | 1 - .../spark/sql/execution/streaming/StreamingRelation.scala | 3 +-- .../org/apache/spark/sql/execution/streaming/console.scala | 3 +-- .../execution/streaming/continuous/ContinuousExecution.scala | 4 +--- .../sql/execution/streaming/sources/RateStreamSourceV2.scala | 2 +- .../spark/sql/execution/streaming/sources/memoryV2.scala | 2 +- .../org/apache/spark/sql/streaming/DataStreamReader.scala | 3 +-- .../org/apache/spark/sql/streaming/DataStreamWriter.scala | 2 +- .../apache/spark/sql/streaming/StreamingQueryManager.scala | 2 +- .../spark/sql/execution/streaming/RateSourceV2Suite.scala | 2 +- .../sql/streaming/sources/StreamingDataSourceV2Suite.scala | 5 ++--- 17 files changed, 19 insertions(+), 33 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{reader => }/ContinuousReadSupport.java (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{reader => }/MicroBatchReadSupport.java (93%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{writer => }/StreamWriteSupport.java (93%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 694ca76e24964..d4fa0359c12d6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,9 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java index 0c1d5d1a9577a..7df5a451ae5f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupport.java @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 5e8f0c0dafdcf..209ffa7a0b9fa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.reader; +package org.apache.spark.sql.sources.v2; import java.util.Optional; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java similarity index 93% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java index 1c0e2e12f8d51..a77b01497269e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/StreamWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamWriteSupport.java @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2.writer; +package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.writer.DataSourceWriter; import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 52324b3792b8a..e3f682bf96a66 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.StreamWriteSupport; import org.apache.spark.sql.sources.v2.WriteSupport; import org.apache.spark.sql.streaming.OutputMode; import org.apache.spark.sql.types.StructType; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 045d2b4b9569c..812533313332e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -29,10 +29,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.{StreamWriteSupport, SupportsWriteInternalRow} +import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala index ce5e63f5bde85..649fbbfa184ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types._ import org.apache.spark.util.{ManualClock, SystemClock} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 845c8d2c14e43..7146190645b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,8 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index db600866067bc..cfba1001c6de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index ed22b9100497a..c3294d64b10cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,10 +31,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.ContinuousReadSupport +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala index 077a255946a6b..4e2459bb05bd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 3411edbc53412..f960208155e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.Sink -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 116ac3da07b75..f23851655350a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -28,8 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 9aac360fd4bbc..2fc903168cfa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index ddb1edc433d5a..7cefd03e43bc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala index 0d68d9c3138aa..983ba1668f58f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 51f44fa6285e4..af4618bed5456 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,10 +25,9 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, Streami import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{ContinuousReadSupport, DataReaderFactory, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.StreamWriteSupport import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} import org.apache.spark.sql.types.StructType From 76e019d9bdcdca176c79c1cd71ddbf496333bf93 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 8 Feb 2018 23:41:30 +0800 Subject: [PATCH 0308/1161] [SPARK-21860][CORE] Improve memory reuse for heap memory in `HeapMemoryAllocator` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In `HeapMemoryAllocator`, when allocating memory from pool, and the key of pool is memory size. Actually some size of memory ,such as 1025bytes,1026bytes,......1032bytes, we can think they are the same,because we allocate memory in multiples of 8 bytes. In this case, we can improve memory reuse. ## How was this patch tested? Existing tests and added unit tests Author: liuxian Closes #19077 from 10110346/headmemoptimize. --- .../unsafe/memory/HeapMemoryAllocator.java | 18 +++++++++------ .../spark/unsafe/PlatformUtilSuite.java | 22 +++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index a9603c1aba051..2733760dd19ef 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -46,9 +46,12 @@ private boolean shouldPool(long size) { @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { + int numWords = (int) ((size + 7) / 8); + long alignedSize = numWords * 8L; + assert (alignedSize >= size); + if (shouldPool(alignedSize)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool != null) { while (!pool.isEmpty()) { final WeakReference arrayReference = pool.pop(); @@ -62,11 +65,11 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { return memory; } } - bufferPoolsBySize.remove(size); + bufferPoolsBySize.remove(alignedSize); } } } - long[] array = new long[(int) ((size + 7) / 8)]; + long[] array = new long[numWords]; MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); @@ -98,12 +101,13 @@ public void free(MemoryBlock memory) { long[] array = (long[]) memory.obj; memory.setObjAndOffset(null, 0); - if (shouldPool(size)) { + long alignedSize = ((size + 7) / 8) * 8; + if (shouldPool(alignedSize)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool == null) { pool = new LinkedList<>(); - bufferPoolsBySize.put(size, pool); + bufferPoolsBySize.put(alignedSize, pool); } pool.add(new WeakReference<>(array)); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 62854837b05ed..71c53d35dcab8 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe; +import org.apache.spark.unsafe.memory.HeapMemoryAllocator; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -134,4 +135,25 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryAllocator.UNSAFE.free(offheap); } + + @Test + public void heapMemoryReuse() { + MemoryAllocator heapMem = new HeapMemoryAllocator(); + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + MemoryBlock onheap1 = heapMem.allocate(513); + Object obj1 = onheap1.getBaseObject(); + heapMem.free(onheap1); + MemoryBlock onheap2 = heapMem.allocate(514); + Assert.assertNotEquals(obj1, onheap2.getBaseObject()); + + // The size is greater than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // reuse the previous memory which has released. + MemoryBlock onheap3 = heapMem.allocate(1024 * 1024 + 1); + Assert.assertEquals(onheap3.size(), 1024 * 1024 + 1); + Object obj3 = onheap3.getBaseObject(); + heapMem.free(onheap3); + MemoryBlock onheap4 = heapMem.allocate(1024 * 1024 + 7); + Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); + Assert.assertEquals(obj3, onheap4.getBaseObject()); + } } From 4df84c3f818aa536515729b442601e08c253ed35 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 8 Feb 2018 12:52:08 -0600 Subject: [PATCH 0309/1161] [SPARK-23336][BUILD] Upgrade snappy-java to 1.1.7.1 ## What changes were proposed in this pull request? This PR upgrade snappy-java from 1.1.2.6 to 1.1.7.1. 1.1.7.1 release notes: - Improved performance for big-endian architecture - The other performance improvement in [snappy-1.1.5](https://github.com/google/snappy/releases/tag/1.1.5) 1.1.4 release notes: - Fix a 1% performance regression when snappy is used in PIE executables. - Improve compression performance by 5%. - Improve decompression performance by 20%. More details: https://github.com/xerial/snappy-java/blob/master/Milestone.md ## How was this patch tested? manual tests Author: Yuming Wang Closes #20510 from wangyum/SPARK-23336. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 48e54568e6fc6..99031384aa22e 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -182,7 +182,7 @@ slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1807a77900e52..cf8d2789b7ee9 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -183,7 +183,7 @@ slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snakeyaml-1.15.jar snappy-0.2.jar -snappy-java-1.1.2.6.jar +snappy-java-1.1.7.1.jar spire-macros_2.11-0.13.0.jar spire_2.11-0.13.0.jar stax-api-1.0-2.jar diff --git a/pom.xml b/pom.xml index d18831df1db6d..de949b94d676c 100644 --- a/pom.xml +++ b/pom.xml @@ -160,7 +160,7 @@ 1.9.13 2.6.7 2.6.7.1 - 1.1.2.6 + 1.1.7.1 1.1.2 1.2.0-incubating 1.10 From 8cbcc33876c773722163b2259644037bbb259bd1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 9 Feb 2018 12:54:57 +0800 Subject: [PATCH 0310/1161] [SPARK-23186][SQL] Initialize DriverManager first before loading JDBC Drivers ## What changes were proposed in this pull request? Since some JDBC Drivers have class initialization code to call `DriverManager`, we need to initialize `DriverManager` first in order to avoid potential executor-side **deadlock** situations like the following (or [STORM-2527](https://issues.apache.org/jira/browse/STORM-2527)). ``` Thread 9587: (state = BLOCKED) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame; information may be imprecise) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - java.util.ServiceLoader$LazyIterator.nextService() bci=119, line=380 (Interpreted frame) - java.util.ServiceLoader$LazyIterator.next() bci=11, line=404 (Interpreted frame) - java.util.ServiceLoader$1.next() bci=37, line=480 (Interpreted frame) - java.sql.DriverManager$2.run() bci=21, line=603 (Interpreted frame) - java.sql.DriverManager$2.run() bci=1, line=583 (Interpreted frame) - java.security.AccessController.doPrivileged(java.security.PrivilegedAction) bci=0 (Compiled frame) - java.sql.DriverManager.loadInitialDrivers() bci=27, line=583 (Interpreted frame) - java.sql.DriverManager.() bci=32, line=101 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getConnection(java.lang.String, java.lang.Integer, java.lang.String, java.util.Properties) bci=12, line=98 (Interpreted frame) - org.apache.phoenix.mapreduce.util.ConnectionUtil.getInputConnection(org.apache.hadoop.conf.Configuration, java.util.Properties) bci=22, line=57 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.getQueryPlan(org.apache.hadoop.mapreduce.JobContext, org.apache.hadoop.conf.Configuration) bci=61, line=116 (Interpreted frame) - org.apache.phoenix.mapreduce.PhoenixInputFormat.createRecordReader(org.apache.hadoop.mapreduce.InputSplit, org.apache.hadoop.mapreduce.TaskAttemptContext) bci=10, line=71 (Interpreted frame) - org.apache.spark.rdd.NewHadoopRDD$$anon$1.(org.apache.spark.rdd.NewHadoopRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=233, line=156 (Interpreted frame) Thread 9170: (state = BLOCKED) - org.apache.phoenix.jdbc.PhoenixDriver.() bci=35, line=125 (Interpreted frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance0(java.lang.reflect.Constructor, java.lang.Object[]) bci=0 (Compiled frame) - sun.reflect.NativeConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=85, line=62 (Compiled frame) - sun.reflect.DelegatingConstructorAccessorImpl.newInstance(java.lang.Object[]) bci=5, line=45 (Compiled frame) - java.lang.reflect.Constructor.newInstance(java.lang.Object[]) bci=79, line=423 (Compiled frame) - java.lang.Class.newInstance() bci=138, line=442 (Compiled frame) - org.apache.spark.sql.execution.datasources.jdbc.DriverRegistry$.register(java.lang.String) bci=89, line=46 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=7, line=53 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$$anonfun$createConnectionFactory$2.apply() bci=1, line=52 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$$anon$1.(org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD, org.apache.spark.Partition, org.apache.spark.TaskContext) bci=81, line=347 (Interpreted frame) - org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD.compute(org.apache.spark.Partition, org.apache.spark.TaskContext) bci=7, line=339 (Interpreted frame) ``` ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #20359 from dongjoon-hyun/SPARK-23186. --- .../sql/execution/datasources/jdbc/DriverRegistry.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7a6c0f9fed2f9..1723596de1db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -32,6 +32,13 @@ import org.apache.spark.util.Utils */ object DriverRegistry extends Logging { + /** + * Load DriverManager first to avoid any race condition between + * DriverManager static initialization block and specific driver class's + * static initialization block. e.g. PhoenixDriver + */ + DriverManager.getDrivers + private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty def register(className: String): Unit = { From 4b4ee2601079f12f8f410a38d2081793cbdedc14 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 9 Feb 2018 14:21:10 +0800 Subject: [PATCH 0311/1161] [SPARK-23328][PYTHON] Disallow default value None in na.replace/replace when 'to_replace' is not a dictionary ## What changes were proposed in this pull request? This PR proposes to disallow default value None when 'to_replace' is not a dictionary. It seems weird we set the default value of `value` to `None` and we ended up allowing the case as below: ```python >>> df.show() ``` ``` +----+------+-----+ | age|height| name| +----+------+-----+ | 10| 80|Alice| ... ``` ```python >>> df.na.replace('Alice').show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` **After** This PR targets to disallow the case above: ```python >>> df.na.replace('Alice').show() ``` ``` ... TypeError: value is required when to_replace is not a dictionary. ``` while we still allow when `to_replace` is a dictionary: ```python >>> df.na.replace({'Alice': None}).show() ``` ``` +----+------+----+ | age|height|name| +----+------+----+ | 10| 80|null| ... ``` ## How was this patch tested? Manually tested, tests were added in `python/pyspark/sql/tests.py` and doctests were fixed. Author: hyukjinkwon Closes #20499 from HyukjinKwon/SPARK-19454-followup. --- docs/sql-programming-guide.md | 1 + python/pyspark/__init__.py | 1 + python/pyspark/_globals.py | 70 +++++++++++++++++++++++++++++++++ python/pyspark/sql/dataframe.py | 26 +++++++++--- python/pyspark/sql/tests.py | 11 +++--- 5 files changed, 99 insertions(+), 10 deletions(-) create mode 100644 python/pyspark/_globals.py diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0e221b39cc34..eab4030ee25d2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1929,6 +1929,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. ## Upgrading From Spark SQL 2.1 to 2.2 diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 4d142c91629cc..58218918693ca 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -54,6 +54,7 @@ from pyspark.taskcontext import TaskContext from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ +from pyspark._globals import _NoValue def since(version): diff --git a/python/pyspark/_globals.py b/python/pyspark/_globals.py new file mode 100644 index 0000000000000..8e6099db09963 --- /dev/null +++ b/python/pyspark/_globals.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Module defining global singleton classes. + +This module raises a RuntimeError if an attempt to reload it is made. In that +way the identities of the classes defined here are fixed and will remain so +even if pyspark itself is reloaded. In particular, a function like the following +will still work correctly after pyspark is reloaded: + + def foo(arg=pyspark._NoValue): + if arg is pyspark._NoValue: + ... + +See gh-7844 for a discussion of the reload problem that motivated this module. + +Note that this approach is taken after from NumPy. +""" + +__ALL__ = ['_NoValue'] + + +# Disallow reloading this module so as to preserve the identities of the +# classes defined here. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading pyspark._globals is not allowed') +_is_loaded = True + + +class _NoValueType(object): + """Special keyword value. + + The instance of this class may be used as the default value assigned to a + deprecated keyword in order to check if it has been given a user defined + value. + + This class was copied from NumPy. + """ + __instance = None + + def __new__(cls): + # ensure that only one instance exists + if not cls.__instance: + cls.__instance = super(_NoValueType, cls).__new__(cls) + return cls.__instance + + # needed for python 2 to preserve identity through a pickle + def __reduce__(self): + return (self.__class__, ()) + + def __repr__(self): + return "" + + +_NoValue = _NoValueType() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8ec24db8717b2..faee870a2d2e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -27,7 +27,7 @@ import warnings -from pyspark import copy_func, since +from pyspark import copy_func, since, _NoValue from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ UTF8Deserializer @@ -1532,7 +1532,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @since(1.4) - def replace(self, to_replace, value=None, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -1545,8 +1545,8 @@ def replace(self, to_replace, value=None, subset=None): :param to_replace: bool, int, long, float, string, list or dict. Value to be replaced. - If the value is a dict, then `value` is ignored and `to_replace` must be a - mapping between a value and a replacement. + If the value is a dict, then `value` is ignored or can be omitted, and `to_replace` + must be a mapping between a value and a replacement. :param value: bool, int, long, float, string, list or None. The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. @@ -1577,6 +1577,16 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ + >>> df4.na.replace({'Alice': None}).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1587,6 +1597,12 @@ def replace(self, to_replace, value=None, subset=None): |null| null|null| +----+------+----+ """ + if value is _NoValue: + if isinstance(to_replace, dict): + value = None + else: + raise TypeError("value argument is required when to_replace is not a dictionary.") + # Helper functions def all_of(types): """Given a type or tuple of types and a sequence of xs @@ -2047,7 +2063,7 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value, subset=None): + def replace(self, to_replace, value=_NoValue, subset=None): return self.df.replace(to_replace, value, subset) replace.__doc__ = DataFrame.replace.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 90ff084fed55e..6ace16955000d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2243,11 +2243,6 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) - # replace list while value is not given (default to None) - row = self.spark.createDataFrame( - [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() - self.assertTupleEqual(row, (None, 10, 80.0)) - # replace string with None and then drop None rows row = self.spark.createDataFrame( [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() @@ -2283,6 +2278,12 @@ def test_replace(self): self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first() + with self.assertRaisesRegexp( + TypeError, + 'value argument is required when to_replace is not a dictionary.'): + self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) From f77270b8811bbd8956d0c08fa556265d2c5ee20e Mon Sep 17 00:00:00 2001 From: liuxian Date: Fri, 9 Feb 2018 08:45:06 -0600 Subject: [PATCH 0312/1161] [SPARK-23358][CORE] When the number of partitions is greater than 2^28, it will result in an error result ## What changes were proposed in this pull request? In the `checkIndexAndDataFile`,the `blocks` is the ` Int` type, when it is greater than 2^28, `blocks*8` will overflow, and this will result in an error result. In fact, `blocks` is actually the number of partitions. ## How was this patch tested? Manual test Author: liuxian Closes #20544 from 10110346/overflow. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index c5f3f6e2b42b6..d88b25cc7e258 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -84,7 +84,7 @@ private[spark] class IndexShuffleBlockResolver( */ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { // the index file should have `block + 1` longs as offset. - if (index.length() != (blocks + 1) * 8) { + if (index.length() != (blocks + 1) * 8L) { return null } val lengths = new Array[Long](blocks) From 0fc26313f8071cdcb4ccd67bb1d6942983199d36 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 9 Feb 2018 08:46:27 -0600 Subject: [PATCH 0313/1161] [SPARK-21860][CORE][FOLLOWUP] fix java style error ## What changes were proposed in this pull request? #19077 introduced a Java style error (too long line). Quick fix. ## How was this patch tested? running `./dev/lint-java` Author: Marco Gaido Closes #20558 from mgaido91/SPARK-21860. --- .../test/java/org/apache/spark/unsafe/PlatformUtilSuite.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 71c53d35dcab8..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -139,7 +139,8 @@ public void memoryDebugFillEnabledInTest() { @Test public void heapMemoryReuse() { MemoryAllocator heapMem = new HeapMemoryAllocator(); - // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // allocate new memory every time. MemoryBlock onheap1 = heapMem.allocate(513); Object obj1 = onheap1.getBaseObject(); heapMem.free(onheap1); From 7f10cf83f311526737fc96d5bb8281d12e41932f Mon Sep 17 00:00:00 2001 From: Rob Vesse Date: Fri, 9 Feb 2018 11:21:20 -0800 Subject: [PATCH 0314/1161] [SPARK-16501][MESOS] Allow providing Mesos principal & secret via files This commit modifies the Mesos submission client to allow the principal and secret to be provided indirectly via files. The path to these files can be specified either via Spark configuration or via environment variable. Assuming these files are appropriately protected by FS/OS permissions this means we don't ever leak the actual values in process info like ps Environment variable specification is useful because it allows you to interpolate the location of this file when using per-user Mesos credentials. For some background as to why we have taken this approach I will briefly describe our set up. On our systems we provide each authorised user account with their own Mesos credentials to provide certain security and audit guarantees to our customers. These credentials are managed by a central Secret management service. In our `spark-env.sh` we determine the appropriate secret and principal files to use depending on the user who is invoking Spark hence the need to inject these via environment variables as well as by configuration properties. So we set these environment variables appropriately and our Spark read in the contents of those files to authenticate itself with Mesos. This is functionality we have been using it in production across multiple customer sites for some time. This has been in the field for around 18 months with no reported issues. These changes have been sufficient to meet our customer security and audit requirements. We have been building and deploying custom builds of Apache Spark with various minor tweaks like this which we are now looking to contribute back into the community in order that we can rely upon stock Apache Spark builds and stop maintaining our own internal fork. Author: Rob Vesse Closes #20167 from rvesse/SPARK-16501. --- docs/running-on-mesos.md | 40 ++++- .../cluster/mesos/MesosSchedulerUtils.scala | 55 ++++-- .../mesos/MesosSchedulerUtilsSuite.scala | 161 +++++++++++++++++- 3 files changed, 238 insertions(+), 18 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 2bb5ecf1b8509..8e58892e2689f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -82,6 +82,27 @@ a Spark driver program configured to connect to Mesos. Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure `spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. +## Authenticating to Mesos + +When Mesos Framework authentication is enabled it is necessary to provide a principal and secret by which to authenticate Spark to Mesos. Each Spark job will register with Mesos as a separate framework. + +Depending on your deployment environment you may wish to create a single set of framework credentials that are shared across all users or create framework credentials for each user. Creating and managing framework credentials should be done following the Mesos [Authentication documentation](http://mesos.apache.org/documentation/latest/authentication/). + +Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. + +Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. + +### Credential Specification Preference Order + +Please note that if you specify multiple ways to obtain the credentials then the following preference order applies. Spark will use the first valid value found and any subsequent values are ignored: + +- `spark.mesos.principal` configuration setting +- `SPARK_MESOS_PRINCIPAL` environment variable +- `spark.mesos.principal.file` configuration setting +- `SPARK_MESOS_PRINCIPAL_FILE` environment variable + +An equivalent order applies for the secret. Essentially we prefer the configuration to be specified directly rather than indirectly by files, and we prefer that configuration settings are used over environment variables. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -427,7 +448,14 @@ See the [configuration page](configuration.html) for information on Spark config
+ + + + + @@ -435,7 +463,15 @@ See the [configuration page](configuration.html) for information on Spark config + + + + + diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index e75450369ad85..ecbcc960fc5a0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.File +import java.nio.charset.StandardCharsets import java.util.{List => JList} import java.util.concurrent.CountDownLatch @@ -25,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter +import com.google.common.io.Files import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.FrameworkInfo.Capability @@ -71,26 +74,15 @@ trait MesosSchedulerUtils extends Logging { failoverTimeout: Option[Double] = None, frameworkId: Option[String] = None): SchedulerDriver = { val fwInfoBuilder = FrameworkInfo.newBuilder().setUser(sparkUser).setName(appName) - val credBuilder = Credential.newBuilder() + fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( + conf.get(DRIVER_HOST_ADDRESS))) webuiUrl.foreach { url => fwInfoBuilder.setWebuiUrl(url) } checkpoint.foreach { checkpoint => fwInfoBuilder.setCheckpoint(checkpoint) } failoverTimeout.foreach { timeout => fwInfoBuilder.setFailoverTimeout(timeout) } frameworkId.foreach { id => fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) } - fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( - conf.get(DRIVER_HOST_ADDRESS))) - conf.getOption("spark.mesos.principal").foreach { principal => - fwInfoBuilder.setPrincipal(principal) - credBuilder.setPrincipal(principal) - } - conf.getOption("spark.mesos.secret").foreach { secret => - credBuilder.setSecret(secret) - } - if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { - throw new SparkException( - "spark.mesos.principal must be configured when spark.mesos.secret is set") - } + conf.getOption("spark.mesos.role").foreach { role => fwInfoBuilder.setRole(role) } @@ -98,6 +90,7 @@ trait MesosSchedulerUtils extends Logging { if (maxGpus > 0) { fwInfoBuilder.addCapabilities(Capability.newBuilder().setType(Capability.Type.GPU_RESOURCES)) } + val credBuilder = buildCredentials(conf, fwInfoBuilder) if (credBuilder.hasPrincipal) { new MesosSchedulerDriver( scheduler, fwInfoBuilder.build(), masterUrl, credBuilder.build()) @@ -106,6 +99,40 @@ trait MesosSchedulerUtils extends Logging { } } + def buildCredentials( + conf: SparkConf, + fwInfoBuilder: Protos.FrameworkInfo.Builder): Protos.Credential.Builder = { + val credBuilder = Credential.newBuilder() + conf.getOption("spark.mesos.principal") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL"))) + .orElse( + conf.getOption("spark.mesos.principal.file") + .orElse(Option(conf.getenv("SPARK_MESOS_PRINCIPAL_FILE"))) + .map { principalFile => + Files.toString(new File(principalFile), StandardCharsets.UTF_8) + } + ).foreach { principal => + fwInfoBuilder.setPrincipal(principal) + credBuilder.setPrincipal(principal) + } + conf.getOption("spark.mesos.secret") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET"))) + .orElse( + conf.getOption("spark.mesos.secret.file") + .orElse(Option(conf.getenv("SPARK_MESOS_SECRET_FILE"))) + .map { secretFile => + Files.toString(new File(secretFile), StandardCharsets.UTF_8) + } + ).foreach { secret => + credBuilder.setSecret(secret) + } + if (credBuilder.hasSecret && !fwInfoBuilder.hasPrincipal) { + throw new SparkException( + "spark.mesos.principal must be configured when spark.mesos.secret is set") + } + credBuilder + } + /** * Starts the MesosSchedulerDriver and stores the current running driver to this new instance. * This driver is expected to not be running. diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 7df738958f85c..8d90e1a8591ad 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler.cluster.mesos +import java.io.{File, FileNotFoundException} + import scala.collection.JavaConverters._ import scala.language.reflectiveCalls -import org.apache.mesos.Protos.{Resource, Value} +import com.google.common.io.Files +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Value} import org.mockito.Mockito._ import org.scalatest._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.internal.config._ +import org.apache.spark.util.SparkConfWithEnv class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoSugar { @@ -237,4 +241,157 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS val portsToUse = getRangesFromResources(resourcesToBeUsed).map{r => r._1} portsToUse.isEmpty shouldBe true } + + test("Principal specified via spark.mesos.principal") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", pFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via spark.mesos.principal.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "test-principal")) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE") { + val pFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + pFile.deleteOnExit() + Files.write("test-principal".getBytes("UTF-8"), pFile); + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> pFile.getAbsolutePath())) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Principal specified via SPARK_MESOS_PRINCIPAL_FILE that does not exist") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL_FILE" -> "/tmp/does-not-exist")) + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Secret specified via spark.mesos.secret") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", sFile.getAbsolutePath()) + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via spark.mesos.secret.file that does not exist") { + val conf = new SparkConf() + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret.file", "/tmp/does-not-exist") + + intercept[FileNotFoundException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specified via SPARK_MESOS_SECRET") { + val env = Map("SPARK_MESOS_SECRET" -> "my-secret") + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Principal specified via SPARK_MESOS_SECRET_FILE") { + val sFile = File.createTempFile("MesosSchedulerUtilsSuite", ".txt"); + sFile.deleteOnExit() + Files.write("my-secret".getBytes("UTF-8"), sFile); + + val sFilePath = sFile.getAbsolutePath() + val env = Map("SPARK_MESOS_SECRET_FILE" -> sFilePath) + val conf = new SparkConfWithEnv(env) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } + + test("Secret specified with no principal") { + val conf = new SparkConf() + conf.set("spark.mesos.secret", "my-secret") + + intercept[SparkException] { + utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + } + } + + test("Principal specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_PRINCIPAL" -> "other-principal")) + conf.set("spark.mesos.principal", "test-principal") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + } + + test("Secret specification preference") { + val conf = new SparkConfWithEnv(Map("SPARK_MESOS_SECRET" -> "other-secret")) + conf.set("spark.mesos.principal", "test-principal") + conf.set("spark.mesos.secret", "my-secret") + + val credBuilder = utils.buildCredentials(conf, FrameworkInfo.newBuilder()) + credBuilder.hasPrincipal shouldBe true + credBuilder.getPrincipal shouldBe "test-principal" + credBuilder.hasSecret shouldBe true + credBuilder.getSecret shouldBe "my-secret" + } } From 557938e2839afce26a10a849a2a4be8fc4580427 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 9 Feb 2018 18:18:30 -0600 Subject: [PATCH 0315/1161] [MINOR][HIVE] Typo fixes ## What changes were proposed in this pull request? Typo fixes (with expanding a Hive property) ## How was this patch tested? local build. Awaiting Jenkins Author: Jacek Laskowski Closes #20550 from jaceklaskowski/hiveutils-typos. --- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 93f3f38e52aa9..c448c5a9821be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -304,7 +304,7 @@ private[spark] object HiveUtils extends Logging { throw new IllegalArgumentException( "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: $builtinHiveVersion != Metastore: $hiveMetastoreVersion. " + - "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + + s"Specify a valid path to the correct hive jars using ${HIVE_METASTORE_JARS.key} " + s"or change ${HIVE_METASTORE_VERSION.key} to $builtinHiveVersion.") } @@ -324,7 +324,7 @@ private[spark] object HiveUtils extends Logging { if (jars.length == 0) { throw new IllegalArgumentException( "Unable to locate hive jars to connect to metastore. " + - "Please set spark.sql.hive.metastore.jars.") + s"Please set ${HIVE_METASTORE_JARS.key}.") } logInfo( From 6d7c38330e68c7beb10f54eee8b4f607ee3c4136 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 9 Feb 2018 16:21:47 -0800 Subject: [PATCH 0316/1161] [SPARK-23275][SQL] fix the thread leaking in hive/tests ## What changes were proposed in this pull request? This is a follow up of https://github.com/apache/spark/pull/20441. The two lines actually can trigger the hive metastore bug: https://issues.apache.org/jira/browse/HIVE-16844 The two configs are not in the default `ObjectStore` properties, so any run hive commands after these two lines will set the `propsChanged` flag in the `ObjectStore.setConf` and then cause thread leaks. I don't think the two lines are very useful. They can be removed safely. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20562 from liufengdb/fix-omm. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 59708e7a0f2ff..19028939f3673 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -530,8 +530,6 @@ private[hive] class TestHiveSparkSession( // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 metadataHive.runSqlHive("set hive.table.parameters.default=") - metadataHive.runSqlHive("set datanucleus.cache.collections=true") - metadataHive.runSqlHive("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. metadataHive.runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") From 97a224a855c4410b2dfb9c0bcc6aae583bd28e92 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 11 Feb 2018 01:08:02 +0900 Subject: [PATCH 0317/1161] [SPARK-23360][SQL][PYTHON] Get local timezone from environment via pytz, or dateutil. ## What changes were proposed in this pull request? Currently we use `tzlocal()` to get Python local timezone, but it sometimes causes unexpected behavior. I changed the way to get Python local timezone to use pytz if the timezone is specified in environment variable, or timezone file via dateutil . ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20559 from ueshin/issues/SPARK-23360/master. --- python/pyspark/sql/tests.py | 28 ++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 23 +++++++++++++++++++---- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6ace16955000d..1087c3fafdd16 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2868,6 +2868,34 @@ def test_create_dataframe_required_pandas_not_found(self): "d": [pd.Timestamp.now().date()]}) self.spark.createDataFrame(pdf) + # Regression test for SPARK-23360 + @unittest.skipIf(not _have_pandas, _pandas_requirement_message) + def test_create_dateframe_from_pandas_with_dst(self): + import pandas as pd + from datetime import datetime + + pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]}) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + + orig_env_tz = os.environ.get('TZ', None) + orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') + try: + tz = 'America/Los_Angeles' + os.environ['TZ'] = tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', tz) + + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) + finally: + del os.environ['TZ'] + if orig_env_tz is not None: + os.environ['TZ'] = orig_env_tz + time.tzset() + self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 093dae5a22e1f..2599dc5fdc599 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1709,6 +1709,21 @@ def _check_dataframe_convert_date(pdf, schema): return pdf +def _get_local_timezone(): + """ Get local timezone using pytz with environment variable, or dateutil. + + If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone + string, otherwise use the special word 'dateutil/:' which means that pandas uses dateutil and + it reads system configuration to know the system local timezone. + + See also: + - https://github.com/pandas-dev/pandas/blob/0.19.x/pandas/tslib.pyx#L1753 + - https://github.com/dateutil/dateutil/blob/2.6.1/dateutil/tz/tz.py#L1338 + """ + import os + return os.environ.get('TZ', 'dateutil/:') + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1721,7 +1736,7 @@ def _check_dataframe_localize_timestamps(pdf, timezone): require_minimum_pandas_version() from pandas.api.types import is_datetime64tz_dtype - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(series.dtype): @@ -1744,7 +1759,7 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): - tz = timezone or 'tzlocal()' + tz = timezone or _get_local_timezone() return s.dt.tz_localize(tz).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') @@ -1766,8 +1781,8 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): import pandas as pd from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - from_tz = from_timezone or 'tzlocal()' - to_tz = to_timezone or 'tzlocal()' + from_tz = from_timezone or _get_local_timezone() + to_tz = to_timezone or _get_local_timezone() # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert(to_tz).dt.tz_localize(None) From 0783876c81f212e1422a1b7786c26e3ac8e84f9f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 10 Feb 2018 10:46:45 -0600 Subject: [PATCH 0318/1161] [SPARK-23344][PYTHON][ML] Add distanceMeasure param to KMeans ## What changes were proposed in this pull request? SPARK-22119 introduced a new parameter for KMeans, ie. `distanceMeasure`. The PR adds it also to the Python interface. ## How was this patch tested? added UTs Author: Marco Gaido Closes #20520 from mgaido91/SPARK-23344. --- python/pyspark/ml/clustering.py | 32 +++++++++++++++++++++++++++----- python/pyspark/ml/tests.py | 18 ++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 66fb00508522e..6448b76a0da88 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -403,17 +403,23 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'euclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") """ super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) - self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) + self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, + distanceMeasure="euclidean") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -423,10 +429,12 @@ def _create_model(self, java_model): @keyword_only @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None): + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, + distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", k=2, \ - initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None) + initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20, seed=None, \ + distanceMeasure="euclidean") Sets params for KMeans. """ @@ -475,6 +483,20 @@ def getInitSteps(self): """ return self.getOrDefault(self.initSteps) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 75d04785a0710..6d6737241e06e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -418,6 +418,9 @@ def test_kmeans_param(self): self.assertEqual(algo.getK(), 10) algo.setInitSteps(10) self.assertEqual(algo.getInitSteps(), 10) + self.assertEqual(algo.getDistanceMeasure(), "euclidean") + algo.setDistanceMeasure("cosine") + self.assertEqual(algo.getDistanceMeasure(), "cosine") def test_hasseed(self): noSeedSpecd = TestParams() @@ -1620,6 +1623,21 @@ def test_kmeans_summary(self): self.assertEqual(s.k, 2) +class KMeansTests(SparkSessionTestCase): + + def test_kmeans_cosine_distance(self): + data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),), + (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),), + (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)] + df = self.spark.createDataFrame(data, ["features"]) + kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine") + model = kmeans.fit(df) + result = model.transform(df).collect() + self.assertTrue(result[0].prediction == result[1].prediction) + self.assertTrue(result[2].prediction == result[3].prediction) + self.assertTrue(result[4].prediction == result[5].prediction) + + class OneVsRestTests(SparkSessionTestCase): def test_copy(self): From a34fce19bc0ee5a7e36c6ecba75d2aeb70fdcbc7 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sun, 11 Feb 2018 17:31:35 +0900 Subject: [PATCH 0319/1161] [SPARK-23314][PYTHON] Add ambiguous=False when localizing tz-naive timestamps in Arrow codepath to deal with dst ## What changes were proposed in this pull request? When tz_localize a tz-naive timetamp, pandas will throw exception if the timestamp is during daylight saving time period, e.g., `2015-11-01 01:30:00`. This PR fixes this issue by setting `ambiguous=False` when calling tz_localize, which is the same default behavior of pytz. ## How was this patch tested? Add `test_timestamp_dst` Author: Li Jin Closes #20537 from icexelloss/SPARK-23314. --- python/pyspark/sql/tests.py | 39 +++++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 37 ++++++++++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1087c3fafdd16..4bc59fd99fca5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3670,6 +3670,21 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + import pandas as pd + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + pdf = pd.DataFrame({'time': dt}) + + df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + df_from_pandas = self.spark.createDataFrame(pdf) + + self.assertPandasEqual(pdf, df_from_python.toPandas()) + self.assertPandasEqual(pdf, df_from_pandas.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4311,6 +4326,18 @@ def test_register_vectorized_udf_basic(self): self.assertEquals(expected.collect(), res1.collect()) self.assertEquals(expected.collect(), res2.collect()) + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda x: x, 'timestamp') + result = df.withColumn('time', foo_udf(df.time)) + self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -4482,6 +4509,18 @@ def test_unsupported_types(self): with self.assertRaisesRegexp(Exception, 'Unsupported data type'): df.groupby('id').apply(f).collect() + # Regression test for SPARK-23314 + def test_timestamp_dst(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am + dt = [datetime.datetime(2015, 11, 1, 0, 30), + datetime.datetime(2015, 11, 1, 1, 30), + datetime.datetime(2015, 11, 1, 2, 30)] + df = self.spark.createDataFrame(dt, 'timestamp').toDF('time') + foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP) + result = df.groupby('time').apply(foo_udf).sort('time') + self.assertPandasEqual(df.toPandas(), result.toPandas()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 2599dc5fdc599..f7141b4549e4e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1759,8 +1759,38 @@ def _check_series_convert_timestamps_internal(s, timezone): from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): + # When tz_localize a tz-naive timestamp, the result is ambiguous if the tz-naive + # timestamp is during the hour when the clock is adjusted backward during due to + # daylight saving time (dst). + # E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to + # 2015-11-01 1:00 from dst-time to standard time, and therefore, when tz_localize + # a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either + # dst time (2015-01-01 1:30-0400) or standard time (2015-11-01 1:30-0500). + # + # Here we explicit choose to use standard time. This matches the default behavior of + # pytz. + # + # Here are some code to help understand this behavior: + # >>> import datetime + # >>> import pandas as pd + # >>> import pytz + # >>> + # >>> t = datetime.datetime(2015, 11, 1, 1, 30) + # >>> ts = pd.Series([t]) + # >>> tz = pytz.timezone('America/New_York') + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=True) + # 0 2015-11-01 01:30:00-04:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> ts.dt.tz_localize(tz, ambiguous=False) + # 0 2015-11-01 01:30:00-05:00 + # dtype: datetime64[ns, America/New_York] + # >>> + # >>> str(tz.localize(t)) + # '2015-11-01 01:30:00-05:00' tz = timezone or _get_local_timezone() - return s.dt.tz_localize(tz).dt.tz_convert('UTC') + return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC') elif is_datetime64tz_dtype(s.dtype): return s.dt.tz_convert('UTC') else: @@ -1788,8 +1818,9 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): return s.dt.tz_convert(to_tz).dt.tz_localize(None) elif is_datetime64_dtype(s.dtype) and from_tz != to_tz: # `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT. - return s.apply(lambda ts: ts.tz_localize(from_tz).tz_convert(to_tz).tz_localize(None) - if ts is not pd.NaT else pd.NaT) + return s.apply( + lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None) + if ts is not pd.NaT else pd.NaT) else: return s From 8acb51f08b448628b65e90af3b268994f9550e45 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 11 Feb 2018 18:55:38 +0900 Subject: [PATCH 0320/1161] [SPARK-23084][PYTHON] Add unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark ## What changes were proposed in this pull request? Added unboundedPreceding(), unboundedFollowing() and currentRow() to PySpark, also updated the rangeBetween API ## How was this patch tested? did unit test on my local. Please let me know if I need to add unit test in tests.py Author: Huaxin Gao Closes #20400 from huaxingao/spark_23084. --- python/pyspark/sql/functions.py | 30 ++++++++++++++ python/pyspark/sql/window.py | 70 ++++++++++++++++++++++++--------- 2 files changed, 82 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 05031f5ec87d7..9bb9c323a5a60 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -809,6 +809,36 @@ def ntile(n): return Column(sc._jvm.functions.ntile(int(n))) +@since(2.4) +def unboundedPreceding(): + """ + Window function: returns the special frame boundary that represents the first row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedPreceding()) + + +@since(2.4) +def unboundedFollowing(): + """ + Window function: returns the special frame boundary that represents the last row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.unboundedFollowing()) + + +@since(2.4) +def currentRow(): + """ + Window function: returns the special frame boundary that represents the current row + in the window partition. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.currentRow()) + + # ---------------------- Date/Timestamp functions ------------------------------ @since(1.5) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index 7ce27f9b102c0..bb841a9b9ff7c 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -16,9 +16,11 @@ # import sys +if sys.version >= '3': + long = int from pyspark import since, SparkContext -from pyspark.sql.column import _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_java_column __all__ = ["Window", "WindowSpec"] @@ -120,20 +122,45 @@ def rangeBetween(start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). + + >>> from pyspark.sql import functions as F, SparkSession, Window + >>> spark = SparkSession.builder.getOrCreate() + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> window = Window.orderBy("id").partitionBy("category").rangeBetween( + ... F.currentRow(), F.lit(1)) + >>> df.withColumn("sum", F.sum("id").over(window)).show() + +---+--------+---+ + | id|category|sum| + +---+--------+---+ + | 1| b| 3| + | 2| b| 5| + | 3| b| 3| + | 1| a| 4| + | 1| a| 4| + | 2| a| 2| + +---+--------+---+ """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) return WindowSpec(jspec) @@ -208,27 +235,34 @@ def rangeBetween(self, start, end): and "5" means the five off after the current row. We recommend users use ``Window.unboundedPreceding``, ``Window.unboundedFollowing``, - and ``Window.currentRow`` to specify special boundary values, rather than using integral - values directly. + ``Window.currentRow``, ``pyspark.sql.functions.unboundedPreceding``, + ``pyspark.sql.functions.unboundedFollowing`` and ``pyspark.sql.functions.currentRow`` + to specify special boundary values, rather than using integral values directly. :param start: boundary start, inclusive. - The frame is unbounded if this is ``Window.unboundedPreceding``, or + The frame is unbounded if this is ``Window.unboundedPreceding``, + a column returned by ``pyspark.sql.functions.unboundedPreceding``, or any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. - The frame is unbounded if this is ``Window.unboundedFollowing``, or + The frame is unbounded if this is ``Window.unboundedFollowing``, + a column returned by ``pyspark.sql.functions.unboundedFollowing``, or any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if start <= Window._PRECEDING_THRESHOLD: - start = Window.unboundedPreceding - if end >= Window._FOLLOWING_THRESHOLD: - end = Window.unboundedFollowing + if isinstance(start, (int, long)) and isinstance(end, (int, long)): + if start <= Window._PRECEDING_THRESHOLD: + start = Window.unboundedPreceding + if end >= Window._FOLLOWING_THRESHOLD: + end = Window.unboundedFollowing + elif isinstance(start, Column) and isinstance(end, Column): + start = start._jc + end = end._jc return WindowSpec(self._jspec.rangeBetween(start, end)) def _test(): import doctest SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod() + (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: exit(-1) From eacb62fbbed317fd0e972102838af231385d54d8 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Sun, 11 Feb 2018 19:23:15 +0900 Subject: [PATCH 0321/1161] [SPARK-22624][PYSPARK] Expose range partitioning shuffle introduced by spark-22614 ## What changes were proposed in this pull request? Expose range partitioning shuffle introduced by spark-22614 ## How was this patch tested? Unit test in dataframe.py Please review http://spark.apache.org/contributing.html before opening a pull request. Author: xubo245 <601450868@qq.com> Closes #20456 from xubo245/SPARK22624_PysparkRangePartition. --- python/pyspark/sql/dataframe.py | 45 +++++++++++++++++++++++++++++++++ python/pyspark/sql/tests.py | 28 ++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index faee870a2d2e2..5cc8b63cdfadf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -667,6 +667,51 @@ def repartition(self, numPartitions, *cols): else: raise TypeError("numPartitions should be an int or Column") + @since("2.4.0") + def repartitionByRange(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is range partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + At least one partition-by expression must be specified. + When no explicit sort order is specified, "ascending nulls first" is assumed. + + >>> df.repartitionByRange(2, "age").rdd.getNumPartitions() + 2 + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + >>> df.repartitionByRange(1, "age").rdd.getNumPartitions() + 1 + >>> data = df.repartitionByRange("age") + >>> df.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return ValueError("At least one partition-by expression must be specified.") + else: + return DataFrame( + self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions,) + cols + return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int, string or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4bc59fd99fca5..fe89bd0685027 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2148,6 +2148,34 @@ def test_expr(self): result = df.select(functions.expr("length(a)")).collect()[0].asDict() self.assertEqual(13, result["length(a)"]) + def test_repartitionByRange_dataframe(self): + schema = StructType([ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", DoubleType(), True)]) + + df1 = self.spark.createDataFrame( + [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema) + df2 = self.spark.createDataFrame( + [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema) + + # test repartitionByRange(numPartitions, *cols) + df3 = df1.repartitionByRange(2, "name", "age") + self.assertEqual(df3.rdd.getNumPartitions(), 2) + self.assertEqual(df3.rdd.first(), df2.rdd.first()) + self.assertEqual(df3.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(numPartitions, *cols) + df4 = df1.repartitionByRange(3, "name", "age") + self.assertEqual(df4.rdd.getNumPartitions(), 3) + self.assertEqual(df4.rdd.first(), df2.rdd.first()) + self.assertEqual(df4.rdd.take(3), df2.rdd.take(3)) + + # test repartitionByRange(*cols) + df5 = df1.repartitionByRange("name", "age") + self.assertEqual(df5.rdd.first(), df2.rdd.first()) + self.assertEqual(df5.rdd.take(3), df2.rdd.take(3)) + def test_replace(self): schema = StructType([ StructField("name", StringType(), True), From 4bbd7443ebb005f81ed6bc39849940ac8db3b3cc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 00:03:49 +0800 Subject: [PATCH 0322/1161] [SPARK-23376][SQL] creating UnsafeKVExternalSorter with BytesToBytesMap may fail ## What changes were proposed in this pull request? This is a long-standing bug in `UnsafeKVExternalSorter` and was reported in the dev list multiple times. When creating `UnsafeKVExternalSorter` with `BytesToBytesMap`, we need to create a `UnsafeInMemorySorter` to sort the data in `BytesToBytesMap`. The data format of the sorter and the map is same, so no data movement is required. However, both the sorter and the map need a point array for some bookkeeping work. There is an optimization in `UnsafeKVExternalSorter`: reuse the point array between the sorter and the map, to avoid an extra memory allocation. This sounds like a reasonable optimization, the length of the `BytesToBytesMap` point array is at least 4 times larger than the number of keys(to avoid hash collision, the hash table size should be at least 2 times larger than the number of keys, and each key occupies 2 slots). `UnsafeInMemorySorter` needs the pointer array size to be 4 times of the number of entries, so we are safe to reuse the point array. However, the number of keys of the map doesn't equal to the number of entries in the map, because `BytesToBytesMap` supports duplicated keys. This breaks the assumption of the above optimization and we may run out of space when inserting data into the sorter, and hit error ``` java.lang.IllegalStateException: There is no space for new record at org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.insertRecord(UnsafeInMemorySorter.java:239) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:149) ... ``` This PR fixes this bug by creating a new point array if the existing one is not big enough. ## How was this patch tested? a new test Author: Wenchen Fan Closes #20561 from cloud-fan/bug. --- .../sql/execution/UnsafeKVExternalSorter.java | 31 +++++++++++---- .../UnsafeKVExternalSorterSuite.scala | 39 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index b0b5383a081a0..9eb03430a7db2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -34,6 +34,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.unsafe.sort.*; @@ -98,19 +99,33 @@ public UnsafeKVExternalSorter( numElementsForSpillThreshold, canUseRadixSort); } else { - // The array will be used to do in-place sort, which require half of the space to be empty. - // Note: each record in the map takes two entries in the array, one is record pointer, - // another is the key prefix. - assert(map.numKeys() * 2 <= map.getArray().size() / 2); - // During spilling, the array in map will not be used, so we can borrow that and use it - // as the underlying array for in-memory sorter (it's always large enough). - // Since we will not grow the array, it's fine to pass `null` as consumer. + // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow + // that and use it as the pointer array for `UnsafeInMemorySorter`. + LongArray pointerArray = map.getArray(); + // `BytesToBytesMap`'s pointer array is only guaranteed to hold all the distinct keys, but + // `UnsafeInMemorySorter`'s pointer array need to hold all the entries. Since + // `BytesToBytesMap` can have duplicated keys, here we need a check to make sure the pointer + // array can hold all the entries in `BytesToBytesMap`. + // The pointer array will be used to do in-place sort, which requires half of the space to be + // empty. Note: each record in the map takes two entries in the pointer array, one is record + // pointer, another is key prefix. So the required size of pointer array is `numRecords * 4`. + // TODO: It's possible to change UnsafeInMemorySorter to have multiple entries with same key, + // so that we can always reuse the pointer array. + if (map.numValues() > pointerArray.size() / 4) { + // Here we ask the map to allocate memory, so that the memory manager won't ask the map + // to spill, if the memory is not enough. + pointerArray = map.allocateArray(map.numValues() * 4L); + } + + // Since the pointer array(either reuse the one in the map, or create a new one) is guaranteed + // to be large enough, it's fine to pass `null` as consumer because we won't allocate more + // memory. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( null, taskMemoryManager, comparatorSupplier.get(), prefixComparator, - map.getArray(), + pointerArray, canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 6af9f8b77f8d3..bf588d3bb7841 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.map.BytesToBytesMap /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -205,4 +206,42 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { spill = true ) } + + test("SPARK-23376: Create UnsafeKVExternalSorter with BytesToByteMap having duplicated keys") { + val memoryManager = new TestMemoryManager(new SparkConf()) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes()) + + // Key/value are a unsafe rows with a single int column + val schema = new StructType().add("i", IntegerType) + val key = new UnsafeRow(1) + key.pointTo(new Array[Byte](32), 32) + key.setInt(0, 1) + val value = new UnsafeRow(1) + value.pointTo(new Array[Byte](32), 32) + value.setInt(0, 2) + + for (_ <- 1 to 65) { + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + loc.append( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + + // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` + // which has duplicated keys and the number of entries exceeds its capacity. + try { + TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + new UnsafeKVExternalSorter( + schema, + schema, + sparkContext.env.blockManager, + sparkContext.env.serializerManager, + taskMemoryManager.pageSizeBytes(), + Int.MaxValue, + map) + } finally { + TaskContext.unset() + } + } } From c0c902aedcf9ed24e482d873d766a7df63b964cb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 11 Feb 2018 20:15:30 -0600 Subject: [PATCH 0323/1161] [SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance ## What changes were proposed in this pull request? In #19340 some comments considered needed to use spherical KMeans when cosine distance measure is specified, as Matlab does; instead of the implementation based on the behavior of other tools/libraries like Rapidminer, nltk and ELKI, ie. the centroids are computed as the mean of all the points in the clusters. The PR introduce the approach used in spherical KMeans. This behavior has the nice feature to minimize the within-cluster cosine distance. ## How was this patch tested? existing/improved UTs Author: Marco Gaido Closes #20518 from mgaido91/SPARK-22119_followup. --- .../spark/mllib/clustering/KMeans.scala | 54 ++++++++++++++++--- .../spark/ml/clustering/KMeansSuite.scala | 15 +++++- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 607145cb59fba..3c4ba0bc60c7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -310,8 +310,7 @@ class KMeans private ( points.foreach { point => val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point) costAccum.add(cost) - val sum = sums(bestCenter) - axpy(1.0, point.vector, sum) + distanceMeasureInstance.updateClusterSum(point, sums(bestCenter)) counts(bestCenter) += 1 } @@ -319,10 +318,9 @@ class KMeans private ( }.reduceByKey { case ((sum1, count1), (sum2, count2)) => axpy(1.0, sum2, sum1) (sum1, count1 + count2) - }.mapValues { case (sum, count) => - scal(1.0 / count, sum) - new VectorWithNorm(sum) - }.collectAsMap() + }.collectAsMap().mapValues { case (sum, count) => + distanceMeasureInstance.centroid(sum, count) + } bcCenters.destroy(blocking = false) @@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable { v1: VectorWithNorm, v2: VectorWithNorm): Double + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } } @Since("2.4.0") @@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure { * @return the cosine distance between the two input vectors */ override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e4506f23feb31..32830b39407ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering import scala.util.Random -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == predictionsMap(Vectors.dense(-100.0, 90.0))) + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } + + test("KMeans with cosine distance is not supported for 0-length vectors") { + val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } test("read/write") { From 6efd5d117e98074d1b16a5c991fbd38df9aa196e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 11 Feb 2018 23:46:23 -0800 Subject: [PATCH 0324/1161] [SPARK-23390][SQL] Flaky Test Suite: FileBasedDataSourceSuite in Spark 2.3/hadoop 2.7 ## What changes were proposed in this pull request? This test only fails with sbt on Hadoop 2.7, I can't reproduce it locally, but here is my speculation by looking at the code: 1. FileSystem.delete doesn't delete the directory entirely, somehow we can still open the file as a 0-length empty file.(just speculation) 2. ORC intentionally allow empty files, and the reader fails during reading without closing the file stream. This PR improves the test to make sure all files are deleted and can't be opened. ## How was this patch tested? N/A Author: Wenchen Fan Closes #20584 from cloud-fan/flaky-test. --- .../spark/sql/FileBasedDataSourceSuite.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 640d6b1583663..2e332362ea644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.FileNotFoundException + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -102,17 +104,27 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { def testIgnoreMissingFiles(): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath + Seq("0").toDF("a").write.format(format).save(new Path(basePath, "first").toString) Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) + val thirdPath = new Path(basePath, "third") + val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) + val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) + val df = spark.read.format(format).load( new Path(basePath, "first").toString, new Path(basePath, "second").toString, new Path(basePath, "third").toString) - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + // Make sure all data files are deleted and can't be opened. + files.foreach(f => fs.delete(f, false)) assert(fs.delete(thirdPath, true)) + for (f <- files) { + intercept[FileNotFoundException](fs.open(f)) + } + checkAnswer(df, Seq(Row("0"), Row("1"))) } } From c338c8cf8253c037ecd4f39bbd58ed5a86581b37 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 12 Feb 2018 20:49:36 +0900 Subject: [PATCH 0325/1161] [SPARK-23352][PYTHON] Explicitly specify supported types in Pandas UDFs ## What changes were proposed in this pull request? This PR targets to explicitly specify supported types in Pandas UDFs. The main change here is to add a deduplicated and explicit type checking in `returnType` ahead with documenting this; however, it happened to fix multiple things. 1. Currently, we don't support `BinaryType` in Pandas UDFs, for example, see: ```python from pyspark.sql.functions import pandas_udf pudf = pandas_udf(lambda x: x, "binary") df = spark.createDataFrame([[bytearray(1)]]) df.select(pudf("_1")).show() ``` ``` ... TypeError: Unsupported type in conversion to Arrow: BinaryType ``` We can document this behaviour for its guide. 2. Also, the grouped aggregate Pandas UDF fails fast on `ArrayType` but seems we can support this case. ```python from pyspark.sql.functions import pandas_udf, PandasUDFType foo = pandas_udf(lambda v: v.mean(), 'array', PandasUDFType.GROUPED_AGG) df = spark.range(100).selectExpr("id", "array(id) as value") df.groupBy("id").agg(foo("value")).show() ``` ``` ... NotImplementedError: ArrayType, StructType and MapType are not supported with PandasUDFType.GROUPED_AGG ``` 3. Since we can check the return type ahead, we can fail fast before actual execution. ```python # we can fail fast at this stage because we know the schema ahead pandas_udf(lambda x: x, BinaryType()) ``` ## How was this patch tested? Manually tested and unit tests for `BinaryType` and `ArrayType(...)` were added. Author: hyukjinkwon Closes #20531 from HyukjinKwon/pudf-cleanup. --- docs/sql-programming-guide.md | 4 +- python/pyspark/sql/tests.py | 130 +++++++++++------- python/pyspark/sql/types.py | 4 + python/pyspark/sql/udf.py | 36 +++-- python/pyspark/worker.py | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- 6 files changed, 111 insertions(+), 67 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index eab4030ee25d2..6174a93b68492 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) @@ -1734,7 +1734,7 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p ### Supported SQL Types -Currently, all Spark SQL data types are supported by Arrow-based conversion except `MapType`, +Currently, all Spark SQL data types are supported by Arrow-based conversion except `BinaryType`, `MapType`, `ArrayType` of `TimestampType`, and nested `StructType`. ### Setting Arrow Batch Size diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fe89bd0685027..2af218a691026 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3790,10 +3790,10 @@ def foo(x): self.assertEqual(foo.returnType, schema) self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) - @pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR) + @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR) def foo(x): return x - self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.returnType, DoubleType()) self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF) @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP) @@ -3830,7 +3830,7 @@ def zero_with_type(): @pandas_udf(returnType=PandasUDFType.GROUPED_MAP) def foo(df): return df - with self.assertRaisesRegexp(ValueError, 'Invalid returnType'): + with self.assertRaisesRegexp(TypeError, 'Invalid returnType'): @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP) def foo(df): return df @@ -3879,7 +3879,7 @@ def random_udf(v): return random_udf def test_vectorized_udf_basic(self): - from pyspark.sql.functions import pandas_udf, col + from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( col('id').cast('string').alias('str'), col('id').cast('int').alias('int'), @@ -3887,7 +3887,8 @@ def test_vectorized_udf_basic(self): col('id').cast('float').alias('float'), col('id').cast('double').alias('double'), col('id').cast('decimal').alias('decimal'), - col('id').cast('boolean').alias('bool')) + col('id').cast('boolean').alias('bool'), + array(col('id')).alias('array_long')) f = lambda x: x str_f = pandas_udf(f, StringType()) int_f = pandas_udf(f, IntegerType()) @@ -3896,10 +3897,11 @@ def test_vectorized_udf_basic(self): double_f = pandas_udf(f, DoubleType()) decimal_f = pandas_udf(f, DecimalType()) bool_f = pandas_udf(f, BooleanType()) + array_long_f = pandas_udf(f, ArrayType(LongType())) res = df.select(str_f(col('str')), int_f(col('int')), long_f(col('long')), float_f(col('float')), double_f(col('double')), decimal_f('decimal'), - bool_f(col('bool'))) + bool_f(col('bool')), array_long_f('array_long')) self.assertEquals(df.collect(), res.collect()) def test_register_nondeterministic_vectorized_udf_basic(self): @@ -4104,10 +4106,11 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10) - f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.select(f(col('id'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType())) def test_vectorized_udf_return_scalar(self): from pyspark.sql.functions import pandas_udf, col @@ -4142,13 +4145,18 @@ def test_vectorized_udf_varargs(self): self.assertEquals(df.collect(), res.collect()) def test_vectorized_udf_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(None,)], schema=schema) - f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + from pyspark.sql.functions import pandas_udf with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.select(f(col('map'))).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*MapType'): + pandas_udf(lambda x: x, MapType(StringType(), IntegerType())) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*scalar Pandas UDF.*BinaryType'): + pandas_udf(lambda x: x, BinaryType()) def test_vectorized_udf_dates(self): from pyspark.sql.functions import pandas_udf, col @@ -4379,15 +4387,16 @@ def data(self): .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ .withColumn("v", explode(col('vs'))).drop('vs') - def test_simple(self): - from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data + def test_supported_types(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + df = self.data.withColumn("arr", array(col("id"))) foo_udf = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), StructType( [StructField('id', LongType()), StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), StructField('v1', DoubleType()), StructField('v2', LongType())]), PandasUDFType.GROUPED_MAP @@ -4490,17 +4499,15 @@ def test_datatype_string(self): def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType - df = self.data - - foo = pandas_udf( - lambda pdf: pdf, - 'id long, v map', - PandasUDFType.GROUPED_MAP - ) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'): - df.groupby('id').apply(foo).sort('id').toPandas() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf( + lambda pdf: pdf, + 'id long, v map', + PandasUDFType.GROUPED_MAP) def test_wrong_args(self): from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType @@ -4519,23 +4526,30 @@ def test_wrong_args(self): df.groupby('id').apply( pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'Invalid udf'): - df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) + df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType())) with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'): df.groupby('id').apply( - pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]), - PandasUDFType.SCALAR)) + pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR)) def test_unsupported_types(self): - from pyspark.sql.functions import pandas_udf, col, PandasUDFType + from pyspark.sql.functions import pandas_udf, PandasUDFType schema = StructType( [StructField("id", LongType(), True), StructField("map", MapType(StringType(), IntegerType()), True)]) - df = self.spark.createDataFrame([(1, None,)], schema=schema) - f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): - df.groupby('id').apply(f).collect() + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*MapType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) + + schema = StructType( + [StructField("id", LongType(), True), + StructField("arr_ts", ArrayType(TimestampType()), True)]) + with QuietTest(self.sc): + with self.assertRaisesRegexp( + NotImplementedError, + 'Invalid returnType.*grouped map Pandas UDF.*ArrayType.*TimestampType'): + pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP) # Regression test for SPARK-23314 def test_timestamp_dst(self): @@ -4614,23 +4628,32 @@ def weighted_mean(v, w): return weighted_mean def test_manual(self): + from pyspark.sql.functions import pandas_udf, array + df = self.data sum_udf = self.pandas_agg_sum_udf mean_udf = self.pandas_agg_mean_udf - - result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id') + mean_arr_udf = pandas_udf( + self.pandas_agg_mean_udf.func, + ArrayType(self.pandas_agg_mean_udf.returnType), + self.pandas_agg_mean_udf.evalType) + + result1 = df.groupby('id').agg( + sum_udf(df.v), + mean_udf(df.v), + mean_arr_udf(array(df.v))).sort('id') expected1 = self.spark.createDataFrame( - [[0, 245.0, 24.5], - [1, 255.0, 25.5], - [2, 265.0, 26.5], - [3, 275.0, 27.5], - [4, 285.0, 28.5], - [5, 295.0, 29.5], - [6, 305.0, 30.5], - [7, 315.0, 31.5], - [8, 325.0, 32.5], - [9, 335.0, 33.5]], - ['id', 'sum(v)', 'avg(v)']) + [[0, 245.0, 24.5, [24.5]], + [1, 255.0, 25.5, [25.5]], + [2, 265.0, 26.5, [26.5]], + [3, 275.0, 27.5, [27.5]], + [4, 285.0, 28.5, [28.5]], + [5, 295.0, 29.5, [29.5]], + [6, 305.0, 30.5, [30.5]], + [7, 315.0, 31.5, [31.5]], + [8, 325.0, 32.5, [32.5]], + [9, 335.0, 33.5, [33.5]]], + ['id', 'sum(v)', 'avg(v)', 'avg(array(v))']) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) @@ -4667,14 +4690,15 @@ def test_basic(self): self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) def test_unsupported_types(self): - from pyspark.sql.types import ArrayType, DoubleType, MapType + from pyspark.sql.types import DoubleType, MapType from pyspark.sql.functions import pandas_udf, PandasUDFType with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return [v.mean(), v.std()] + pandas_udf( + lambda x: x, + ArrayType(ArrayType(TimestampType())), + PandasUDFType.GROUPED_AGG) with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f7141b4549e4e..e25941cd37595 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1638,6 +1638,8 @@ def to_arrow_type(dt): # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') elif type(dt) == ArrayType: + if type(dt.elementType) == TimestampType: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) @@ -1680,6 +1682,8 @@ def from_arrow_type(at): elif types.is_timestamp(at): spark_type = TimestampType() elif types.is_list(at): + if types.is_timestamp(at.value_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 08c6b9e521e82..e5b35fc60e167 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -23,7 +23,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string + _parse_datatype_string, to_arrow_type, to_arrow_schema __all__ = ["UDFRegistration"] @@ -112,15 +112,31 @@ def returnType(self): else: self._returnType_placeholder = _parse_datatype_string(self._returnType) - if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ - and not isinstance(self._returnType_placeholder, StructType): - raise ValueError("Invalid returnType: returnType must be a StructType for " - "pandas_udf with function type GROUPED_MAP") - elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \ - and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)): - raise NotImplementedError( - "ArrayType, StructType and MapType are not supported with " - "PandasUDFType.GROUPED_AGG") + if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with scalar Pandas UDFs: %s is " + "not supported" % str(self._returnType_placeholder)) + elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_schema(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped map Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) + else: + raise TypeError("Invalid returnType for grouped map Pandas " + "UDFs: returnType must be a StructType.") + elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise NotImplementedError( + "Invalid returnType with grouped aggregate Pandas UDFs: " + "%s is not supported" % str(self._returnType_placeholder)) return self._returnType_placeholder diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 121b3dd1aeec9..89a3a92bc66d6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -116,7 +116,7 @@ def wrap_grouped_agg_pandas_udf(f, return_type): def wrapped(*series): import pandas as pd result = f(*series) - return pd.Series(result) + return pd.Series([result]) return lambda *a: (wrapped(*a), arrow_return_type) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1e2501ee7757d..7835dbaa58439 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1064,7 +1064,7 @@ object SQLConf { "for use with pyspark.sql.DataFrame.toPandas, and " + "pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame. " + "The following data types are unsupported: " + - "MapType, ArrayType of TimestampType, and nested StructType.") + "BinaryType, MapType, ArrayType of TimestampType, and nested StructType.") .booleanConf .createWithDefault(false) From caeb108e25e5bfb7cffcf09ef9abbb1abcfa355d Mon Sep 17 00:00:00 2001 From: caoxuewen Date: Mon, 12 Feb 2018 22:05:27 +0800 Subject: [PATCH 0326/1161] [MINOR][TEST] spark.testing` No effect on the SparkFunSuite unit test ## What changes were proposed in this pull request? Currently, we use SBT and MAVN to spark unit test, are affected by the parameters of `spark.testing`. However, when using the IDE test tool, `spark.testing` support is not very good, sometimes need to be manually added to the beforeEach. example: HiveSparkSubmitSuite RPackageUtilsSuite SparkSubmitSuite. The PR unified `spark.testing` parameter extraction to SparkFunSuite, support IDE test tool, and the test code is more compact. ## How was this patch tested? the existed test cases. Author: caoxuewen Closes #20582 from heary-cao/sparktesting. --- core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 1 + .../test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala | 1 - .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 1 - .../spark/network/netty/NettyBlockTransferServiceSuite.scala | 1 + .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 1 - 5 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 3af9d82393bc4..31289026b0027 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -59,6 +59,7 @@ abstract class SparkFunSuite protected val enableAutoThreadAudit = true protected override def beforeAll(): Unit = { + System.setProperty("spark.testing", "true") if (enableAutoThreadAudit) { doThreadPreAudit() } diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 32dd3ecc2f027..ef947eb074647 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -66,7 +66,6 @@ class RPackageUtilsSuite override def beforeEach(): Unit = { super.beforeEach() - System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 27dd435332348..803a38d77fb82 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -107,7 +107,6 @@ class SparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } // scalastyle:off println diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index f7bc3725d7278..78423ee68a0ec 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,6 +80,7 @@ class NettyBlockTransferServiceSuite private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { actualPort should be >= expectedPort // avoid testing equality in case of simultaneous tests + // if `spark.testing` is true, // the default value for `spark.port.maxRetries` is 100 under test actualPort should be <= (expectedPort + 100) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 10204f4694663..2d31781132edc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -50,7 +50,6 @@ class HiveSparkSubmitSuite override def beforeEach() { super.beforeEach() - System.setProperty("spark.testing", "true") } test("temporary Hive UDF: define a UDF and use it") { From 0e2c266de7189473177f45aa68ea6a45c7e47ec3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 22:07:59 +0800 Subject: [PATCH 0327/1161] [SPARK-22977][SQL] fix web UI SQL tab for CTAS ## What changes were proposed in this pull request? This is a regression in Spark 2.3. In Spark 2.2, we have a fragile UI support for SQL data writing commands. We only track the input query plan of `FileFormatWriter` and display its metrics. This is not ideal because we don't know who triggered the writing(can be table insertion, CTAS, etc.), but it's still useful to see the metrics of the input query. In Spark 2.3, we introduced a new mechanism: `DataWritigCommand`, to fix the UI issue entirely. Now these writing commands have real children, and we don't need to hack into the `FileFormatWriter` for the UI. This also helps with `explain`, now `explain` can show the physical plan of the input query, while in 2.2 the physical writing plan is simply `ExecutedCommandExec` and it has no child. However there is a regression in CTAS. CTAS commands don't extend `DataWritigCommand`, and we don't have the UI hack in `FileFormatWriter` anymore, so the UI for CTAS is just an empty node. See https://issues.apache.org/jira/browse/SPARK-22977 for more information about this UI issue. To fix it, we should apply the `DataWritigCommand` mechanism to CTAS commands. TODO: In the future, we should refactor this part and create some physical layer code pieces for data writing, and reuse them in different writing commands. We should have different logical nodes for different operators, even some of them share some same logic, e.g. CTAS, CREATE TABLE, INSERT TABLE. Internally we can share the same physical logic. ## How was this patch tested? manually tested. For data source table 1 For hive table 2 Author: Wenchen Fan Closes #20521 from cloud-fan/UI. --- .../command/createDataSourceTables.scala | 21 +++---- .../execution/datasources/DataSource.scala | 44 +++++++++++++-- .../datasources/DataSourceStrategy.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 2 +- .../CreateHiveTableAsSelectCommand.scala | 55 ++++++++++--------- .../sql/hive/execution/HiveExplainSuite.scala | 26 --------- 6 files changed, 80 insertions(+), 70 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 306f43dc4214a..e9747769dfcfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -21,7 +21,9 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -136,12 +138,11 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, - query: LogicalPlan) - extends RunnableCommand { - - override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + query: LogicalPlan, + outputColumns: Seq[Attribute]) + extends DataWritingCommand { - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) assert(table.provider.isDefined) @@ -163,7 +164,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) + sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) @@ -173,7 +174,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) + sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of @@ -198,10 +199,10 @@ case class CreateDataSourceTableAsSelectCommand( session: SparkSession, table: CatalogTable, tableLocation: Option[URI], - data: LogicalPlan, + physicalPlan: SparkPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { - // Create the relation based on the input logical plan: `data`. + // Create the relation based on the input logical plan: `query`. val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, @@ -212,7 +213,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query) + dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 25e1210504273..6e1b5727e3fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -31,8 +31,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -435,10 +437,11 @@ case class DataSource( } /** - * Writes the given [[LogicalPlan]] out in this [[FileFormat]]. + * Creates a command node to write the given [[LogicalPlan]] out to the given [[FileFormat]]. + * The returned command is unresolved and need to be analyzed. */ private def planForWritingFileFormat( - format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = { + format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = { // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -482,9 +485,24 @@ case class DataSource( /** * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for * the following reading. + * + * @param mode The save mode for this writing. + * @param data The input query plan that produces the data to be written. Note that this plan + * is analyzed and optimized. + * @param outputColumns The original output columns of the input query plan. The optimizer may not + * preserve the output column's names' case, so we need this parameter + * instead of `data.output`. + * @param physicalPlan The physical plan of the input query plan. We should run the writing + * command with this physical plan instead of creating a new physical plan, + * so that the metrics can be correctly linked to the given physical plan and + * shown in the web UI. */ - def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { + def writeAndRead( + mode: SaveMode, + data: LogicalPlan, + outputColumns: Seq[Attribute], + physicalPlan: SparkPlan): BaseRelation = { + if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -493,9 +511,23 @@ case class DataSource( dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd + val cmd = planForWritingFileFormat(format, mode, data) + val resolvedPartCols = cmd.partitionColumns.map { col => + // The partition columns created in `planForWritingFileFormat` should always be + // `UnresolvedAttribute` with a single name part. + assert(col.isInstanceOf[UnresolvedAttribute]) + val unresolved = col.asInstanceOf[UnresolvedAttribute] + assert(unresolved.nameParts.length == 1) + val name = unresolved.nameParts.head + outputColumns.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring - copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() + copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d94c5bbccdd84..3f41612c08065 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index ab857b9055720..8df05cbb20361 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -157,7 +157,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 65e8b4e3c725c..1e801fe1845c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} -import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand /** @@ -36,15 +37,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, + outputColumns: Seq[Attribute], mode: SaveMode) - extends RunnableCommand { + extends DataWritingCommand { private val tableIdentifier = tableDesc.identifier - override def innerChildren: Seq[LogicalPlan] = Seq(query) - - override def run(sparkSession: SparkSession): Seq[Row] = { - if (sparkSession.sessionState.catalog.tableExists(tableIdentifier)) { + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + if (catalog.tableExists(tableIdentifier)) { assert(mode != SaveMode.Overwrite, s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite") @@ -56,34 +57,36 @@ case class CreateHiveTableAsSelectCommand( return Seq.empty } - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = false, - ifPartitionNotExists = false)).toRdd + InsertIntoHiveTable( + tableDesc, + Map.empty, + query, + overwrite = false, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - sparkSession.sessionState.catalog.createTable( - tableDesc.copy(schema = query.schema), ignoreIfExists = false) + catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) try { - sparkSession.sessionState.executePlan( - InsertIntoTable( - UnresolvedRelation(tableIdentifier), - Map(), - query, - overwrite = true, - ifPartitionNotExists = false)).toRdd + // Read back the metadata of the table which was created just now. + val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier) + // For CTAS, there is no static partition values to insert. + val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap + InsertIntoHiveTable( + createdTableMeta, + partition, + query, + overwrite = true, + ifPartitionNotExists = false, + outputColumns = outputColumns).run(sparkSession, child) } catch { case NonFatal(e) => // drop the created table. - sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, - purge = false) + catalog.dropTable(tableIdentifier, ignoreIfNotExists = true, purge = false) throw e } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index f84d188075b72..5d56f89c2271c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -128,32 +128,6 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { - withTempView("jt") { - val ds = (1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""").toDS() - spark.read.json(ds).createOrReplaceTempView("jt") - val outputs = sql( - s""" - |EXPLAIN EXTENDED - |CREATE TABLE t1 - |AS - |SELECT * FROM jt - """.stripMargin).collect().map(_.mkString).mkString - - val shouldContain = - "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: - "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: - "CreateHiveTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil - for (key <- shouldContain) { - assert(outputs.contains(key), s"$key doesn't exist in result") - } - - val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should contain SubqueryAlias since the query should not be optimized") - } - } - test("explain output of physical plan should contain proper codegen stage ID") { checkKeywordsExist(sql( """ From 4a4dd4f36f65410ef5c87f7b61a960373f044e61 Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 12 Feb 2018 08:49:45 -0600 Subject: [PATCH 0328/1161] [SPARK-23391][CORE] It may lead to overflow for some integer multiplication ## What changes were proposed in this pull request? In the `getBlockData`,`blockId.reduceId` is the `Int` type, when it is greater than 2^28, `blockId.reduceId*8` will overflow In the `decompress0`, `len` and `unitSize` are Int type, so `len * unitSize` may lead to overflow ## How was this patch tested? N/A Author: liuxian Closes #20581 from 10110346/overflow2. --- .../org/apache/spark/shuffle/IndexShuffleBlockResolver.scala | 4 ++-- .../execution/columnar/compression/compressionSchemes.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d88b25cc7e258..d3f1c7ec1bbee 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -202,13 +202,13 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) - channel.position(blockId.reduceId * 8) + channel.position(blockId.reduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { val offset = in.readLong() val nextOffset = in.readLong() val actualPosition = channel.position() - val expectedPosition = blockId.reduceId * 8 + 16 + val expectedPosition = blockId.reduceId * 8L + 16 if (actualPosition != expectedPosition) { throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 79dcf3a6105ce..00a1d54b41709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -116,7 +116,7 @@ private[columnar] case object PassThrough extends CompressionScheme { while (pos < capacity) { if (pos != nextNullIndex) { val len = nextNullIndex - pos - assert(len * unitSize < Int.MaxValue) + assert(len * unitSize.toLong < Int.MaxValue) putFunction(columnVector, pos, bufferPos, len) bufferPos += len * unitSize pos += len From 5bb11411aec18b8d623e54caba5397d7cb8e89f0 Mon Sep 17 00:00:00 2001 From: James Thompson Date: Mon, 12 Feb 2018 11:34:56 -0800 Subject: [PATCH 0329/1161] [SPARK-23388][SQL] Support for Parquet Binary DecimalType in VectorizedColumnReader ## What changes were proposed in this pull request? Re-add support for parquet binary DecimalType in VectorizedColumnReader ## How was this patch tested? Existing test suite Author: James Thompson Closes #20580 from jamesthomp/jt/add-back-binary-decimal. --- .../execution/datasources/parquet/VectorizedColumnReader.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index c120863152a96..47dd625f4b154 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -444,7 +444,8 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; - if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType) { + if (column.dataType() == DataTypes.StringType || column.dataType() == DataTypes.BinaryType + || DecimalType.isByteArrayDecimalType(column.dataType())) { defColumn.readBinarys(num, column, rowId, maxDefLevel, data); } else if (column.dataType() == DataTypes.TimestampType) { if (!shouldConvertTimestamps()) { From 0c66fe4f22f8af4932893134bb0fd56f00fabeae Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 12 Feb 2018 12:20:29 -0800 Subject: [PATCH 0330/1161] [SPARK-22002][SQL][FOLLOWUP][TEST] Add a test to check if the original schema doesn't have metadata. ## What changes were proposed in this pull request? This is a follow-up pr of #19231 which modified the behavior to remove metadata from JDBC table schema. This pr adds a test to check if the schema doesn't have metadata. ## How was this patch tested? Added a test and existing tests. Author: Takuya UESHIN Closes #20585 from ueshin/issues/SPARK-22002/fup1. --- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index cb2df0ac54f4c..5238adce4a699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1168,4 +1168,26 @@ class JDBCSuite extends SparkFunSuite val df3 = sql("SELECT * FROM test_sessionInitStatement") assert(df3.collect() === Array(Row(21519, 1234))) } + + test("jdbc data source shouldn't have unnecessary metadata in its schema") { + val schema = StructType(Seq( + StructField("NAME", StringType, true), StructField("THEID", IntegerType, true))) + + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("DbTaBle", "TEST.PEOPLE") + .load() + assert(df.schema === schema) + + withTempView("people_view") { + sql( + s""" + |CREATE TEMPORARY VIEW people_view + |USING org.apache.spark.sql.jdbc + |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + assert(sql("select * from people_view").schema === schema) + } + } } From fba01b9a65e5d9438d35da0bd807c179ba741911 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 14:58:31 -0800 Subject: [PATCH 0331/1161] [SPARK-23378][SQL] move setCurrentDatabase from HiveExternalCatalog to HiveClientImpl ## What changes were proposed in this pull request? This removes the special case that `alterPartitions` call from `HiveExternalCatalog` can reset the current database in the hive client as a side effect. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20564 from liufengdb/move. --- .../spark/sql/hive/HiveExternalCatalog.scala | 5 ---- .../sql/hive/client/HiveClientImpl.scala | 26 ++++++++++++++----- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 3b8a8ca301c27..1ee1d57b8ebe1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -1107,11 +1107,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - // Note: Before altering table partitions in Hive, you *must* set the current database - // to the one that contains the table of interest. Otherwise you will end up with the - // most helpful error message ever: "Unable to alter partition. alter is not possible." - // See HIVE-2742 for more detail. - client.setCurrentDatabase(db) client.alterPartitions(db, table, withStatsProps) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 6c0f4144992ae..c223f51b1be75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -291,14 +291,18 @@ private[hive] class HiveClientImpl( state.err = stream } - override def setCurrentDatabase(databaseName: String): Unit = withHiveState { - if (databaseExists(databaseName)) { - state.setCurrentDatabase(databaseName) + private def setCurrentDatabaseRaw(db: String): Unit = { + if (databaseExists(db)) { + state.setCurrentDatabase(db) } else { - throw new NoSuchDatabaseException(databaseName) + throw new NoSuchDatabaseException(db) } } + override def setCurrentDatabase(databaseName: String): Unit = withHiveState { + setCurrentDatabaseRaw(databaseName) + } + override def createDatabase( database: CatalogDatabase, ignoreIfExists: Boolean): Unit = withHiveState { @@ -598,8 +602,18 @@ private[hive] class HiveClientImpl( db: String, table: String, newParts: Seq[CatalogTablePartition]): Unit = withHiveState { - val hiveTable = toHiveTable(getTable(db, table), Some(userName)) - shim.alterPartitions(client, table, newParts.map { p => toHivePartition(p, hiveTable) }.asJava) + // Note: Before altering table partitions in Hive, you *must* set the current database + // to the one that contains the table of interest. Otherwise you will end up with the + // most helpful error message ever: "Unable to alter partition. alter is not possible." + // See HIVE-2742 for more detail. + val original = state.getCurrentDatabase + try { + setCurrentDatabaseRaw(db) + val hiveTable = toHiveTable(getTable(db, table), Some(userName)) + shim.alterPartitions(client, table, newParts.map { toHivePartition(_, hiveTable) }.asJava) + } finally { + state.setCurrentDatabase(original) + } } /** From 6cb59708c70c03696c772fbb5d158eed57fe67d4 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 12 Feb 2018 15:26:37 -0800 Subject: [PATCH 0332/1161] [SPARK-23313][DOC] Add a migration guide for ORC ## What changes were proposed in this pull request? This PR adds a migration guide documentation for ORC. ![orc-guide](https://user-images.githubusercontent.com/9700541/36123859-ec165cae-1002-11e8-90b7-7313be7a81a5.png) ## How was this patch tested? N/A. Author: Dongjoon Hyun Closes #20484 from dongjoon-hyun/SPARK-23313. --- docs/sql-programming-guide.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6174a93b68492..0f9f01e18682f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1776,6 +1776,35 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 + - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. + + - New configurations + +
spark.mesos.principal (none) - Set the principal with which Spark framework will use to authenticate with Mesos. + Set the principal with which Spark framework will use to authenticate with Mesos. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL`. +
spark.mesos.principal.file(none) + Set the file containing the principal with which Spark framework will use to authenticate with Mesos. Allows specifying the principal indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_PRINCIPAL_FILE`.
(none) Set the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when - authenticating with the registry. + authenticating with the registry. You can also specify this via the environment variable `SPARK_MESOS_SECRET`. +
spark.mesos.secret.file(none) + Set the file containing the secret with which Spark framework will use to authenticate with Mesos. Used, for example, when + authenticating with the registry. Allows for specifying the secret indirectly in more security conscious deployments. The file must be readable by the user launching the job and be UTF-8 encoded plaintext. You can also specify this via the environment variable `SPARK_MESOS_SECRET_FILE`.
+ + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
+ + - Changed configurations + + + + + + + + +
Property NameDefaultMeaning
spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
+ - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. From 4104b68e958cd13975567a96541dac7cccd8195c Mon Sep 17 00:00:00 2001 From: sychen Date: Mon, 12 Feb 2018 16:00:47 -0800 Subject: [PATCH 0333/1161] [SPARK-23230][SQL] When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error When hive.default.fileformat is other kinds of file types, create textfile table cause a serde error. We should take the default type of textfile and sequencefile both as org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe. ``` set hive.default.fileformat=orc; create table tbl( i string ) stored as textfile; desc formatted tbl; Serde Library org.apache.hadoop.hive.ql.io.orc.OrcSerde InputFormat org.apache.hadoop.mapred.TextInputFormat OutputFormat org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat ``` Author: sychen Closes #20406 from cxzl25/default_serde. --- .../apache/spark/sql/internal/HiveSerDe.scala | 6 ++++-- .../sql/hive/execution/HiveSerDeSuite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index dac463641cfab..eca612f06f9bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -31,7 +31,8 @@ object HiveSerDe { "sequencefile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "rcfile" -> HiveSerDe( @@ -54,7 +55,8 @@ object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")), "avro" -> HiveSerDe( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 1c9f00141ae1d..d7752e987cb4b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -100,6 +100,25 @@ class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfte assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS textfile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } + + withSQLConf("hive.default.fileformat" -> "orc") { + val (desc, exists) = extractTableDesc( + "CREATE TABLE IF NOT EXISTS fileformat_test (id int) STORED AS sequencefile") + assert(exists) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.SequenceFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.mapred.SequenceFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + } } test("create hive serde table with new syntax - basic") { From c1bcef876c1415e39e624cfbca9c9bdeae24cbb9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 13 Feb 2018 11:40:34 +0800 Subject: [PATCH 0334/1161] [SPARK-23323][SQL] Support commit coordinator for DataSourceV2 writes ## What changes were proposed in this pull request? DataSourceV2 batch writes should use the output commit coordinator if it is required by the data source. This adds a new method, `DataWriterFactory#useCommitCoordinator`, that determines whether the coordinator will be used. If the write factory returns true, `WriteToDataSourceV2` will use the coordinator for batch writes. ## How was this patch tested? This relies on existing write tests, which now use the commit coordinator. Author: Ryan Blue Closes #20490 from rdblue/SPARK-23323-add-commit-coordinator. --- .../sources/v2/writer/DataSourceWriter.java | 19 +++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 41 +++++++++++++++---- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index e3f682bf96a66..0a0fd8db58035 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -63,6 +63,16 @@ public interface DataSourceWriter { */ DataWriterFactory createWriterFactory(); + /** + * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for + * each task commits. + * + * @return true if commit coordinator should be used, false otherwise. + */ + default boolean useCommitCoordinator() { + return true; + } + /** * Handles a commit message on receiving from a successful data writer. * @@ -79,10 +89,11 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * - * Note that, one partition may have multiple committed data writers because of speculative tasks. - * Spark will pick the first successful one and get its commit message. Implementations should be - * aware of this and handle it correctly, e.g., have a coordinator to make sure only one data - * writer can commit, or have a way to clean up the data of already-committed writers. + * Note that speculative execution may cause multiple tasks to run for a partition. By default, + * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple + * attempts may have committed successfully and one successful commit message per task will be + * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index eefbcf4c0e087..535e7962d7439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -53,6 +54,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } + val useCommitCoordinator = writer.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) @@ -73,7 +75,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e DataWritingSparkTask.runContinuous(writeTask, context, iter) case _ => (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter) + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) } sparkContext.runJob( @@ -116,21 +118,44 @@ object DataWritingSparkTask extends Logging { def run( writeTask: DataWriterFactory[InternalRow], context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) + iter: Iterator[InternalRow], + useCommitCoordinator: Boolean): WriterCommitMessage = { + val stageId = context.stageId() + val partId = context.partitionId() + val attemptId = context.attemptNumber() + val dataWriter = writeTask.createDataWriter(partId, attemptId) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { iter.foreach(dataWriter.write) - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") + + val msg = if (useCommitCoordinator) { + val coordinator = SparkEnv.get.outputCommitCoordinator + val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + if (commitAuthorized) { + logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + dataWriter.commit() + } else { + val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + logInfo(message) + // throwing CommitDeniedException will trigger the catch block for abort + throw new CommitDeniedException(message, stageId, partId, attemptId) + } + + } else { + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + dataWriter.commit() + } + + logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + msg + })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for partition ${context.partitionId()} is aborting.") + logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") + logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } From ed4e78bd606e7defc2cd01a5c2e9b47954baa424 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Mon, 12 Feb 2018 20:57:26 -0800 Subject: [PATCH 0335/1161] [SPARK-23379][SQL] skip when setting the same current database in HiveClientImpl ## What changes were proposed in this pull request? If the target database name is as same as the current database, we should be able to skip one metastore access. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20565 from liufengdb/remove-redundant. --- .../apache/spark/sql/hive/client/HiveClientImpl.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c223f51b1be75..146fa54a1bce4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -292,10 +292,12 @@ private[hive] class HiveClientImpl( } private def setCurrentDatabaseRaw(db: String): Unit = { - if (databaseExists(db)) { - state.setCurrentDatabase(db) - } else { - throw new NoSuchDatabaseException(db) + if (state.getCurrentDatabase != db) { + if (databaseExists(db)) { + state.setCurrentDatabase(db) + } else { + throw new NoSuchDatabaseException(db) + } } } From f17b936f0ddb7d46d1349bd42f9a64c84c06e48d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 12 Feb 2018 21:12:22 -0800 Subject: [PATCH 0336/1161] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- Relation AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- Relation AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == Relation AdvancedDataSourceV2[j#1] == Physical Plan == *(1) Scan AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- Relation JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) Scan JavaAdvancedDataSourceV2[i#88, j#89] (PushedFilter: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming Relation FakeDataSourceV2$[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) Scan FakeDataSourceV2$[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20477 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +--- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../v2/DataSourceV2QueryPlan.scala | 96 +++++++++++++++++++ .../datasources/v2/DataSourceV2Relation.scala | 26 +++-- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 +++-- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 +-- 15 files changed, 157 insertions(+), 127 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..72ee0c551ec3d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,8 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..d34458ac81014 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,8 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 02c87643568bd..cb09cce75ff6f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,7 +117,8 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => + r.reader.asInstanceOf[KafkaContinuousReader] } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..984b6510f2dbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,11 +189,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() + val ds = cls.newInstance().asInstanceOf[DataSourceV2] val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -221,7 +219,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala new file mode 100644 index 0000000000000..1e0d088f3a57c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A base class for data source v2 related query plan(both logical and physical). It defines the + * equals/hashCode methods, and provides a string representation of the query plan, according to + * some common information. + */ +trait DataSourceV2QueryPlan { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + /** + * The metadata of this data source query plan that can be used for equality check. + */ + private def metadata: Seq[Any] = Seq(output, source.getClass, filters) + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) + }, " (", ", ", ")") + } else { + "" + } + + s"${source.getClass.getSimpleName}$outputStr$entriesStr" + } + + private def redact(text: String): String = { + Utils.redact(SQLConf.get.stringRedationPattern, text) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..cd97e0cab6b5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,15 +20,23 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + reader: DataSourceReader, + override val isStreaming: Boolean) + extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + override def simpleString: String = { + val streamingHeader = if (isStreaming) "Streaming " else "" + s"${streamingHeader}Relation $metadataString" + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -41,18 +49,8 @@ case class DataSourceV2Relation( } } -/** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. - */ -class StreamingDataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { - override def isStreaming: Boolean = true -} - object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..c99d535efcf81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,11 +37,14 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = s"Scan $metadataString" + override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..fb61e6f32b1f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..4cfdd50e8f46b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + case FilterAndProject(fields, condition, r: DataSourceV2Relation) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = reader match { + val stayUpFilters: Seq[Expression] = r.reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..84564b6639ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -90,6 +92,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -405,12 +408,15 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) + reader.setOffsetRange(toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + Some(reader -> new DataSourceV2Relation( + reader.readSchema().toAttributes, + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), + reader, + isStreaming = true)) case _ => None } } @@ -500,3 +506,5 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..f87d57d0b3209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(ds, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,8 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => + r.reader.asInstanceOf[ContinuousReader] }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..70eb9f0ac66d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..254394685857b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case d: DataSourceV2Relation => d.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..9ee9aaf87f87c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 407f67249639709c40c46917700ed6dd736daa7d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 15:05:13 +0900 Subject: [PATCH 0337/1161] [SPARK-20090][FOLLOW-UP] Revert the deprecation of `names` in PySpark ## What changes were proposed in this pull request? Deprecating the field `name` in PySpark is not expected. This PR is to revert the change. ## How was this patch tested? N/A Author: gatorsmile Closes #20595 from gatorsmile/removeDeprecate. --- python/pyspark/sql/types.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index e25941cd37595..cd857402db8f7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -455,9 +455,6 @@ class StructType(DataType): Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by name or position. - .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead - to get a list of field names. - >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] StructField(f1,StringType,true) From 9dae715168a8e72e318ab231c34a1069bfa342a6 Mon Sep 17 00:00:00 2001 From: Arseniy Tashoyan Date: Tue, 13 Feb 2018 06:20:34 -0600 Subject: [PATCH 0338/1161] [SPARK-23318][ML] FP-growth: WARN FPGrowth: Input data is not cached ## What changes were proposed in this pull request? Cache the RDD of items in ml.FPGrowth before passing it to mllib.FPGrowth. Cache only when the user did not cache the input dataset of transactions. This fixes the warning about uncached data emerging from mllib.FPGrowth. ## How was this patch tested? Manually: 1. Run ml.FPGrowthExample - warning is there 2. Apply the fix 3. Run ml.FPGrowthExample again - no warning anymore Author: Arseniy Tashoyan Closes #20578 from tashoyan/SPARK-23318. --- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index aa7871d6ff29d..3d041fc80eb7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel /** * Common params for FPGrowth and FPGrowthModel @@ -158,18 +159,30 @@ class FPGrowth @Since("2.2.0") ( } private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { + val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val data = dataset.select($(itemsCol)) - val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray) + val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) if (isSet(numPartitions)) { mllibFP.setNumPartitions($(numPartitions)) } + + if (handlePersistence) { + items.persist(StorageLevel.MEMORY_AND_DISK) + } + val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val frequentItems = dataset.sparkSession.createDataFrame(rows, schema) + + if (handlePersistence) { + items.unpersist() + } + copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) } From 300c40f50ab4258d697f06a814d1491dc875c847 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 06:23:10 -0600 Subject: [PATCH 0339/1161] [SPARK-23384][WEB-UI] When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. ## What changes were proposed in this pull request? When it has no incomplete(completed) applications found, the last updated time is not formatted and client local time zone is not show in history server web ui. It is a bug. fix before: ![1](https://user-images.githubusercontent.com/26266482/36070635-264d7cf0-0f3a-11e8-8426-14135ffedb16.png) fix after: ![2](https://user-images.githubusercontent.com/26266482/36070651-8ec3800e-0f3a-11e8-991c-6122cc9539fe.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20573 from guoxiaolongzte/SPARK-23384. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 5d62a7d8bebb4..6fc12d721e6f1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - + ++ +
    @@ -65,7 +66,6 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") if (allAppsSize > 0) { ++
    ++ - ++ ++ } else if (requestedIncomplete) { From 116c581d2658571d38f8b9b27a516ef517170589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 06:54:15 -0800 Subject: [PATCH 0340/1161] [SPARK-20659][CORE] Removing sc.getExecutorStorageStatus and making StorageStatus private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In this PR StorageStatus is made to private and simplified a bit moreover SparkContext.getExecutorStorageStatus method is removed. The reason of keeping StorageStatus is that it is usage from SparkContext.getRDDStorageInfo. Instead of the method SparkContext.getExecutorStorageStatus executor infos are extended with additional memory metrics such as usedOnHeapStorageMemory, usedOffHeapStorageMemory, totalOnHeapStorageMemory, totalOffHeapStorageMemory. ## How was this patch tested? By running existing unit tests. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20546 from attilapiros/SPARK-20659. --- .../org/apache/spark/SparkExecutorInfo.java | 4 + .../scala/org/apache/spark/SparkContext.scala | 19 +- .../org/apache/spark/SparkStatusTracker.scala | 9 +- .../org/apache/spark/StatusAPIImpl.scala | 6 +- .../apache/spark/storage/StorageUtils.scala | 119 +--------- .../org/apache/spark/DistributedSuite.scala | 7 +- .../StandaloneDynamicAllocationSuite.scala | 2 +- .../apache/spark/storage/StorageSuite.scala | 219 ------------------ project/MimaExcludes.scala | 14 ++ .../spark/repl/SingletonReplSuite.scala | 6 +- 10 files changed, 44 insertions(+), 361 deletions(-) diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java index dc3e826475987..2b93385adf103 100644 --- a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java +++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java @@ -30,4 +30,8 @@ public interface SparkExecutorInfo extends Serializable { int port(); long cacheSize(); int numRunningTasks(); + long usedOnHeapStorageMemory(); + long usedOffHeapStorageMemory(); + long totalOnHeapStorageMemory(); + long totalOffHeapStorageMemory(); } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3828d4f703247..c4f74c4f1f9c2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1715,7 +1715,13 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.foreach { rddInfo => + val rddId = rddInfo.id + val rddStorageInfo = statusStore.asOption(statusStore.rdd(rddId)) + rddInfo.numCachedPartitions = rddStorageInfo.map(_.numCachedPartitions).getOrElse(0) + rddInfo.memSize = rddStorageInfo.map(_.memoryUsed).getOrElse(0L) + rddInfo.diskSize = rddStorageInfo.map(_.diskUsed).getOrElse(0L) + } rddInfos.filter(_.isCached) } @@ -1726,17 +1732,6 @@ class SparkContext(config: SparkConf) extends Logging { */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - @deprecated("This method may change or be removed in a future release.", "2.2.0") - def getExecutorStorageStatus: Array[StorageStatus] = { - assertNotStopped() - env.blockManager.master.getStorageStatus - } - /** * :: DeveloperApi :: * Return pools for fair scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 70865cb58c571..815237eba0174 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -97,7 +97,8 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore } /** - * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. + * Returns information of all known executors, including host, port, cacheSize, numRunningTasks + * and memory metrics. */ def getExecutorInfos: Array[SparkExecutorInfo] = { store.executorList(true).map { exec => @@ -113,7 +114,11 @@ class SparkStatusTracker private[spark] (sc: SparkContext, store: AppStatusStore host, port, cachedMem, - exec.activeTasks) + exec.activeTasks, + exec.memoryMetrics.map(_.usedOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.usedOnHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOffHeapStorageMemory).getOrElse(0L), + exec.memoryMetrics.map(_.totalOnHeapStorageMemory).getOrElse(0L)) }.toArray } } diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index c1f24a6377788..6a888c1e9e772 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -38,5 +38,9 @@ private class SparkExecutorInfoImpl( val host: String, val port: Int, val cacheSize: Long, - val numRunningTasks: Int) + val numRunningTasks: Int, + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) extends SparkExecutorInfo diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index e9694fdbca2de..adc406bb1c441 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -24,19 +24,15 @@ import scala.collection.mutable import sun.nio.ch.DirectBuffer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging /** - * :: DeveloperApi :: * Storage information for each BlockManager. * * This class assumes BlockId and BlockStatus are immutable, such that the consumers of this * class cannot mutate the source of the information. Accesses are not thread-safe. */ -@DeveloperApi -@deprecated("This class may be removed or made private in a future release.", "2.2.0") -class StorageStatus( +private[spark] class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, val maxOnHeapMem: Option[Long], @@ -44,9 +40,6 @@ class StorageStatus( /** * Internal representation of the blocks stored in this block manager. - * - * We store RDD blocks and non-RDD blocks separately to allow quick retrievals of RDD blocks. - * These collections should only be mutated through the add/update/removeBlock methods. */ private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] @@ -87,9 +80,6 @@ class StorageStatus( */ def rddBlocks: Map[BlockId, BlockStatus] = _rddBlocks.flatMap { case (_, blocks) => blocks } - /** Return the blocks that belong to the given RDD stored in this block manager. */ - def rddBlocksById(rddId: Int): Map[BlockId, BlockStatus] = _rddBlocks.getOrElse(rddId, Map.empty) - /** Add the given block to this storage status. If it already exists, overwrite it. */ private[spark] def addBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { updateStorageInfo(blockId, blockStatus) @@ -101,46 +91,6 @@ class StorageStatus( } } - /** Update the given block in this storage status. If it doesn't already exist, add it. */ - private[spark] def updateBlock(blockId: BlockId, blockStatus: BlockStatus): Unit = { - addBlock(blockId, blockStatus) - } - - /** Remove the given block from this storage status. */ - private[spark] def removeBlock(blockId: BlockId): Option[BlockStatus] = { - updateStorageInfo(blockId, BlockStatus.empty) - blockId match { - case RDDBlockId(rddId, _) => - // Actually remove the block, if it exists - if (_rddBlocks.contains(rddId)) { - val removed = _rddBlocks(rddId).remove(blockId) - // If the given RDD has no more blocks left, remove the RDD - if (_rddBlocks(rddId).isEmpty) { - _rddBlocks.remove(rddId) - } - removed - } else { - None - } - case _ => - _nonRddBlocks.remove(blockId) - } - } - - /** - * Return whether the given block is stored in this block manager in O(1) time. - * - * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. - */ - def containsBlock(blockId: BlockId): Boolean = { - blockId match { - case RDDBlockId(rddId, _) => - _rddBlocks.get(rddId).exists(_.contains(blockId)) - case _ => - _nonRddBlocks.contains(blockId) - } - } - /** * Return the given block stored in this block manager in O(1) time. * @@ -155,37 +105,12 @@ class StorageStatus( } } - /** - * Return the number of blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.blocks.size`, which is O(blocks) time. - */ - def numBlocks: Int = _nonRddBlocks.size + numRddBlocks - - /** - * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * - * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. - */ - def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum - - /** - * Return the number of blocks that belong to the given RDD in O(1) time. - * - * @note This is much faster than `this.rddBlocksById(rddId).size`, which is - * O(blocks in this RDD) time. - */ - def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) - /** Return the max memory can be used by this block manager. */ def maxMem: Long = maxMemory /** Return the memory remaining in this block manager. */ def memRemaining: Long = maxMem - memUsed - /** Return the memory used by caching RDDs */ - def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) - /** Return the memory used by this block manager. */ def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) @@ -220,15 +145,9 @@ class StorageStatus( /** Return the disk space used by this block manager. */ def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum - /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) - /** Return the disk space used by the given RDD in this block manager in O(1) time. */ def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) - /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) - /** * Update the relevant storage info, taking into account any existing status for this block. */ @@ -295,40 +214,4 @@ private[spark] object StorageUtils extends Logging { cleaner.clean() } } - - /** - * Update the given list of RDDInfo with the given list of storage statuses. - * This method overwrites the old values stored in the RDDInfo's. - */ - def updateRddInfo(rddInfos: Seq[RDDInfo], statuses: Seq[StorageStatus]): Unit = { - rddInfos.foreach { rddInfo => - val rddId = rddInfo.id - // Assume all blocks belonging to the same RDD have the same storage level - val storageLevel = statuses - .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) - val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum - val memSize = statuses.map(_.memUsedByRdd(rddId)).sum - val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum - - rddInfo.storageLevel = storageLevel - rddInfo.numCachedPartitions = numCachedPartitions - rddInfo.memSize = memSize - rddInfo.diskSize = diskSize - } - } - - /** - * Return a mapping from block ID to its locations for each block that belongs to the given RDD. - */ - def getRddBlockLocations(rddId: Int, statuses: Seq[StorageStatus]): Map[BlockId, Seq[String]] = { - val blockLocations = new mutable.HashMap[BlockId, mutable.ListBuffer[String]] - statuses.foreach { status => - status.rddBlocksById(rddId).foreach { case (bid, _) => - val location = status.blockManagerId.hostPort - blockLocations.getOrElseUpdate(bid, mutable.ListBuffer.empty) += location - } - } - blockLocations - } - } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index e09d5f59817b9..28ea0c6f0bdba 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -160,11 +160,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) assert(cachedData.count === 1000) - assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum === - storageLevel.replication * data.getNumPartitions) - assert(cachedData.count === 1000) - assert(cachedData.count === 1000) - + assert(sc.getRDDStorageInfo.filter(_.id == cachedData.id).map(_.numCachedPartitions).sum === + data.getNumPartitions) // Get all the locations of the first partition and try to fetch the partitions // from those locations. val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index bf7480d79f8a1..c21ee7d26f8ca 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -610,7 +610,7 @@ class StandaloneDynamicAllocationSuite * we submit a request to kill them. This must be called before each kill request. */ private def syncExecutors(sc: SparkContext): Unit = { - val driverExecutors = sc.getExecutorStorageStatus + val driverExecutors = sc.env.blockManager.master.getStorageStatus .map(_.blockManagerId.executorId) .filter { _ != SparkContext.DRIVER_IDENTIFIER} val masterExecutors = getExecutorIds(sc) diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index da198f946fd64..ca352387055f4 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -51,27 +51,6 @@ class StorageSuite extends SparkFunSuite { assert(status.diskUsed === 60L) } - test("storage status update non-RDD blocks") { - val status = storageStatus1 - status.updateBlock(TestBlockId("foo"), BlockStatus(memAndDisk, 50L, 100L)) - status.updateBlock(TestBlockId("fee"), BlockStatus(memAndDisk, 100L, 20L)) - assert(status.blocks.size === 3) - assert(status.memUsed === 160L) - assert(status.memRemaining === 840L) - assert(status.diskUsed === 140L) - } - - test("storage status remove non-RDD blocks") { - val status = storageStatus1 - status.removeBlock(TestBlockId("foo")) - status.removeBlock(TestBlockId("faa")) - assert(status.blocks.size === 1) - assert(status.blocks.contains(TestBlockId("fee"))) - assert(status.memUsed === 10L) - assert(status.memRemaining === 990L) - assert(status.diskUsed === 20L) - } - // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) @@ -95,85 +74,6 @@ class StorageSuite extends SparkFunSuite { assert(status.rddBlocks.contains(RDDBlockId(2, 2))) assert(status.rddBlocks.contains(RDDBlockId(2, 3))) assert(status.rddBlocks.contains(RDDBlockId(2, 4))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(1).contains(RDDBlockId(1, 1))) - assert(status.rddBlocksById(2).size === 3) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 2))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 4))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 30L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 80L) - assert(status.rddStorageLevel(0) === Some(memAndDisk)) - assert(status.rddStorageLevel(1) === Some(memAndDisk)) - assert(status.rddStorageLevel(2) === Some(memAndDisk)) - - // Verify default values for RDDs that don't exist - assert(status.rddBlocksById(10).isEmpty) - assert(status.memUsedByRdd(10) === 0L) - assert(status.diskUsedByRdd(10) === 0L) - assert(status.rddStorageLevel(10) === None) - } - - test("storage status update RDD blocks") { - val status = storageStatus2 - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 5000L, 0L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 0L, 1000L)) - assert(status.blocks.size === 7) - assert(status.rddBlocks.size === 5) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(1).size === 1) - assert(status.rddBlocksById(2).size === 3) - assert(status.memUsedByRdd(0) === 0L) - assert(status.memUsedByRdd(1) === 100L) - assert(status.memUsedByRdd(2) === 20L) - assert(status.diskUsedByRdd(0) === 0L) - assert(status.diskUsedByRdd(1) === 200L) - assert(status.diskUsedByRdd(2) === 1060L) - } - - test("storage status remove RDD blocks") { - val status = storageStatus2 - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(1, 1)) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 4)) - assert(status.blocks.size === 3) - assert(status.rddBlocks.size === 2) - assert(status.rddBlocks.contains(RDDBlockId(0, 0))) - assert(status.rddBlocks.contains(RDDBlockId(2, 3))) - assert(status.rddBlocksById(0).size === 1) - assert(status.rddBlocksById(0).contains(RDDBlockId(0, 0))) - assert(status.rddBlocksById(1).size === 0) - assert(status.rddBlocksById(2).size === 1) - assert(status.rddBlocksById(2).contains(RDDBlockId(2, 3))) - assert(status.memUsedByRdd(0) === 10L) - assert(status.memUsedByRdd(1) === 0L) - assert(status.memUsedByRdd(2) === 10L) - assert(status.diskUsedByRdd(0) === 20L) - assert(status.diskUsedByRdd(1) === 0L) - assert(status.diskUsedByRdd(2) === 20L) - } - - test("storage status containsBlock") { - val status = storageStatus2 - // blocks that actually exist - assert(status.blocks.contains(TestBlockId("dan")) === status.containsBlock(TestBlockId("dan"))) - assert(status.blocks.contains(TestBlockId("man")) === status.containsBlock(TestBlockId("man"))) - assert(status.blocks.contains(RDDBlockId(0, 0)) === status.containsBlock(RDDBlockId(0, 0))) - assert(status.blocks.contains(RDDBlockId(1, 1)) === status.containsBlock(RDDBlockId(1, 1))) - assert(status.blocks.contains(RDDBlockId(2, 2)) === status.containsBlock(RDDBlockId(2, 2))) - assert(status.blocks.contains(RDDBlockId(2, 3)) === status.containsBlock(RDDBlockId(2, 3))) - assert(status.blocks.contains(RDDBlockId(2, 4)) === status.containsBlock(RDDBlockId(2, 4))) - // blocks that don't exist - assert(status.blocks.contains(TestBlockId("fan")) === status.containsBlock(TestBlockId("fan"))) - assert(status.blocks.contains(RDDBlockId(100, 0)) === status.containsBlock(RDDBlockId(100, 0))) } test("storage status getBlock") { @@ -191,40 +91,6 @@ class StorageSuite extends SparkFunSuite { assert(status.blocks.get(RDDBlockId(100, 0)) === status.getBlock(RDDBlockId(100, 0))) } - test("storage status num[Rdd]Blocks") { - val status = storageStatus2 - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - status.addBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 4), BlockStatus(memAndDisk, 0L, 0L)) - status.addBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - status.updateBlock(TestBlockId("Foo"), BlockStatus(memAndDisk, 0L, 10L)) - status.updateBlock(RDDBlockId(4, 0), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(4, 8), BlockStatus(memAndDisk, 0L, 0L)) - status.updateBlock(RDDBlockId(10, 10), BlockStatus(memAndDisk, 0L, 0L)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(100).size === status.numRddBlocksById(100)) - status.removeBlock(RDDBlockId(4, 0)) - status.removeBlock(RDDBlockId(10, 10)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - // remove a block that doesn't exist - status.removeBlock(RDDBlockId(1000, 999)) - assert(status.blocks.size === status.numBlocks) - assert(status.rddBlocks.size === status.numRddBlocks) - assert(status.rddBlocksById(4).size === status.numRddBlocksById(4)) - assert(status.rddBlocksById(10).size === status.numRddBlocksById(10)) - assert(status.rddBlocksById(1000).size === status.numRddBlocksById(1000)) - } test("storage status memUsed, diskUsed, externalBlockStoreUsed") { val status = storageStatus2 @@ -237,17 +103,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(memAndDisk, 4L, 5L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations @@ -273,65 +128,6 @@ class StorageSuite extends SparkFunSuite { Seq(info0, info1) } - test("StorageUtils.updateRddInfo") { - val storageStatuses = stockStorageStatuses - val rddInfos = stockRDDInfos - StorageUtils.updateRddInfo(rddInfos, storageStatuses) - assert(rddInfos(0).storageLevel === memAndDisk) - assert(rddInfos(0).numCachedPartitions === 5) - assert(rddInfos(0).memSize === 5L) - assert(rddInfos(0).diskSize === 10L) - assert(rddInfos(0).externalBlockStoreSize === 0L) - assert(rddInfos(1).storageLevel === memAndDisk) - assert(rddInfos(1).numCachedPartitions === 3) - assert(rddInfos(1).memSize === 3L) - assert(rddInfos(1).diskSize === 6L) - assert(rddInfos(1).externalBlockStoreSize === 0L) - } - - test("StorageUtils.getRddBlockLocations") { - val storageStatuses = stockStorageStatuses - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0.contains(RDDBlockId(0, 0))) - assert(blockLocations0.contains(RDDBlockId(0, 1))) - assert(blockLocations0.contains(RDDBlockId(0, 2))) - assert(blockLocations0.contains(RDDBlockId(0, 3))) - assert(blockLocations0.contains(RDDBlockId(0, 4))) - assert(blockLocations1.contains(RDDBlockId(1, 0))) - assert(blockLocations1.contains(RDDBlockId(1, 1))) - assert(blockLocations1.contains(RDDBlockId(1, 2))) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - - test("StorageUtils.getRddBlockLocations with multiple locations") { - val storageStatuses = stockStorageStatuses - storageStatuses(0).addBlock(RDDBlockId(1, 0), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(0).addBlock(RDDBlockId(0, 4), BlockStatus(memAndDisk, 1L, 2L)) - storageStatuses(2).addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) - val blockLocations0 = StorageUtils.getRddBlockLocations(0, storageStatuses) - val blockLocations1 = StorageUtils.getRddBlockLocations(1, storageStatuses) - assert(blockLocations0.size === 5) - assert(blockLocations1.size === 3) - assert(blockLocations0(RDDBlockId(0, 0)) === Seq("dog:1", "cat:3")) - assert(blockLocations0(RDDBlockId(0, 1)) === Seq("dog:1")) - assert(blockLocations0(RDDBlockId(0, 2)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 3)) === Seq("duck:2")) - assert(blockLocations0(RDDBlockId(0, 4)) === Seq("dog:1", "cat:3")) - assert(blockLocations1(RDDBlockId(1, 0)) === Seq("dog:1", "duck:2")) - assert(blockLocations1(RDDBlockId(1, 1)) === Seq("duck:2")) - assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) - } - private val offheap = StorageLevel.OFF_HEAP // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD onheap // and offheap blocks @@ -373,21 +169,6 @@ class StorageSuite extends SparkFunSuite { status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L)) assert(status.memUsed === actualMemUsed) assert(status.diskUsed === actualDiskUsed) - - status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L)) - status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L)) - status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) - assert(status.onHeapMemUsed.get === actualOnHeapMemUsed) - assert(status.offHeapMemUsed.get === actualOffHeapMemUsed) - - status.removeBlock(TestBlockId("fire")) - status.removeBlock(TestBlockId("man")) - status.removeBlock(RDDBlockId(2, 2)) - status.removeBlock(RDDBlockId(2, 3)) - assert(status.memUsed === actualMemUsed) - assert(status.diskUsed === actualDiskUsed) } private def storageStatus4: StorageStatus = { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d35c50e1d00fe..381f7b5be1ddf 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,20 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20659] Remove StorageStatus, or make it private + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") ) // Exclude rules for 2.3.x diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index ec3d790255ad3..d49e0fd85229f 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -350,7 +350,7 @@ class SingletonReplSuite extends SparkFunSuite { """ |val timeout = 60000 // 60 seconds |val start = System.currentTimeMillis - |while(sc.getExecutorStorageStatus.size != 3 && + |while(sc.statusTracker.getExecutorInfos.size != 3 && | (System.currentTimeMillis - start) < timeout) { | Thread.sleep(10) |} @@ -361,11 +361,11 @@ class SingletonReplSuite extends SparkFunSuite { |case class Foo(i: Int) |val ret = sc.parallelize((1 to 100).map(Foo), 10).persist(MEMORY_AND_DISK_2) |ret.count() - |val res = sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum + |val res = sc.getRDDStorageInfo.filter(_.id == ret.id).map(_.numCachedPartitions).sum """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res: Int = 20", output) + assertContains("res: Int = 10", output) } test("should clone and clean line object in ClosureCleaner") { From d6e1958a2472898e60bd013902c2f35111596e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 09:54:52 -0600 Subject: [PATCH 0341/1161] [SPARK-23189][CORE][WEB UI] Reflect stage level blacklisting on executor tab MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The purpose of this PR to reflect the stage level blacklisting on the executor tab for the currently active stages. After this change in the executor tab at the Status column one of the following label will be: - "Blacklisted" when the executor is blacklisted application level (old flag) - "Dead" when the executor is not Blacklisted and not Active - "Blacklisted in Stages: [...]" when the executor is Active but the there are active blacklisted stages for the executor. Within the [] coma separated active stageIDs are listed. - "Active" when the executor is Active and there is no active blacklisted stages for the executor ## How was this patch tested? Both with unit tests and manually. #### Manual test Spark was started as: ```bash bin/spark-shell --master "local-cluster[2,1,1024]" --conf "spark.blacklist.enabled=true" --conf "spark.blacklist.stage.maxFailedTasksPerExecutor=1" --conf "spark.blacklist.application.maxFailedTasksPerExecutor=10" ``` And the job was: ```scala import org.apache.spark.SparkEnv val pairs = sc.parallelize(1 to 10000, 10).map { x => if (SparkEnv.get.executorId.toInt == 0) throw new RuntimeException("Bad executor") else { Thread.sleep(10) (x % 10, x) } } val all = pairs.cogroup(pairs) all.collect() ``` UI screenshots about the running: - One executor is blacklisted in the two stages: ![One executor is blacklisted in two stages](https://issues.apache.org/jira/secure/attachment/12908314/multiple_stages_1.png) - One stage completes the other one is still running: ![One stage completes the other is still running](https://issues.apache.org/jira/secure/attachment/12908315/multiple_stages_2.png) - Both stages are completed: ![Both stages are completed](https://issues.apache.org/jira/secure/attachment/12908316/multiple_stages_3.png) ### Unit tests In AppStatusListenerSuite.scala both the node blacklisting for a stage and the executor blacklisting for stage are tested. Author: “attilapiros” Closes #20408 from attilapiros/SPARK-23189. --- .../apache/spark/ui/static/executorspage.js | 21 +++++--- .../spark/status/AppStatusListener.scala | 49 ++++++++++++++----- .../org/apache/spark/status/LiveEntity.scala | 7 ++- .../org/apache/spark/status/api/v1/api.scala | 3 +- .../executor_list_json_expectation.json | 3 +- .../executor_memory_usage_expectation.json | 15 ++++-- ...xecutor_node_blacklisting_expectation.json | 15 ++++-- ...acklisting_unblacklisting_expectation.json | 15 ++++-- .../spark/status/AppStatusListenerSuite.scala | 21 ++++++++ 9 files changed, 113 insertions(+), 36 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index d430d8c5fb35a..6717af3ac4daf 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -25,12 +25,18 @@ function getThreadDumpEnabled() { return threadDumpEnabled; } -function formatStatus(status, type) { +function formatStatus(status, type, row) { + if (row.isBlacklisted) { + return "Blacklisted"; + } + if (status) { - return "Active" - } else { - return "Dead" + if (row.blacklistedInStages.length == 0) { + return "Active" + } + return "Active (Blacklisted in Stages: [" + row.blacklistedInStages.join(", ") + "])"; } + return "Dead" } jQuery.extend(jQuery.fn.dataTableExt.oSort, { @@ -415,9 +421,10 @@ $(document).ready(function () { } }, {data: 'hostPort'}, - {data: 'isActive', render: function (data, type, row) { - if (row.isBlacklisted) return "Blacklisted"; - else return formatStatus (data, type); + { + data: 'isActive', + render: function (data, type, row) { + return formatStatus (data, type, row); } }, {data: 'rddBlocks'}, diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index ab01cddfca5b0..79a17e26665fd 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -213,11 +213,13 @@ private[spark] class AppStatusListener( override def onExecutorBlacklistedForStage( event: SparkListenerExecutorBlacklistedForStage): Unit = { + val now = System.nanoTime() + Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - val now = System.nanoTime() - val esummary = stage.executorSummary(event.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) + setStageBlackListStatus(stage, now, event.executorId) + } + liveExecutors.get(event.executorId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } @@ -226,16 +228,29 @@ private[spark] class AppStatusListener( // Implicitly blacklist every available executor for the stage associated with this node Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => - liveExecutors.values.foreach { exec => - if (exec.hostname == event.hostId) { - val esummary = stage.executorSummary(exec.executorId) - esummary.isBlacklisted = true - maybeUpdate(esummary, now) - } - } + val executorIds = liveExecutors.values.filter(_.host == event.hostId).map(_.executorId).toSeq + setStageBlackListStatus(stage, now, executorIds: _*) + } + liveExecutors.values.filter(_.hostname == event.hostId).foreach { exec => + addBlackListedStageTo(exec, event.stageId, now) } } + private def addBlackListedStageTo(exec: LiveExecutor, stageId: Int, now: Long): Unit = { + exec.blacklistedInStages += stageId + liveUpdate(exec, now) + } + + private def setStageBlackListStatus(stage: LiveStage, now: Long, executorIds: String*): Unit = { + executorIds.foreach { executorId => + val executorStageSummary = stage.executorSummary(executorId) + executorStageSummary.isBlacklisted = true + maybeUpdate(executorStageSummary, now) + } + stage.blackListedExecutors ++= executorIds + maybeUpdate(stage, now) + } + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { updateBlackListStatus(event.executorId, false) } @@ -594,12 +609,24 @@ private[spark] class AppStatusListener( stage.executorSummaries.values.foreach(update(_, now)) update(stage, now, last = true) + + val executorIdsForStage = stage.blackListedExecutors + executorIdsForStage.foreach { executorId => + liveExecutors.get(executorId).foreach { exec => + removeBlackListedStageFrom(exec, event.stageInfo.stageId, now) + } + } } appSummary = new AppSummary(appSummary.numCompletedJobs, appSummary.numCompletedStages + 1) kvstore.write(appSummary) } + private def removeBlackListedStageFrom(exec: LiveExecutor, stageId: Int, now: Long) = { + exec.blacklistedInStages -= stageId + liveUpdate(exec, now) + } + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { // This needs to set fields that are already set by onExecutorAdded because the driver is // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index d5f9e19ffdcd0..79e3f13b826ce 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -20,6 +20,7 @@ package org.apache.spark.status import java.util.Date import java.util.concurrent.atomic.AtomicInteger +import scala.collection.immutable.{HashSet, TreeSet} import scala.collection.mutable.HashMap import com.google.common.collect.Interners @@ -254,6 +255,7 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE var totalShuffleRead = 0L var totalShuffleWrite = 0L var isBlacklisted = false + var blacklistedInStages: Set[Int] = TreeSet() var executorLogs = Map[String, String]() @@ -299,7 +301,8 @@ private class LiveExecutor(val executorId: String, _addTime: Long) extends LiveE Option(removeTime), Option(removeReason), executorLogs, - memoryMetrics) + memoryMetrics, + blacklistedInStages) new ExecutorSummaryWrapper(info) } @@ -371,6 +374,8 @@ private class LiveStage extends LiveEntity { val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + var blackListedExecutors = new HashSet[String]() + // Used for cleanup of tasks after they reach the configured limit. Not written to the store. @volatile var cleaning = false var savedTasks = new AtomicInteger(0) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 550eac3952bbb..a333f1aaf6325 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -95,7 +95,8 @@ class ExecutorSummary private[spark]( val removeTime: Option[Date], val removeReason: Option[String], val executorLogs: Map[String, String], - val memoryMetrics: Option[MemoryMetrics]) + val memoryMetrics: Option[MemoryMetrics], + val blacklistedInStages: Set[Int]) class MemoryMetrics private[spark]( val usedOnHeapStorageMemory: Long, diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index 942e6d8f04363..7bb8fe8fd8f98 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -19,5 +19,6 @@ "isBlacklisted" : false, "maxMemory" : 278302556, "addTime" : "2015-02-03T16:43:00.906GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index ed33c90dd39ba..dd5b1dcb7372b 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ,{ "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index 73519f1d9e2e4..3e55d3d9d7eb9 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -25,7 +25,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -56,7 +57,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.167:51487", @@ -87,7 +89,8 @@ "usedOffHeapStorageMemory" : 0, "totalOnHeapStorageMemory" : 384093388, "totalOffHeapStorageMemory" : 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -118,7 +121,8 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -149,5 +153,6 @@ "usedOffHeapStorageMemory": 0, "totalOnHeapStorageMemory": 384093388, "totalOffHeapStorageMemory": 524288000 - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json index 6931fead3d2ff..e87f3e78f2dc8 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -19,7 +19,8 @@ "isBlacklisted" : false, "maxMemory" : 384093388, "addTime" : "2016-11-15T23:20:38.836GMT", - "executorLogs" : { } + "executorLogs" : { }, + "blacklistedInStages" : [ ] }, { "id" : "3", "hostPort" : "172.22.0.111:64543", @@ -44,7 +45,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "2", "hostPort" : "172.22.0.111:64539", @@ -69,7 +71,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "1", "hostPort" : "172.22.0.111:64541", @@ -94,7 +97,8 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" - } + }, + "blacklistedInStages" : [ ] }, { "id" : "0", "hostPort" : "172.22.0.111:64540", @@ -119,5 +123,6 @@ "executorLogs" : { "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" - } + }, + "blacklistedInStages" : [ ] } ] diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index b74d6ee2ec836..749502709b5c8 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -273,6 +273,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(exec.info.isBlacklistedForStage === expectedBlacklistedFlag) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.head.stageId)) + } + // Blacklisting node for stage time += 1 listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( @@ -439,6 +443,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.numCompleteTasks === pending.size) } + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set()) + } + // Submit stage 2. time += 1 stages.last.submissionTime = Some(time) @@ -453,6 +461,19 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) } + // Blacklisting node for stage + time += 1 + listener.onNodeBlacklistedForStage(SparkListenerNodeBlacklistedForStage( + time = time, + hostId = "1.example.com", + executorFailures = 1, + stageId = stages.last.stageId, + stageAttemptId = stages.last.attemptId)) + + check[ExecutorSummaryWrapper](execIds.head) { exec => + assert(exec.info.blacklistedInStages === Set(stages.last.stageId)) + } + // Start and fail all tasks of stage 2. time += 1 val s2Tasks = createTasks(4, execIds) From 091a000d27f324de8c5c527880854ecfcf5de9a4 Mon Sep 17 00:00:00 2001 From: huangtengfei Date: Tue, 13 Feb 2018 09:59:21 -0600 Subject: [PATCH 0342/1161] [SPARK-23053][CORE] taskBinarySerialization and task partitions calculate in DagScheduler.submitMissingTasks should keep the same RDD checkpoint status ## What changes were proposed in this pull request? When we run concurrent jobs using the same rdd which is marked to do checkpoint. If one job has finished running the job, and start the process of RDD.doCheckpoint, while another job is submitted, then submitStage and submitMissingTasks will be called. In [submitMissingTasks](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L961), will serialize taskBinaryBytes and calculate task partitions which are both affected by the status of checkpoint, if the former is calculated before doCheckpoint finished, while the latter is calculated after doCheckpoint finished, when run task, rdd.compute will be called, for some rdds with particular partition type such as [UnionRDD](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala) who will do partition type cast, will get a ClassCastException because the part params is actually a CheckpointRDDPartition. This error occurs because rdd.doCheckpoint occurs in the same thread that called sc.runJob, while the task serialization occurs in the DAGSchedulers event loop. ## How was this patch tested? the exist uts and also add a test case in DAGScheduerSuite to show the exception case. Author: huangtengfei Closes #20244 from ivoson/branch-taskpart-mistype. --- .../apache/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 199937b8c27af..8c46a84323392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1016,15 +1016,24 @@ class DAGScheduler( // might modify state of objects referenced in their closures. This is necessary in Hadoop // where the JobConf/Configuration object is not thread-safe. var taskBinary: Broadcast[Array[Byte]] = null + var partitions: Array[Partition] = null try { // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func). - val taskBinaryBytes: Array[Byte] = stage match { - case stage: ShuffleMapStage => - JavaUtils.bufferToArray( - closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) - case stage: ResultStage => - JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + var taskBinaryBytes: Array[Byte] = null + // taskBinaryBytes and partitions are both effected by the checkpoint status. We need + // this synchronization in case another concurrent job is checkpointing this RDD, so we get a + // consistent view of both variables. + RDDCheckpointData.synchronized { + taskBinaryBytes = stage match { + case stage: ShuffleMapStage => + JavaUtils.bufferToArray( + closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef)) + case stage: ResultStage => + JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef)) + } + + partitions = stage.rdd.partitions } taskBinary = sc.broadcast(taskBinaryBytes) @@ -1049,7 +1058,7 @@ class DAGScheduler( stage.pendingPartitions.clear() partitionsToCompute.map { id => val locs = taskIdToLocations(id) - val part = stage.rdd.partitions(id) + val part = partitions(id) stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), @@ -1059,7 +1068,7 @@ class DAGScheduler( case stage: ResultStage => partitionsToCompute.map { id => val p: Int = stage.partitions(id) - val part = stage.rdd.partitions(p) + val part = partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, From bd24731722a9142c90cf3d76008115f308203844 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Tue, 13 Feb 2018 11:39:33 -0600 Subject: [PATCH 0343/1161] [SPARK-23382][WEB-UI] Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. ## What changes were proposed in this pull request? Spark Streaming ui about the contents of the for need to have hidden and show features, when the table records very much. please refer to https://github.com/apache/spark/pull/20216 fix after: ![1](https://user-images.githubusercontent.com/26266482/36068644-df029328-0f14-11e8-8350-cfdde9733ffc.png) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20570 from guoxiaolongzte/SPARK-23382. --- .../org/apache/spark/ui/static/webui.js | 2 + .../spark/streaming/ui/StreamingPage.scala | 37 ++++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e575c4c78970d..83009df91d30a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -80,4 +80,6 @@ $(function() { collapseTablePageLoad('collapse-aggregated-poolActiveStages','aggregated-poolActiveStages'); collapseTablePageLoad('collapse-aggregated-tasks','aggregated-tasks'); collapseTablePageLoad('collapse-aggregated-rdds','aggregated-rdds'); + collapseTablePageLoad('collapse-aggregated-activeBatches','aggregated-activeBatches'); + collapseTablePageLoad('collapse-aggregated-completedBatches','aggregated-completedBatches'); }); \ No newline at end of file diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 7abafd6ba7908..3a176f64cdd60 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -490,15 +490,40 @@ private[ui] class StreamingPage(parent: StreamingTab) sortBy(_.batchTime.milliseconds).reverse val activeBatchesContent = { -

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ - new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Active Batches ({runningBatches.size + waitingBatches.size}) +

    +
    +
    + {new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } val completedBatchesContent = { -

    - Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches}) -

    ++ - new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq +
    +
    + +

    + + Completed Batches (last {completedBatches.size} + out of {listener.numTotalCompletedBatches}) +

    +
    +
    + {new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq} +
    +
    +
    } activeBatchesContent ++ completedBatchesContent From 263531466f4a7e223c94caa8705e6e8394a12054 Mon Sep 17 00:00:00 2001 From: xubo245 <601450868@qq.com> Date: Tue, 13 Feb 2018 11:45:20 -0600 Subject: [PATCH 0344/1161] [SPARK-23392][TEST] Add some test cases for images feature ## What changes were proposed in this pull request? Add some test cases for images feature ## How was this patch tested? Add some test cases in ImageSchemaSuite Author: xubo245 <601450868@qq.com> Closes #20583 from xubo245/CARBONDATA23392_AddTestForImage. --- .../spark/ml/image/ImageSchemaSuite.scala | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala index a8833c615865d..527b3f8955968 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala @@ -65,11 +65,71 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(count50 > 0 && count50 < countTotal) } + test("readImages test: recursive = false") { + val df = readImages(imagePath, null, false, 3, true, 1.0, 0) + assert(df.count() === 0) + } + + test("readImages test: read jpg image") { + val df = readImages(imagePath + "/kittens/DP153539.jpg", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read png image") { + val df = readImages(imagePath + "/multi-channel/BGRA.png", null, false, 3, true, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: read non image") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, true, 1.0, 0) + assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema") + assert(df.count() === 0) + } + + test("readImages test: read non image and dropImageFailures is false") { + val df = readImages(imagePath + "/kittens/not-image.txt", null, false, 3, false, 1.0, 0) + assert(df.count() === 1) + } + + test("readImages test: sampleRatio > 1") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, 1.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio < 0") { + val e = intercept[IllegalArgumentException] { + readImages(imagePath, null, true, 3, true, -0.1, 0) + } + assert(e.getMessage.contains("sampleRatio")) + } + + test("readImages test: sampleRatio = 0") { + val df = readImages(imagePath, null, true, 3, true, 0.0, 0) + assert(df.count() === 0) + } + + test("readImages test: with sparkSession") { + val df = readImages(imagePath, sparkSession = spark, true, 3, true, 1.0, 0) + assert(df.count() === 8) + } + test("readImages partition test") { val df = readImages(imagePath, null, true, 3, true, 1.0, 0) assert(df.rdd.getNumPartitions === 3) } + test("readImages partition test: < 0") { + val df = readImages(imagePath, null, true, -3, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + + test("readImages partition test: = 0") { + val df = readImages(imagePath, null, true, 0, true, 1.0, 0) + assert(df.rdd.getNumPartitions === spark.sparkContext.defaultParallelism) + } + // Images with the different number of channels test("readImages pixel values test") { @@ -93,7 +153,7 @@ class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext { // - default representation for 3-channel RGB images is BGR row-wise: // (B00, G00, R00, B10, G10, R10, ...) // - default representation for 4-channel RGB images is BGRA row-wise: - // (B00, G00, R00, A00, B10, G10, R10, A00, ...) + // (B00, G00, R00, A00, B10, G10, R10, A10, ...) private val firstBytes20 = Map( "grayscale.jpg" -> (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, -67, -73, -73, -62, From 05d051293fe46938e9cb012342fea6e8a3715cd4 Mon Sep 17 00:00:00 2001 From: Bogdan Raducanu Date: Tue, 13 Feb 2018 09:49:52 -0800 Subject: [PATCH 0345/1161] [SPARK-23316][SQL] AnalysisException after max iteration reached for IN query ## What changes were proposed in this pull request? Added flag ignoreNullability to DataType.equalsStructurally. The previous semantic is for ignoreNullability=false. When ignoreNullability=true equalsStructurally ignores nullability of contained types (map key types, value types, array element types, structure field types). In.checkInputTypes calls equalsStructurally to check if the children types match. They should match regardless of nullability (which is just a hint), so it is now called with ignoreNullability=true. ## How was this patch tested? New test in SubquerySuite Author: Bogdan Raducanu Closes #20548 from bogdanrdc/SPARK-23316. --- .../sql/catalyst/expressions/predicates.scala | 3 ++- .../org/apache/spark/sql/types/DataType.scala | 18 ++++++++++++------ .../org/apache/spark/sql/SubquerySuite.scala | 5 +++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b469f5cb7586a..a6d41ea7d00d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -157,7 +157,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") override def checkInputDataTypes(): TypeCheckResult = { - val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType)) + val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, + ignoreNullability = true)) if (mismatchOpt.isDefined) { list match { case ListQuery(_, _, _, childOutputs) :: Nil => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index d6e0df12218ad..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -295,25 +295,31 @@ object DataType { } /** - * Returns true if the two data types share the same "shape", i.e. the types (including - * nullability) are the same, but the field names don't need to be the same. + * Returns true if the two data types share the same "shape", i.e. the types + * are the same, but the field names don't need to be the same. + * + * @param ignoreNullability whether to ignore nullability when comparing the types */ - def equalsStructurally(from: DataType, to: DataType): Boolean = { + def equalsStructurally( + from: DataType, + to: DataType, + ignoreNullability: Boolean = false): Boolean = { (from, to) match { case (left: ArrayType, right: ArrayType) => equalsStructurally(left.elementType, right.elementType) && - left.containsNull == right.containsNull + (ignoreNullability || left.containsNull == right.containsNull) case (left: MapType, right: MapType) => equalsStructurally(left.keyType, right.keyType) && equalsStructurally(left.valueType, right.valueType) && - left.valueContainsNull == right.valueContainsNull + (ignoreNullability || left.valueContainsNull == right.valueContainsNull) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields) .forall { case (l, r) => - equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable + equalsStructurally(l.dataType, r.dataType) && + (ignoreNullability || l.nullable == r.nullable) } case (fromDataType, toDataType) => fromDataType == toDataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 8673dc14f7597..31e8b0e8dede0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -950,4 +950,9 @@ class SubquerySuite extends QueryTest with SharedSQLContext { assert(join.duplicateResolved) assert(optimizedPlan.resolved) } + + test("SPARK-23316: AnalysisException after max iteration reached for IN query") { + // before the fix this would throw AnalysisException + spark.range(10).where("(id,id) in (select id, null from range(3))").count + } } From 4e0fb010ccdf13fe411f2a4796bbadc385b01520 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 13 Feb 2018 11:51:19 -0600 Subject: [PATCH 0346/1161] [SPARK-23217][ML] Add cosine distance measure to ClusteringEvaluator ## What changes were proposed in this pull request? The PR provided an implementation of ClusteringEvaluator using the cosine distance measure. This allows to evaluate clustering results created using the cosine distance, introduced in SPARK-22119. In the corresponding JIRA, there is a design document for the algorithm implemented here. ## How was this patch tested? Added UT which compares the result to the one provided by python sklearn. Author: Marco Gaido Closes #20396 from mgaido91/SPARK-23217. --- .../ml/evaluation/ClusteringEvaluator.scala | 334 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 32 +- 2 files changed, 300 insertions(+), 66 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index d6ec5223237bb..8d4ae562b3d2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,11 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, + SchemaUtils} +import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -32,15 +33,11 @@ import org.apache.spark.sql.types.DoubleType * :: Experimental :: * * Evaluator for clustering results. - * The metric computes the Silhouette measure - * using the squared Euclidean distance. - * - * The Silhouette is a measure for the validation - * of the consistency within clusters. It ranges - * between 1 and -1, where a value close to 1 - * means that the points in a cluster are close - * to the other points in the same cluster and - * far from the points of the other clusters. + * The metric computes the Silhouette measure using the specified distance measure. + * + * The Silhouette is a measure for the validation of the consistency within clusters. It ranges + * between 1 and -1, where a value close to 1 means that the points in a cluster are close to the + * other points in the same cluster and far from the points of the other clusters. */ @Experimental @Since("2.3.0") @@ -84,18 +81,40 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") def setMetricName(value: String): this.type = set(metricName, value) - setDefault(metricName -> "silhouette") + /** + * param for distance measure to be used in evaluation + * (supports `"squaredEuclidean"` (default), `"cosine"`) + * @group param + */ + @Since("2.4.0") + val distanceMeasure: Param[String] = { + val availableValues = Array("squaredEuclidean", "cosine") + val allowedParams = ParamValidators.inArray(availableValues) + new Param(this, "distanceMeasure", "distance measure in evaluation. Supported options: " + + availableValues.mkString("'", "', '", "'"), allowedParams) + } + + /** @group getParam */ + @Since("2.4.0") + def getDistanceMeasure: String = $(distanceMeasure) + + /** @group setParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + + setDefault(metricName -> "silhouette", distanceMeasure -> "squaredEuclidean") @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) - $(metricName) match { - case "silhouette" => + ($(metricName), $(distanceMeasure)) match { + case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol) - ) + dataset, $(predictionCol), $(featuresCol)) + case ("silhouette", "cosine") => + CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) } } } @@ -111,6 +130,48 @@ object ClusteringEvaluator } +private[evaluation] abstract class Silhouette { + + /** + * It computes the Silhouette coefficient for a point. + */ + def pointSilhouetteCoefficient( + clusterIds: Set[Double], + pointClusterId: Double, + pointClusterNumOfPoints: Long, + averageDistanceToCluster: (Double) => Double): Double = { + // Here we compute the average dissimilarity of the current point to any cluster of which the + // point is not a member. + // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current + // point - is said to be the "neighboring cluster". + val otherClusterIds = clusterIds.filter(_ != pointClusterId) + val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min + + // adjustment for excluding the node itself from the computation of the average dissimilarity + val currentClusterDissimilarity = if (pointClusterNumOfPoints == 1) { + 0.0 + } else { + averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints / + (pointClusterNumOfPoints - 1) + } + + if (currentClusterDissimilarity < neighboringClusterDissimilarity) { + 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) + } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) { + (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 + } else { + 0.0 + } + } + + /** + * Compute the mean Silhouette values of all samples. + */ + def overallScore(df: DataFrame, scoreColumn: Column): Double = { + df.select(avg(scoreColumn)).collect()(0).getDouble(0) + } +} + /** * SquaredEuclideanSilhouette computes the average of the * Silhouette over all the data of the dataset, which is @@ -259,7 +320,7 @@ object ClusteringEvaluator * `N` is the number of points in the dataset and `W` is the number * of worker nodes. */ -private[evaluation] object SquaredEuclideanSilhouette { +private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { private[this] var kryoRegistrationPerformed: Boolean = false @@ -336,18 +397,19 @@ private[evaluation] object SquaredEuclideanSilhouette { * It computes the Silhouette coefficient for a point. * * @param broadcastedClustersMap A map of the precomputed values for each cluster. - * @param features The [[org.apache.spark.ml.linalg.Vector]] representing the current point. + * @param point The [[org.apache.spark.ml.linalg.Vector]] representing the current point. * @param clusterId The id of the cluster the current point belongs to. * @param squaredNorm The `$\Xi_{X}$` (which is the squared norm) precomputed for the point. * @return The Silhouette for the point. */ def computeSilhouetteCoefficient( broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]], - features: Vector, + point: Vector, clusterId: Double, squaredNorm: Double): Double = { - def compute(squaredNorm: Double, point: Vector, clusterStats: ClusterStats): Double = { + def compute(targetClusterId: Double): Double = { + val clusterStats = broadcastedClustersMap.value(targetClusterId) val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum) squaredNorm + @@ -355,41 +417,14 @@ private[evaluation] object SquaredEuclideanSilhouette { 2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints } - // Here we compute the average dissimilarity of the - // current point to any cluster of which the point - // is not a member. - // The cluster with the lowest average dissimilarity - // - i.e. the nearest cluster to the current point - - // is said to be the "neighboring cluster". - var neighboringClusterDissimilarity = Double.MaxValue - broadcastedClustersMap.value.keySet.foreach { - c => - if (c != clusterId) { - val dissimilarity = compute(squaredNorm, features, broadcastedClustersMap.value(c)) - if(dissimilarity < neighboringClusterDissimilarity) { - neighboringClusterDissimilarity = dissimilarity - } - } - } - val currentCluster = broadcastedClustersMap.value(clusterId) - // adjustment for excluding the node itself from - // the computation of the average dissimilarity - val currentClusterDissimilarity = if (currentCluster.numOfPoints == 1) { - 0 - } else { - compute(squaredNorm, features, currentCluster) * currentCluster.numOfPoints / - (currentCluster.numOfPoints - 1) - } - - (currentClusterDissimilarity compare neighboringClusterDissimilarity).signum match { - case -1 => 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity) - case 1 => (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1 - case 0 => 0.0 - } + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId).numOfPoints, + compute) } /** - * Compute the mean Silhouette values of all samples. + * Compute the Silhouette score of the dataset using squared Euclidean distance measure. * * @param dataset The input dataset (previously clustered) on which compute the Silhouette. * @param predictionCol The name of the column which contains the predicted cluster id @@ -412,7 +447,7 @@ private[evaluation] object SquaredEuclideanSilhouette { val clustersStatsMap = SquaredEuclideanSilhouette .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol) - // Silhouette is reasonable only when the number of clusters is grater then 1 + // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) @@ -421,13 +456,190 @@ private[evaluation] object SquaredEuclideanSilhouette { computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double) } - val silhouetteScore = dfWithSquaredNorm - .select(avg( - computeSilhouetteCoefficientUDF( - col(featuresCol), col(predictionCol).cast(DoubleType), col("squaredNorm")) - )) - .collect()(0) - .getDouble(0) + val silhouetteScore = overallScore(dfWithSquaredNorm, + computeSilhouetteCoefficientUDF(col(featuresCol), col(predictionCol).cast(DoubleType), + col("squaredNorm"))) + + bClustersStatsMap.destroy() + + silhouetteScore + } +} + + +/** + * The algorithm which is implemented in this object, instead, is an efficient and parallel + * implementation of the Silhouette using the cosine distance measure. The cosine distance + * measure is defined as `1 - s` where `s` is the cosine similarity between two points. + * + * The total distance of the point `X` to the points `$C_{i}$` belonging to the cluster `$\Gamma$` + * is: + * + *
    + * $$ + * \sum\limits_{i=1}^N d(X, C_{i} ) = + * \sum\limits_{i=1}^N \Big( 1 - \frac{\sum\limits_{j=1}^D x_{j}c_{ij} }{ \|X\|\|C_{i}\|} \Big) + * = \sum\limits_{i=1}^N 1 - \sum\limits_{i=1}^N \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} + * \frac{c_{ij}}{\|C_{i}\|} + * = N - \sum\limits_{j=1}^D \frac{x_{j}}{\|X\|} \Big( \sum\limits_{i=1}^N + * \frac{c_{ij}}{\|C_{i}\|} \Big) + * $$ + *
    + * + * where `$x_{j}$` is the `j`-th dimension of the point `X` and `$c_{ij}$` is the `j`-th dimension + * of the `i`-th point in cluster `$\Gamma$`. + * + * Then, we can define the vector: + * + *
    + * $$ + * \xi_{X} : \xi_{X i} = \frac{x_{i}}{\|X\|}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed for each point and the vector + * + *
    + * $$ + * \Omega_{\Gamma} : \Omega_{\Gamma i} = \sum\limits_{j=1}^N \xi_{C_{j}i}, i = 1, ..., D + * $$ + *
    + * + * which can be precomputed too for each cluster `$\Gamma$` by its points `$C_{i}$`. + * + * With these definitions, the numerator becomes: + * + *
    + * $$ + * N - \sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j} + * $$ + *
    + * + * Thus the average distance of a point `X` to the points of the cluster `$\Gamma$` is: + * + *
    + * $$ + * 1 - \frac{\sum\limits_{j=1}^D \xi_{X j} \Omega_{\Gamma j}}{N} + * $$ + *
    + * + * In the implementation, the precomputed values for the clusters are distributed among the worker + * nodes via broadcasted variables, because we can assume that the clusters are limited in number. + * + * The main strengths of this algorithm are the low computational complexity and the intrinsic + * parallelism. The precomputed information for each point and for each cluster can be computed + * with a computational complexity which is `O(N/W)`, where `N` is the number of points in the + * dataset and `W` is the number of worker nodes. After that, every point can be analyzed + * independently from the others. + * + * For every point we need to compute the average distance to all the clusters. Since the formula + * above requires `O(D)` operations, this phase has a computational complexity which is + * `O(C*D*N/W)` where `C` is the number of clusters (which we assume quite low), `D` is the number + * of dimensions, `N` is the number of points in the dataset and `W` is the number of worker + * nodes. + */ +private[evaluation] object CosineSilhouette extends Silhouette { + + private[this] val normalizedFeaturesColName = "normalizedFeatures" + + /** + * The method takes the input dataset and computes the aggregated values + * about a cluster which are needed by the algorithm. + * + * @param df The DataFrame which contains the input data + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a + * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). + */ + def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + val clustersStatsRDD = df.select( + col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) + .rdd + .map { row => (row.getDouble(0), row.getAs[Vector](1)) } + .aggregateByKey[(DenseVector, Long)]((Vectors.zeros(numFeatures).toDense, 0L))( + seqOp = { + case ((normalizedFeaturesSum: DenseVector, numOfPoints: Long), (normalizedFeatures)) => + BLAS.axpy(1.0, normalizedFeatures, normalizedFeaturesSum) + (normalizedFeaturesSum, numOfPoints + 1) + }, + combOp = { + case ((normalizedFeaturesSum1, numOfPoints1), (normalizedFeaturesSum2, numOfPoints2)) => + BLAS.axpy(1.0, normalizedFeaturesSum2, normalizedFeaturesSum1) + (normalizedFeaturesSum1, numOfPoints1 + numOfPoints2) + } + ) + + clustersStatsRDD + .collectAsMap() + .toMap + } + + /** + * It computes the Silhouette coefficient for a point. + * + * @param broadcastedClustersMap A map of the precomputed values for each cluster. + * @param normalizedFeatures The [[org.apache.spark.ml.linalg.Vector]] representing the + * normalized features of the current point. + * @param clusterId The id of the cluster the current point belongs to. + */ + def computeSilhouetteCoefficient( + broadcastedClustersMap: Broadcast[Map[Double, (Vector, Long)]], + normalizedFeatures: Vector, + clusterId: Double): Double = { + + def compute(targetClusterId: Double): Double = { + val (normalizedFeatureSum, numOfPoints) = broadcastedClustersMap.value(targetClusterId) + 1 - BLAS.dot(normalizedFeatures, normalizedFeatureSum) / numOfPoints + } + + pointSilhouetteCoefficient(broadcastedClustersMap.value.keySet, + clusterId, + broadcastedClustersMap.value(clusterId)._2, + compute) + } + + /** + * Compute the Silhouette score of the dataset using the cosine distance measure. + * + * @param dataset The input dataset (previously clustered) on which compute the Silhouette. + * @param predictionCol The name of the column which contains the predicted cluster id + * for the point. + * @param featuresCol The name of the column which contains the feature vector of the point. + * @return The average of the Silhouette values of the clustered data. + */ + def computeSilhouetteScore( + dataset: Dataset[_], + predictionCol: String, + featuresCol: String): Double = { + val normalizeFeatureUDF = udf { + features: Vector => { + val norm = Vectors.norm(features, 2.0) + features match { + case d: DenseVector => Vectors.dense(d.values.map(_ / norm)) + case s: SparseVector => Vectors.sparse(s.size, s.indices, s.values.map(_ / norm)) + } + } + } + val dfWithNormalizedFeatures = dataset.withColumn(normalizedFeaturesColName, + normalizeFeatureUDF(col(featuresCol))) + + // compute aggregate values for clusters needed by the algorithm + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + + // Silhouette is reasonable only when the number of clusters is greater then 1 + assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") + + val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap) + + val computeSilhouetteCoefficientUDF = udf { + computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double) + } + + val silhouetteScore = overallScore(dfWithNormalizedFeatures, + computeSilhouetteCoefficientUDF(col(normalizedFeaturesColName), + col(predictionCol).cast(DoubleType))) bClustersStatsMap.destroy() diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 677ce49a903ab..3bf34770f5687 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -66,16 +66,38 @@ class ClusteringEvaluatorSuite assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) } - test("number of clusters must be greater than one") { - val singleClusterDataset = irisDataset.where($"label" === 0.0) + /* + Use the following python code to load the data and evaluate it using scikit-learn package. + + from sklearn import datasets + from sklearn.metrics import silhouette_score + iris = datasets.load_iris() + round(silhouette_score(iris.data, iris.target, metric='cosine'), 10) + + 0.7222369298 + */ + test("cosine Silhouette") { val evaluator = new ClusteringEvaluator() .setFeaturesCol("features") .setPredictionCol("label") + .setDistanceMeasure("cosine") + + assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + } + + test("number of clusters must be greater than one") { + val singleClusterDataset = irisDataset.where($"label" === 0.0) + Seq("squaredEuclidean", "cosine").foreach { distanceMeasure => + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + .setDistanceMeasure(distanceMeasure) - val e = intercept[AssertionError]{ - evaluator.evaluate(singleClusterDataset) + val e = intercept[AssertionError] { + evaluator.evaluate(singleClusterDataset) + } + assert(e.getMessage.contains("Number of clusters must be greater than one")) } - assert(e.getMessage.contains("Number of clusters must be greater than one")) } } From d58fe28836639e68e262812d911f167cb071007b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 13 Feb 2018 11:18:45 -0800 Subject: [PATCH 0347/1161] [SPARK-23154][ML][DOC] Document backwards compatibility guarantees for ML persistence ## What changes were proposed in this pull request? Added documentation about what MLlib guarantees in terms of loading ML models and Pipelines from old Spark versions. Discussed & confirmed on linked JIRA. Author: Joseph K. Bradley Closes #20592 from jkbradley/SPARK-23154-backwards-compat-doc. --- docs/ml-pipeline.md | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index aa92c0a37c0f4..e22e9003c30f6 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -188,9 +188,36 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -## Saving and Loading Pipelines +## ML persistence: Saving and Loading Pipelines -Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. +As of Spark 2.3, the DataFrame-based API in `spark.ml` and `pyspark.ml` has complete coverage. + +ML persistence works across Scala, Java and Python. However, R currently uses a modified format, +so models saved in R can only be loaded back in R; this should be fixed in the future and is +tracked in [SPARK-15572](https://issues.apache.org/jira/browse/SPARK-15572). + +### Backwards compatibility for ML persistence + +In general, MLlib maintains backwards compatibility for ML persistence. I.e., if you save an ML +model or Pipeline in one version of Spark, then you should be able to load it back and use it in a +future version of Spark. However, there are rare exceptions, described below. + +Model persistence: Is a model or Pipeline saved using Apache Spark ML persistence in Spark +version X loadable by Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Yes; these are backwards compatible. +* Note about the format: There are no guarantees for a stable persistence format, but model loading itself is designed to be backwards compatible. + +Model behavior: Does a model or Pipeline in Spark version X behave identically in Spark version Y? + +* Major versions: No guarantees, but best-effort. +* Minor and patch versions: Identical behavior, except for bug fixes. + +For both model persistence and model behavior, any breaking changes across a minor version or patch +version are reported in the Spark version release notes. If a breakage is not reported in release +notes, then it should be treated as a bug to be fixed. # Code examples From 2ee76c22b6e48e643694c9475e5f0d37124215e7 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 11:56:49 -0800 Subject: [PATCH 0348/1161] [SPARK-23400][SQL] Add a constructors for ScalaUDF ## What changes were proposed in this pull request? In this upcoming 2.3 release, we changed the interface of `ScalaUDF`. Unfortunately, some Spark packages (e.g., spark-deep-learning) are using our internal class `ScalaUDF`. In the release 2.3, we added new parameters into this class. The users hit the binary compatibility issues and got the exception: ``` > java.lang.NoSuchMethodError: org.apache.spark.sql.catalyst.expressions.ScalaUDF.<init>(Ljava/lang/Object;Lorg/apache/spark/sql/types/DataType;Lscala/collection/Seq;Lscala/collection/Seq;Lscala/Option;)V ``` This PR is to improve the backward compatibility. However, we definitely should not encourage the external packages to use our internal classes. This might make us hard to maintain/develop the codes in Spark. ## How was this patch tested? N/A Author: gatorsmile Closes #20591 from gatorsmile/scalaUDF. --- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 388ef42883ad3..989c02305620a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,6 +49,17 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { + // The constructor for SPARK 2.1 and 2.2 + def this( + function: AnyRef, + dataType: DataType, + children: Seq[Expression], + inputTypes: Seq[DataType], + udfName: Option[String]) = { + this( + function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true) + } + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = From a5a4b83501526e02d0e3cd0056e4a5c0e1c8284f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Tue, 13 Feb 2018 16:46:43 -0600 Subject: [PATCH 0349/1161] [SPARK-23235][CORE] Add executor Threaddump to api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending api with the executor thread dump data. For this new REST URL is introduced: - GET http://localhost:4040/api/v1/applications/{applicationId}/executors/{executorId}/threads
    Example response: ``` javascript [ { "threadId" : 52, "threadName" : "context-cleaner-periodic-gc", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1385411893})", "holdingLocks" : [ ] }, { "threadId" : 48, "threadName" : "dag-scheduler-event-loop", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingDeque.takeFirst(LinkedBlockingDeque.java:492)\njava.util.concurrent.LinkedBlockingDeque.take(LinkedBlockingDeque.java:680)\norg.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:46)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1138053349})", "holdingLocks" : [ ] }, { "threadId" : 17, "threadName" : "dispatcher-event-loop-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker832743930})" ] }, { "threadId" : 18, "threadName" : "dispatcher-event-loop-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker834153999})" ] }, { "threadId" : 19, "threadName" : "dispatcher-event-loop-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker664836465})" ] }, { "threadId" : 20, "threadName" : "dispatcher-event-loop-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1645557354})" ] }, { "threadId" : 21, "threadName" : "dispatcher-event-loop-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1188871851})" ] }, { "threadId" : 22, "threadName" : "dispatcher-event-loop-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker920926249})" ] }, { "threadId" : 23, "threadName" : "dispatcher-event-loop-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker355222677})" ] }, { "threadId" : 24, "threadName" : "dispatcher-event-loop-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.rpc.netty.Dispatcher$MessageLoop.run(Dispatcher.scala:215)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1764626380})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1589745212})" ] }, { "threadId" : 49, "threadName" : "driver-heartbeater", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1602885835})", "holdingLocks" : [ ] }, { "threadId" : 53, "threadName" : "element-tracking-store-worker", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1439439099})", "holdingLocks" : [ ] }, { "threadId" : 3, "threadName" : "Finalizer", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:164)\njava.lang.ref.Finalizer$FinalizerThread.run(Finalizer.java:209)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1213098236})", "holdingLocks" : [ ] }, { "threadId" : 15, "threadName" : "ForkJoinPool-1-worker-13", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\nscala.concurrent.forkjoin.ForkJoinPool.scan(ForkJoinPool.java:2075)\nscala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979)\nscala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107)", "blockedByLock" : "Lock(scala.concurrent.forkjoin.ForkJoinPool380286413})", "holdingLocks" : [ ] }, { "threadId" : 45, "threadName" : "heartbeat-receiver-event-loop-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject715135812})", "holdingLocks" : [ ] }, { "threadId" : 1, "threadName" : "main", "threadState" : "RUNNABLE", "stackTrace" : "java.io.FileInputStream.read0(Native Method)\njava.io.FileInputStream.read(FileInputStream.java:207)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:169) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:137)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.read(NonBlockingInputStream.java:246)\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:261) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.internal.InputStreamReader.read(InputStreamReader.java:198) => holding Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})\nscala.tools.jline_embedded.console.ConsoleReader.readCharacter(ConsoleReader.java:2145)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2349)\nscala.tools.jline_embedded.console.ConsoleReader.readLine(ConsoleReader.java:2269)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readOneLine(JLineReader.scala:57)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$$anonfun$readLine$2.apply(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.InteractiveReader$.restartSysCalls(InteractiveReader.scala:44)\nscala.tools.nsc.interpreter.InteractiveReader$class.readLine(InteractiveReader.scala:37)\nscala.tools.nsc.interpreter.jline_embedded.InteractiveReader.readLine(JLineReader.scala:28)\nscala.tools.nsc.interpreter.ILoop.readOneLine(ILoop.scala:404)\nscala.tools.nsc.interpreter.ILoop.loop(ILoop.scala:413)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply$mcZ$sp(ILoop.scala:923)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:909)\nscala.reflect.internal.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:97)\nscala.tools.nsc.interpreter.ILoop.process(ILoop.scala:909)\norg.apache.spark.repl.Main$.doMain(Main.scala:76)\norg.apache.spark.repl.Main$.main(Main.scala:56)\norg.apache.spark.repl.Main.main(Main.scala)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52)\norg.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:879)\norg.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:197)\norg.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:227)\norg.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:136)\norg.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})" ] }, { "threadId" : 26, "threadName" : "map-output-dispatcher-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1791280119})" ] }, { "threadId" : 27, "threadName" : "map-output-dispatcher-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1947378744})" ] }, { "threadId" : 28, "threadName" : "map-output-dispatcher-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker507507251})" ] }, { "threadId" : 29, "threadName" : "map-output-dispatcher-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1016408627})" ] }, { "threadId" : 30, "threadName" : "map-output-dispatcher-4", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1879219501})" ] }, { "threadId" : 31, "threadName" : "map-output-dispatcher-5", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker290509937})" ] }, { "threadId" : 32, "threadName" : "map-output-dispatcher-6", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1889468930})" ] }, { "threadId" : 33, "threadName" : "map-output-dispatcher-7", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.MapOutputTrackerMaster$MessageLoop.run(MapOutputTracker.scala:384)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject350285679})", "holdingLocks" : [ "Lock(java.util.concurrent.ThreadPoolExecutor$Worker1699637904})" ] }, { "threadId" : 47, "threadName" : "netty-rpc-env-timeout", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject977194847})", "holdingLocks" : [ ] }, { "threadId" : 14, "threadName" : "NonBlockingInputStreamThread", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\nscala.tools.jline_embedded.internal.NonBlockingInputStream.run(NonBlockingInputStream.java:278)\njava.lang.Thread.run(Thread.java:748)", "blockedByThreadId" : 1, "blockedByLock" : "Lock(scala.tools.jline_embedded.internal.NonBlockingInputStream46248392})", "holdingLocks" : [ ] }, { "threadId" : 2, "threadName" : "Reference Handler", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.lang.ref.Reference.tryHandlePending(Reference.java:191)\njava.lang.ref.Reference$ReferenceHandler.run(Reference.java:153)", "blockedByLock" : "Lock(java.lang.ref.Reference$Lock1359433302})", "holdingLocks" : [ ] }, { "threadId" : 35, "threadName" : "refresh progress", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.util.TimerThread.mainLoop(Timer.java:552)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue44276328})", "holdingLocks" : [ ] }, { "threadId" : 34, "threadName" : "RemoteBlock-temp-file-clean-thread", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager.org$apache$spark$storage$BlockManager$RemoteBlockTempFileManager$$keepCleaning(BlockManager.scala:1630)\norg.apache.spark.storage.BlockManager$RemoteBlockTempFileManager$$anon$1.run(BlockManager.scala:1608)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock391748181})", "holdingLocks" : [ ] }, { "threadId" : 25, "threadName" : "rpc-server-3-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet1066929256})", "Monitor(java.util.Collections$UnmodifiableSet561426729})", "Monitor(sun.nio.ch.KQueueSelectorImpl2057702496})" ] }, { "threadId" : 50, "threadName" : "shuffle-server-5-1", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nio.netty.channel.nio.SelectedSelectionKeySetSelector.select(SelectedSelectionKeySetSelector.java:62)\nio.netty.channel.nio.NioEventLoop.select(NioEventLoop.java:753)\nio.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:409)\nio.netty.util.concurrent.SingleThreadEventExecutor$5.run(SingleThreadEventExecutor.java:858)\nio.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:138)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(io.netty.channel.nio.SelectedSelectionKeySet385972319})", "Monitor(java.util.Collections$UnmodifiableSet477937109})", "Monitor(sun.nio.ch.KQueueSelectorImpl1401522546})" ] }, { "threadId" : 4, "threadName" : "Signal Dispatcher", "threadState" : "RUNNABLE", "stackTrace" : "", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 51, "threadName" : "Spark Context Cleaner", "threadState" : "TIMED_WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.ref.ReferenceQueue.remove(ReferenceQueue.java:143)\norg.apache.spark.ContextCleaner$$anonfun$org$apache$spark$ContextCleaner$$keepCleaning$1.apply$mcV$sp(ContextCleaner.scala:181)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.ContextCleaner.org$apache$spark$ContextCleaner$$keepCleaning(ContextCleaner.scala:178)\norg.apache.spark.ContextCleaner$$anon$1.run(ContextCleaner.scala:73)", "blockedByLock" : "Lock(java.lang.ref.ReferenceQueue$Lock1739420764})", "holdingLocks" : [ ] }, { "threadId" : 16, "threadName" : "spark-listener-group-appStatus", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1287190987})", "holdingLocks" : [ ] }, { "threadId" : 44, "threadName" : "spark-listener-group-executorManagement", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject943262890})", "holdingLocks" : [ ] }, { "threadId" : 54, "threadName" : "spark-listener-group-shared", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\norg.apache.spark.scheduler.AsyncEventQueue$$anonfun$org$apache$spark$scheduler$AsyncEventQueue$$dispatch$1.apply(AsyncEventQueue.scala:94)\nscala.util.DynamicVariable.withValue(DynamicVariable.scala:58)\norg.apache.spark.scheduler.AsyncEventQueue.org$apache$spark$scheduler$AsyncEventQueue$$dispatch(AsyncEventQueue.scala:83)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1$$anonfun$run$1.apply$mcV$sp(AsyncEventQueue.scala:79)\norg.apache.spark.util.Utils$.tryOrStopSparkContext(Utils.scala:1319)\norg.apache.spark.scheduler.AsyncEventQueue$$anon$1.run(AsyncEventQueue.scala:78)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject334604425})", "holdingLocks" : [ ] }, { "threadId" : 37, "threadName" : "SparkUI-37", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 38, "threadName" : "SparkUI-38", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl841741934})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3873523986})", "Monitor(java.util.Collections$UnmodifiableSet1769333189})", "Monitor(sun.nio.ch.KQueueSelectorImpl841741934})" ] }, { "threadId" : 40, "threadName" : "SparkUI-40-acceptor-034929380-Spark3a557b62{HTTP/1.1,[http/1.1]}{0.0.0.0:4040}", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.ServerSocketChannelImpl.accept0(Native Method)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:422)\nsun.nio.ch.ServerSocketChannelImpl.accept(ServerSocketChannelImpl.java:250) => holding Monitor(java.lang.Object1134240909})\norg.spark_project.jetty.server.ServerConnector.accept(ServerConnector.java:371)\norg.spark_project.jetty.server.AbstractConnector$Acceptor.run(AbstractConnector.java:601)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(java.lang.Object1134240909})" ] }, { "threadId" : 43, "threadName" : "SparkUI-43", "threadState" : "RUNNABLE", "stackTrace" : "sun.management.ThreadImpl.dumpThreads0(Native Method)\nsun.management.ThreadImpl.dumpAllThreads(ThreadImpl.java:454)\norg.apache.spark.util.Utils$.getThreadDump(Utils.scala:2170)\norg.apache.spark.SparkContext.getExecutorThreadDump(SparkContext.scala:596)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:66)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1$$anonfun$apply$1.apply(OneApplicationResource.scala:65)\nscala.Option.flatMap(Option.scala:171)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:65)\norg.apache.spark.status.api.v1.AbstractApplicationResource$$anonfun$threadDump$1.apply(OneApplicationResource.scala:58)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:139)\norg.apache.spark.status.api.v1.BaseAppResource$$anonfun$withUI$1.apply(ApiRootResource.scala:134)\norg.apache.spark.ui.SparkUI.withSparkUI(SparkUI.scala:106)\norg.apache.spark.status.api.v1.BaseAppResource$class.withUI(ApiRootResource.scala:134)\norg.apache.spark.status.api.v1.AbstractApplicationResource.withUI(OneApplicationResource.scala:32)\norg.apache.spark.status.api.v1.AbstractApplicationResource.threadDump(OneApplicationResource.scala:58)\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\norg.glassfish.jersey.server.model.internal.ResourceMethodInvocationHandlerFactory$1.invoke(ResourceMethodInvocationHandlerFactory.java:81)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher$1.run(AbstractJavaResourceMethodDispatcher.java:144)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.invoke(AbstractJavaResourceMethodDispatcher.java:161)\norg.glassfish.jersey.server.model.internal.JavaResourceMethodDispatcherProvider$TypeOutInvoker.doDispatch(JavaResourceMethodDispatcherProvider.java:205)\norg.glassfish.jersey.server.model.internal.AbstractJavaResourceMethodDispatcher.dispatch(AbstractJavaResourceMethodDispatcher.java:99)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.invoke(ResourceMethodInvoker.java:389)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:347)\norg.glassfish.jersey.server.model.ResourceMethodInvoker.apply(ResourceMethodInvoker.java:102)\norg.glassfish.jersey.server.ServerRuntime$2.run(ServerRuntime.java:326)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:271)\norg.glassfish.jersey.internal.Errors$1.call(Errors.java:267)\norg.glassfish.jersey.internal.Errors.process(Errors.java:315)\norg.glassfish.jersey.internal.Errors.process(Errors.java:297)\norg.glassfish.jersey.internal.Errors.process(Errors.java:267)\norg.glassfish.jersey.process.internal.RequestScope.runInScope(RequestScope.java:317)\norg.glassfish.jersey.server.ServerRuntime.process(ServerRuntime.java:305)\norg.glassfish.jersey.server.ApplicationHandler.handle(ApplicationHandler.java:1154)\norg.glassfish.jersey.servlet.WebComponent.serviceImpl(WebComponent.java:473)\norg.glassfish.jersey.servlet.WebComponent.service(WebComponent.java:427)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:388)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:341)\norg.glassfish.jersey.servlet.ServletContainer.service(ServletContainer.java:228)\norg.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848)\norg.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584)\norg.spark_project.jetty.server.handler.ContextHandler.doHandle(ContextHandler.java:1180)\norg.spark_project.jetty.servlet.ServletHandler.doScope(ServletHandler.java:512)\norg.spark_project.jetty.server.handler.ContextHandler.doScope(ContextHandler.java:1112)\norg.spark_project.jetty.server.handler.ScopedHandler.handle(ScopedHandler.java:141)\norg.spark_project.jetty.server.handler.gzip.GzipHandler.handle(GzipHandler.java:493)\norg.spark_project.jetty.server.handler.ContextHandlerCollection.handle(ContextHandlerCollection.java:213)\norg.spark_project.jetty.server.handler.HandlerWrapper.handle(HandlerWrapper.java:134)\norg.spark_project.jetty.server.Server.handle(Server.java:534)\norg.spark_project.jetty.server.HttpChannel.handle(HttpChannel.java:320)\norg.spark_project.jetty.server.HttpConnection.onFillable(HttpConnection.java:251)\norg.spark_project.jetty.io.AbstractConnection$ReadCallback.succeeded(AbstractConnection.java:283)\norg.spark_project.jetty.io.FillInterest.fillable(FillInterest.java:108)\norg.spark_project.jetty.io.SelectChannelEndPoint$2.run(SelectChannelEndPoint.java:93)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:303)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ ] }, { "threadId" : 67, "threadName" : "SparkUI-67", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3881415814})", "Monitor(java.util.Collections$UnmodifiableSet62050480})", "Monitor(sun.nio.ch.KQueueSelectorImpl1837806480})" ] }, { "threadId" : 68, "threadName" : "SparkUI-68", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl223607814})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3543145185})", "Monitor(java.util.Collections$UnmodifiableSet897441546})", "Monitor(sun.nio.ch.KQueueSelectorImpl223607814})" ] }, { "threadId" : 71, "threadName" : "SparkUI-71", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 77, "threadName" : "SparkUI-77", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\norg.spark_project.jetty.util.BlockingArrayQueue.poll(BlockingArrayQueue.java:392)\norg.spark_project.jetty.util.thread.QueuedThreadPool.idleJobPoll(QueuedThreadPool.java:563)\norg.spark_project.jetty.util.thread.QueuedThreadPool.access$800(QueuedThreadPool.java:48)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:626)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1503479572})", "holdingLocks" : [ ] }, { "threadId" : 78, "threadName" : "SparkUI-78", "threadState" : "RUNNABLE", "stackTrace" : "sun.nio.ch.KQueueArrayWrapper.kevent0(Native Method)\nsun.nio.ch.KQueueArrayWrapper.poll(KQueueArrayWrapper.java:198)\nsun.nio.ch.KQueueSelectorImpl.doSelect(KQueueSelectorImpl.java:117)\nsun.nio.ch.SelectorImpl.lockAndDoSelect(SelectorImpl.java:86) => holding Monitor(sun.nio.ch.KQueueSelectorImpl403077801})\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:97)\nsun.nio.ch.SelectorImpl.select(SelectorImpl.java:101)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.select(ManagedSelector.java:243)\norg.spark_project.jetty.io.ManagedSelector$SelectorProducer.produce(ManagedSelector.java:191)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.executeProduceConsume(ExecuteProduceConsume.java:249)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.produceConsume(ExecuteProduceConsume.java:148)\norg.spark_project.jetty.util.thread.strategy.ExecuteProduceConsume.run(ExecuteProduceConsume.java:136)\norg.spark_project.jetty.util.thread.QueuedThreadPool.runJob(QueuedThreadPool.java:671)\norg.spark_project.jetty.util.thread.QueuedThreadPool$2.run(QueuedThreadPool.java:589)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "", "holdingLocks" : [ "Monitor(sun.nio.ch.Util$3261312406})", "Monitor(java.util.Collections$UnmodifiableSet852901260})", "Monitor(sun.nio.ch.KQueueSelectorImpl403077801})" ] }, { "threadId" : 72, "threadName" : "SparkUI-JettyScheduler", "threadState" : "TIMED_WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:215)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2078)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:1093)\njava.util.concurrent.ScheduledThreadPoolExecutor$DelayedWorkQueue.take(ScheduledThreadPoolExecutor.java:809)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject1587346642})", "holdingLocks" : [ ] }, { "threadId" : 63, "threadName" : "task-result-getter-0", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 64, "threadName" : "task-result-getter-1", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 65, "threadName" : "task-result-getter-2", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 66, "threadName" : "task-result-getter-3", "threadState" : "WAITING", "stackTrace" : "sun.misc.Unsafe.park(Native Method)\njava.util.concurrent.locks.LockSupport.park(LockSupport.java:175)\njava.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.await(AbstractQueuedSynchronizer.java:2039)\njava.util.concurrent.LinkedBlockingQueue.take(LinkedBlockingQueue.java:442)\njava.util.concurrent.ThreadPoolExecutor.getTask(ThreadPoolExecutor.java:1074)\njava.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1134)\njava.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\njava.lang.Thread.run(Thread.java:748)", "blockedByLock" : "Lock(java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject537563105})", "holdingLocks" : [ ] }, { "threadId" : 46, "threadName" : "Timer-0", "threadState" : "WAITING", "stackTrace" : "java.lang.Object.wait(Native Method)\njava.lang.Object.wait(Object.java:502)\njava.util.TimerThread.mainLoop(Timer.java:526)\njava.util.TimerThread.run(Timer.java:505)", "blockedByLock" : "Lock(java.util.TaskQueue635634547})", "holdingLocks" : [ ] } ] ```
    ## How was this patch tested? It was tested manually. Old executor page with thread dumps: screen shot 2018-02-01 at 14 31 19 New api: screen shot 2018-02-01 at 14 31 56 Testing error cases. Initial state: ![screen shot 2018-02-06 at 13 05 05](https://user-images.githubusercontent.com/2017933/35858990-ad2982be-0b3e-11e8-879b-656112065c7f.png) Dead executor: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/1/threads Executor is not active. 400 ``` Never existed (but well formatted: number) executor ID: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/42/threads Executor does not exist. 404 ``` Not available stacktrace (dead executor but UI has not registered as dead yet): ```bash $ kill -9 ; curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/2/threads No thread dump is available. 404 ``` Invalid executor ID format: ```bash $ curl -o - -s -w "\n%{http_code}\n" http://localhost:4040/api/v1/applications/app-20180206122543-0000/executors/something6/threads Invalid executorId: neither 'driver' nor number. 400 ``` Author: “attilapiros” Closes #20474 from attilapiros/SPARK-23235. --- .../scala/org/apache/spark/SparkContext.scala | 1 + .../spark/status/api/v1/ApiRootResource.scala | 8 +++++ .../api/v1/OneApplicationResource.scala | 29 +++++++++++++++-- .../org/apache/spark/status/api/v1/api.scala | 9 ++++++ .../ui/exec/ExecutorThreadDumpPage.scala | 13 +------- .../apache/spark/util/ThreadStackTrace.scala | 31 ------------------- .../scala/org/apache/spark/util/Utils.scala | 18 ++++++++++- docs/monitoring.md | 7 +++++ 8 files changed, 69 insertions(+), 47 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c4f74c4f1f9c2..dc531e3337014 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.ThreadStackTrace import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index ed9bdc6e1e3c2..7127397f6205c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -157,6 +157,14 @@ private[v1] class NotFoundException(msg: String) extends WebApplicationException .build() ) +private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( + new ServiceUnavailableException(msg), + Response + .status(Response.Status.SERVICE_UNAVAILABLE) + .entity(ErrorWrapper(msg)) + .build() +) + private[v1] class BadParameterException(msg: String) extends WebApplicationException( new IllegalArgumentException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index bd4df07e7afc6..974697890dd03 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -19,13 +19,13 @@ package org.apache.spark.status.api.v1 import java.io.OutputStream import java.util.{List => JList} import java.util.zip.ZipOutputStream -import javax.ws.rs.{GET, Path, PathParam, Produces, QueryParam} +import javax.ws.rs._ import javax.ws.rs.core.{MediaType, Response, StreamingOutput} import scala.util.control.NonFatal -import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.SparkUI +import org.apache.spark.{JobExecutionStatus, SparkContext} +import org.apache.spark.ui.UIUtils @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class AbstractApplicationResource extends BaseAppResource { @@ -51,6 +51,29 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { @Path("executors") def executorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(true)) + @GET + @Path("executors/{executorId}/threads") + def threadDump(@PathParam("executorId") execId: String): Array[ThreadStackTrace] = withUI { ui => + if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) { + throw new BadParameterException( + s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.") + } + + val safeSparkContext = ui.sc.getOrElse { + throw new ServiceUnavailable("Thread dumps not available through the history server.") + } + + ui.store.asOption(ui.store.executorSummary(execId)) match { + case Some(executorSummary) if executorSummary.isActive => + val safeThreadDump = safeSparkContext.getExecutorThreadDump(execId).getOrElse { + throw new NotFoundException("No thread dump is available.") + } + safeThreadDump + case Some(_) => throw new BadParameterException("Executor is not active.") + case _ => throw new NotFoundException("Executor does not exist.") + } + } + @GET @Path("allexecutors") def allExecutorList(): Seq[ExecutorSummary] = withUI(_.store.executorList(false)) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index a333f1aaf6325..369e98b683b1a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -316,3 +316,12 @@ class RuntimeInfo private[spark]( val javaVersion: String, val javaHome: String, val scalaVersion: String) + +case class ThreadStackTrace( + val threadId: Long, + val threadName: String, + val threadState: Thread.State, + val stackTrace: String, + val blockedByThreadId: Option[Long], + val blockedByLock: String, + val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index f4686ea3cf91f..7a9aaf29a8b05 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,7 +17,6 @@ package org.apache.spark.ui.exec -import java.util.Locale import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Text} @@ -41,17 +40,7 @@ private[ui] class ExecutorThreadDumpPage( val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) val content = maybeThreadDump.map { threadDump => - val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => - val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 - val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 - if (v1 == v2) { - threadTrace1.threadName.toLowerCase(Locale.ROOT) < - threadTrace2.threadName.toLowerCase(Locale.ROOT) - } else { - v1 > v2 - } - }.map { thread => + val dumpRows = threadDump.map { thread => val threadId = thread.threadId val blockedBy = thread.blockedByThreadId match { case Some(_) => diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala deleted file mode 100644 index b1217980faf1f..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -/** - * Used for shipping per-thread stacktraces from the executors to driver. - */ -private[spark] case class ThreadStackTrace( - threadId: Long, - threadName: String, - threadState: Thread.State, - stackTrace: String, - blockedByThreadId: Option[Long], - blockedByLock: String, - holdingLocks: Seq[String]) - diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5853302973140..d493663f0b168 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -63,6 +63,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.status.api.v1.ThreadStackTrace /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2168,7 +2169,22 @@ private[spark] object Utils extends Logging { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + threadInfos.sortWith { case (threadTrace1, threadTrace2) => + val v1 = if (threadTrace1.getThreadName.contains("Executor task launch")) 1 else 0 + val v2 = if (threadTrace2.getThreadName.contains("Executor task launch")) 1 else 0 + if (v1 == v2) { + val name1 = threadTrace1.getThreadName().toLowerCase(Locale.ROOT) + val name2 = threadTrace2.getThreadName().toLowerCase(Locale.ROOT) + val nameCmpRes = name1.compareTo(name2) + if (nameCmpRes == 0) { + threadTrace1.getThreadId < threadTrace2.getThreadId + } else { + nameCmpRes < 0 + } + } else { + v1 > v2 + } + }.map(threadInfoToThreadStackTrace) } def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { diff --git a/docs/monitoring.md b/docs/monitoring.md index 6f6cfc1288d73..d5f7ffcc260a1 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -347,6 +347,13 @@ can be identified by their `[attempt-id]`. In the API listed below, when running /applications/[app-id]/executors A list of all active executors for the given application. + + /applications/[app-id]/executors/[executor-id]/threads + + Stack traces of all the threads running within the given active executor. + Not available via the history server. + + /applications/[app-id]/allexecutors A list of all(active and dead) executors for the given application. From d6f5e172b480c62165be168deae0deff8062f476 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 13 Feb 2018 16:21:17 -0800 Subject: [PATCH 0350/1161] Revert "[SPARK-23303][SQL] improve the explain result for data source v2 relations" This reverts commit f17b936f0ddb7d46d1349bd42f9a64c84c06e48d. --- .../kafka010/KafkaContinuousSourceSuite.scala | 18 +++- .../sql/kafka010/KafkaContinuousTest.scala | 3 +- .../spark/sql/kafka010/KafkaSourceSuite.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../v2/DataSourceReaderHolder.scala | 64 +++++++++++++ .../v2/DataSourceV2QueryPlan.scala | 96 ------------------- .../datasources/v2/DataSourceV2Relation.scala | 26 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/PushDownOperatorsToDataSource.scala | 4 +- .../streaming/MicroBatchExecution.scala | 22 ++--- .../continuous/ContinuousExecution.scala | 9 +- .../spark/sql/streaming/StreamSuite.scala | 8 +- .../spark/sql/streaming/StreamTest.scala | 2 +- .../continuous/ContinuousSuite.scala | 11 ++- 15 files changed, 127 insertions(+), 157 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index 72ee0c551ec3d..a7083fa4e3417 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,9 +17,20 @@ package org.apache.spark.sql.kafka010 -import org.apache.spark.sql.Dataset +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.time.SpanSugar._ +import scala.collection.mutable +import scala.util.Random + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution +import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -60,8 +71,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index d34458ac81014..5a1a14f7a307a 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,8 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index cb09cce75ff6f..02c87643568bd 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -117,8 +117,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] => - r.reader.asInstanceOf[KafkaContinuousReader] + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) if (sources.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 984b6510f2dbe..fcaf8d618c168 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -189,9 +189,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance().asInstanceOf[DataSourceV2] + val ds = cls.newInstance() val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava) + DataSourceV2Utils.extractSessionConfigs( + ds = ds.asInstanceOf[DataSourceV2], + conf = sparkSession.sessionState.conf)).asJava) // Streaming also uses the data source V2 API. So it may be that the data source implements // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading @@ -219,7 +221,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (reader == null) { loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader)) + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..81219e9771bd8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The held data source reader. + */ + def reader: DataSourceReader + + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(output, reader.getClass, filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala deleted file mode 100644 index 1e0d088f3a57c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.util.Utils - -/** - * A base class for data source v2 related query plan(both logical and physical). It defines the - * equals/hashCode methods, and provides a string representation of the query plan, according to - * some common information. - */ -trait DataSourceV2QueryPlan { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. - */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } - - /** - * The metadata of this data source query plan that can be used for equality check. - */ - private def metadata: Seq[Any] = Seq(output, source.getClass, filters) - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } - - def metadataString: String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]") - - val outputStr = Utils.truncatedString(output, "[", ", ", "]") - - val entriesStr = if (entries.nonEmpty) { - Utils.truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100) - }, " (", ", ", ")") - } else { - "" - } - - s"${source.getClass.getSimpleName}$outputStr$entriesStr" - } - - private def redact(text: String): String = { - Utils.redact(SQLConf.get.stringRedationPattern, text) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cd97e0cab6b5c..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,23 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( output: Seq[AttributeReference], - source: DataSourceV2, - reader: DataSourceReader, - override val isStreaming: Boolean) - extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan { + reader: DataSourceReader) + extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] - override def simpleString: String = { - val streamingHeader = if (isStreaming) "Streaming " else "" - s"${streamingHeader}Relation $metadataString" - } - override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) @@ -49,8 +41,18 @@ case class DataSourceV2Relation( } } +/** + * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical + * to the non-streaming relation. + */ +class StreamingDataSourceV2Relation( + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + override def isStreaming: Boolean = true +} + object DataSourceV2Relation { - def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false) + def apply(reader: DataSourceReader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, reader) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index c99d535efcf81..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -37,14 +36,11 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], - @transient source: DataSourceV2, @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan { + extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def simpleString: String = s"Scan $metadataString" - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index fb61e6f32b1f4..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 4cfdd50e8f46b..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: Ideally column pruning should be implemented via a plan property that is propagated // top-down, then we can simplify the logic here and only collect target operators. val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r: DataSourceV2Relation) => + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => val (candidates, nonDeterministic) = splitConjunctivePredicates(condition).partition(_.deterministic) - val stayUpFilters: Seq[Expression] = r.reader match { + val stayUpFilters: Seq[Expression] = reader match { case r: SupportsPushDownCatalystFilters => r.pushCatalystFilters(candidates.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84564b6639ac9..812533313332e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,8 +52,6 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2] - private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -92,7 +90,6 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = source StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => @@ -408,15 +405,12 @@ class MicroBatchExecution( case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange(toJava(current), Optional.of(availableV2)) + reader.setOffsetRange( + toJava(current), + Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> new DataSourceV2Relation( - reader.readSchema().toAttributes, - // Provide a fake value here just in case something went wrong, e.g. the reader gives - // a wrong `equals` implementation. - readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2), - reader, - isStreaming = true)) + Some(reader -> + new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) case _ => None } } @@ -506,5 +500,3 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } - -object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f87d57d0b3209..c3294d64b10cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(ds, _, output) => + case ContinuousExecutionRelation(_, _, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new DataSourceV2Relation(newOutput, ds, reader, isStreaming = true) + new StreamingDataSourceV2Relation(newOutput, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,8 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case r: DataSourceV2Relation if r.reader.isInstanceOf[ContinuousReader] => - r.reader.asInstanceOf[ContinuousReader] + case DataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 70eb9f0ac66d5..d1a04833390f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,16 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("Streaming Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("Scan FakeDataSourceV2".r.findAllMatchIn(explainWithExtended).size === 1) + assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 254394685857b..37fe595529baf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case d: DataSourceV2Relation => d.reader + case DataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 9ee9aaf87f87c..4b4ed82dc6520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.streaming.continuous -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import java.util.UUID + +import org.apache.spark.{SparkContext, SparkEnv, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -40,7 +43,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 357babde5a8eb9710de7016d7ae82dee21fa4ef3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 14 Feb 2018 10:55:24 +0800 Subject: [PATCH 0351/1161] [SPARK-23399][SQL] Register a task completion listener first for OrcColumnarBatchReader ## What changes were proposed in this pull request? This PR aims to resolve an open file leakage issue reported at [SPARK-23390](https://issues.apache.org/jira/browse/SPARK-23390) by moving the listener registration position. Currently, the sequence is like the following. 1. Create `batchReader` 2. `batchReader.initialize` opens a ORC file. 3. `batchReader.initBatch` may take a long time to alloc memory in some environment and cause errors. 4. `Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))` This PR moves 4 before 2 and 3. To sum up, the new sequence is 1 -> 4 -> 2 -> 3. ## How was this patch tested? Manual. The following test case makes OOM intentionally to cause leaked filesystem connection in the current code base. With this patch, leakage doesn't occurs. ```scala // This should be tested manually because it raises OOM intentionally // in order to cause `Leaked filesystem connection`. test("SPARK-23399 Register a task completion listener first for OrcColumnarBatchReader") { withSQLConf(SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("orc").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("orc").save(new Path(basePath, "second").toString) val df = spark.read.orc( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20590 from dongjoon-hyun/SPARK-23399. --- .../sql/execution/datasources/orc/OrcFileFormat.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index dbf3bc6f0ee6c..1de2ca2914c44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -188,6 +188,12 @@ class OrcFileFormat if (enableVectorizedReader) { val batchReader = new OrcColumnarBatchReader( enableOffHeapColumnVector && taskContext.isDefined, copyToSpark, capacity) + // SPARK-23399 Register a task completion listener first to call `close()` in all cases. + // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM) + // after opening a file. + val iter = new RecordReaderIterator(batchReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) + batchReader.initialize(fileSplit, taskAttemptContext) batchReader.initBatch( reader.getSchema, @@ -196,8 +202,6 @@ class OrcFileFormat partitionSchema, file.partitionValues) - val iter = new RecordReaderIterator(batchReader) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) iter.asInstanceOf[Iterator[InternalRow]] } else { val orcRecordReader = new OrcInputFormat[OrcStruct] From 140f87533a468b1046504fc3ff01fbe1637e41cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Feb 2018 06:45:54 -0800 Subject: [PATCH 0352/1161] [SPARK-23394][UI] In RDD storage page show the executor addresses instead of the IDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Extending RDD storage page to show executor addresses in the block table. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 10 30 59](https://user-images.githubusercontent.com/2017933/36142668-0b3578f8-10a9-11e8-95ea-2f57703ee4af.png) Author: “attilapiros” Closes #20589 from attilapiros/SPARK-23394. --- .../org/apache/spark/ui/storage/RDDPage.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 02cee7f8c5b33..2674b9291203a 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -23,7 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} import org.apache.spark.status.AppStatusStore -import org.apache.spark.status.api.v1.{RDDDataDistribution, RDDPartitionInfo} +import org.apache.spark.status.api.v1.{ExecutorSummary, RDDDataDistribution, RDDPartitionInfo} import org.apache.spark.ui._ import org.apache.spark.util.Utils @@ -76,7 +76,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, - blockSortDesc) + blockSortDesc, + store.executorList(true)) _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -182,7 +183,8 @@ private[ui] class BlockDataSource( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedDataSource[BlockTableRowData](pageSize) { + desc: Boolean, + executorIdToAddress: Map[String, String]) extends PagedDataSource[BlockTableRowData](pageSize) { private val data = rddPartitions.map(blockRow).sorted(ordering(sortColumn, desc)) @@ -198,7 +200,10 @@ private[ui] class BlockDataSource( rddPartition.storageLevel, rddPartition.memoryUsed, rddPartition.diskUsed, - rddPartition.executors.mkString(" ")) + rddPartition.executors + .map { id => executorIdToAddress.get(id).getOrElse(id) } + .sorted + .mkString(" ")) } /** @@ -226,7 +231,8 @@ private[ui] class BlockPagedTable( rddPartitions: Seq[RDDPartitionInfo], pageSize: Int, sortColumn: String, - desc: Boolean) extends PagedTable[BlockTableRowData] { + desc: Boolean, + executorSummaries: Seq[ExecutorSummary]) extends PagedTable[BlockTableRowData] { override def tableId: String = "rdd-storage-by-block-table" @@ -243,7 +249,8 @@ private[ui] class BlockPagedTable( rddPartitions, pageSize, sortColumn, - desc) + desc, + executorSummaries.map { ex => (ex.id, ex.hostPort) }.toMap) override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") From 400a1d9e25c1196f0be87323bd89fb3af0660166 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 10:57:12 -0800 Subject: [PATCH 0353/1161] Revert "[SPARK-23249][SQL] Improved block merging logic for partitions" This reverts commit 8c21170decfb9ca4d3233e1ea13bd1b6e3199ed9. --- .../sql/execution/DataSourceScanExec.scala | 29 +++++-------------- .../datasources/FileSourceStrategySuite.scala | 15 ++++++---- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index ba1157d5b6a49..08ff33afbba3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -444,29 +444,16 @@ case class FileSourceScanExec( currentSize = 0 } - def addFile(file: PartitionedFile): Unit = { - currentFiles += file - currentSize += file.length + openCostInBytes - } - - var frontIndex = 0 - var backIndex = splitFiles.length - 1 - - while (frontIndex <= backIndex) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - while (frontIndex <= backIndex && - currentSize + splitFiles(frontIndex).length <= maxSplitBytes) { - addFile(splitFiles(frontIndex)) - frontIndex += 1 - } - while (backIndex > frontIndex && - currentSize + splitFiles(backIndex).length <= maxSplitBytes) { - addFile(splitFiles(backIndex)) - backIndex -= 1 + // Assign files to partitions using "Next Fit Decreasing" + splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() } - closePartition() + // Add the given file to the current partition. + currentSize += file.length + openCostInBytes + currentFiles += file } + closePartition() new FileScanRDD(fsRelation.sparkSession, readFile, partitions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index bfccc9335b361..c1d61b843d899 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -141,17 +141,16 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(file1, file6), (file2, file3), (file4, file5)] - assert(partitions.size == 3, "when checking partitions") - assert(partitions(0).files.size == 2, "when checking partition 1") + // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] + assert(partitions.size == 4, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") assert(partitions(2).files.size == 2, "when checking partition 3") + assert(partitions(3).files.size == 1, "when checking partition 4") - // First partition reads (file1, file6) + // First partition reads (file1) assert(partitions(0).files(0).start == 0) assert(partitions(0).files(0).length == 2) - assert(partitions(0).files(1).start == 0) - assert(partitions(0).files(1).length == 1) // Second partition reads (file2, file3) assert(partitions(1).files(0).start == 0) @@ -164,6 +163,10 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi assert(partitions(2).files(0).length == 1) assert(partitions(2).files(1).start == 0) assert(partitions(2).files(1).length == 1) + + // Final partition reads (file6) + assert(partitions(3).files(0).start == 0) + assert(partitions(3).files(0).length == 1) } checkPartitionSchema(StructType(Nil)) From 658d9d9d785a30857bf35d164e6cbbd9799d6959 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 14 Feb 2018 14:27:02 -0800 Subject: [PATCH 0354/1161] [SPARK-23406][SS] Enable stream-stream self-joins ## What changes were proposed in this pull request? Solved two bugs to enable stream-stream self joins. ### Incorrect analysis due to missing MultiInstanceRelation trait Streaming leaf nodes did not extend MultiInstanceRelation, which is necessary for the catalyst analyzer to convert the self-join logical plan DAG into a tree (by creating new instances of the leaf relations). This was causing the error `Failure when resolving conflicting references in Join:` (see JIRA for details). ### Incorrect attribute rewrite when splicing batch plans in MicroBatchExecution When splicing the source's batch plan into the streaming plan (by replacing the StreamingExecutionPlan), we were rewriting the attribute reference in the streaming plan with the new attribute references from the batch plan. This was incorrectly handling the scenario when multiple StreamingExecutionRelation point to the same source, and therefore eventually point to the same batch plan returned by the source. Here is an example query, and its corresponding plan transformations. ``` val df = input.toDF val join = df.select('value % 5 as "key", 'value).join( df.select('value % 5 as "key", 'value), "key") ``` Streaming logical plan before splicing the batch plan ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- StreamingExecutionRelation Memory[#1], value#12 // two different leaves pointing to same source ``` Batch logical plan after splicing the batch plan and before rewriting ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#1 +- Project [(value#12 % 5) AS key#9, value#12] +- LocalRelation [value#66] // replaces StreamingExecutionRelation Memory[#1], value#12 ``` Batch logical plan after rewriting the attributes. Specifically, for spliced, the new output attributes (value#66) replace the earlier output attributes (value#12, and value#1, one for each StreamingExecutionRelation). ``` Project [key#6, value#66, value#66] // both value#1 and value#12 replaces by value#66 +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9, value#66] +- LocalRelation [value#66] ``` This causes the optimizer to eliminate value#66 from one side of the join. ``` Project [key#6, value#66, value#66] +- Join Inner, (key#6 = key#9) :- Project [(value#66 % 5) AS key#6, value#66] : +- LocalRelation [value#66] +- Project [(value#66 % 5) AS key#9] // this does not generate value, incorrect join results +- LocalRelation [value#66] ``` **Solution**: Instead of rewriting attributes, use a Project to introduce aliases between the output attribute references and the new reference generated by the spliced plans. The analyzer and optimizer will take care of the rest. ``` Project [key#6, value#1, value#12] +- Join Inner, (key#6 = key#9) :- Project [(value#1 % 5) AS key#6, value#1] : +- Project [value#66 AS value#1] // solution: project with aliases : +- LocalRelation [value#66] +- Project [(value#12 % 5) AS key#9, value#12] +- Project [value#66 AS value#12] // solution: project with aliases +- LocalRelation [value#66] ``` ## How was this patch tested? New unit test Author: Tathagata Das Closes #20598 from tdas/SPARK-23406. --- .../streaming/MicroBatchExecution.scala | 16 ++++++------ .../streaming/StreamingRelation.scala | 20 ++++++++++----- .../sql/streaming/StreamingJoinSuite.scala | 25 ++++++++++++++++++- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 812533313332e..ac73ba3417904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} @@ -415,8 +415,6 @@ class MicroBatchExecution( } } - // A list of attributes that will need to be updated. - val replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. val newBatchesPlan = logicalPlan transform { case StreamingExecutionRelation(source, output) => @@ -424,18 +422,18 @@ class MicroBatchExecution( assert(output.size == dataPlan.output.size, s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + s"${Utils.truncatedString(dataPlan.output, ",")}") - replacements ++= output.zip(dataPlan.output) - dataPlan + + val aliases = output.zip(dataPlan.output).map { case (to, from) => + Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata)) + } + Project(aliases, dataPlan) }.getOrElse { LocalRelation(output, isStreaming = true) } } // Rewire the plan to use the new attributes that were returned by the source. - val replacementMap = AttributeMap(replacements) val newAttributePlan = newBatchesPlan transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => - replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, ct.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 7146190645b37..f02d3a2c3733f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.LeafNode -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} @@ -42,7 +42,7 @@ object StreamingRelation { * passing to [[StreamExecution]] to run a query. */ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName @@ -53,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance())) } /** @@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: case class StreamingExecutionRelation( source: BaseStreamingSource, output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -74,6 +76,8 @@ case class StreamingExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } // We have to pack in the V1 data source as a shim, for the case when a source implements @@ -92,13 +96,15 @@ case class StreamingRelationV2( extraOptions: Map[String, String], output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = sourceName override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** @@ -108,7 +114,7 @@ case class ContinuousExecutionRelation( source: ContinuousReadSupport, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) - extends LeafNode { + extends LeafNode with MultiInstanceRelation { override def isStreaming: Boolean = true override def toString: String = source.toString @@ -120,6 +126,8 @@ case class ContinuousExecutionRelation( override def computeStats(): Statistics = Statistics( sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) ) + + override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 54eb863dacc83..92087f68ad74a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -28,7 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} -import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId} import org.apache.spark.sql.functions._ @@ -323,6 +325,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } + test("stream stream self join") { + val input = MemoryStream[Int] + val df = input.toDF + val join = + df.select('value % 5 as "key", 'value).join( + df.select('value % 5 as "key", 'value), "key") + + testStream(join)( + AddData(input, 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + StopStream, + StartStream(), + AddData(input, 3, 6), + /* + (1, 1) (1, 1) + (2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6) + (1, 6) (1, 6) + */ + CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6))) + } + test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ From a77ebb0921e390cf4fc6279a8c0a92868ad7e69b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:52:59 -0800 Subject: [PATCH 0355/1161] [SPARK-23421][SPARK-22356][SQL] Document the behavior change in ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/19579 introduces a behavior change. We need to document it in the migration guide. ## How was this patch tested? Also update the HiveExternalCatalogVersionsSuite to verify it. Author: gatorsmile Closes #20606 from gatorsmile/addMigrationGuide. --- docs/sql-programming-guide.md | 2 ++ .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0f9f01e18682f..cf9529a79f4f9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1963,6 +1963,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + + - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). ## Upgrading From Spark SQL 2.0 to 2.1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index ae4aeb7b4ce4a..c13a750dbb270 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") protected var spark: SparkSession = _ @@ -249,7 +249,7 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // SPARK-22356: overlapped columns between data and partition schema in data source tables val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" - // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0. + // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0, 2.2.1, 2.3+ if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { spark.sql("msck repair table " + tbl_with_col_overlap) assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) From 95e4b4916065e66a4f8dba57e98e725796f75e04 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Feb 2018 23:56:02 -0800 Subject: [PATCH 0356/1161] [SPARK-23094] Revert [] Fix invalid character handling in JsonDataSource ## What changes were proposed in this pull request? This PR is to revert the PR https://github.com/apache/spark/pull/20302, because it causes a regression. ## How was this patch tested? N/A Author: gatorsmile Closes #20614 from gatorsmile/revertJsonFix. --- .../catalyst/json/CreateJacksonParser.scala | 5 ++- .../sources/JsonHadoopFsRelationSuite.scala | 34 ------------------- 2 files changed, 2 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index b1672e7e2fca2..025a388aacaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -40,11 +40,10 @@ private[sql] object CreateJacksonParser extends Serializable { } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { - val bain = new ByteArrayInputStream(record.getBytes, 0, record.getLength) - jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) + jsonFactory.createParser(record.getBytes, 0, record.getLength) } def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { - jsonFactory.createParser(new InputStreamReader(record, "UTF-8")) + jsonFactory.createParser(record) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 27f398ebf301a..49be30435ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" - private val badJson = "\u0000\u0000\u0000A\u0001AAA" - // JSON does not write data of NullType and does not play well with BinaryType. override protected def supportsDataType(dataType: DataType): Boolean = dataType match { case _: NullType => false @@ -107,36 +105,4 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } - - test("invalid json with leading nulls - from file (multiLine=true)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val expected = s"""$badJson\n{"a":1}\n""" - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", true).schema(schema).load(path) - checkAnswer(df, Row(null, expected)) - } - } - - test("invalid json with leading nulls - from file (multiLine=false)") { - import testImplicits._ - withTempDir { tempDir => - val path = tempDir.getAbsolutePath - Seq(badJson, """{"a":1}""").toDS().write.mode("overwrite").text(path) - val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) - val df = - spark.read.format(dataSourceName).option("multiLine", false).schema(schema).load(path) - checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) - } - } - - test("invalid json with leading nulls - from dataset") { - import testImplicits._ - checkAnswer( - spark.read.json(Seq(badJson).toDS()), - Row(badJson)) - } } From f38c760638063f1fb45e9ee2c772090fb203a4a0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 15 Feb 2018 16:59:44 +0800 Subject: [PATCH 0357/1161] [SPARK-23419][SPARK-23416][SS] data source v2 write path should re-throw interruption exceptions directly ## What changes were proposed in this pull request? Streaming execution has a list of exceptions that means interruption, and handle them specially. `WriteToDataSourceV2Exec` should also respect this list and not wrap them with `SparkException`. ## How was this patch tested? existing test. Author: Wenchen Fan Closes #20605 from cloud-fan/write. --- .../datasources/v2/WriteToDataSourceV2.scala | 11 ++++- .../execution/streaming/StreamExecution.scala | 40 ++++++++++--------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 535e7962d7439..41cdfc80d8a19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.util.control.NonFatal + import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging @@ -27,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -107,7 +110,13 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e throw new SparkException("Writing job failed.", cause) } logError(s"Data source writer $writer aborted.") - throw new SparkException("Writing job aborted.", cause) + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } } sparkContext.emptyRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index e7982d7880ceb..3fc8c7887896a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -356,25 +356,7 @@ abstract class StreamExecution( private def isInterruptedByStop(e: Throwable): Boolean = { if (state.get == TERMINATED) { - e match { - // InterruptedIOException - thrown when an I/O operation is interrupted - // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted - case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => - true - // The cause of the following exceptions may be one of the above exceptions: - // - // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as - // BiFunction.apply - // ExecutionException - thrown by codes running in a thread pool and these codes throw an - // exception - // UncheckedExecutionException - thrown by codes that cannot throw a checked - // ExecutionException, such as BiFunction.apply - case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) - if e2.getCause != null => - isInterruptedByStop(e2.getCause) - case _ => - false - } + StreamExecution.isInterruptionException(e) } else { false } @@ -565,6 +547,26 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + + def isInterruptionException(e: Throwable): Boolean = e match { + // InterruptedIOException - thrown when an I/O operation is interrupted + // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted + case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException => + true + // The cause of the following exceptions may be one of the above exceptions: + // + // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as + // BiFunction.apply + // ExecutionException - thrown by codes running in a thread pool and these codes throw an + // exception + // UncheckedExecutionException - thrown by codes that cannot throw a checked + // ExecutionException, such as BiFunction.apply + case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException) + if e2.getCause != null => + isInterruptionException(e2.getCause) + case _ => + false + } } /** From 7539ae59d6c354c95c50528abe9ddff6972e960f Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 15 Feb 2018 17:09:06 +0800 Subject: [PATCH 0358/1161] [SPARK-23366] Improve hot reading path in ReadAheadInputStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `ReadAheadInputStream` was introduced in https://github.com/apache/spark/pull/18317/ to optimize reading spill files from disk. However, from the profiles it seems that the hot path of reading small amounts of data (like readInt) is inefficient - it involves taking locks, and multiple checks. Optimize locking: Lock is not needed when simply accessing the active buffer. Only lock when needing to swap buffers or trigger async reading, or get information about the async state. Optimize short-path single byte reads, that are used e.g. by Java library DataInputStream.readInt. The asyncReader used to call "read" only once on the underlying stream, that never filled the underlying buffer when it was wrapping an LZ4BlockInputStream. If the buffer was returned unfilled, that would trigger the async reader to be triggered to fill the read ahead buffer on each call, because the reader would see that the active buffer is below the refill threshold all the time. However, filling the full buffer all the time could introduce increased latency, so also add an `AtomicBoolean` flag for the async reader to return earlier if there is a reader waiting for data. Remove `readAheadThresholdInBytes` and instead immediately trigger async read when switching the buffers. It allows to simplify code paths, especially the hot one that then only has to check if there is available data in the active buffer, without worrying if it needs to retrigger async read. It seems to have positive effect on perf. ## How was this patch tested? It was noticed as a regression in some workloads after upgrading to Spark 2.3.  It was particularly visible on TPCDS Q95 running on instances with fast disk (i3 AWS instances). Running with profiling: * Spark 2.2 - 5.2-5.3 minutes 9.5% in LZ4BlockInputStream.read * Spark 2.3 - 6.4-6.6 minutes 31.1% in ReadAheadInputStream.read * Spark 2.3 + fix - 5.3-5.4 minutes 13.3% in ReadAheadInputStream.read - very slightly slower, practically within noise. We didn't see other regressions, and many workloads in general seem to be faster with Spark 2.3 (not investigated if thanks to async readed, or unrelated). Author: Juliusz Sompolski Closes #20555 from juliuszsompolski/SPARK-23366. --- .../apache/spark/io/ReadAheadInputStream.java | 119 +++++++++--------- .../unsafe/sort/UnsafeSorterSpillReader.java | 10 +- .../spark/io/GenericFileInputStreamSuite.java | 98 ++++++++------- .../spark/io/NioBufferedInputStreamSuite.java | 6 +- .../spark/io/ReadAheadInputStreamSuite.java | 17 ++- 5 files changed, 133 insertions(+), 117 deletions(-) diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5b45d268ace8d..0cced9e222952 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -27,6 +27,7 @@ import java.nio.ByteBuffer; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream { // whether there is a read ahead task running, private boolean isReading; - // If the remaining data size in the current buffer is below this threshold, - // we issue an async read from the underlying input stream. - private final int readAheadThresholdInBytes; + // whether there is a reader waiting for data. + private AtomicBoolean isWaiting = new AtomicBoolean(false); private final InputStream underlyingInputStream; @@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream { * * @param inputStream The underlying input stream. * @param bufferSizeInBytes The buffer size. - * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead - * threshold, an async read is triggered. */ public ReadAheadInputStream( - InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + InputStream inputStream, int bufferSizeInBytes) { Preconditions.checkArgument(bufferSizeInBytes > 0, "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); - Preconditions.checkArgument(readAheadThresholdInBytes > 0 && - readAheadThresholdInBytes < bufferSizeInBytes, - "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " + - "but the value is " + readAheadThresholdInBytes); activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); - this.readAheadThresholdInBytes = readAheadThresholdInBytes; this.underlyingInputStream = inputStream; activeBuffer.flip(); readAheadBuffer.flip(); @@ -166,12 +159,17 @@ public void run() { // in that case the reader waits for this async read to complete. // So there is no race condition in both the situations. int read = 0; + int off = 0, len = arr.length; Throwable exception = null; try { - while (true) { - read = underlyingInputStream.read(arr); - if (0 != read) break; - } + // try to fill the read ahead buffer. + // if a reader is waiting, possibly return early. + do { + read = underlyingInputStream.read(arr, off, len); + if (read <= 0) break; + off += read; + len -= read; + } while (len > 0 && !isWaiting.get()); } catch (Throwable ex) { exception = ex; if (ex instanceof Error) { @@ -181,13 +179,12 @@ public void run() { } } finally { stateChangeLock.lock(); + readAheadBuffer.limit(off); if (read < 0 || (exception instanceof EOFException)) { endOfStream = true; } else if (exception != null) { readAborted = true; readException = exception; - } else { - readAheadBuffer.limit(read); } readInProgress = false; signalAsyncReadComplete(); @@ -230,7 +227,10 @@ private void signalAsyncReadComplete() { private void waitForAsyncReadComplete() throws IOException { stateChangeLock.lock(); + isWaiting.set(true); try { + // There is only one reader, and one writer, so the writer should signal only once, + // but a while loop checking the wake up condition is still needed to avoid spurious wakeups. while (readInProgress) { asyncReadComplete.await(); } @@ -239,6 +239,7 @@ private void waitForAsyncReadComplete() throws IOException { iio.initCause(e); throw iio; } finally { + isWaiting.set(false); stateChangeLock.unlock(); } checkReadException(); @@ -246,8 +247,13 @@ private void waitForAsyncReadComplete() throws IOException { @Override public int read() throws IOException { - byte[] oneByteArray = oneByte.get(); - return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + if (activeBuffer.hasRemaining()) { + // short path - just get one byte. + return activeBuffer.get() & 0xFF; + } else { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + } } @Override @@ -258,54 +264,43 @@ public int read(byte[] b, int offset, int len) throws IOException { if (len == 0) { return 0; } - stateChangeLock.lock(); - try { - return readInternal(b, offset, len); - } finally { - stateChangeLock.unlock(); - } - } - /** - * flip the active and read ahead buffer - */ - private void swapBuffers() { - ByteBuffer temp = activeBuffer; - activeBuffer = readAheadBuffer; - readAheadBuffer = temp; - } - - /** - * Internal read function which should be called only from read() api. The assumption is that - * the stateChangeLock is already acquired in the caller before calling this function. - */ - private int readInternal(byte[] b, int offset, int len) throws IOException { - assert (stateChangeLock.isLocked()); if (!activeBuffer.hasRemaining()) { - waitForAsyncReadComplete(); - if (readAheadBuffer.hasRemaining()) { - swapBuffers(); - } else { - // The first read or activeBuffer is skipped. - readAsync(); + // No remaining in active buffer - lock and switch to write ahead buffer. + stateChangeLock.lock(); + try { waitForAsyncReadComplete(); - if (isEndOfStream()) { - return -1; + if (!readAheadBuffer.hasRemaining()) { + // The first read. + readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } } + // Swap the newly read read ahead buffer in place of empty active buffer. swapBuffers(); + // After swapping buffers, trigger another async read for read ahead buffer. + readAsync(); + } finally { + stateChangeLock.unlock(); } - } else { - checkReadException(); } len = Math.min(len, activeBuffer.remaining()); activeBuffer.get(b, offset, len); - if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { - readAsync(); - } return len; } + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + @Override public int available() throws IOException { stateChangeLock.lock(); @@ -323,6 +318,11 @@ public long skip(long n) throws IOException { if (n <= 0L) { return 0L; } + if (n <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position((int) n + activeBuffer.position()); + return n; + } stateChangeLock.lock(); long skipped; try { @@ -346,21 +346,14 @@ private long skipInternal(long n) throws IOException { if (available() >= n) { // we can skip from the internal buffers int toSkip = (int) n; - if (toSkip <= activeBuffer.remaining()) { - // Only skipping from active buffer is sufficient - activeBuffer.position(toSkip + activeBuffer.position()); - if (activeBuffer.remaining() <= readAheadThresholdInBytes - && !readAheadBuffer.hasRemaining()) { - readAsync(); - } - return n; - } // We need to skip from both active buffer and read ahead buffer toSkip -= activeBuffer.remaining(); + assert(toSkip > 0); // skipping from activeBuffer already handled. activeBuffer.position(0); activeBuffer.flip(); readAheadBuffer.position(toSkip + readAheadBuffer.position()); swapBuffers(); + // Trigger async read to emptied read ahead buffer. readAsync(); return n; } else { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 2c53c8d809d2e..fb179d07edebc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -72,21 +72,15 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } - final double readAheadFraction = - SparkEnv.get() == null ? 0.5 : - SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); - - // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf - // regression for TPC-DS queries. final boolean readAheadEnabled = SparkEnv.get() != null && - SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false); + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { if (readAheadEnabled) { this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), - (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + (int) bufferSizeBytes); } else { this.in = serializerManager.wrapStream(blockId, bs); } diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 3440e1aea2f46..22db3592ecc96 100644 --- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite { protected File inputFile; - protected InputStream inputStream; + protected InputStream[] inputStreams; @Before public void setUp() throws IOException { @@ -54,77 +54,91 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - for (int i = 0; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testReadMultipleBytes() throws IOException { - byte[] readBytes = new byte[8 * 1024]; - int i = 0; - while (i < randomBytes.length) { - int read = inputStream.read(readBytes, 0, 8 * 1024); - for (int j = 0; j < read; j++) { - assertEquals(randomBytes[i], readBytes[j]); - i++; + for (InputStream inputStream: inputStreams) { + byte[] readBytes = new byte[8 * 1024]; + int i = 0; + while (i < randomBytes.length) { + int read = inputStream.read(readBytes, 0, 8 * 1024); + for (int j = 0; j < read; j++) { + assertEquals(randomBytes[i], readBytes[j]); + i++; + } } } } @Test public void testBytesSkipped() throws IOException { - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - for (int i = 0; i < 1024; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - // Skipping negative bytes should essential be a no-op - assertEquals(0, inputStream.skip(-1)); - assertEquals(0, inputStream.skip(-1024)); - assertEquals(0, inputStream.skip(Long.MIN_VALUE)); - assertEquals(1024, inputStream.skip(1024)); - for (int i = 2048; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + for (int i = 0; i < 1024; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + // Skipping negative bytes should essential be a no-op + assertEquals(0, inputStream.skip(-1)); + assertEquals(0, inputStream.skip(-1024)); + assertEquals(0, inputStream.skip(Long.MIN_VALUE)); + assertEquals(1024, inputStream.skip(1024)); + for (int i = 2048; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testSkipFromFileChannel() throws IOException { - // Since the buffer is smaller than the skipped bytes, this will guarantee - // we skip from underlying file channel. - assertEquals(1024, inputStream.skip(1024)); - for (int i = 1024; i < 2048; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); - } - assertEquals(256, inputStream.skip(256)); - assertEquals(256, inputStream.skip(256)); - assertEquals(512, inputStream.skip(512)); - for (int i = 3072; i < randomBytes.length; i++) { - assertEquals(randomBytes[i], (byte) inputStream.read()); + for (InputStream inputStream: inputStreams) { + // Since the buffer is smaller than the skipped bytes, this will guarantee + // we skip from underlying file channel. + assertEquals(1024, inputStream.skip(1024)); + for (int i = 1024; i < 2048; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } + assertEquals(256, inputStream.skip(256)); + assertEquals(256, inputStream.skip(256)); + assertEquals(512, inputStream.skip(512)); + for (int i = 3072; i < randomBytes.length; i++) { + assertEquals(randomBytes[i], (byte) inputStream.read()); + } } } @Test public void testBytesSkippedAfterEOF() throws IOException { - assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); - assertEquals(-1, inputStream.read()); + for (InputStream inputStream: inputStreams) { + assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); + assertEquals(-1, inputStream.read()); + } } } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java index 211b33a1a9fb0..a320f8662f707 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -18,6 +18,7 @@ import org.junit.Before; +import java.io.InputStream; import java.io.IOException; /** @@ -28,6 +29,9 @@ public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new NioBufferedFileInputStream(inputFile); + inputStreams = new InputStream[] { + new NioBufferedFileInputStream(inputFile), // default + new NioBufferedFileInputStream(inputFile, 123) // small, unaligned buffer + }; } } diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java index 918ddc4517ec4..bfa1e0b908824 100644 --- a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -19,16 +19,27 @@ import org.junit.Before; import java.io.IOException; +import java.io.InputStream; /** - * Tests functionality of {@link NioBufferedFileInputStream} + * Tests functionality of {@link ReadAheadInputStreamSuite} */ public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { @Before public void setUp() throws IOException { super.setUp(); - inputStream = new ReadAheadInputStream( - new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + inputStreams = new InputStream[] { + // Tests equal and aligned buffers of wrapped an outer stream. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 8 * 1024), 8 * 1024), + // Tests aligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 3 * 1024), 2 * 1024), + // Tests aligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 2 * 1024), 3 * 1024), + // Tests unaligned buffers, wrapped bigger than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 321), 123), + // Tests unaligned buffers, wrapped smaller than outer. + new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 123), 321) + }; } } From ed8647609883fcef16be5d24c2cb4ebda25bd6f0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Feb 2018 17:13:05 +0800 Subject: [PATCH 0359/1161] [SPARK-23359][SQL] Adds an alias 'names' of 'fieldNames' in Scala's StructType ## What changes were proposed in this pull request? This PR proposes to add an alias 'names' of 'fieldNames' in Scala. Please see the discussion in [SPARK-20090](https://issues.apache.org/jira/browse/SPARK-20090). ## How was this patch tested? Unit tests added in `DataTypeSuite.scala`. Author: hyukjinkwon Closes #20545 from HyukjinKwon/SPARK-23359. --- .../scala/org/apache/spark/sql/types/StructType.scala | 7 +++++++ .../scala/org/apache/spark/sql/types/DataTypeSuite.scala | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index e3b0969283a84..d5011c3cb87e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -104,6 +104,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) + /** + * Returns all field names in an array. This is an alias of `fieldNames`. + * + * @since 2.4.0 + */ + def names: Array[String] = fieldNames + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 8e2b32c2b9a08..5a86f4055dce7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -134,6 +134,14 @@ class DataTypeSuite extends SparkFunSuite { assert(mapped === expected) } + test("fieldNames and names returns field names") { + val struct = StructType( + StructField("a", LongType) :: StructField("b", FloatType) :: Nil) + + assert(struct.fieldNames === Seq("a", "b")) + assert(struct.names === Seq("a", "b")) + } + test("merge where right contains type conflict") { val left = StructType( StructField("a", LongType) :: From 44e20c42254bc6591b594f54cd94ced5fcfadae3 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 15 Feb 2018 03:52:40 -0800 Subject: [PATCH 0360/1161] =?UTF-8?q?[SPARK-23422][CORE]=20YarnShuffleInte?= =?UTF-8?q?grationSuite=20fix=20when=20SPARK=5FPREPEN=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …D_CLASSES set to 1 ## What changes were proposed in this pull request? YarnShuffleIntegrationSuite fails when SPARK_PREPEND_CLASSES set to 1. Normally mllib built before yarn module. When SPARK_PREPEND_CLASSES used mllib classes are on yarn test classpath. Before 2.3 that did not cause issues. But 2.3 has SPARK-22450, which registered some mllib classes with the kryo serializer. Now it dies with the following error: ` 18/02/13 07:33:29 INFO SparkContext: Starting job: collect at YarnShuffleIntegrationSuite.scala:143 Exception in thread "dag-scheduler-event-loop" java.lang.NoClassDefFoundError: breeze/linalg/DenseMatrix ` In this PR NoClassDefFoundError caught only in case of testing and then do nothing. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20608 from gaborgsomogyi/SPARK-23422. --- .../main/scala/org/apache/spark/serializer/KryoSerializer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 538ae05e4eea1..72427dd6ce4d4 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -206,6 +206,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(clazz) } catch { case NonFatal(_) => // do nothing + case _: NoClassDefFoundError if Utils.isTesting => // See SPARK-23422. } } From f217d7d9b22c4b9c947fc5467379af17f036ee61 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Feb 2018 07:47:40 -0800 Subject: [PATCH 0361/1161] [INFRA] Close stale PRs. Closes #20587 Closes #20586 From 2f0498d1e85a53b60da6a47d20bbdf56b42b7dcb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 08:55:39 -0800 Subject: [PATCH 0362/1161] [SPARK-23426][SQL] Use `hive` ORC impl and disable PPD for Spark 2.3.0 ## What changes were proposed in this pull request? To prevent any regressions, this PR changes ORC implementation to `hive` by default like Spark 2.2.X. Users can enable `native` ORC. Also, ORC PPD is also restored to `false` like Spark 2.2.X. ![orc_section](https://user-images.githubusercontent.com/9700541/36221575-57a1d702-1173-11e8-89fe-dca5842f4ca7.png) ## How was this patch tested? Pass all test cases. Author: Dongjoon Hyun Closes #20610 from dongjoon-hyun/SPARK-ORC-DISABLE. --- docs/sql-programming-guide.md | 52 ++++++++----------- .../apache/spark/sql/internal/SQLConf.scala | 6 +-- .../spark/sql/FileBasedDataSourceSuite.scala | 17 +++++- .../sql/streaming/FileStreamSinkSuite.scala | 13 +++++ .../sql/streaming/FileStreamSourceSuite.scala | 13 +++++ 5 files changed, 68 insertions(+), 33 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cf9529a79f4f9..91e43678481d6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1004,6 +1004,29 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession +## ORC Files + +Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. +To do that, the following configurations are newly added. The vectorized reader is used for the +native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` +is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC +serde tables (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), +the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also set to `true`. + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.orc.implhiveThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    + ## JSON Datasets
    @@ -1776,35 +1799,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.2 to 2.3 - - Since Spark 2.3, Spark supports a vectorized ORC reader with a new ORC file format for ORC files. To do that, the following configurations are newly added or change their default values. The vectorized reader is used for the native ORC tables (e.g., the ones created using the clause `USING ORC`) when `spark.sql.orc.impl` is set to `native` and `spark.sql.orc.enableVectorizedReader` is set to `true`. For the Hive ORC serde table (e.g., the ones created using the clause `USING HIVE OPTIONS (fileFormat 'ORC')`), the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is set to `true`. - - - New configurations - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.implnativeThe name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1 which is used prior to Spark 2.3.
    spark.sql.orc.enableVectorizedReadertrueEnables vectorized orc decoding in native implementation. If false, a new non-vectorized ORC reader is used in native implementation. For hive implementation, this is ignored.
    - - - Changed configurations - - - - - - - - -
    Property NameDefaultMeaning
    spark.sql.orc.filterPushdowntrueEnables filter pushdown for ORC files. It is false by default prior to Spark 2.3.
    - - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles. - Since Spark 2.3, the Join/Filter's deterministic predicates that are after the first non-deterministic predicates are also pushed down/through the child operators, if possible. In prior Spark versions, these filters are not eligible for predicate pushdown. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7835dbaa58439..f24fd7ff74d3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default prior to Spark 2.3.") + "1.2.1. It is 'hive' by default.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("native") + .createWithDefault("hive") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 2e332362ea644..b5d4c558f0d3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -20,14 +20,29 @@ package org.apache.spark.sql import java.io.FileNotFoundException import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext { + +class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + private val allFileBasedDataSources = Seq("orc", "parquet", "csv", "json", "text") private val nameWithSpecialChars = "sp&cial%c hars" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 8c4e1fd00b0a2..ba48bc1ce0c4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -33,6 +33,19 @@ import org.apache.spark.util.Utils class FileStreamSinkSuite extends StreamTest { import testImplicits._ + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + test("unpartitioned writing and batch reading") { val inputData = MemoryStream[Int] val df = inputData.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 5bb0f4d643bbe..d4bd9c7987f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -207,6 +207,19 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head } + override def beforeAll(): Unit = { + super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ORC_IMPLEMENTATION, "native") + } + + override def afterAll(): Unit = { + try { + spark.sessionState.conf.unsetConf(SQLConf.ORC_IMPLEMENTATION) + } finally { + super.afterAll() + } + } + // ============= Basic parameter exists tests ================ test("FileStreamSource schema: no path") { From 6968c3cfd70961c4e86daffd6a156d0a9c1d7a2a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 Feb 2018 09:40:08 -0800 Subject: [PATCH 0363/1161] [MINOR][SQL] Fix an error message about inserting into bucketed tables ## What changes were proposed in this pull request? This replaces `Sparkcurrently` to `Spark currently` in the following error message. ```scala scala> sql("insert into t2 select * from v1") org.apache.spark.sql.AnalysisException: Output Hive table `default`.`t2` is bucketed but Sparkcurrently does NOT populate bucketed ... ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20617 from dongjoon-hyun/SPARK-ERROR-MSG. --- .../apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3ce5b8469d6fc..02a60f16b3b3a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -172,7 +172,7 @@ case class InsertIntoHiveTable( val enforceBucketingConfig = "hive.enforce.bucketing" val enforceSortingConfig = "hive.enforce.sorting" - val message = s"Output Hive table ${table.identifier} is bucketed but Spark" + + val message = s"Output Hive table ${table.identifier} is bucketed but Spark " + "currently does NOT populate bucketed output which is compatible with Hive." if (hadoopConf.get(enforceBucketingConfig, "true").toBoolean || From db45daab90ede4c03c1abc9096f4eac584e9db17 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 15 Feb 2018 09:54:39 -0800 Subject: [PATCH 0364/1161] [SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug ## What changes were proposed in this pull request? #### Problem: Since 2.3, `Bucketizer` supports multiple input/output columns. We will check if exclusive params are set during transformation. E.g., if `inputCols` and `outputCol` are both set, an error will be thrown. However, when we write `Bucketizer`, looks like the default params and user-supplied params are merged during writing. All saved params are loaded back and set to created model instance. So the default `outputCol` param in `HasOutputCol` trait will be set in `paramMap` and become an user-supplied param. That makes the check of exclusive params failed. #### Fix: This changes the saving logic of Bucketizer to handle this case. This is a quick fix to catch the time of 2.3. We should consider modify the persistence mechanism later. Please see the discussion in the JIRA. Note: The multi-column `QuantileDiscretizer` also has the same issue. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20594 from viirya/SPARK-23377-2. --- .../apache/spark/ml/feature/Bucketizer.scala | 28 +++++++++++++++++++ .../ml/feature/QuantileDiscretizer.scala | 28 +++++++++++++++++++ .../spark/ml/feature/BucketizerSuite.scala | 12 ++++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 14 ++++++++-- 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index c13bf47eacb94..f49c410cbcfe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -19,6 +19,10 @@ package org.apache.spark.ml.feature import java.{util => ju} +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Model @@ -213,6 +217,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } + + override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -290,6 +296,28 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } + + private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 1ec5f8cb6139b..3b4c25478fb1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.feature +import org.json4s.JsonDSL._ +import org.json4s.JValue +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ @@ -249,11 +253,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) + + override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { + private[QuantileDiscretizer] + class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // SPARK-23377: The default params will be saved and loaded as user-supplied params. + // Once `inputCols` is set, the default value of `outputCol` param causes the error + // when checking exclusive params. As a temporary to fix it, we skip the default value + // of `outputCol` if `inputCols` is set when saving the metadata. + // TODO: If we modify the persistence mechanism later to better handle default params, + // we can get rid of this. + var paramWithoutOutputCol: Option[JValue] = None + if (instance.isSet(instance.inputCols)) { + val params = instance.extractParamMap().toSeq + val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + paramWithoutOutputCol = Some(render(jsonParams)) + } + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) + } + } + @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 7403680ae3fdc..41cf72fe3470a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setSplits(Array(0.1, 0.8, 0.9)) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) } test("Bucket numeric features") { @@ -327,7 +330,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCols(Array("myInputCol")) .setOutputCols(Array("myOutputCol")) .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) - testDefaultReadWrite(t) + + val bucketizer = testDefaultReadWrite(t) + val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2") + bucketizer.transform(data) + assert(t.hasDefault(t.outputCol)) + assert(bucketizer.hasDefault(bucketizer.outputCol)) } test("Bucketizer in a pipeline") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index e9a75e931e6a8..6c363799dd300 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + test("Test observed number of buckets and their sizes match expected values") { val spark = this.spark import spark.implicits._ @@ -132,7 +134,10 @@ class QuantileDiscretizerSuite .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setNumBuckets(6) - testDefaultReadWrite(t) + + val readDiscretizer = testDefaultReadWrite(t) + val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol") + readDiscretizer.fit(data) } test("Verify resulting model has parent") { @@ -379,7 +384,12 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBucketsArray(Array(5, 10)) - testDefaultReadWrite(discretizer) + + val readDiscretizer = testDefaultReadWrite(discretizer) + val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2") + readDiscretizer.fit(data) + assert(discretizer.hasDefault(discretizer.outputCol)) + assert(readDiscretizer.hasDefault(readDiscretizer.outputCol)) } test("Multiple Columns: Both inputCol and inputCols are set") { From 1dc2c1d5e85c5f404f470aeb44c1f3c22786bdea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 15 Feb 2018 13:51:24 -0600 Subject: [PATCH 0365/1161] [SPARK-23413][UI] Fix sorting tasks by Host / Executor ID at the Stage page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Fixing exception got at sorting tasks by Host / Executor ID: ``` java.lang.IllegalArgumentException: Invalid sort column: Host at org.apache.spark.ui.jobs.ApiHelper$.indexName(StagePage.scala:1017) at org.apache.spark.ui.jobs.TaskDataSource.sliceData(StagePage.scala:694) at org.apache.spark.ui.PagedDataSource.pageData(PagedTable.scala:61) at org.apache.spark.ui.PagedTable$class.table(PagedTable.scala:96) at org.apache.spark.ui.jobs.TaskPagedTable.table(StagePage.scala:708) at org.apache.spark.ui.jobs.StagePage.liftedTree1$1(StagePage.scala:293) at org.apache.spark.ui.jobs.StagePage.render(StagePage.scala:282) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.WebUI$$anonfun$2.apply(WebUI.scala:82) at org.apache.spark.ui.JettyUtils$$anon$3.doGet(JettyUtils.scala:90) at javax.servlet.http.HttpServlet.service(HttpServlet.java:687) at javax.servlet.http.HttpServlet.service(HttpServlet.java:790) at org.spark_project.jetty.servlet.ServletHolder.handle(ServletHolder.java:848) at org.spark_project.jetty.servlet.ServletHandler.doHandle(ServletHandler.java:584) ``` Moreover some refactoring to avoid similar problems by introducing constants for each header name and reusing them at the identification of the corresponding sorting index. ## How was this patch tested? Manually: ![screen shot 2018-02-13 at 18 57 10](https://user-images.githubusercontent.com/2017933/36166532-1cfdf3b8-10f3-11e8-8d32-5fcaad2af214.png) Author: “attilapiros” Closes #20601 from attilapiros/SPARK-23413. --- .../org/apache/spark/status/storeTypes.scala | 2 + .../org/apache/spark/ui/jobs/StagePage.scala | 121 +++++++++++------- .../org/apache/spark/ui/StagePageSuite.scala | 63 ++++++++- 3 files changed, 139 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala index 412644d3657b5..646cf25880e37 100644 --- a/core/src/main/scala/org/apache/spark/status/storeTypes.scala +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -109,6 +109,7 @@ private[spark] object TaskIndexNames { final val DURATION = "dur" final val ERROR = "err" final val EXECUTOR = "exe" + final val HOST = "hst" final val EXEC_CPU_TIME = "ect" final val EXEC_RUN_TIME = "ert" final val GC_TIME = "gc" @@ -165,6 +166,7 @@ private[spark] class TaskDataWrapper( val duration: Long, @KVIndexParam(value = TaskIndexNames.EXECUTOR, parent = TaskIndexNames.STAGE) val executorId: String, + @KVIndexParam(value = TaskIndexNames.HOST, parent = TaskIndexNames.STAGE) val host: String, @KVIndexParam(value = TaskIndexNames.STATUS, parent = TaskIndexNames.STAGE) val status: String, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 5c2b0c3a19996..a9265d4dbcdfb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -750,37 +750,39 @@ private[ui] class TaskPagedTable( } def headers: Seq[Node] = { + import ApiHelper._ + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID", ""), ("Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), - ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ - {if (hasAccumulators(stage)) Seq(("Accumulators", "")) else Nil} ++ - {if (hasInput(stage)) Seq(("Input Size / Records", "")) else Nil} ++ - {if (hasOutput(stage)) Seq(("Output Size / Records", "")) else Nil} ++ + (HEADER_TASK_INDEX, ""), (HEADER_ID, ""), (HEADER_ATTEMPT, ""), (HEADER_STATUS, ""), + (HEADER_LOCALITY, ""), (HEADER_EXECUTOR, ""), (HEADER_HOST, ""), (HEADER_LAUNCH_TIME, ""), + (HEADER_DURATION, ""), (HEADER_SCHEDULER_DELAY, TaskDetailsClassNames.SCHEDULER_DELAY), + (HEADER_DESER_TIME, TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + (HEADER_GC_TIME, ""), + (HEADER_SER_TIME, TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + (HEADER_GETTING_RESULT_TIME, TaskDetailsClassNames.GETTING_RESULT_TIME), + (HEADER_PEAK_MEM, TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ + {if (hasAccumulators(stage)) Seq((HEADER_ACCUMULATORS, "")) else Nil} ++ + {if (hasInput(stage)) Seq((HEADER_INPUT_SIZE, "")) else Nil} ++ + {if (hasOutput(stage)) Seq((HEADER_OUTPUT_SIZE, "")) else Nil} ++ {if (hasShuffleRead(stage)) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + Seq((HEADER_SHUFFLE_READ_TIME, TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + (HEADER_SHUFFLE_TOTAL_READS, ""), + (HEADER_SHUFFLE_REMOTE_READS, TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) } else { Nil }} ++ {if (hasShuffleWrite(stage)) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + Seq((HEADER_SHUFFLE_WRITE_TIME, ""), (HEADER_SHUFFLE_WRITE_SIZE, "")) } else { Nil }} ++ {if (hasBytesSpilled(stage)) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + Seq((HEADER_MEM_SPILL, ""), (HEADER_DISK_SPILL, "")) } else { Nil }} ++ - Seq(("Errors", "")) + Seq((HEADER_ERROR, "")) if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { throw new IllegalArgumentException(s"Unknown column: $sortColumn") @@ -961,35 +963,62 @@ private[ui] class TaskPagedTable( } } -private object ApiHelper { - - - private val COLUMN_TO_INDEX = Map( - "ID" -> null.asInstanceOf[String], - "Index" -> TaskIndexNames.TASK_INDEX, - "Attempt" -> TaskIndexNames.ATTEMPT, - "Status" -> TaskIndexNames.STATUS, - "Locality Level" -> TaskIndexNames.LOCALITY, - "Executor ID / Host" -> TaskIndexNames.EXECUTOR, - "Launch Time" -> TaskIndexNames.LAUNCH_TIME, - "Duration" -> TaskIndexNames.DURATION, - "Scheduler Delay" -> TaskIndexNames.SCHEDULER_DELAY, - "Task Deserialization Time" -> TaskIndexNames.DESER_TIME, - "GC Time" -> TaskIndexNames.GC_TIME, - "Result Serialization Time" -> TaskIndexNames.SER_TIME, - "Getting Result Time" -> TaskIndexNames.GETTING_RESULT_TIME, - "Peak Execution Memory" -> TaskIndexNames.PEAK_MEM, - "Accumulators" -> TaskIndexNames.ACCUMULATORS, - "Input Size / Records" -> TaskIndexNames.INPUT_SIZE, - "Output Size / Records" -> TaskIndexNames.OUTPUT_SIZE, - "Shuffle Read Blocked Time" -> TaskIndexNames.SHUFFLE_READ_TIME, - "Shuffle Read Size / Records" -> TaskIndexNames.SHUFFLE_TOTAL_READS, - "Shuffle Remote Reads" -> TaskIndexNames.SHUFFLE_REMOTE_READS, - "Write Time" -> TaskIndexNames.SHUFFLE_WRITE_TIME, - "Shuffle Write Size / Records" -> TaskIndexNames.SHUFFLE_WRITE_SIZE, - "Shuffle Spill (Memory)" -> TaskIndexNames.MEM_SPILL, - "Shuffle Spill (Disk)" -> TaskIndexNames.DISK_SPILL, - "Errors" -> TaskIndexNames.ERROR) +private[ui] object ApiHelper { + + val HEADER_ID = "ID" + val HEADER_TASK_INDEX = "Index" + val HEADER_ATTEMPT = "Attempt" + val HEADER_STATUS = "Status" + val HEADER_LOCALITY = "Locality Level" + val HEADER_EXECUTOR = "Executor ID" + val HEADER_HOST = "Host" + val HEADER_LAUNCH_TIME = "Launch Time" + val HEADER_DURATION = "Duration" + val HEADER_SCHEDULER_DELAY = "Scheduler Delay" + val HEADER_DESER_TIME = "Task Deserialization Time" + val HEADER_GC_TIME = "GC Time" + val HEADER_SER_TIME = "Result Serialization Time" + val HEADER_GETTING_RESULT_TIME = "Getting Result Time" + val HEADER_PEAK_MEM = "Peak Execution Memory" + val HEADER_ACCUMULATORS = "Accumulators" + val HEADER_INPUT_SIZE = "Input Size / Records" + val HEADER_OUTPUT_SIZE = "Output Size / Records" + val HEADER_SHUFFLE_READ_TIME = "Shuffle Read Blocked Time" + val HEADER_SHUFFLE_TOTAL_READS = "Shuffle Read Size / Records" + val HEADER_SHUFFLE_REMOTE_READS = "Shuffle Remote Reads" + val HEADER_SHUFFLE_WRITE_TIME = "Write Time" + val HEADER_SHUFFLE_WRITE_SIZE = "Shuffle Write Size / Records" + val HEADER_MEM_SPILL = "Shuffle Spill (Memory)" + val HEADER_DISK_SPILL = "Shuffle Spill (Disk)" + val HEADER_ERROR = "Errors" + + private[ui] val COLUMN_TO_INDEX = Map( + HEADER_ID -> null.asInstanceOf[String], + HEADER_TASK_INDEX -> TaskIndexNames.TASK_INDEX, + HEADER_ATTEMPT -> TaskIndexNames.ATTEMPT, + HEADER_STATUS -> TaskIndexNames.STATUS, + HEADER_LOCALITY -> TaskIndexNames.LOCALITY, + HEADER_EXECUTOR -> TaskIndexNames.EXECUTOR, + HEADER_HOST -> TaskIndexNames.HOST, + HEADER_LAUNCH_TIME -> TaskIndexNames.LAUNCH_TIME, + HEADER_DURATION -> TaskIndexNames.DURATION, + HEADER_SCHEDULER_DELAY -> TaskIndexNames.SCHEDULER_DELAY, + HEADER_DESER_TIME -> TaskIndexNames.DESER_TIME, + HEADER_GC_TIME -> TaskIndexNames.GC_TIME, + HEADER_SER_TIME -> TaskIndexNames.SER_TIME, + HEADER_GETTING_RESULT_TIME -> TaskIndexNames.GETTING_RESULT_TIME, + HEADER_PEAK_MEM -> TaskIndexNames.PEAK_MEM, + HEADER_ACCUMULATORS -> TaskIndexNames.ACCUMULATORS, + HEADER_INPUT_SIZE -> TaskIndexNames.INPUT_SIZE, + HEADER_OUTPUT_SIZE -> TaskIndexNames.OUTPUT_SIZE, + HEADER_SHUFFLE_READ_TIME -> TaskIndexNames.SHUFFLE_READ_TIME, + HEADER_SHUFFLE_TOTAL_READS -> TaskIndexNames.SHUFFLE_TOTAL_READS, + HEADER_SHUFFLE_REMOTE_READS -> TaskIndexNames.SHUFFLE_REMOTE_READS, + HEADER_SHUFFLE_WRITE_TIME -> TaskIndexNames.SHUFFLE_WRITE_TIME, + HEADER_SHUFFLE_WRITE_SIZE -> TaskIndexNames.SHUFFLE_WRITE_SIZE, + HEADER_MEM_SPILL -> TaskIndexNames.MEM_SPILL, + HEADER_DISK_SPILL -> TaskIndexNames.DISK_SPILL, + HEADER_ERROR -> TaskIndexNames.ERROR) def hasAccumulators(stageData: StageData): Boolean = { stageData.accumulatorUpdates.exists { acc => acc.name != null && acc.value != null } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 0aeddf730cd35..6044563f7dde7 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -28,13 +28,74 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore +import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} import org.apache.spark.status.config._ -import org.apache.spark.ui.jobs.{StagePage, StagesTab} +import org.apache.spark.ui.jobs.{ApiHelper, StagePage, StagesTab, TaskPagedTable} class StagePageSuite extends SparkFunSuite with LocalSparkContext { private val peakExecutionMemory = 10 + test("ApiHelper.COLUMN_TO_INDEX should match headers of the task table") { + val conf = new SparkConf(false).set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + val statusStore = AppStatusStore.createLiveStore(conf) + try { + val stageData = new StageData( + status = StageStatus.ACTIVE, + stageId = 1, + attemptId = 1, + numTasks = 1, + numActiveTasks = 1, + numCompleteTasks = 1, + numFailedTasks = 1, + numKilledTasks = 1, + numCompletedIndices = 1, + + executorRunTime = 1L, + executorCpuTime = 1L, + submissionTime = None, + firstTaskLaunchedTime = None, + completionTime = None, + failureReason = None, + + inputBytes = 1L, + inputRecords = 1L, + outputBytes = 1L, + outputRecords = 1L, + shuffleReadBytes = 1L, + shuffleReadRecords = 1L, + shuffleWriteBytes = 1L, + shuffleWriteRecords = 1L, + memoryBytesSpilled = 1L, + diskBytesSpilled = 1L, + + name = "stage1", + description = Some("description"), + details = "detail", + schedulingPool = "pool1", + + rddIds = Seq(1), + accumulatorUpdates = Seq(new UIAccumulableInfo(0L, "acc", None, "value")), + tasks = None, + executorSummary = None, + killedTasksSummary = Map.empty + ) + val taskTable = new TaskPagedTable( + stageData, + basePath = "/a/b/c", + currentTime = 0, + pageSize = 10, + sortColumn = "Index", + desc = false, + store = statusStore + ) + val columnNames = (taskTable.headers \ "th" \ "a").map(_.child(1).text).toSet + assert(columnNames === ApiHelper.COLUMN_TO_INDEX.keySet) + } finally { + statusStore.close() + } + } + test("peak execution memory should displayed") { val html = renderStagePage().toString().toLowerCase(Locale.ROOT) val targetString = "peak execution memory" From c5857e496ff0d170ed0339f14afc7d36b192da6d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 16 Feb 2018 09:41:17 -0800 Subject: [PATCH 0366/1161] [SPARK-23446][PYTHON] Explicitly check supported types in toPandas ## What changes were proposed in this pull request? This PR explicitly specifies and checks the types we supported in `toPandas`. This was a hole. For example, we haven't finished the binary type support in Python side yet but now it allows as below: ```python spark.conf.set("spark.sql.execution.arrow.enabled", "false") df = spark.createDataFrame([[bytearray("a")]]) df.toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", "true") df.toPandas() ``` ``` _1 0 [97] _1 0 a ``` This should be disallowed. I think the same things also apply to nested timestamps too. I also added some nicer message about `spark.sql.execution.arrow.enabled` in the error message. ## How was this patch tested? Manually tested and tests added in `python/pyspark/sql/tests.py`. Author: hyukjinkwon Closes #20625 from HyukjinKwon/pandas_convertion_supported_type. --- python/pyspark/sql/dataframe.py | 15 +++++++++------ python/pyspark/sql/tests.py | 9 ++++++++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5cc8b63cdfadf..f37777e13ee12 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1988,10 +1988,11 @@ def toPandas(self): if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps + _check_dataframe_localize_timestamps, to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version - import pyarrow require_minimum_pyarrow_version() + import pyarrow + to_arrow_schema(self.schema) tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) @@ -2000,10 +2001,12 @@ def toPandas(self): return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) - except ImportError as e: - msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enabled=true" - raise ImportError("%s\n%s" % (_exception_message(e), msg)) + except Exception as e: + msg = ( + "Note: toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " + "to disable this.") + raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2af218a691026..19653072ea316 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3497,7 +3497,14 @@ def test_unsupported_datatype(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + df.toPandas() + + df = self.spark.createDataFrame([(None,)], schema="a binary") + with QuietTest(self.sc): + with self.assertRaisesRegexp( + Exception, + 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): df.toPandas() def test_null_conversion(self): From 0a73aa31f41c83503d5d99eff3c9d7b406014ab3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Feb 2018 14:30:19 -0800 Subject: [PATCH 0367/1161] [SPARK-23362][SS] Migrate Kafka Microbatch source to v2 ## What changes were proposed in this pull request? Migrating KafkaSource (with data source v1) to KafkaMicroBatchReader (with data source v2). Performance comparison: In a unit test with in-process Kafka broker, I tested the read throughput of V1 and V2 using 20M records in a single partition. They were comparable. ## How was this patch tested? Existing tests, few modified to be better tests than the existing ones. Author: Tathagata Das Closes #20554 from tdas/SPARK-23362. --- dev/.rat-excludes | 1 + .../sql/kafka010/CachedKafkaConsumer.scala | 2 +- .../sql/kafka010/KafkaContinuousReader.scala | 29 +- .../sql/kafka010/KafkaMicroBatchReader.scala | 403 ++++++++++++++++++ .../KafkaRecordToUnsafeRowConverter.scala | 52 +++ .../spark/sql/kafka010/KafkaSource.scala | 19 +- .../sql/kafka010/KafkaSourceProvider.scala | 70 ++- ...a-source-initial-offset-future-version.bin | 2 + ...ka-source-initial-offset-version-2.1.0.bin | 2 +- ...scala => KafkaMicroBatchSourceSuite.scala} | 254 +++++++---- .../apache/spark/sql/internal/SQLConf.scala | 15 +- .../streaming/MicroBatchExecution.scala | 20 +- 12 files changed, 741 insertions(+), 128 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala create mode 100644 external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin rename external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/{KafkaSourceSuite.scala => KafkaMicroBatchSourceSuite.scala} (85%) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 243fbe3e1bc24..9552d001a079c 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -105,3 +105,4 @@ META-INF/* spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin +kafka-source-initial-offset-future-version.bin diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 90ed7b1fba2f8..e97881cb0a163 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index b049a054cb40e..97a0f66e1880d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType @@ -187,13 +187,9 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val topic = topicPartition.topic - private val kafkaPartition = topicPartition.partition - private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams) - - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val consumer = + CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _ @@ -232,22 +228,7 @@ class KafkaContinuousDataReader( } override def get(): UnsafeRow = { - bufferHolder.reset() - - if (currentRecord.key == null) { - rowWriter.setNullAt(0) - } else { - rowWriter.write(0, currentRecord.key) - } - rowWriter.write(1, currentRecord.value) - rowWriter.write(2, UTF8String.fromString(currentRecord.topic)) - rowWriter.write(3, currentRecord.partition) - rowWriter.write(4, currentRecord.offset) - rowWriter.write(5, - DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp))) - rowWriter.write(6, currentRecord.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + converter.toUnsafeRow(currentRecord) } override def getOffset(): KafkaSourcePartitionOffset = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala new file mode 100644 index 0000000000000..fb647ca7e70dd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.io._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[MicroBatchReader]] that reads data from Kafka. + * + * The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent + * with the semantics of `KafkaConsumer.position()`. + * + * Zero data lost is not guaranteed when topics are deleted. If zero data lost is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: the query using Kafka maybe cannot be stopped. + * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers + * and not use wrong broker addresses. + */ +private[kafka010] class KafkaMicroBatchReader( + kafkaOffsetReader: KafkaOffsetReader, + executorKafkaParams: ju.Map[String, Object], + options: DataSourceOptions, + metadataPath: String, + startingOffsets: KafkaOffsetRangeLimit, + failOnDataLoss: Boolean) + extends MicroBatchReader with SupportsScanUnsafeRow with Logging { + + type PartitionOffsetMap = Map[TopicPartition, Long] + + private var startPartitionOffsets: PartitionOffsetMap = _ + private var endPartitionOffsets: PartitionOffsetMap = _ + + private val pollTimeoutMs = options.getLong( + "kafkaConsumer.pollTimeoutMs", + SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + + private val maxOffsetsPerTrigger = + Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894). + */ + private lazy val initialPartitionOffsets = getOrCreateInitialPartitionOffsets() + + override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + startPartitionOffsets = Option(start.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse(initialPartitionOffsets) + + endPartitionOffsets = Option(end.orElse(null)) + .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets) + .getOrElse { + val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets() + maxOffsetsPerTrigger.map { maxOffsets => + rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets) + }.getOrElse { + latestPartitionOffsets + } + } + } + + override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + // Find the new partitions, and get their earliest offsets + val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) + val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionOffsets") + newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + // Find deleted partitions, and report data loss if required + val deletedPartitions = startPartitionOffsets.keySet.diff(endPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the until partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = endPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. + newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + val sortedExecutors = getSortedExecutorList() + val numExecutors = sortedExecutors.length + logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) + + // Calculate offset ranges + val factories = topicPartitions.flatMap { tp => + val fromOffset = startPartitionOffsets.get(tp).getOrElse { + newPartitionOffsets.getOrElse( + tp, { + // This should not happen since newPartitionOffsets contains all partitions not in + // fromPartitionOffsets + throw new IllegalStateException(s"$tp doesn't have a from offset") + }) + } + val untilOffset = endPartitionOffsets(tp) + + if (untilOffset >= fromOffset) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + val preferredLoc = if (numExecutors > 0) { + Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) + } else None + val range = KafkaOffsetRange(tp, fromOffset, untilOffset) + Some( + new KafkaMicroBatchDataReaderFactory( + range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) + } else { + reportDataLoss( + s"Partition $tp's offset was changed from " + + s"$fromOffset to $untilOffset, some data may have been missed") + None + } + } + factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + } + + override def getStartOffset: Offset = { + KafkaSourceOffset(startPartitionOffsets) + } + + override def getEndOffset: Offset = { + KafkaSourceOffset(endPartitionOffsets) + } + + override def deserializeOffset(json: String): Offset = { + KafkaSourceOffset(JsonUtils.partitionOffsets(json)) + } + + override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = { + kafkaOffsetReader.close() + } + + override def toString(): String = s"Kafka[$kafkaOffsetReader]" + + /** + * Read initial partition offsets from the checkpoint, or decide the offsets and write them to + * the checkpoint. + */ + private def getOrCreateInitialPartitionOffsets(): PartitionOffsetMap = { + // Make sure that `KafkaConsumer.poll` is only called in StreamExecutionThread. + // Otherwise, interrupting a thread while running `KafkaConsumer.poll` may hang forever + // (KAFKA-1894). + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + + // SparkSession is required for getting Hadoop configuration for writing to checkpoints + assert(SparkSession.getActiveSession.nonEmpty) + + val metadataLog = + new KafkaSourceInitialOffsetWriter(SparkSession.getActiveSession.get, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = startingOffsets match { + case EarliestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchEarliestOffsets()) + case LatestOffsetRangeLimit => + KafkaSourceOffset(kafkaOffsetReader.fetchLatestOffsets()) + case SpecificOffsetRangeLimit(p) => + kafkaOffsetReader.fetchSpecificOffsets(p, reportDataLoss) + } + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: PartitionOffsetMap, + until: PartitionOffsetMap): PartitionOffsetMap = { + val fromNew = kafkaOffsetReader.fetchEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } + } + + private def getSortedExecutorList(): Array[String] = { + + def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { + a.executorId > b.executorId + } else { + a.host > b.host + } + } + + val bm = SparkEnv.get.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE") + } else { + logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE") + } + } + + /** A version of [[HDFSMetadataLog]] specialized for saving the initial offsets. */ + class KafkaSourceInitialOffsetWriter(sparkSession: SparkSession, metadataPath: String) + extends HDFSMetadataLog[KafkaSourceOffset](sparkSession, metadataPath) { + + val VERSION = 1 + + override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { + out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517) + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): KafkaSourceOffset = { + in.read() // A zero byte is read to support Spark 2.1.0 (SPARK-19517) + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + KafkaSourceOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + // The log was generated by Spark 2.1.0 + KafkaSourceOffset(SerializedOffset(content)) + } + } + } +} + +/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReaderFactory( + range: KafkaOffsetRange, + preferredLoc: Option[String], + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + + override def preferredLocations(): Array[String] = preferredLoc.toArray + + override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) +} + +/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] class KafkaMicroBatchDataReader( + offsetRange: KafkaOffsetRange, + executorKafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + private val rangeToRead = resolveRange(offsetRange) + private val converter = new KafkaRecordToUnsafeRowConverter + + private var nextOffset = rangeToRead.fromOffset + private var nextRow: UnsafeRow = _ + + override def next(): Boolean = { + if (nextOffset < rangeToRead.untilOffset) { + val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss) + if (record != null) { + nextRow = converter.toUnsafeRow(record) + true + } else { + false + } + } else { + false + } + } + + override def get(): UnsafeRow = { + assert(nextRow != null) + nextOffset += 1 + nextRow + } + + override def close(): Unit = { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + + private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { + if (range.fromOffset < 0 || range.untilOffset < 0) { + // Late bind the offset range + val availableOffsetRange = consumer.getAvailableOffsetRange() + val fromOffset = if (range.fromOffset < 0) { + assert(range.fromOffset == KafkaOffsetRangeLimit.EARLIEST, + s"earliest offset ${range.fromOffset} does not equal ${KafkaOffsetRangeLimit.EARLIEST}") + availableOffsetRange.earliest + } else { + range.fromOffset + } + val untilOffset = if (range.untilOffset < 0) { + assert(range.untilOffset == KafkaOffsetRangeLimit.LATEST, + s"latest offset ${range.untilOffset} does not equal ${KafkaOffsetRangeLimit.LATEST}") + availableOffsetRange.latest + } else { + range.untilOffset + } + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + } else { + range + } + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala new file mode 100644 index 0000000000000..1acdd56125741 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.clients.consumer.ConsumerRecord + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String + +/** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ +private[kafka010] class KafkaRecordToUnsafeRowConverter { + private val sharedRow = new UnsafeRow(7) + private val bufferHolder = new BufferHolder(sharedRow) + private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + + def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { + bufferHolder.reset() + + if (record.key == null) { + rowWriter.setNullAt(0) + } else { + rowWriter.write(0, record.key) + } + rowWriter.write(1, record.value) + rowWriter.write(2, UTF8String.fromString(record.topic)) + rowWriter.write(3, record.partition) + rowWriter.write(4, record.offset) + rowWriter.write( + 5, + DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) + rowWriter.write(6, record.timestampType.id) + sharedRow.setTotalSize(bufferHolder.totalSize) + sharedRow + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 169a5d006fb04..1c7b3a29a861f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -306,7 +307,7 @@ private[kafka010] class KafkaSource( kafkaReader.close() } - override def toString(): String = s"KafkaSource[$kafkaReader]" + override def toString(): String = s"KafkaSourceV1[$kafkaReader]" /** * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. @@ -323,22 +324,6 @@ private[kafka010] class KafkaSource( /** Companion object for the [[KafkaSource]]. */ private[kafka010] object KafkaSource { - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you want your streaming query to fail on such cases, set the source - | option "failOnDataLoss" to "true". - """.stripMargin - - val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = - """ - |Some data may have been lost because they are not available in Kafka any more; either the - | data was aged out by Kafka or the topic may have been deleted before all the data in the - | topic was processed. If you don't want your streaming query to fail on such cases, set the - | source option "failOnDataLoss" to "false". - """.stripMargin - private[kafka010] val VERSION = 1 def getSortedExecutorList(sc: SparkContext): Array[String] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index d4fa0359c12d6..0aa64a6a9cf90 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -30,13 +30,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * The provider class for all Kafka readers and writers. It is designed such that it throws * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ @@ -47,6 +47,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with CreatableRelationProvider with StreamWriteSupport with ContinuousReadSupport + with MicroBatchReadSupport with Logging { import KafkaSourceProvider._ @@ -105,6 +106,52 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader]] to read batches + * of Kafka data in a micro-batch streaming query. + */ + override def createMicroBatchReader( + schema: Optional[StructType], + metadataPath: String, + options: DataSourceOptions): KafkaMicroBatchReader = { + + val parameters = options.asMap().asScala.toMap + validateStreamOptions(parameters) + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams, + STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy(caseInsensitiveParams), + kafkaParamsForDriver(specifiedKafkaParams), + parameters, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + + new KafkaMicroBatchReader( + kafkaOffsetReader, + kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), + options, + metadataPath, + startingStreamOffsets, + failOnDataLoss(caseInsensitiveParams)) + } + + /** + * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Kafka data in a continuous streaming query. + */ override def createContinuousReader( schema: Optional[StructType], metadataPath: String, @@ -408,8 +455,27 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you want your streaming query to fail on such cases, set the source + | option "failOnDataLoss" to "true". + """.stripMargin + + val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE = + """ + |Some data may have been lost because they are not available in Kafka any more; either the + | data was aged out by Kafka or the topic may have been deleted before all the data in the + | topic was processed. If you don't want your streaming query to fail on such cases, set the + | source option "failOnDataLoss" to "false". + """.stripMargin + + + private val deserClassName = classOf[ByteArrayDeserializer].getName def getKafkaOffsetRangeLimit( diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin new file mode 100644 index 0000000000000..d530773f57327 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-future-version.bin @@ -0,0 +1,2 @@ +0v99999 +{"kafka-initial-offset-future-version":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin index ae928e724967d..8c78d9e390a0e 100644 --- a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin @@ -1 +1 @@ -2{"kafka-initial-offset-2-1-0":{"2":0,"1":0,"0":0}} \ No newline at end of file +2{"kafka-initial-offset-2-1-0":{"2":2,"1":1,"0":0}} \ No newline at end of file diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala similarity index 85% rename from external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala rename to external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 02c87643568bd..ed4ecfeafa972 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -25,6 +25,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable +import scala.io.Source import scala.util.Random import org.apache.kafka.clients.producer.RecordMetadata @@ -42,7 +43,6 @@ import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} -import org.apache.spark.util.Utils abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -112,14 +112,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: KafkaSource, _) => source - } ++ (query.get.lastExecution match { - case null => Seq() - case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader - } - }) + val sources = { + query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: KafkaSource, _) => source + case StreamingExecutionRelation(source: KafkaMicroBatchReader, _) => source + } ++ (query.get.lastExecution match { + case null => Seq() + case e => e.logical.collect { + case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + } + }) + }.distinct + if (sources.isEmpty) { throw new Exception( "Could not find Kafka source in the StreamExecution logical plan to add data to") @@ -155,7 +159,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}" } -class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { +abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { import testImplicits._ @@ -303,94 +307,105 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { ) } - testWithUninterruptibleThread( - "deserialization of initial offset with Spark 2.1.0") { + test("ensure that initial offset are written with an extra byte in the beginning (SPARK-19517)") { withTempDir { metadataPath => - val topic = newTopic - testUtils.createTopic(topic, partitions = 3) + val topic = "kafka-initial-offset-current" + testUtils.createTopic(topic, partitions = 1) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - source.getOffset.get // Write initial offset - - // Make sure Spark 2.1.0 will throw an exception when reading the new log - intercept[java.lang.IllegalArgumentException] { - // Simulate how Spark 2.1.0 reads the log - Utils.tryWithResource(new FileInputStream(metadataPath.getAbsolutePath + "/0")) { in => - val length = in.read() - val bytes = new Array[Byte](length) - in.read(bytes) - KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8))) - } + val initialOffsetFile = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0").toFile + + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + + // Test the written initial offset file has 0 byte in the beginning, so that + // Spark 2.1.0 can read the offsets (see SPARK-19517) + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + makeSureGetOffsetCalled) + + val binarySource = Source.fromFile(initialOffsetFile) + try { + assert(binarySource.next().toInt == 0) // first byte is binary 0 + } finally { + binarySource.close() } } } - testWithUninterruptibleThread("deserialization of initial offset written by Spark 2.1.0") { + test("deserialization of initial offset written by Spark 2.1.0 (SPARK-19517)") { withTempDir { metadataPath => val topic = "kafka-initial-offset-2-1-0" testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, Array("0", "1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("0", "10", "20"), Some(1)) + testUtils.sendMessages(topic, Array("0", "100", "200"), Some(2)) - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. val from = new File( getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").toURI).toPath - val to = Paths.get(s"${metadataPath.getAbsolutePath}/0") + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) Files.copy(from, to) - val source = provider.createSource( - spark.sqlContext, metadataPath.toURI.toString, None, "", parameters) - val deserializedOffset = source.getOffset.get - val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - assert(referenceOffset == deserializedOffset) + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", s"earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + // Test that the query starts from the expected initial offset (i.e. read older offsets, + // even though startingOffsets is latest). + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + AddKafkaData(Set(topic), 1000), + CheckAnswer(0, 1, 2, 10, 20, 200, 1000)) } } - testWithUninterruptibleThread("deserialization of initial offset written by future version") { + test("deserialization of initial offset written by future version") { withTempDir { metadataPath => - val futureMetadataLog = - new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, - metadataPath.getAbsolutePath) { - override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { - out.write(0) - val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8)) - writer.write(s"v99999\n${metadata.json}") - writer.flush - } - } - - val topic = newTopic + val topic = "kafka-initial-offset-future-version" testUtils.createTopic(topic, partitions = 3) - val offset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L)) - futureMetadataLog.add(0, offset) - - val provider = new KafkaSourceProvider - val parameters = Map( - "kafka.bootstrap.servers" -> testUtils.brokerAddress, - "subscribe" -> topic - ) - val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None, - "", parameters) - val e = intercept[java.lang.IllegalStateException] { - source.getOffset.get // Read initial offset - } + // Copy the initial offset file into the right location inside the checkpoint root directory + // such that the Kafka source can read it for initial offsets. + val from = new File( + getClass.getResource("/kafka-source-initial-offset-future-version.bin").toURI).toPath + val to = Paths.get(s"${metadataPath.getAbsolutePath}/sources/0/0") + Files.createDirectories(to.getParent) + Files.copy(from, to) - Seq( - s"maximum supported log version is v${KafkaSource.VERSION}, but encountered v99999", - "produced by a newer version of Spark and cannot be read by this version" - ).foreach { message => - assert(e.getMessage.contains(message)) - } + val df = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + + testStream(df)( + StartStream(checkpointLocation = metadataPath.getAbsolutePath), + ExpectFailure[IllegalStateException](e => { + Seq( + s"maximum supported log version is v1, but encountered v99999", + "produced by a newer version of Spark and cannot be read by this version" + ).foreach { message => + assert(e.toString.contains(message)) + } + })) } } @@ -542,6 +557,91 @@ class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase { CheckLastBatch(120 to 124: _*) ) } + + test("ensure stream-stream self-join generates only one offset in offset log") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + require(testUtils.getLatestOffsets(Set(topic)).size === 2) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .load() + + val values = kafka + .selectExpr("CAST(CAST(value AS STRING) AS INT) AS value", + "CAST(CAST(value AS STRING) AS INT) % 5 AS key") + + val join = values.join(values, "key") + + testStream(join)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2), + CheckAnswer((1, 1, 1), (2, 2, 2)), + AddKafkaData(Set(topic), 6, 3), + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + ) + } +} + + +class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBase { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set( + "spark.sql.streaming.disabledV2MicroBatchReaders", + classOf[KafkaSourceProvider].getCanonicalName) + } + + test("V1 Source is used when disabled through SQLConf") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaSource, _) => true + }.nonEmpty + } + ) + } +} + +class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { + + test("V2 Source is used by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + + val kafka = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topic.*") + .load() + + testStream(kafka)( + makeSureGetOffsetCalled, + AssertOnQuery { query => + query.logicalPlan.collect { + case StreamingExecutionRelation(_: KafkaMicroBatchReader, _) => true + }.nonEmpty + } + ) + } } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f24fd7ff74d3f..e75e1d66ebcf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1146,10 +1146,20 @@ object SQLConf { val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .internal() .doc("A comma-separated list of fully qualified data source register class names for which" + - " StreamWriteSupport is disabled. Writes to these sources will fail back to the V1 Sink.") + " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.") .stringConf .createWithDefault("") + val DISABLED_V2_STREAMING_MICROBATCH_READERS = + buildConf("spark.sql.streaming.disabledV2MicroBatchReaders") + .internal() + .doc( + "A comma-separated list of fully qualified data source register class names for which " + + "MicroBatchReadSupport is disabled. Reads from these sources will fall back to the " + + "V1 Sources.") + .stringConf + .createWithDefault("") + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } @@ -1525,6 +1535,9 @@ class SQLConf extends Serializable with Logging { def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) + def disabledV2StreamingMicroBatchReaders: String = + getConf(DISABLED_V2_STREAMING_MICROBATCH_READERS) + def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ac73ba3417904..84655013ba957 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -72,27 +72,36 @@ class MicroBatchExecution( // Note that we have to use the previous `output` as attributes in StreamingExecutionRelation, // since the existing logical plan has already used those attributes. The per-microbatch // transformation is responsible for replacing attributes with their final values. + + val disabledSources = + sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",") + val _logicalPlan = analyzedPlan.transform { - case streamingRelation@StreamingRelation(dataSource, _, output) => + case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) + val source = dataSourceV1.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(source: MicroBatchReadSupport, _, options, output, _) => + case s @ StreamingRelationV2( + dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = source.createMicroBatchReader( + val reader = dataSourceV2.createMicroBatchReader( Optional.empty(), // user specified schema metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + logInfo(s"Using MicroBatchReader [$reader] from " + + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) }) - case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) => + case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" @@ -102,6 +111,7 @@ class MicroBatchExecution( } val source = v1Relation.get.dataSource.createSource(metadataPath) nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(source, output)(sparkSession) }) } From d5ed2108d32e1d95b26ee7fed39e8a733e935e2c Mon Sep 17 00:00:00 2001 From: Shintaro Murakami Date: Fri, 16 Feb 2018 17:17:55 -0800 Subject: [PATCH 0368/1161] [SPARK-23381][CORE] Murmur3 hash generates a different value from other implementations ## What changes were proposed in this pull request? Murmur3 hash generates a different value from the original and other implementations (like Scala standard library and Guava or so) when the length of a bytes array is not multiple of 4. ## How was this patch tested? Added a unit test. **Note: When we merge this PR, please give all the credits to Shintaro Murakami.** Author: Shintaro Murakami Author: gatorsmile Author: Shintaro Murakami Closes #20630 from gatorsmile/pr-20568. --- .../spark/util/sketch/Murmur3_x86_32.java | 16 +++++++++ .../spark/unsafe/hash/Murmur3_x86_32.java | 16 +++++++++ .../unsafe/hash/Murmur3_x86_32Suite.java | 19 +++++++++++ .../spark/ml/feature/FeatureHasher.scala | 33 ++++++++++++++++++- .../spark/mllib/feature/HashingTF.scala | 2 +- .../spark/ml/feature/FeatureHasherSuite.scala | 11 ++++++- python/pyspark/ml/feature.py | 4 +-- 7 files changed, 96 insertions(+), 5 deletions(-) diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java index a61ce4fb7241d..e83b331391e39 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 5e7ee480cafd1..d239de6083ad0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -60,6 +60,8 @@ public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, i } public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + // This is not compatible with original and another implementations. + // But remain it for backward compatibility for the components existing before 2.3. assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; int h1 = hashBytesByInt(base, offset, lengthAligned, seed); @@ -71,6 +73,20 @@ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, i return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + // This is compatible with original and another implementations. + // Use this method for new components after Spark 2.3. + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int k1 = 0; + for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { + k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + } + h1 ^= mixK1(k1); + return fmix(h1, lengthInBytes); + } + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { assert (lengthInBytes % 4 == 0); int h1 = seed; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index e759cb33b3e6a..6348a73bf3895 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -22,6 +22,8 @@ import java.util.Random; import java.util.Set; +import scala.util.hashing.MurmurHash3$; + import org.apache.spark.unsafe.Platform; import org.junit.Assert; import org.junit.Test; @@ -51,6 +53,23 @@ public void testKnownLongInputs() { Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); } + // SPARK-23381 Check whether the hash of the byte array is the same as another implementations + @Test + public void testKnownBytesInputs() { + byte[] test = "test".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test, 0), + Murmur3_x86_32.hashUnsafeBytes2(test, Platform.BYTE_ARRAY_OFFSET, test.length, 0)); + byte[] test1 = "test1".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(test1, 0), + Murmur3_x86_32.hashUnsafeBytes2(test1, Platform.BYTE_ARRAY_OFFSET, test1.length, 0)); + byte[] te = "te".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(te, 0), + Murmur3_x86_32.hashUnsafeBytes2(te, Platform.BYTE_ARRAY_OFFSET, te.length, 0)); + byte[] tes = "tes".getBytes(StandardCharsets.UTF_8); + Assert.assertEquals(MurmurHash3$.MODULE$.bytesHash(tes, 0), + Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index a918dd4c075da..c78f61ac3ef71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup @@ -28,6 +29,8 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -138,7 +141,7 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { - val hashFunc: Any => Int = OldHashingTF.murmur3Hash + val hashFunc: Any => Int = FeatureHasher.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) val catCols = if (isSet(categoricalCols)) { @@ -218,4 +221,32 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { @Since("2.3.0") override def load(path: String): FeatureHasher = super.load(path) + + private val seed = OldHashingTF.seed + + /** + * Calculate a hash code value for the term object using + * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32). + * This is the default hash algorithm used from Spark 2.0 onwards. + * Use hashUnsafeBytes2 to match the original algorithm with the value. + * See SPARK-23381. + */ + @Since("2.3.0") + private[feature] def murmur3Hash(term: Any): Int = { + term match { + case null => seed + case b: Boolean => hashInt(if (b) 1 else 0, seed) + case b: Byte => hashInt(b, seed) + case s: Short => hashInt(s, seed) + case i: Int => hashInt(i, seed) + case l: Long => hashLong(l, seed) + case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) + case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) + case s: String => + val utf8 = UTF8String.fromString(s) + hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + + s"support type ${term.getClass.getCanonicalName} of input data.") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 9abdd44a635d1..8935c8496cdbb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -135,7 +135,7 @@ object HashingTF { private[HashingTF] val Murmur3: String = "murmur3" - private val seed = 42 + private[spark] val seed = 42 /** * Calculate a hash code value for the term object using the native Scala implementation. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 3fc3cbb62d5b5..7bc1825b69c43 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class FeatureHasherSuite extends SparkFunSuite with MLlibTestSparkContext @@ -34,7 +35,7 @@ class FeatureHasherSuite extends SparkFunSuite import testImplicits._ - import HashingTFSuite.murmur3FeatureIdx + import FeatureHasherSuite.murmur3FeatureIdx implicit private val vectorEncoder = ExpressionEncoder[Vector]() @@ -216,3 +217,11 @@ class FeatureHasherSuite extends SparkFunSuite testDefaultReadWrite(t) } } + +object FeatureHasherSuite { + + private[feature] def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = { + Utils.nonNegativeMod(FeatureHasher.murmur3Hash(term), numFeatures) + } + +} diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index da85ba761a145..04b07e6a05481 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -741,9 +741,9 @@ class FeatureHasher(JavaTransformer, HasInputCols, HasOutputCol, HasNumFeatures, >>> df = spark.createDataFrame(data, cols) >>> hasher = FeatureHasher(inputCols=cols, outputCol="features") >>> hasher.transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 174475: 2.0, 253195: 1.0}) + SparseVector(262144, {174475: 2.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasher.setCategoricalCols(["real"]).transform(df).head().features - SparseVector(262144, {51871: 1.0, 63643: 1.0, 171257: 1.0, 253195: 1.0}) + SparseVector(262144, {171257: 1.0, 247670: 1.0, 257907: 1.0, 262126: 1.0}) >>> hasherPath = temp_path + "/hasher" >>> hasher.save(hasherPath) >>> loadedHasher = FeatureHasher.load(hasherPath) From 15ad4a7f1000c83cefbecd41e315c964caa3c39f Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Sat, 17 Feb 2018 10:54:14 +0800 Subject: [PATCH 0369/1161] [SPARK-23447][SQL] Cleanup codegen template for Literal ## What changes were proposed in this pull request? Cleaned up the codegen templates for `Literal`s, to make sure that the `ExprCode` returned from `Literal.doGenCode()` has: 1. an empty `code` field; 2. an `isNull` field of either literal `true` or `false`; 3. a `value` field that is just a simple literal/constant. Before this PR, there are a couple of paths that would return a non-trivial `code` and all of them are actually unnecessary. The `NaN` and `Infinity` constants for `double` and `float` can be accessed through constants directly available so there's no need to add a reference for them. Also took the opportunity to add a new util method for ease of creating `ExprCode` for inline-able non-null values. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #20626 from rednaxelafx/codegen-literal. --- .../expressions/codegen/CodeGenerator.scala | 6 +++ .../sql/catalyst/expressions/literals.scala | 51 ++++++++++--------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dcbb702893da..31ba29ae8d8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -58,6 +58,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} */ case class ExprCode(var code: String, var isNull: String, var value: String) +object ExprCode { + def forNonNullValue(value: String): ExprCode = { + ExprCode(code = "", isNull = "false", value = value) + } +} + /** * State used for subexpression elimination. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index cd176d941819f..c1e65e34c2ea6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -278,40 +278,45 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - // change the isNull and primitive to consts, to inline them if (value == null) { - ev.isNull = "true" - ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};") + val defaultValueLiteral = ctx.defaultValue(javaType) match { + case "null" => s"(($javaType)null)" + case lit => lit + } + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) } else { - ev.isNull = "false" dataType match { case BooleanType | IntegerType | DateType => - ev.copy(code = "", value = value.toString) + ExprCode.forNonNullValue(value.toString) case FloatType => - val v = value.asInstanceOf[Float] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}f") + value.asInstanceOf[Float] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Float.NaN") + case Float.PositiveInfinity => + ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + case Float.NegativeInfinity => + ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}F") } case DoubleType => - val v = value.asInstanceOf[Double] - if (v.isNaN || v.isInfinite) { - val boxedValue = ctx.addReferenceObj("boxedValue", v) - val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" - ev.copy(code = code) - } else { - ev.copy(code = "", value = s"${value}D") + value.asInstanceOf[Double] match { + case v if v.isNaN => + ExprCode.forNonNullValue("Double.NaN") + case Double.PositiveInfinity => + ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + case Double.NegativeInfinity => + ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + case _ => + ExprCode.forNonNullValue(s"${value}D") } case ByteType | ShortType => - ev.copy(code = "", value = s"($javaType)$value") + ExprCode.forNonNullValue(s"($javaType)$value") case TimestampType | LongType => - ev.copy(code = "", value = s"${value}L") + ExprCode.forNonNullValue(s"${value}L") case _ => - ev.copy(code = "", value = ctx.addReferenceObj("literal", value, - ctx.javaType(dataType))) + val constRef = ctx.addReferenceObj("literal", value, javaType) + ExprCode.forNonNullValue(constRef) } } } From 3ee3b2ae1ff8fbeb43a08becef43a9bd763b06bb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 17 Feb 2018 00:25:36 -0800 Subject: [PATCH 0370/1161] [SPARK-23340][SQL] Upgrade Apache ORC to 1.4.3 ## What changes were proposed in this pull request? This PR updates Apache ORC dependencies to 1.4.3 released on February 9th. Apache ORC 1.4.2 release removes unnecessary dependencies and 1.4.3 has 5 more patches (https://s.apache.org/Fll8). Especially, the following ORC-285 is fixed at 1.4.3. ```scala scala> val df = Seq(Array.empty[Float]).toDF() scala> df.write.format("orc").save("/tmp/floatarray") scala> spark.read.orc("/tmp/floatarray") res1: org.apache.spark.sql.DataFrame = [value: array] scala> spark.read.orc("/tmp/floatarray").show() 18/02/12 22:09:10 ERROR Executor: Exception in task 0.0 in stage 1.0 (TID 1) java.io.IOException: Error reading file: file:/tmp/floatarray/part-00000-9c0b461b-4df1-4c23-aac1-3e4f349ac7d6-c000.snappy.orc at org.apache.orc.impl.RecordReaderImpl.nextBatch(RecordReaderImpl.java:1191) at org.apache.orc.mapreduce.OrcMapreduceRecordReader.ensureBatch(OrcMapreduceRecordReader.java:78) ... Caused by: java.io.EOFException: Read past EOF for compressed stream Stream for column 2 kind DATA position: 0 length: 0 range: 0 offset: 0 limit: 0 ``` ## How was this patch tested? Pass the Jenkins test. Author: Dongjoon Hyun Closes #20511 from dongjoon-hyun/SPARK-23340. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- pom.xml | 6 +----- .../sql/execution/datasources/orc/OrcSourceSuite.scala | 9 +++++++++ .../apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala | 10 ++++++++++ 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 99031384aa22e..ed310507d14ed 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -157,8 +157,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index cf8d2789b7ee9..04dec04796af4 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -158,8 +158,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.1-nohive.jar -orc-mapreduce-1.4.1-nohive.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index de949b94d676c..ac30107066389 100644 --- a/pom.xml +++ b/pom.xml @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.1 + 1.4.3 nohive 1.6.0 9.3.20.v20170531 @@ -1740,10 +1740,6 @@ org.apache.hive hive-storage-api - - io.airlift - slice - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 6f5f2fd795f74..523f7cf77e103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -160,6 +160,15 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + Seq(Seq(Array.empty[Float]).toDF(), Seq(Array.empty[Double]).toDF()).foreach { df => + withTempPath { path => + df.write.format("orc").save(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), df) + } + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala index 92b2f069cacd6..597b0f56a55e4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcQuerySuite.scala @@ -208,4 +208,14 @@ class HiveOrcQuerySuite extends OrcQueryTest with TestHiveSingleton { } } } + + test("SPARK-23340 Empty float/double array columns raise EOFException") { + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "false") { + withTable("spark_23340") { + sql("CREATE TABLE spark_23340(a array, b array) STORED AS ORC") + sql("INSERT INTO spark_23340 VALUES (array(), array())") + checkAnswer(spark.table("spark_23340"), Seq(Row(Array.empty[Float], Array.empty[Double]))) + } + } + } } From f5850e78924d03448ad243cdd32b24c3fe0ea8af Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 13:33:03 +0800 Subject: [PATCH 0371/1161] [SPARK-23457][SQL] Register task completion listeners first in ParquetFileFormat ## What changes were proposed in this pull request? ParquetFileFormat leaks opened files in some cases. This PR prevents that by registering task completion listers first before initialization. - [spark-branch-2.3-test-sbt-hadoop-2.7](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-branch-2.3-test-sbt-hadoop-2.7/205/testReport/org.apache.spark.sql/FileBasedDataSourceSuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) - [spark-master-test-sbt-hadoop-2.6](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.6/4228/testReport/junit/org.apache.spark.sql.execution.datasources.parquet/ParquetQuerySuite/_It_is_not_a_test_it_is_a_sbt_testing_SuiteSelector_/) ``` Caused by: sbt.ForkMain$ForkError: java.lang.Throwable: null at org.apache.spark.DebugFilesystem$.addOpenStream(DebugFilesystem.scala:36) at org.apache.spark.DebugFilesystem.open(DebugFilesystem.scala:70) at org.apache.hadoop.fs.FileSystem.open(FileSystem.java:769) at org.apache.parquet.hadoop.ParquetFileReader.(ParquetFileReader.java:538) at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.initialize(SpecificParquetRecordReaderBase.java:149) at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader.initialize(VectorizedParquetRecordReader.java:133) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anonfun$buildReaderWithPartitionValues$1.apply(ParquetFileFormat.scala:400) at ``` ## How was this patch tested? Manual. The following test case generates the same leakage. ```scala test("SPARK-23457 Register task completion listeners first in ParquetFileFormat") { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE.key -> s"${Int.MaxValue}") { withTempDir { dir => val basePath = dir.getCanonicalPath Seq(0).toDF("a").write.format("parquet").save(new Path(basePath, "first").toString) Seq(1).toDF("a").write.format("parquet").save(new Path(basePath, "second").toString) val df = spark.read.parquet( new Path(basePath, "first").toString, new Path(basePath, "second").toString) val e = intercept[SparkException] { df.collect() } assert(e.getCause.isInstanceOf[OutOfMemoryError]) } } } ``` Author: Dongjoon Hyun Closes #20619 from dongjoon-hyun/SPARK-23390. --- .../parquet/ParquetFileFormat.scala | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index ba69f9a26c968..476bd02374364 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -395,16 +395,21 @@ class ParquetFileFormat ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) } val taskContext = Option(TaskContext.get()) - val parquetReader = if (enableVectorizedReader) { + if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader( convertTz.orNull, enableOffHeapColumnVector && taskContext.isDefined, capacity) + val iter = new RecordReaderIterator(vectorizedReader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) if (returningBatch) { vectorizedReader.enableReturningBatches() } - vectorizedReader + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow @@ -414,18 +419,11 @@ class ParquetFileFormat } else { new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz)) } + val iter = new RecordReaderIterator(reader) + // SPARK-23457 Register a task completion lister before `initialization`. + taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) reader.initialize(split, hadoopAttemptContext) - reader - } - val iter = new RecordReaderIterator(parquetReader) - taskContext.foreach(_.addTaskCompletionListener(_ => iter.close())) - - // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. - if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && - enableVectorizedReader) { - iter.asInstanceOf[Iterator[InternalRow]] - } else { val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) From 651b0277fe989119932d5ae1ef729c9768aa018d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 20 Feb 2018 13:56:38 +0800 Subject: [PATCH 0372/1161] [SPARK-23436][SQL] Infer partition as Date only if it can be casted to Date ## What changes were proposed in this pull request? Before the patch, Spark could infer as Date a partition value which cannot be casted to Date (this can happen when there are extra characters after a valid date, like `2018-02-15AAA`). When this happens and the input format has metadata which define the schema of the table, then `null` is returned as a value for the partition column, because the `cast` operator used in (`PartitioningAwareFileIndex.inferPartitioning`) is unable to convert the value. The PR checks in the partition inference that values can be casted to Date and Timestamp, in order to infer that datatype to them. ## How was this patch tested? added UT Author: Marco Gaido Closes #20621 from mgaido91/SPARK-23436. --- .../datasources/PartitioningUtils.scala | 40 ++++++++++++++----- .../ParquetPartitionDiscoverySuite.scala | 14 +++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 472bf82d3604d..379acb67f7c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -407,6 +407,34 @@ object PartitioningUtils { Literal(bigDecimal) } + val dateTry = Try { + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // DateType + DateTimeUtils.getThreadLocalDateFormat.parse(raw) + // SPARK-23436: Casting the string to date may still return null if a bad Date is provided. + // This can happen since DateFormat.parse may not use the entire text of the given string: + // so if there are extra-characters after the date, it returns correctly. + // We need to check that we can cast the raw string since we later can use Cast to get + // the partition values with the right DataType (see + // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning) + val dateValue = Cast(Literal(raw), DateType).eval() + // Disallow DateType if the cast returned null + require(dateValue != null) + Literal.create(dateValue, DateType) + } + + val timestampTry = Try { + val unescapedRaw = unescapePathName(raw) + // try and parse the date, if no exception occurs this is a candidate to be resolved as + // TimestampType + DateTimeUtils.getThreadLocalTimestampFormat(timeZone).parse(unescapedRaw) + // SPARK-23436: see comment for date + val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval() + // Disallow TimestampType if the cast returned null + require(timestampValue != null) + Literal.create(timestampValue, TimestampType) + } + if (typeInference) { // First tries integral types Try(Literal.create(Integer.parseInt(raw), IntegerType)) @@ -415,16 +443,8 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try( - Literal.create( - DateTimeUtils.getThreadLocalTimestampFormat(timeZone) - .parse(unescapePathName(raw)).getTime * 1000L, - TimestampType))) - .orElse(Try( - Literal.create( - DateTimeUtils.millisToDays( - DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), - DateType))) + .orElse(timestampTry) + .orElse(dateTry) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index d4902641e335f..edb3da904d10d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1120,4 +1120,18 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Row(3, BigDecimal("2" * 30)) :: Nil) } } + + test("SPARK-23436: invalid Dates should be inferred as String in partition inference") { + withTempPath { path => + val data = Seq(("1", "2018-01", "2018-01-01-04", "test")) + .toDF("id", "date_month", "date_hour", "data") + + data.write.partitionBy("date_month", "date_hour").parquet(path.getAbsolutePath) + val input = spark.read.parquet(path.getAbsolutePath).select("id", + "date_month", "date_hour", "data") + + assert(input.schema.sameType(input.schema)) + checkAnswer(input, data) + } + } } From aadf9535b4a11b42fd9d72f636576d2da0766199 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 20 Feb 2018 16:04:22 +0800 Subject: [PATCH 0373/1161] [SPARK-23203][SQL] DataSourceV2: Use immutable logical plans. ## What changes were proposed in this pull request? SPARK-23203: DataSourceV2 should use immutable catalyst trees instead of wrapping a mutable DataSourceV2Reader. This commit updates DataSourceV2Relation and consolidates much of the DataSourceV2 API requirements for the read path in it. Instead of wrapping a reader that changes, the relation lazily produces a reader from its configuration. This commit also updates the predicate and projection push-down. Instead of the implementation from SPARK-22197, this reuses the rule matching from the Hive and DataSource read paths (using `PhysicalOperation`) and copies most of the implementation of `SparkPlanner.pruneFilterProject`, with updates for DataSourceV2. By reusing the implementation from other read paths, this should have fewer regressions from other read paths and is less code to maintain. The new push-down rules also supports the following edge cases: * The output of DataSourceV2Relation should be what is returned by the reader, in case the reader can only partially satisfy the requested schema projection * The requested projection passed to the DataSourceV2Reader should include filter columns * The push-down rule may be run more than once if filters are not pushed through projections ## How was this patch tested? Existing push-down and read tests. Author: Ryan Blue Closes #20387 from rdblue/SPARK-22386-push-down-immutable-trees. --- .../kafka010/KafkaContinuousSourceSuite.scala | 19 +- .../sql/kafka010/KafkaContinuousTest.scala | 4 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 41 +--- .../datasources/v2/DataSourceV2Relation.scala | 212 +++++++++++++++++- .../datasources/v2/DataSourceV2Strategy.scala | 7 +- .../v2/PushDownOperatorsToDataSource.scala | 159 ++++--------- .../continuous/ContinuousExecution.scala | 2 +- .../sql/sources/v2/DataSourceV2Suite.scala | 2 +- .../spark/sql/streaming/StreamTest.scala | 6 +- 10 files changed, 269 insertions(+), 187 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index a7083fa4e3417..f679e9bfc0450 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -17,20 +17,9 @@ package org.apache.spark.sql.kafka010 -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.time.SpanSugar._ -import scala.collection.mutable -import scala.util.Random - -import org.apache.spark.SparkContext -import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution -import org.apache.spark.sql.streaming.{StreamTest, Trigger} -import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation +import org.apache.spark.sql.streaming.Trigger // Run tests in KafkaSourceSuiteBase in continuous execution mode. class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest @@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 5a1a14f7a307a..48ac3fc1e8f9d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.SparkContext import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.streaming.Trigger @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case DataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index ed4ecfeafa972..89c9ef4cc73b5 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -35,7 +35,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} @@ -119,7 +119,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fcaf8d618c168..4274f120a375a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -189,39 +189,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val ds = cls.newInstance() - val options = new DataSourceOptions((extraOptions ++ - DataSourceV2Utils.extractSessionConfigs( - ds = ds.asInstanceOf[DataSourceV2], - conf = sparkSession.sessionState.conf)).asJava) - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading - // the dataframe as a v1 source. - val reader = (ds, userSpecifiedSchema) match { - case (ds: ReadSupportWithSchema, Some(schema)) => - ds.createReader(schema, options) - - case (ds: ReadSupport, None) => - ds.createReader(options) - - case (ds: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $ds.") - - case (ds: ReadSupport, Some(schema)) => - val reader = ds.createReader(options) - if (reader.readSchema() != schema) { - throw new AnalysisException(s"$ds does not allow user-specified schemas.") - } - reader - - case _ => null // fall back to v1 - } + val ds = cls.newInstance().asInstanceOf[DataSourceV2] + if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + ds = ds, conf = sparkSession.sessionState.conf) + Dataset.ofRows(sparkSession, DataSourceV2Relation.create( + ds, extraOptions.toMap ++ sessionOptions, + userSpecifiedSchema = userSpecifiedSchema)) - if (reader == null) { - loadV1Source(paths: _*) } else { - Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + loadV1Source(paths: _*) } } else { loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 38f6b15224788..a98dd4866f82a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,17 +17,80 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceReader) - extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { + source: DataSourceV2, + options: Map[String, String], + projection: Seq[AttributeReference], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + + import DataSourceV2Relation._ + + override def simpleString: String = { + s"DataSourceV2Relation(source=${source.name}, " + + s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + + s"filters=[${pushedFilters.mkString(", ")}], options=$options)" + } + + override lazy val schema: StructType = reader.readSchema() + + override lazy val output: Seq[AttributeReference] = { + // use the projection attributes to avoid assigning new ids. fields that are not projected + // will be assigned new ids, which is okay because they are not projected. + val attrMap = projection.map(a => a.name -> a).toMap + schema.map(f => attrMap.getOrElse(f.name, + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) + } + + private lazy val v2Options: DataSourceOptions = makeV2Options(options) + + lazy val ( + reader: DataSourceReader, + unsupportedFilters: Seq[Expression], + pushedFilters: Seq[Expression]) = { + val newReader = userSpecifiedSchema match { + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + + DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + val (remainingFilters, pushedFilters) = filters match { + case Some(filterSeq) => + DataSourceV2Relation.pushFilters(newReader, filterSeq) + case _ => + (Nil, Nil) + } + + (newReader, remainingFilters, pushedFilters) + } + + override def doCanonicalize(): LogicalPlan = { + val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + + // override output with canonicalized output to avoid attempting to configure a reader + val canonicalOutput: Seq[AttributeReference] = this.output + .map(a => QueryPlan.normalizeExprId(a, projection)) + + new DataSourceV2Relation(c.source, c.options, c.projection) { + override lazy val output: Seq[AttributeReference] = canonicalOutput + } + } override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => @@ -37,7 +100,9 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(output = output.map(_.newInstance())) + // projection is used to maintain id assignment. + // if projection is not set, use output so the copy is not equal to the original + copy(projection = projection.map(_.newInstance())) } } @@ -45,14 +110,137 @@ case class DataSourceV2Relation( * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical * to the non-streaming relation. */ -class StreamingDataSourceV2Relation( +case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { + reader: DataSourceReader) + extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { override def isStreaming: Boolean = true + + override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + + override def computeStats(): Statistics = reader match { + case r: SupportsReportStatistics => + Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } } object DataSourceV2Relation { - def apply(reader: DataSourceReader): DataSourceV2Relation = { - new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + private implicit class SourceHelpers(source: DataSourceV2) { + def asReadSupport: ReadSupport = { + source match { + case support: ReadSupport => + support + case _: ReadSupportWithSchema => + // this method is only called if there is no user-supplied schema. if there is no + // user-supplied schema and ReadSupport was not implemented, throw a helpful exception. + throw new AnalysisException(s"Data source requires a user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def asReadSupportWithSchema: ReadSupportWithSchema = { + source match { + case support: ReadSupportWithSchema => + support + case _: ReadSupport => + throw new AnalysisException( + s"Data source does not support user-supplied schema: $name") + case _ => + throw new AnalysisException(s"Data source is not readable: $name") + } + } + + def name: String = { + source match { + case registered: DataSourceRegister => + registered.shortName() + case _ => + source.getClass.getSimpleName + } + } + } + + private def makeV2Options(options: Map[String, String]): DataSourceOptions = { + new DataSourceOptions(options.asJava) + } + + private def schema( + source: DataSourceV2, + v2Options: DataSourceOptions, + userSchema: Option[StructType]): StructType = { + val reader = userSchema match { + // TODO: remove this case because it is confusing for users + case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => + val reader = source.asReadSupport.createReader(v2Options) + if (reader.readSchema() != s) { + throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") + } + reader + case Some(s) => + source.asReadSupportWithSchema.createReader(s, v2Options) + case _ => + source.asReadSupport.createReader(v2Options) + } + reader.readSchema() + } + + def create( + source: DataSourceV2, + options: Map[String, String], + filters: Option[Seq[Expression]] = None, + userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { + val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, options, projection, filters, + // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must + // be equal to the reader's schema. the schema method enforces this. because the user schema + // and the reader's schema are identical, drop the user schema. + if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + } + + private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + reader match { + case projectionSupport: SupportsPushDownRequiredColumns => + projectionSupport.pruneColumns(struct) + case _ => + } + } + + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case catalystFilterSupport: SupportsPushDownCatalystFilters => + ( + catalystFilterSupport.pushCatalystFilters(filters.toArray), + catalystFilterSupport.pushedCatalystFilters() + ) + + case filterSupport: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, Filter] = filters.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet + val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => + unhandledFilters.contains(f) + } + + (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + + case _ => (filters, Nil) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df5b524485f54..c4e7644683c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,8 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case DataSourceV2Relation(output, reader) => - DataSourceV2ScanExec(output, reader) :: Nil + case relation: DataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + + case relation: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(relation.output, relation.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 1ca6cbf061b4e..f23d228567241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -17,130 +17,55 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper} -import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources -import org.apache.spark.sql.sources.v2.reader._ -/** - * Pushes down various operators to the underlying data source for better performance. Operators are - * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you - * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the - * data source should execute FILTER before LIMIT. And required columns are calculated at the end, - * because when more operators are pushed down, we may need less columns at Spark side. - */ -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { - override def apply(plan: LogicalPlan): LogicalPlan = { - // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may - // appear in many places for column pruning. - // TODO: Ideally column pruning should be implemented via a plan property that is propagated - // top-down, then we can simplify the logic here and only collect target operators. - val filterPushed = plan transformUp { - case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => - val (candidates, nonDeterministic) = - splitConjunctivePredicates(condition).partition(_.deterministic) - - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(candidates.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => candidates - } - - val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And) - val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) - if (withFilter.output == fields) { - withFilter - } else { - Project(fields, withFilter) - } - } - - // TODO: add more push down rules. - - val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) - // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(columnPruned) - } - - // TODO: nested fields pruning - private def pushDownRequiredColumns( - plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { - plan match { - case p @ Project(projectList, child) => - val required = projectList.flatMap(_.references) - p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - - case f @ Filter(condition, child) => - val required = requiredByParent ++ condition.references - f.copy(child = pushDownRequiredColumns(child, required)) - - case relation: DataSourceV2Relation => relation.reader match { - case reader: SupportsPushDownRequiredColumns => - // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now - // it's possible that the mutable reader being updated by someone else, and we need to - // always call `reader.pruneColumns` here to correct it. - // assert(relation.output.toStructType == reader.readSchema(), - // "Schema of data source reader does not match the relation plan.") - - val requiredColumns = relation.output.filter(requiredByParent.contains) - reader.pruneColumns(requiredColumns.toStructType) - - val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - val newOutput = reader.readSchema().map(_.name).map(nameToAttr) - relation.copy(output = newOutput) - - case _ => relation +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { + override def apply( + plan: LogicalPlan): LogicalPlan = plan transformUp { + // PhysicalOperation guarantees that filters are deterministic; no need to check + case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => + // merge the filters + val filters = relation.filters match { + case Some(existing) => + existing ++ newFilters + case _ => + newFilters } - // TODO: there may be more operators that can be used to calculate the required columns. We - // can add more and more in the future. - case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) - } - } - - /** - * Finds a Filter node(with an optional Project child) above data source relation. - */ - object FilterAndProject { - // returns the project list, the filter condition and the data source relation. - def unapply(plan: LogicalPlan) - : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + val projectAttrs = project.map(_.toAttribute) + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(projectAttrs) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + projectAttrs + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } - case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + val newRelation = relation.copy( + projection = projection.asInstanceOf[Seq[AttributeReference]], + filters = Some(filters)) - case Filter(condition, Project(fields, r: DataSourceV2Relation)) - if fields.forall(_.deterministic) => - val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) - val substituted = condition.transform { - case a: Attribute => attributeMap.getOrElse(a, a) - } - Some((fields, substituted, r)) + // Add a Filter for any filters that could not be pushed + val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) + val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) - case _ => None - } + // Add a Project to ensure the output matches the required projection + if (newRelation.output != projectAttrs) { + Project(project, filtered) + } else { + filtered + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c3294d64b10cd..2c1d6c509d21b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case DataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index a1c87fb15542c..1157a350461d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -146,7 +146,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => withClue(cls.getName) { val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) - assert(e.message.contains("A schema needs to be specified")) + assert(e.message.contains("requires a user-supplied schema")) val schema = new StructType().add("i", "int").add("s", "string") val df = spark.read.format(cls.getName).schema(schema).load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 37fe595529baf..159dd0ecb5902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -38,9 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ @@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case StreamingExecutionRelation(s, _) => s - case DataSourceV2Relation(_, r) => r + case StreamingDataSourceV2Relation(_, r) => r } .zipWithIndex .find(_._1 == source) From 862fa697d829cdddf0f25e5613c91b040f9d9652 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 20 Feb 2018 20:26:26 +0900 Subject: [PATCH 0374/1161] [SPARK-23240][PYTHON] Better error message when extraneous data in pyspark.daemon's stdout ## What changes were proposed in this pull request? Print more helpful message when daemon module's stdout is empty or contains a bad port number. ## How was this patch tested? Manually recreated the environmental issues that caused the mysterious exceptions at one site. Tested that the expected messages are logged. Also, ran all scala unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bruce Robbins Closes #20424 from bersprockets/SPARK-23240_prop2. --- .../api/python/PythonWorkerFactory.scala | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 30976ac752a8a..2340580b54f67 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} +import java.io.{DataInputStream, DataOutputStream, EOFException, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.nio.charset.StandardCharsets import java.util.Arrays @@ -182,7 +182,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", daemonModule)) + val command = Arrays.asList(pythonExec, "-m", daemonModule) + val pb = new ProcessBuilder(command) val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) @@ -191,7 +192,29 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) - daemonPort = in.readInt() + try { + daemonPort = in.readInt() + } catch { + case _: EOFException => + throw new SparkException(s"No port number in $daemonModule's stdout") + } + + // test that the returned port number is within a valid range. + // note: this does not cover the case where the port number + // is arbitrary data but is also coincidentally within range + if (daemonPort < 1 || daemonPort > 0xffff) { + val exceptionMessage = f""" + |Bad data in $daemonModule's standard output. Invalid port number: + | $daemonPort (0x$daemonPort%08x) + |Python command to execute the daemon was: + | ${command.asScala.mkString(" ")} + |Check that you don't have any unexpected modules or libraries in + |your PYTHONPATH: + | $pythonPath + |Also, check if you have a sitecustomize.py module in your python path, + |or in your python installation, that is printing to standard output""" + throw new SparkException(exceptionMessage.stripMargin) + } // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) From 189f56f3dcdad4d997248c01aa5490617f018bd0 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 20 Feb 2018 07:51:30 -0600 Subject: [PATCH 0375/1161] [SPARK-23383][BUILD][MINOR] Make a distribution should exit with usage while detecting wrong options ## What changes were proposed in this pull request? ```shell ./dev/make-distribution.sh --name ne-1.0.0-SNAPSHOT xyz --tgz -Phadoop-2.7 +++ dirname ./dev/make-distribution.sh ++ cd ./dev/.. ++ pwd + SPARK_HOME=/Users/Kent/Documents/spark + DISTDIR=/Users/Kent/Documents/spark/dist + MAKE_TGZ=false + MAKE_PIP=false + MAKE_R=false + NAME=none + MVN=/Users/Kent/Documents/spark/build/mvn + (( 5 )) + case $1 in + NAME=ne-1.0.0-SNAPSHOT + shift + shift + (( 3 )) + case $1 in + break + '[' -z /Users/Kent/.jenv/candidates/java/current ']' + '[' -z /Users/Kent/.jenv/candidates/java/current ']' ++ command -v git + '[' /usr/local/bin/git ']' ++ git rev-parse --short HEAD + GITREV=98ea6a7 + '[' '!' -z 98ea6a7 ']' + GITREVSTRING=' (git revision 98ea6a7)' + unset GITREV ++ command -v /Users/Kent/Documents/spark/build/mvn + '[' '!' /Users/Kent/Documents/spark/build/mvn ']' ++ /Users/Kent/Documents/spark/build/mvn help:evaluate -Dexpression=project.version xyz --tgz -Phadoop-2.7 ++ grep -v INFO ++ tail -n 1 + VERSION=' -X,--debug Produce execution debug output' ``` It is better to declare the mistakes and exit with usage than `break` ## How was this patch tested? manually cc srowen Author: Kent Yao Closes #20571 from yaooqinn/SPARK-23383. --- dev/make-distribution.sh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 8b02446b2f15f..84233c64caa9c 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -72,9 +72,17 @@ while (( "$#" )); do --help) exit_with_usage ;; - *) + --*) + echo "Error: $1 is not supported" + exit_with_usage + ;; + -*) break ;; + *) + echo "Error: $1 is not supported" + exit_with_usage + ;; esac shift done From 83c008762af444eef73d835eb6f506ecf5aebc17 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 09:14:56 -0800 Subject: [PATCH 0376/1161] [SPARK-23456][SPARK-21783] Turn on `native` ORC impl and PPD by default ## What changes were proposed in this pull request? Apache Spark 2.3 introduced `native` ORC supports with vectorization and many fixes. However, it's shipped as a not-default option. This PR enables `native` ORC implementation and predicate-pushdown by default for Apache Spark 2.4. We will improve and stabilize ORC data source before Apache Spark 2.4. And, eventually, Apache Spark will drop old Hive-based ORC code. ## How was this patch tested? Pass the Jenkins with existing tests. Author: Dongjoon Hyun Closes #20634 from dongjoon-hyun/SPARK-23456. --- docs/sql-programming-guide.md | 6 +++++- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 91e43678481d6..c37c338a134f3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1018,7 +1018,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also spark.sql.orc.impl hive - The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4.1. `hive` means the ORC library in Hive 1.2.1. + The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1. spark.sql.orc.enableVectorizedReader @@ -1797,6 +1797,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see # Migration Guide +## Upgrading From Spark SQL 2.3 to 2.4 + + - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e75e1d66ebcf8..ce3f94618edeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -399,11 +399,11 @@ object SQLConf { val ORC_IMPLEMENTATION = buildConf("spark.sql.orc.impl") .doc("When native, use the native version of ORC support instead of the ORC library in Hive " + - "1.2.1. It is 'hive' by default.") + "1.2.1. It is 'hive' by default prior to Spark 2.4.") .internal() .stringConf .checkValues(Set("hive", "native")) - .createWithDefault("hive") + .createWithDefault("native") val ORC_VECTORIZED_READER_ENABLED = buildConf("spark.sql.orc.enableVectorizedReader") .doc("Enables vectorized orc decoding.") @@ -426,7 +426,7 @@ object SQLConf { val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown") .doc("When true, enable filter pushdown for ORC files.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + From 3e48f3b9ee7645e4218ad3ff7559e578d4bd9667 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 20 Feb 2018 16:02:44 -0800 Subject: [PATCH 0377/1161] [SPARK-23434][SQL] Spark should not warn `metadata directory` for a HDFS file path ## What changes were proposed in this pull request? In a kerberized cluster, when Spark reads a file path (e.g. `people.json`), it warns with a wrong warning message during looking up `people.json/_spark_metadata`. The root cause of this situation is the difference between `LocalFileSystem` and `DistributedFileSystem`. `LocalFileSystem.exists()` returns `false`, but `DistributedFileSystem.exists` raises `org.apache.hadoop.security.AccessControlException`. ```scala scala> spark.version res0: String = 2.4.0-SNAPSHOT scala> spark.read.json("file:///usr/hdp/current/spark-client/examples/src/main/resources/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ scala> spark.read.json("hdfs:///tmp/people.json") 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. 18/02/15 05:00:48 WARN streaming.FileStreamSink: Error while looking for metadata directory. ``` After this PR, ```scala scala> spark.read.json("hdfs:///tmp/people.json").show +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ ``` ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #20616 from dongjoon-hyun/SPARK-23434. --- .../spark/sql/execution/streaming/FileStreamSink.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 2715fa93d0e98..87a17cebdc10c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -42,9 +42,11 @@ object FileStreamSink extends Logging { try { val hdfsPath = new Path(singlePath) val fs = hdfsPath.getFileSystem(hadoopConf) - val metadataPath = new Path(hdfsPath, metadataDir) - val res = fs.exists(metadataPath) - res + if (fs.isDirectory(hdfsPath)) { + fs.exists(new Path(hdfsPath, metadataDir)) + } else { + false + } } catch { case NonFatal(e) => logWarning(s"Error while looking for metadata directory.") From 2ba77ed9e51922303e3c3533e368b95788bd7de5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 17:54:06 -0800 Subject: [PATCH 0378/1161] [SPARK-23470][UI] Use first attempt of last stage to define job description. This is much faster than finding out what the last attempt is, and the data should be the same. There's room for improvement in this page (like only loading data for the jobs being shown, instead of loading all available jobs and sorting them), but this should bring performance on par with the 2.2 version. Author: Marcelo Vanzin Closes #20644 from vanzin/SPARK-23470. --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index a9265d4dbcdfb..ac83de10f9237 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -1048,7 +1048,7 @@ private[ui] object ApiHelper { } def lastStageNameAndDescription(store: AppStatusStore, job: JobData): (String, String) = { - val stage = store.asOption(store.lastStageAttempt(job.stageIds.max)) + val stage = store.asOption(store.stageAttempt(job.stageIds.max, 0)) (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } From 6d398c05cbad69aa9093429e04ae44c73b81cd5a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 20 Feb 2018 18:06:21 -0800 Subject: [PATCH 0379/1161] [SPARK-23468][CORE] Stringify auth secret before storing it in credentials. The secret is used as a string in many parts of the code, so it has to be turned into a hex string to avoid issues such as the random byte sequence not containing a valid UTF8 sequence. Author: Marcelo Vanzin Closes #20643 from vanzin/SPARK-23468. --- core/src/main/scala/org/apache/spark/SecurityManager.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 4c1dbe3ffb4ad..5b15a1c57779d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -541,7 +541,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretBytes) + val secretStr = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From 601d653bff9160db8477f86d961e609fc2190237 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 20 Feb 2018 18:16:10 -0800 Subject: [PATCH 0380/1161] [SPARK-23454][SS][DOCS] Added trigger information to the Structured Streaming programming guide ## What changes were proposed in this pull request? - Added clear information about triggers - Made the semantics guarantees of watermarks more clear for streaming aggregations and stream-stream joins. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tathagata Das Closes #20631 from tdas/SPARK-23454. --- .../structured-streaming-programming-guide.md | 214 +++++++++++++++++- 1 file changed, 207 insertions(+), 7 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 48d6d0b542cc0..9a83f157452ad 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -904,7 +904,7 @@ windowedCounts <- count(
    -### Handling Late Data and Watermarking +#### Handling Late Data and Watermarking Now consider what happens if one of the events arrives late to the application. For example, say, a word generated at 12:04 (i.e. event time) could be received by the application at 12:11. The application should use the time 12:04 instead of 12:11 @@ -925,7 +925,9 @@ specifying the event time column and the threshold on how late the data is expec event time. For a specific window starting at time `T`, the engine will maintain state and allow late data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, -but data later than the threshold will be dropped. Let's understand this with an example. We can +but data later than the threshold will start getting dropped +(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below.
    @@ -1031,7 +1033,9 @@ then drops intermediate state of a window < watermark, and appends the final counts to the Result Table/sink. For example, the final counts of window `12:00 - 12:10` is appended to the Result Table only after the watermark is updated to `12:11`. -**Conditions for watermarking to clean aggregation state** +##### Conditions for watermarking to clean aggregation state +{:.no_toc} + It is important to note that the following conditions must be satisfied for the watermarking to clean the state in aggregation queries *(as of Spark 2.1.1, subject to change in the future)*. @@ -1051,6 +1055,16 @@ from the aggregation column. For example, `df.groupBy("time").count().withWatermark("time", "1 min")` is invalid in Append output mode. +##### Semantic Guarantees of Aggregation with Watermarking +{:.no_toc} + +- A watermark delay (set with `withWatermark`) of "2 hours" guarantees that the engine will never +drop any data that is less than 2 hours delayed. In other words, any data less than 2 hours behind +(in terms of event-time) the latest data processed till then is guaranteed to be aggregated. + +- However, the guarantee is strict only in one direction. Data delayed by more than 2 hours is +not guaranteed to be dropped; it may or may not get aggregated. More delayed is the data, less +likely is the engine going to process it. ### Join Operations Structured Streaming supports joining a streaming Dataset/DataFrame with a static Dataset/DataFrame @@ -1062,7 +1076,7 @@ Dataset/DataFrame will be the exactly the same as if it was with a static Datase containing the same data in the stream. -#### Stream-static joins +#### Stream-static Joins Since the introduction in Spark 2.0, Structured Streaming has supported joins (inner join and some type of outer joins) between a streaming and a static DataFrame/Dataset. Here is a simple example. @@ -1269,6 +1283,12 @@ joined <- join(
+###### Semantic Guarantees of Stream-stream Inner Joins with Watermarking +{:.no_toc} +This is similar to the [guarantees provided by watermarking on aggregations](#semantic-guarantees-of-aggregation-with-watermarking). +A watermark delay of "2 hours" guarantees that the engine will never drop any data that is less than + 2 hours delayed. But data delayed by more than 2 hours may or may not get processed. + ##### Outer Joins with Watermarking While the watermark + event-time constraints is optional for inner joins, for left and right outer joins they must be specified. This is because for generating the NULL results in outer join, the @@ -1347,7 +1367,14 @@ joined <- join(
-There are a few points to note regarding outer joins. +###### Semantic Guarantees of Stream-stream Outer Joins with Watermarking +{:.no_toc} +Outer joins have the same guarantees as [inner joins](#semantic-guarantees-of-stream-stream-inner-joins-with-watermarking) +regarding watermark delays and whether data will be dropped or not. + +###### Caveats +{:.no_toc} +There are a few important characteristics to note regarding how the outer results are generated. - *The outer NULL results will be generated with a delay that depends on the specified watermark delay and the time range condition.* This is because the engine has to wait for that long to ensure @@ -1962,7 +1989,7 @@ head(sql("select * from aggregates")) -#### Using Foreach +##### Using Foreach The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.1, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points. @@ -1979,6 +2006,172 @@ which has methods that get called whenever there is a sequence of rows generated - Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that have been created in `open` such that there are no resource leaks. +#### Triggers +The trigger settings of a streaming query defines the timing of streaming data processing, whether +the query is going to executed as micro-batch query with a fixed batch interval or as a continuous processing query. +Here are the different kinds of triggers that are supported. + + + + + + + + + + + + + + + + + + + + + + +
Trigger TypeDescription
unspecified (default) + If no trigger setting is explicitly specified, then by default, the query will be + executed in micro-batch mode, where micro-batches will be generated as soon as + the previous micro-batch has completed processing. +
Fixed interval micro-batches + The query will be executed with micro-batches mode, where micro-batches will be kicked off + at the user-specified intervals. +
    +
  • If the previous micro-batch completes within the interval, then the engine will wait until + the interval is over before kicking off the next micro-batch.
  • + +
  • If the previous micro-batch takes longer than the interval to complete (i.e. if an + interval boundary is missed), then the next micro-batch will start as soon as the + previous one completes (i.e., it will not wait for the next interval boundary).
  • + +
  • If no new data is available, then no micro-batch will be kicked off.
  • +
+
One-time micro-batch + The query will execute *only one* micro-batch to process all the available data and then + stop on its own. This is useful in scenarios you want to periodically spin up a cluster, + process everything that is available since the last period, and then shutdown the + cluster. In some case, this may lead to significant cost savings. +
Continuous with fixed checkpoint interval
(experimental)
+ The query will be executed in the new low-latency, continuous processing mode. Read more + about this in the Continuous Processing section below. +
+ +Here are a few code examples. + +
+
+ +{% highlight scala %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start() + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start() + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start() + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start() + +{% endhighlight %} + + +
+
+ +{% highlight java %} +import org.apache.spark.sql.streaming.Trigger + +// Default trigger (runs micro-batch as soon as it can) +df.writeStream + .format("console") + .start(); + +// ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream + .format("console") + .trigger(Trigger.ProcessingTime("2 seconds")) + .start(); + +// One-time trigger +df.writeStream + .format("console") + .trigger(Trigger.Once()) + .start(); + +// Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(Trigger.Continuous("1 second")) + .start(); + +{% endhighlight %} + +
+
+ +{% highlight python %} + +# Default trigger (runs micro-batch as soon as it can) +df.writeStream \ + .format("console") \ + .start() + +# ProcessingTime trigger with two-seconds micro-batch interval +df.writeStream \ + .format("console") \ + .trigger(processingTime='2 seconds') \ + .start() + +# One-time trigger +df.writeStream \ + .format("console") \ + .trigger(once=True) \ + .start() + +# Continuous trigger with one-second checkpointing interval +df.writeStream + .format("console") + .trigger(continuous='1 second') + .start() + +{% endhighlight %} +
+
+ +{% highlight r %} +# Default trigger (runs micro-batch as soon as it can) +write.stream(df, "console") + +# ProcessingTime trigger with two-seconds micro-batch interval +write.stream(df, "console", trigger.processingTime = "2 seconds") + +# One-time trigger +write.stream(df, "console", trigger.once = TRUE) + +# Continuous trigger is not yet supported +{% endhighlight %} +
+
+ + ## Managing Streaming Queries The `StreamingQuery` object created when a query is started can be used to monitor and manage the query. @@ -2516,7 +2709,10 @@ write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "pat -# Continuous Processing [Experimental] +# Continuous Processing +## [Experimental] +{:.no_toc} + **Continuous processing** is a new, experimental streaming execution mode introduced in Spark 2.3 that enables low (~1 ms) end-to-end latency with at-least-once fault-tolerance guarantees. Compare this with the default *micro-batch processing* engine which can achieve exactly-once guarantees but achieve latencies of ~100ms at best. For some types of queries (discussed below), you can choose which mode to execute them in without modifying the application logic (i.e. without changing the DataFrame/Dataset operations). To run a supported query in continuous processing mode, all you need to do is specify a **continuous trigger** with the desired checkpoint interval as a parameter. For example, @@ -2589,6 +2785,8 @@ spark \ A checkpoint interval of 1 second means that the continuous processing engine will records the progress of the query every second. The resulting checkpoints are in a format compatible with the micro-batch engine, hence any query can be restarted with any trigger. For example, a supported query started with the micro-batch mode can be restarted in continuous mode, and vice versa. Note that any time you switch to continuous mode, you will get at-least-once fault-tolerance guarantees. ## Supported Queries +{:.no_toc} + As of Spark 2.3, only the following type of queries are supported in the continuous processing mode. - *Operations*: Only map-like Dataset/DataFrame operations are supported in continuous mode, that is, only projections (`select`, `map`, `flatMap`, `mapPartitions`, etc.) and selections (`where`, `filter`, etc.). @@ -2606,6 +2804,8 @@ As of Spark 2.3, only the following type of queries are supported in the continu See [Input Sources](#input-sources) and [Output Sinks](#output-sinks) sections for more details on them. While the console sink is good for testing, the end-to-end low-latency processing can be best observed with Kafka as the source and sink, as this allows the engine to process the data and make the results available in the output topic within milliseconds of the input data being available in the input topic. ## Caveats +{:.no_toc} + - Continuous processing engine launches multiple long-running tasks that continuously read data from sources, process it and continuously write to sinks. The number of tasks required by the query depends on how many partitions the query can read from the sources in parallel. Therefore, before starting a continuous processing query, you must ensure there are enough cores in the cluster to all the tasks in parallel. For example, if you are reading from a Kafka topic that has 10 partitions, then the cluster must have at least 10 cores for the query to make progress. - Stopping a continuous processing stream may produce spurious task termination warnings. These can be safely ignored. - There are currently no automatic retries of failed tasks. Any failure will lead to the query being stopped and it needs to be manually restarted from the checkpoint. From 95e25ed1a8b56937345eff637c0032aea85a503d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 21 Feb 2018 11:26:06 +0800 Subject: [PATCH 0381/1161] [SPARK-23424][SQL] Add codegenStageId in comment ## What changes were proposed in this pull request? This PR always adds `codegenStageId` in comment of the generated class. This is a replication of #20419 for post-Spark 2.3. Closes #20419 ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; ... ``` ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20612 from kiszk/SPARK-23424. --- .../expressions/codegen/CodeGenerator.scala | 21 ++++++++++++++++--- .../sql/execution/WholeStageCodegenExec.scala | 4 +++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 31ba29ae8d8ce..60a6f50472504 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1232,14 +1232,29 @@ class CodegenContext { /** * Register a comment and return the corresponding place holder + * + * @param placeholderId an optionally specified identifier for the comment's placeholder. + * The caller should make sure this identifier is unique within the + * compilation unit. If this argument is not specified, a fresh identifier + * will be automatically created and used as the placeholder. + * @param force whether to force registering the comments */ - def registerComment(text: => String): String = { + def registerComment( + text: => String, + placeholderId: String = "", + force: Boolean = false): String = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this // flat, see SPARK-15680. - if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { - val name = freshName("c") + if (force || + SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + val name = if (placeholderId != "") { + assert(!placeHolderToComments.contains(placeholderId)) + placeholderId + } else { + freshName("c") + } val comment = if (text.contains("\n") || text.contains("\r")) { text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 0e525b1e22eb9..deb0a044c2fb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -540,7 +540,9 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) ${ctx.registerComment( s"""Codegend pipeline for stage (id=$codegenStageId) - |${this.treeString.trim}""".stripMargin)} + |${this.treeString.trim}""".stripMargin, + "wsc_codegenPipeline")} + ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)} final class $className extends ${classOf[BufferedRowIterator].getName} { private Object[] references; From c8c4441dfdfeda22f8d92e25aee1b6a6269752f9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 21 Feb 2018 15:10:08 +0800 Subject: [PATCH 0382/1161] [SPARK-23418][SQL] Fail DataSourceV2 reads when user schema is passed, but not supported. ## What changes were proposed in this pull request? DataSourceV2 initially allowed user-supplied schemas when a source doesn't implement `ReadSupportWithSchema`, as long as the schema was identical to the source's schema. This is confusing behavior because changes to an underlying table can cause a previously working job to fail with an exception that user-supplied schemas are not allowed. This reverts commit adcb25a0624, which was added to #20387 so that it could be removed in a separate JIRA issue and PR. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #20603 from rdblue/SPARK-23418-revert-adcb25a0624. --- .../datasources/v2/DataSourceV2Relation.scala | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a98dd4866f82a..cc6cb631e3f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -174,13 +174,6 @@ object DataSourceV2Relation { v2Options: DataSourceOptions, userSchema: Option[StructType]): StructType = { val reader = userSchema match { - // TODO: remove this case because it is confusing for users - case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] => - val reader = source.asReadSupport.createReader(v2Options) - if (reader.readSchema() != s) { - throw new AnalysisException(s"${source.name} does not allow user-specified schemas.") - } - reader case Some(s) => source.asReadSupportWithSchema.createReader(s, v2Options) case _ => @@ -195,11 +188,7 @@ object DataSourceV2Relation { filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, - // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must - // be equal to the reader's schema. the schema method enforces this. because the user schema - // and the reader's schema are identical, drop the user schema. - if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None) + DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) } private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { From e836c27ce011ca9aef822bef6320b4a7059ec343 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Feb 2018 12:39:36 -0600 Subject: [PATCH 0383/1161] [SPARK-23217][ML][PYTHON] Add distanceMeasure param to ClusteringEvaluator Python API ## What changes were proposed in this pull request? The PR adds the `distanceMeasure` param to ClusteringEvaluator in the Python API. This allows the user to specify `cosine` as distance measure in addition to the default `squaredEuclidean`. ## How was this patch tested? added UT Author: Marco Gaido Closes #20627 from mgaido91/SPARK-23217_python. --- python/pyspark/ml/evaluation.py | 28 +++++++++++++++++++++++----- python/pyspark/ml/tests.py | 16 ++++++++++++++-- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 0cbce9b40048f..695d8ab27cc96 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -362,18 +362,21 @@ class ClusteringEvaluator(JavaEvaluator, HasPredictionCol, HasFeaturesCol, metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (silhouette)", typeConverter=TypeConverters.toString) + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + + "Supported options: 'squaredEuclidean' and 'cosine'.", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ __init__(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") """ super(ClusteringEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.ClusteringEvaluator", self.uid) - self._setDefault(metricName="silhouette") + self._setDefault(metricName="silhouette", distanceMeasure="squaredEuclidean") kwargs = self._input_kwargs self._set(**kwargs) @@ -394,15 +397,30 @@ def getMetricName(self): @keyword_only @since("2.3.0") def setParams(self, predictionCol="prediction", featuresCol="features", - metricName="silhouette"): + metricName="silhouette", distanceMeasure="squaredEuclidean"): """ setParams(self, predictionCol="prediction", featuresCol="features", \ - metricName="silhouette") + metricName="silhouette", distanceMeasure="squaredEuclidean") Sets params for clustering evaluator. """ kwargs = self._input_kwargs return self._set(**kwargs) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` + """ + return self.getOrDefault(self.distanceMeasure) + + if __name__ == "__main__": import doctest import tempfile diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6d6737241e06e..116885969345c 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -51,7 +51,7 @@ from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.evaluation import BinaryClassificationEvaluator, \ +from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \ MulticlassClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.fpm import FPGrowth, FPGrowthModel @@ -541,6 +541,15 @@ def test_java_params(self): self.assertEqual(evaluator._java_obj.getMetricName(), "r2") self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae") + def test_clustering_evaluator_with_cosine_distance(self): + featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]), + [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0), + ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)]) + dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"]) + evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine") + self.assertEqual(evaluator.getDistanceMeasure(), "cosine") + self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5)) + class FeatureTests(SparkSessionTestCase): @@ -1961,11 +1970,14 @@ def test_java_params(self): import pyspark.ml.feature import pyspark.ml.classification import pyspark.ml.clustering + import pyspark.ml.evaluation import pyspark.ml.pipeline import pyspark.ml.recommendation import pyspark.ml.regression + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, - pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression] + pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation, + pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): if not name.endswith('Model') and issubclass(cls, JavaParams)\ From 3fd0ccb13fea44727d970479af1682ef00592147 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Feb 2018 14:56:13 -0800 Subject: [PATCH 0384/1161] [SPARK-23484][SS] Fix possible race condition in KafkaContinuousReader ## What changes were proposed in this pull request? var `KafkaContinuousReader.knownPartitions` should be threadsafe as it is accessed from multiple threads - the query thread at the time of reader factory creation, and the epoch tracking thread at the time of `needsReconfiguration`. ## How was this patch tested? Existing tests. Author: Tathagata Das Closes #20655 from tdas/SPARK-23484. --- .../org/apache/spark/sql/kafka010/KafkaContinuousReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 97a0f66e1880d..ecd1170321f3f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -66,7 +66,7 @@ class KafkaContinuousReader( // Initialized when creating reader factories. If this diverges from the partitions at the latest // offsets, we need to reconfigure. // Exposed outside this object only for unit tests. - private[sql] var knownPartitions: Set[TopicPartition] = _ + @volatile private[sql] var knownPartitions: Set[TopicPartition] = _ override def readSchema: StructType = KafkaOffsetReader.kafkaSchema From 744d5af652ee8cece361cbca31e5201134e0fb42 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 15:37:28 -0800 Subject: [PATCH 0385/1161] [SPARK-23481][WEBUI] lastStageAttempt should fail when a stage doesn't exist ## What changes were proposed in this pull request? The issue here is `AppStatusStore.lastStageAttempt` will return the next available stage in the store when a stage doesn't exist. This PR adds `last(stageId)` to ensure it returns a correct `StageData` ## How was this patch tested? The new unit test. Author: Shixiong Zhu Closes #20654 from zsxwing/SPARK-23481. --- .../apache/spark/status/AppStatusStore.scala | 6 +++- .../spark/status/AppStatusListenerSuite.scala | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index efc28538a33db..688f25a9fdea1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -95,7 +95,11 @@ private[spark] class AppStatusStore( } def lastStageAttempt(stageId: Int): v1.StageData = { - val it = store.view(classOf[StageDataWrapper]).index("stageId").reverse().first(stageId) + val it = store.view(classOf[StageDataWrapper]) + .index("stageId") + .reverse() + .first(stageId) + .last(stageId) .closeableIterator() try { if (it.hasNext()) { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 749502709b5c8..673d191b5a4db 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1121,6 +1121,39 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("lastStageAttempt should fail when the stage doesn't exist") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 1) + val listener = new AppStatusListener(store, testConf, true) + val appStore = new AppStatusStore(store) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Make stage 3 complete before stage 2 so that stage 3 will be evicted + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + stage3.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage3)) + + time += 1 + stage2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage2, new Properties())) + stage2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage2)) + + assert(appStore.asOption(appStore.lastStageAttempt(1)) === None) + assert(appStore.asOption(appStore.lastStageAttempt(2)).map(_.stageId) === Some(2)) + assert(appStore.asOption(appStore.lastStageAttempt(3)) === None) + } + test("driver logs") { val listener = new AppStatusListener(store, conf, true) From 45cf714ee6d4eead2fe00794a0d754fa6d33d4a6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 21 Feb 2018 19:43:11 -0800 Subject: [PATCH 0386/1161] [SPARK-23475][WEBUI] Skipped stages should be evicted before completed stages ## What changes were proposed in this pull request? The root cause of missing completed stages is because `cleanupStages` will never remove skipped stages. This PR changes the logic to always remove skipped stage first. This is safe since the job itself contains enough information to render skipped stages in the UI. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu Closes #20656 from zsxwing/SPARK-23475. --- .../spark/status/AppStatusListener.scala | 5 ++- .../spark/status/AppStatusListenerSuite.scala | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 79a17e26665fd..5ea161cd0d151 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -915,7 +915,10 @@ private[spark] class AppStatusListener( return } - val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime").first(0L) + // As the completion time of a skipped stage is always -1, we will remove skipped stages first. + // This is safe since the job itself contains enough information to render skipped stages in the + // UI. + val view = kvstore.view(classOf[StageDataWrapper]).index("completionTime") val stages = KVUtils.viewToSeq(view, countToDelete.toInt) { s => s.info.status != v1.StageStatus.ACTIVE && s.info.status != v1.StageStatus.PENDING } diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 673d191b5a4db..1cd71955ad4d9 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1089,6 +1089,42 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } } + test("skipped stages should be evicted before completed stages") { + val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) + val listener = new AppStatusListener(store, testConf, true) + + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + + // Sart job 1 + time += 1 + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) + + // Start and stop stage 1 + time += 1 + stage1.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) + + time += 1 + stage1.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stage1)) + + // Stop job 1 and stage 2 will become SKIPPED + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + // Submit stage 3 and verify stage 2 is evicted + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + time += 1 + stage3.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) + + assert(store.count(classOf[StageDataWrapper]) === 2) + intercept[NoSuchElementException] { + store.read(classOf[StageDataWrapper], Array(2, 0)) + } + } + test("eviction should respect task completion time") { val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) val listener = new AppStatusListener(store, testConf, true) From 87293c746e19d66f475d506d0adb43421f496843 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 22 Feb 2018 11:00:12 -0800 Subject: [PATCH 0387/1161] [SPARK-23475][UI] Show also skipped stages ## What changes were proposed in this pull request? SPARK-20648 introduced the status `SKIPPED` for the stages. On the UI, previously, skipped stages were shown as `PENDING`; after this change, they are not shown on the UI. The PR introduce a new section in order to show also `SKIPPED` stages in a proper table. ## How was this patch tested? manual tests Author: Marco Gaido Closes #20651 from mgaido91/SPARK-23475. --- .../org/apache/spark/ui/static/webui.js | 1 + .../apache/spark/ui/jobs/AllStagesPage.scala | 27 +++++++++++++++++++ .../org/apache/spark/ui/UISeleniumSuite.scala | 17 ++++++++++++ 3 files changed, 45 insertions(+) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index 83009df91d30a..f01c567ba58ad 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -72,6 +72,7 @@ $(function() { collapseTablePageLoad('collapse-aggregated-allActiveStages','aggregated-allActiveStages'); collapseTablePageLoad('collapse-aggregated-allPendingStages','aggregated-allPendingStages'); collapseTablePageLoad('collapse-aggregated-allCompletedStages','aggregated-allCompletedStages'); + collapseTablePageLoad('collapse-aggregated-allSkippedStages','aggregated-allSkippedStages'); collapseTablePageLoad('collapse-aggregated-allFailedStages','aggregated-allFailedStages'); collapseTablePageLoad('collapse-aggregated-activeStages','aggregated-activeStages'); collapseTablePageLoad('collapse-aggregated-pendingOrSkippedStages','aggregated-pendingOrSkippedStages'); diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 606dc1e180e5b..38450b9126ff0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -36,6 +36,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) val pendingStages = allStages.filter(_.status == StageStatus.PENDING) + val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse @@ -51,6 +52,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val completedStagesTable = new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", parent.basePath, subPath, parent.isFairScheduler, false, false) + val skippedStagesTable = + new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", + parent.basePath, subPath, parent.isFairScheduler, false, false) val failedStagesTable = new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", parent.basePath, subPath, parent.isFairScheduler, false, true) @@ -66,6 +70,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val shouldShowActiveStages = activeStages.nonEmpty val shouldShowPendingStages = pendingStages.nonEmpty val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = skippedStages.nonEmpty val shouldShowFailedStages = failedStages.nonEmpty val appSummary = parent.store.appSummary() @@ -102,6 +107,14 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { } } + { + if (shouldShowSkippedStages) { +
  • + Skipped Stages: + {skippedStages.size} +
  • + } + } { if (shouldShowFailedStages) {
  • @@ -172,6 +185,20 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { {completedStagesTable.toNodeSeq} } + if (shouldShowSkippedStages) { + content ++= + +

    + + Skipped Stages ({skippedStages.size}) +

    +
    ++ +
    + {skippedStagesTable.toNodeSeq} +
    + } if (shouldShowFailedStages) { content ++= + val rdd = sc.parallelize(0 to 100, 100).repartition(10).cache() + rdd.count() + rdd.count() + + eventually(timeout(5 seconds), interval(50 milliseconds)) { + goToUi(sc, "/stages") + find(id("skipped")).get.text should be("Skipped Stages (1)") + } + val stagesJson = getJson(sc.ui.get, "stages") + stagesJson.children.size should be (4) + val stagesStatus = stagesJson.children.map(_ \ "status") + stagesStatus.count(_ == JString(StageStatus.SKIPPED.name())) should be (1) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } From c5abb3c2d16f601d507bee3c53663d4e117eb8b5 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 22 Feb 2018 12:07:51 -0800 Subject: [PATCH 0388/1161] [SPARK-23476][CORE] Generate secret in local mode when authentication on ## What changes were proposed in this pull request? If spark is run with "spark.authenticate=true", then it will fail to start in local mode. This PR generates secret in local mode when authentication on. ## How was this patch tested? Modified existing unit test. Manually started spark-shell. Author: Gabor Somogyi Closes #20652 from gaborgsomogyi/SPARK-23476. --- .../org/apache/spark/SecurityManager.scala | 16 ++++-- .../apache/spark/SecurityManagerSuite.scala | 50 +++++++++++++------ docs/security.md | 2 +- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 5b15a1c57779d..2519d266879aa 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -520,19 +520,25 @@ private[spark] class SecurityManager( * * If authentication is disabled, do nothing. * - * In YARN mode, generate a new secret and store it in the current user's credentials. + * In YARN and local mode, generate a new secret and store it in the current user's credentials. * * In other modes, assert that the auth secret is set in the configuration. */ def initializeAuth(): Unit = { + import SparkMasterRegex._ + if (!sparkConf.get(NETWORK_AUTH_ENABLED)) { return } - if (sparkConf.get(SparkLauncher.SPARK_MASTER, null) != "yarn") { - require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), - s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") - return + val master = sparkConf.get(SparkLauncher.SPARK_MASTER, "") + master match { + case "yarn" | "local" | LOCAL_N_REGEX(_) | LOCAL_N_FAILURES_REGEX(_, _) => + // Secret generation allowed here + case _ => + require(sparkConf.contains(SPARK_AUTH_SECRET_CONF), + s"A secret key must be specified via the $SPARK_AUTH_SECRET_CONF config.") + return } val rnd = new SecureRandom() diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index cf59265dd646d..106ece7aed0a4 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -440,23 +440,41 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(keyFromEnv === new SecurityManager(conf2).getSecretKey()) } - test("secret key generation in yarn mode") { - val conf = new SparkConf() - .set(NETWORK_AUTH_ENABLED, true) - .set(SparkLauncher.SPARK_MASTER, "yarn") - val mgr = new SecurityManager(conf) - - UserGroupInformation.createUserForTesting("authTest", Array()).doAs( - new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - mgr.initializeAuth() - val creds = UserGroupInformation.getCurrentUser().getCredentials() - val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) - assert(secret != null) - assert(new String(secret, UTF_8) === mgr.getSecretKey()) + test("secret key generation") { + Seq( + ("yarn", true), + ("local", true), + ("local[*]", true), + ("local[1, 2]", true), + ("local-cluster[2, 1, 1024]", false), + ("invalid", false) + ).foreach { case (master, shouldGenerateSecret) => + val conf = new SparkConf() + .set(NETWORK_AUTH_ENABLED, true) + .set(SparkLauncher.SPARK_MASTER, master) + val mgr = new SecurityManager(conf) + + UserGroupInformation.createUserForTesting("authTest", Array()).doAs( + new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + if (shouldGenerateSecret) { + mgr.initializeAuth() + val creds = UserGroupInformation.getCurrentUser().getCredentials() + val secret = creds.getSecretKey(SecurityManager.SECRET_LOOKUP_KEY) + assert(secret != null) + assert(new String(secret, UTF_8) === mgr.getSecretKey()) + } else { + intercept[IllegalArgumentException] { + mgr.initializeAuth() + } + intercept[IllegalArgumentException] { + mgr.getSecretKey() + } + } + } } - } - ) + ) + } } } diff --git a/docs/security.md b/docs/security.md index bebc28ddbfb0e..0f384b411812a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -6,7 +6,7 @@ title: Security Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: -* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. +* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. ## Web UI From 049f243c59737699fee54fdc9d65cbd7c788032a Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Thu, 22 Feb 2018 21:49:25 -0800 Subject: [PATCH 0389/1161] [SPARK-23490][SQL] Check storage.locationUri with existing table in CreateTable ## What changes were proposed in this pull request? For CreateTable with Append mode, we should check if `storage.locationUri` is the same with existing table in `PreprocessTableCreation` In the current code, there is only a simple exception if the `storage.locationUri` is different with existing table: `org.apache.spark.sql.AnalysisException: Table or view not found:` which can be improved. ## How was this patch tested? Unit test Author: Wang Gengliang Closes #20660 from gengliangwang/locationUri. --- .../sql/execution/datasources/rules.scala | 8 +++++ .../sql/execution/command/DDLSuite.scala | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5cc21eeaeaa94..0dea767840ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -118,6 +118,14 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " + s"`${specifiedProvider.getSimpleName}`.") } + tableDesc.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw new AnalysisException( + s"The location of the existing table ${tableIdentWithDB.quotedString} is " + + s"`${existingTable.location}`. It doesn't match the specified location " + + s"`${tableDesc.location}`.") + case _ => + } if (query.schema.length != existingTable.schema.length) { throw new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f76bfd2fda2b9..b800e6ff5b0ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -536,6 +536,35 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("create table - append to a non-partitioned table created with different paths") { + import testImplicits._ + withTempDir { dir1 => + withTempDir { dir2 => + withTable("path_test") { + Seq(1L -> "a").toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir1.getCanonicalPath) + .saveAsTable("path_test") + + val ex = intercept[AnalysisException] { + Seq((3L, "c")).toDF("v1", "v2") + .write + .mode(SaveMode.Append) + .format("json") + .option("path", dir2.getCanonicalPath) + .saveAsTable("path_test") + }.getMessage + assert(ex.contains("The location of the existing table `default`.`path_test`")) + + checkAnswer( + spark.table("path_test"), Row(1L, "a") :: Nil) + } + } + } + } + test("Refresh table after changing the data source table partitioning") { import testImplicits._ From 855ce13d045569b7b16fdc7eee9c981f4ff3a545 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 23 Feb 2018 12:40:58 -0800 Subject: [PATCH 0390/1161] [SPARK-23408][SS] Synchronize successive AddData actions in Streaming*JoinSuite **The best way to review this PR is to ignore whitespace/indent changes. Use this link - https://github.com/apache/spark/pull/20650/files?w=1** ## What changes were proposed in this pull request? The stream-stream join tests add data to multiple sources and expect it all to show up in the next batch. But there's a race condition; the new batch might trigger when only one of the AddData actions has been reached. Prior attempt to solve this issue by jose-torres in #20646 attempted to simultaneously synchronize on all memory sources together when consecutive AddData was found in the actions. However, this carries the risk of deadlock as well as unintended modification of stress tests (see the above PR for a detailed explanation). Instead, this PR attempts the following. - A new action called `StreamProgressBlockedActions` that allows multiple actions to be executed while the streaming query is blocked from making progress. This allows data to be added to multiple sources that are made visible simultaneously in the next batch. - An alias of `StreamProgressBlockedActions` called `MultiAddData` is explicitly used in the `Streaming*JoinSuites` to add data to two memory sources simultaneously. This should avoid unintentional modification of the stress tests (or any other test for that matter) while making sure that the flaky tests are deterministic. ## How was this patch tested? Modified test cases in `Streaming*JoinSuites` where there are consecutive `AddData` actions. Author: Tathagata Das Closes #20650 from tdas/SPARK-23408. --- .../streaming/MicroBatchExecution.scala | 10 + .../spark/sql/streaming/StreamTest.scala | 472 ++++++++++-------- .../sql/streaming/StreamingJoinSuite.scala | 54 +- 3 files changed, 284 insertions(+), 252 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 84655013ba957..6bd03972c301d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -504,6 +504,16 @@ class MicroBatchExecution( } } + /** Execute a function while locking the stream from making an progress */ + private[sql] def withProgressLocked(f: => Unit): Unit = { + awaitProgressLock.lock() + try { + f + } finally { + awaitProgressLock.unlock() + } + } + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { Optional.ofNullable(scalaOption.orNull) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 159dd0ecb5902..08f722ecb10e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -102,6 +102,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be AddDataMemory(source, data) } + /** + * Adds data to multiple memory streams such that all the data will be made visible in the + * same batch. This is applicable only to MicroBatchExecution, as this coordination cannot be + * performed at the driver in ContinuousExecutions. + */ + object MultiAddData { + def apply[A] + (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: A*): StreamAction = { + val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, data2)) + StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | ", " ]")) + } + } + /** A trait that can be extended when testing a source. */ trait AddData extends StreamAction { /** @@ -217,6 +230,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]" } + /** + * Performs multiple actions while locking the stream from progressing. + * This is applicable only to MicroBatchExecution, as progress of ContinuousExecution + * cannot be controlled from the driver. + */ + case class StreamProgressLockedActions(actions: Seq[StreamAction], desc: String = null) + extends StreamAction { + + override def toString(): String = { + if (desc != null) desc else super.toString + } + } + /** Assert that a body is true */ class Assert(condition: => Boolean, val message: String = "") extends StreamAction { def run(): Unit = { Assertions.assert(condition) } @@ -295,6 +321,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + var manualClockExpectedTime = -1L @volatile var streamThreadDeathCause: Throwable = null @@ -425,243 +454,254 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - var manualClockExpectedTime = -1L - val defaultCheckpointLocation = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - try { - startedTest.foreach { action => - logInfo(s"Processing test stream action: $action") - action match { - case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => - verify(currentStream == null, "stream already running") - verify(triggerClock.isInstanceOf[SystemClock] - || triggerClock.isInstanceOf[StreamManualClock], - "Use either SystemClock or StreamManualClock to start the stream") - if (triggerClock.isInstanceOf[StreamManualClock]) { - manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + def executeAction(action: StreamAction): Unit = { + logInfo(s"Processing test stream action: $action") + action match { + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => + verify(currentStream == null, "stream already running") + verify(triggerClock.isInstanceOf[SystemClock] + || triggerClock.isInstanceOf[StreamManualClock], + "Use either SystemClock or StreamManualClock to start the stream") + if (triggerClock.isInstanceOf[StreamManualClock]) { + manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + + additionalConfs.foreach(pair => { + val value = + if (sparkSession.conf.contains(pair._1)) { + Some(sparkSession.conf.get(pair._1)) + } else None + resetConfValues(pair._1) = value + sparkSession.conf.set(pair._1, pair._2) + }) + + lastStream = currentStream + currentStream = + sparkSession + .streams + .startQuery( + None, + Some(metadataRoot), + stream, + Map(), + sink, + outputMode, + trigger = trigger, + triggerClock = triggerClock) + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + // Wait until the initialization finishes, because some tests need to use `logicalPlan` + // after starting the query. + try { + currentStream.awaitInitialization(streamingTimeout.toMillis) + currentStream match { + case s: ContinuousExecution => eventually("IncrementalExecution was not created") { + assert(s.lastExecution != null) + } + case _ => } - val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) + } catch { + case _: StreamingQueryException => + // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. + } - additionalConfs.foreach(pair => { - val value = - if (sparkSession.conf.contains(pair._1)) { - Some(sparkSession.conf.get(pair._1)) - } else None - resetConfValues(pair._1) = value - sparkSession.conf.set(pair._1, pair._2) - }) + case AdvanceManualClock(timeToAdd) => + verify(currentStream != null, + "can not advance manual clock when a stream is not running") + verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], + s"can not advance clock of type ${currentStream.triggerClock.getClass}") + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + assert(manualClockExpectedTime >= 0) + + // Make sure we don't advance ManualClock too early. See SPARK-16002. + eventually("StreamManualClock has not yet entered the waiting state") { + assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + } + clock.advance(timeToAdd) + manualClockExpectedTime += timeToAdd + verify(clock.getTimeMillis() === manualClockExpectedTime, + s"Unexpected clock time after updating: " + + s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") + + case StopStream => + verify(currentStream != null, "can not stop a stream that is not running") + try failAfter(streamingTimeout) { + currentStream.stop() + verify(!currentStream.queryExecutionThread.isAlive, + s"microbatch thread not stopped") + verify(!currentStream.isActive, + "query.isActive() is false even after stopping") + verify(currentStream.exception.isEmpty, + s"query.exception() is not empty after clean stop: " + + currentStream.exception.map(_.toString()).getOrElse("")) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest( + "Timed out while stopping and waiting for microbatchthread to terminate.", e) + case t: Throwable => + failTest("Error while stopping stream", t) + } finally { lastStream = currentStream - currentStream = - sparkSession - .streams - .startQuery( - None, - Some(metadataRoot), - stream, - Map(), - sink, - outputMode, - trigger = trigger, - triggerClock = triggerClock) - .asInstanceOf[StreamingQueryWrapper] - .streamingQuery - // Wait until the initialization finishes, because some tests need to use `logicalPlan` - // after starting the query. - try { - currentStream.awaitInitialization(streamingTimeout.toMillis) - currentStream match { - case s: ContinuousExecution => eventually("IncrementalExecution was not created") { - assert(s.lastExecution != null) - } - case _ => - } - } catch { - case _: StreamingQueryException => - // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well. - } + currentStream = null + } - case AdvanceManualClock(timeToAdd) => - verify(currentStream != null, - "can not advance manual clock when a stream is not running") - verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], - s"can not advance clock of type ${currentStream.triggerClock.getClass}") - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - assert(manualClockExpectedTime >= 0) - - // Make sure we don't advance ManualClock too early. See SPARK-16002. - eventually("StreamManualClock has not yet entered the waiting state") { - assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + case ef: ExpectFailure[_] => + verify(currentStream != null, "can not expect failure when stream is not running") + try failAfter(streamingTimeout) { + val thrownException = intercept[StreamingQueryException] { + currentStream.awaitTermination() } - - clock.advance(timeToAdd) - manualClockExpectedTime += timeToAdd - verify(clock.getTimeMillis() === manualClockExpectedTime, - s"Unexpected clock time after updating: " + - s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") - - case StopStream => - verify(currentStream != null, "can not stop a stream that is not running") - try failAfter(streamingTimeout) { - currentStream.stop() - verify(!currentStream.queryExecutionThread.isAlive, - s"microbatch thread not stopped") - verify(!currentStream.isActive, - "query.isActive() is false even after stopping") - verify(currentStream.exception.isEmpty, - s"query.exception() is not empty after clean stop: " + - currentStream.exception.map(_.toString()).getOrElse("")) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest( - "Timed out while stopping and waiting for microbatchthread to terminate.", e) - case t: Throwable => - failTest("Error while stopping stream", t) - } finally { - lastStream = currentStream - currentStream = null + eventually("microbatch thread not stopped after termination with failure") { + assert(!currentStream.queryExecutionThread.isAlive) } + verify(currentStream.exception === Some(thrownException), + s"incorrect exception returned by query.exception()") + + val exception = currentStream.exception.get + verify(exception.cause.getClass === ef.causeClass, + "incorrect cause in exception returned by query.exception()\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + if (ef.isFatalError) { + // This is a fatal error, `streamThreadDeathCause` should be set to this error in + // UncaughtExceptionHandler. + verify(streamThreadDeathCause != null && + streamThreadDeathCause.getClass === ef.causeClass, + "UncaughtExceptionHandler didn't receive the correct error\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") + streamThreadDeathCause = null + } + ef.assertFailure(exception.getCause) + } catch { + case _: InterruptedException => + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure", e) + case t: Throwable => + failTest("Error while checking stream failure", t) + } finally { + lastStream = currentStream + currentStream = null + } - case ef: ExpectFailure[_] => - verify(currentStream != null, "can not expect failure when stream is not running") - try failAfter(streamingTimeout) { - val thrownException = intercept[StreamingQueryException] { - currentStream.awaitTermination() - } - eventually("microbatch thread not stopped after termination with failure") { - assert(!currentStream.queryExecutionThread.isAlive) + case a: AssertOnQuery => + verify(currentStream != null || lastStream != null, + "cannot assert when no stream has been started") + val streamToAssert = Option(currentStream).getOrElse(lastStream) + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } + + case a: Assert => + val streamToAssert = Option(currentStream).getOrElse(lastStream) + verify({ a.run(); true }, s"Assert failed: ${a.message}") + + case a: AddData => + try { + + // If the query is running with manual clock, then wait for the stream execution + // thread to start waiting for the clock to increment. This is needed so that we + // are adding data when there is no trigger that is active. This would ensure that + // the data gets deterministically added to the next batch triggered after the manual + // clock is incremented in following AdvanceManualClock. This avoid race conditions + // between the test thread and the stream execution thread in tests using manual + // clock. + if (currentStream != null && + currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + eventually("Error while synchronizing with manual clock before adding data") { + if (currentStream.isActive) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } } - verify(currentStream.exception === Some(thrownException), - s"incorrect exception returned by query.exception()") - - val exception = currentStream.exception.get - verify(exception.cause.getClass === ef.causeClass, - "incorrect cause in exception returned by query.exception()\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") - if (ef.isFatalError) { - // This is a fatal error, `streamThreadDeathCause` should be set to this error in - // UncaughtExceptionHandler. - verify(streamThreadDeathCause != null && - streamThreadDeathCause.getClass === ef.causeClass, - "UncaughtExceptionHandler didn't receive the correct error\n" + - s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") - streamThreadDeathCause = null + if (!currentStream.isActive) { + failTest("Query terminated while synchronizing with manual clock") } - ef.assertFailure(exception.getCause) - } catch { - case _: InterruptedException => - case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out while waiting for failure", e) - case t: Throwable => - failTest("Error while checking stream failure", t) - } finally { - lastStream = currentStream - currentStream = null } - - case a: AssertOnQuery => - verify(currentStream != null || lastStream != null, - "cannot assert when no stream has been started") - val streamToAssert = Option(currentStream).getOrElse(lastStream) - try { - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") - } catch { - case NonFatal(e) => - failTest(s"Assert on query failed: ${a.message}", e) + // Add data + val queryToUse = Option(currentStream).orElse(Option(lastStream)) + val (source, offset) = a.addData(queryToUse) + + def findSourceIndex(plan: LogicalPlan): Option[Int] = { + plan + .collect { + case StreamingExecutionRelation(s, _) => s + case StreamingDataSourceV2Relation(_, r) => r + } + .zipWithIndex + .find(_._1 == source) + .map(_._2) } - case a: Assert => - val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify({ a.run(); true }, s"Assert failed: ${a.message}") - - case a: AddData => - try { - - // If the query is running with manual clock, then wait for the stream execution - // thread to start waiting for the clock to increment. This is needed so that we - // are adding data when there is no trigger that is active. This would ensure that - // the data gets deterministically added to the next batch triggered after the manual - // clock is incremented in following AdvanceManualClock. This avoid race conditions - // between the test thread and the stream execution thread in tests using manual - // clock. - if (currentStream != null && - currentStream.triggerClock.isInstanceOf[StreamManualClock]) { - val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] - eventually("Error while synchronizing with manual clock before adding data") { - if (currentStream.isActive) { - assert(clock.isStreamWaitingAt(clock.getTimeMillis())) - } + // Try to find the index of the source to which data was added. Either get the index + // from the current active query or the original input logical plan. + val sourceIndex = + queryToUse.flatMap { query => + findSourceIndex(query.logicalPlan) + }.orElse { + findSourceIndex(stream.logicalPlan) + }.orElse { + queryToUse.flatMap { q => + findSourceIndex(q.lastExecution.logical) } - if (!currentStream.isActive) { - failTest("Query terminated while synchronizing with manual clock") - } - } - // Add data - val queryToUse = Option(currentStream).orElse(Option(lastStream)) - val (source, offset) = a.addData(queryToUse) - - def findSourceIndex(plan: LogicalPlan): Option[Int] = { - plan - .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r - } - .zipWithIndex - .find(_._1 == source) - .map(_._2) + }.getOrElse { + throw new IllegalArgumentException( + "Could not find index of the source to which data was added") } - // Try to find the index of the source to which data was added. Either get the index - // from the current active query or the original input logical plan. - val sourceIndex = - queryToUse.flatMap { query => - findSourceIndex(query.logicalPlan) - }.orElse { - findSourceIndex(stream.logicalPlan) - }.orElse { - queryToUse.flatMap { q => - findSourceIndex(q.lastExecution.logical) - } - }.getOrElse { - throw new IllegalArgumentException( - "Could not find index of the source to which data was added") - } + // Store the expected offset of added data to wait for it later + awaiting.put(sourceIndex, offset) + } catch { + case NonFatal(e) => + failTest("Error adding data", e) + } - // Store the expected offset of added data to wait for it later - awaiting.put(sourceIndex, offset) - } catch { - case NonFatal(e) => - failTest("Error adding data", e) - } + case e: ExternalAction => + e.runAction() - case e: ExternalAction => - e.runAction() + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { + error => failTest(error) + } - case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { - error => failTest(error) - } + case CheckAnswerRowsContains(expectedAnswer, lastOnly) => + val sparkAnswer = currentStream match { + case null => fetchStreamAnswer(lastStream, lastOnly) + case s => fetchStreamAnswer(s, lastOnly) + } + QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } - case CheckAnswerRowsContains(expectedAnswer, lastOnly) => - val sparkAnswer = currentStream match { - case null => fetchStreamAnswer(lastStream, lastOnly) - case s => fetchStreamAnswer(s, lastOnly) - } - QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach { - error => failTest(error) - } + case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => + val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) + try { + globalCheckFunction(sparkAnswer) + } catch { + case e: Throwable => failTest(e.toString) + } + } + pos += 1 + } - case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) => - val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly) - try { - globalCheckFunction(sparkAnswer) - } catch { - case e: Throwable => failTest(e.toString) - } - } - pos += 1 + try { + startedTest.foreach { + case StreamProgressLockedActions(actns, _) => + // Perform actions while holding the stream from progressing + assert(currentStream != null, + s"Cannot perform stream-progress-locked actions $actns when query is not active") + assert(currentStream.isInstanceOf[MicroBatchExecution], + s"Cannot perform stream-progress-locked actions on non-microbatch queries") + currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { + actns.foreach(executeAction) + } + + case action: StreamAction => executeAction(action) } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 92087f68ad74a..11bdd13942dcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -462,15 +462,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -493,15 +491,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with value <= 7 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) @@ -524,15 +520,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 3, 4, 5), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with value <= 4 should never be added to the state. CheckLastBatch(Row(3, 10, 6, "9")), assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) @@ -555,15 +549,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) testStream(joined)( - AddData(leftInput, 3, 4, 5), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), AddData(rightInput, 20), CheckLastBatch((20, 30, 40, "60")) @@ -575,13 +567,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -595,13 +585,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. - AddData(leftInput, 1, 2, 3, 4, 5), - AddData(rightInput, 3, 4, 5, 6, 7), + MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), // Old state doesn't get dropped until the batch *after* it gets introduced, so the // nulls won't show up until the next batch after the watermark advances. - AddData(leftInput, 21), - AddData(rightInput, 22), + MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), assertNumStateRows(total = 12, updated = 2), AddData(leftInput, 22), @@ -676,11 +664,9 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys - AddData(leftInput, 1, 2, 3), - AddData(rightInput, 1, 2, 3), + MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - AddData(leftInput, 20), - AddData(rightInput, 21), + MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, 20), @@ -688,22 +674,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), // leftValue and rightValue both satisfying condition should not generate outer join rows - AddData(leftInput, 40, 41), - AddData(rightInput, 40, 41), + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - AddData(leftInput, 70), - AddData(rightInput, 71), + MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), assertNumStateRows(total = 6, updated = 2), AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), // rightValue between 300 and 1000 should generate outer join rows even though it matches left - AddData(leftInput, 101, 102, 103), - AddData(rightInput, 101, 102, 103), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), CheckLastBatch(), - AddData(leftInput, 1000), - AddData(rightInput, 1001), + MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), assertNumStateRows(total = 8, updated = 2), AddData(rightInput, 1000), From 1a198ce8f580bcf35b9cbfab403fc40f821046a1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 23 Feb 2018 16:30:32 -0800 Subject: [PATCH 0391/1161] [SPARK-23459][SQL] Improve the error message when unknown column is specified in partition columns ## What changes were proposed in this pull request? This PR avoids to print schema internal information when unknown column is specified in partition columns. This PR prints column names in the schema with more readable format. The following is an example. Source code ``` test("save with an unknown partition column") { withTempDir { dir => val path = dir.getCanonicalPath Seq(1L -> "a").toDF("i", "j").write .format("parquet") .partitionBy("unknownColumn") .save(path) } ``` Output without this PR ``` Partition column unknownColumn not found in schema StructType(StructField(i,LongType,false), StructField(j,StringType,true)); ``` Output with this PR ``` Partition column unknownColumn not found in schema struct; ``` ## How was this patch tested? Manually tested Author: Kazuaki Ishizaki Closes #20653 from kiszk/SPARK-23459. --- .../datasources/PartitioningUtils.scala | 3 ++- .../apache/spark/sql/sources/SaveLoadSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 379acb67f7c71..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -486,7 +486,8 @@ object PartitioningUtils { val equality = columnNameEquality(caseSensitive) StructType(partitionColumns.map { col => schema.find(f => equality(f.name, col)).getOrElse { - throw new AnalysisException(s"Partition column $col not found in schema $schema") + val schemaCatalog = schema.catalogString + throw new AnalysisException(s"Partition column `$col` not found in schema $schemaCatalog") } }).asNullable } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 773d34dfaf9a8..12779b46bfe8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -126,4 +126,20 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA checkLoad(df2, "jsonTable2") } + + test("SPARK-23459: Improve error message when specified unknown column in partition columns") { + withTempDir { dir => + val path = dir.getCanonicalPath + val unknown = "unknownColumn" + val df = Seq(1L -> "a").toDF("i", "j") + val schemaCatalog = df.schema.catalogString + val e = intercept[AnalysisException] { + df.write + .format("parquet") + .partitionBy(unknown) + .save(path) + }.getMessage + assert(e.contains(s"Partition column `$unknown` not found in schema $schemaCatalog")) + } + } } From 3ca9a2c56513444d7b233088b020d2d43fa6b77a Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Sun, 25 Feb 2018 09:29:59 -0600 Subject: [PATCH 0392/1161] =?UTF-8?q?[SPARK-22886][ML][TESTS]=20ML=20test?= =?UTF-8?q?=20for=20structured=20streaming:=20ml.recomme=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Converting spark.ml.recommendation tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi Closes #20362 from gaborgsomogyi/SPARK-22886. --- .../spark/ml/recommendation/ALSSuite.scala | 213 ++++++++++++------ .../apache/spark/ml/util/MLTestingUtils.scala | 44 ---- 2 files changed, 143 insertions(+), 114 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index addcd21d50aac..e3dfe2faf5698 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,8 +22,7 @@ import java.util.Random import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.WrappedArray +import scala.collection.mutable.{ArrayBuffer, WrappedArray} import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -35,21 +34,20 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.recommendation.ALS.Rating -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class ALSSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { +class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() @@ -413,34 +411,36 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { - case Row(rating: Float, prediction: Float) => - (rating.toDouble, prediction.toDouble) + testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") { + case rows: Seq[Row] => + val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble)) + + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the + // weighted RMSE with the confidence scores as weights. + val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val errorSquares = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + } + val mse = errorSquares.sum / errorSquares.length + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) } - val rmse = - if (implicitPrefs) { - // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. - // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE - // with the confidence scores as weights. - val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => - val confidence = 1.0 + alpha * math.abs(rating) - val rating01 = math.max(math.min(rating, 1.0), 0.0) - val prediction01 = math.max(math.min(prediction, 1.0), 0.0) - val err = prediction01 - rating01 - (confidence, confidence * err * err) - }.reduce { case ((c0, e0), (c1, e1)) => - (c0 + c1, e0 + e1) - } - math.sqrt(weightedSumSq / totalWeight) - } else { - val mse = predictions.map { case (rating, prediction) => - val err = rating - prediction - err * err - }.mean() - math.sqrt(mse) - } - logInfo(s"Test RMSE is $rmse.") - assert(rmse < targetRMSE) MLTestingUtils.checkCopyAndUids(als, model) } @@ -586,6 +586,68 @@ class ALSSuite allModelParamSettings, checkModelData) } + private def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val df = dfs.find { + case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType + } match { + case Some((_, df)) => df + } + val expected = estimator.fit(df) + val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } + + val baseDF = dfs.find(_._1.numericType == baseType).get._2 + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + + private class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Int, Double)]) + + private def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String) = { + + import testImplicits._ + + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col) + val types = + Seq(new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()) + ) + types.map { t => + val cols = Seq(col(column).cast(t.numericType)) ++ others + t -> df.select(cols: _*) + } + } + test("input type validation") { val spark = this.spark import spark.implicits._ @@ -595,12 +657,16 @@ class ALSSuite val als = new ALS().setMaxIter(1).setRank(1) Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { case (colName, sqlType) => - MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + checkNumericTypesALS(als, spark, colName, sqlType) { (ex, act) => - ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) - } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== - act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1) + } { (ex, act, df, enc) => + val expected = ex.transform(df).selectExpr("prediction") + .first().getFloat(0) + testTransformerByGlobalCheckFunc(df, act, "prediction") { + case rows: Seq[Row] => + expected ~== rows.head.getFloat(0) absTol 1e-6 + }(enc) } } // check user/item ids falling outside of Int range @@ -628,18 +694,22 @@ class ALSSuite } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) - assert(intercept[SparkException] { - model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains(msg)) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + assert(intercept[SparkException] { + model.transform(dataFrame).first + }.getMessage.contains(msg)) + assert(intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + }.getMessage.contains(msg)) + } + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } } @@ -662,28 +732,31 @@ class ALSSuite val knownItem = data.select(max("item")).as[Int].first() val unknownItem = knownItem + 20 val test = Seq( - (unknownUser, unknownItem), - (knownUser, unknownItem), - (unknownUser, knownItem), - (knownUser, knownItem) - ).toDF("user", "item") + (unknownUser, unknownItem, true), + (knownUser, unknownItem, true), + (unknownUser, knownItem, true), + (knownUser, knownItem, false) + ).toDF("user", "item", "expectedIsNaN") val als = new ALS().setMaxIter(1).setRank(1) // default is 'nan' val defaultModel = als.fit(data) - val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() - assert(defaultPredictions.length == 4) - assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) - assert(!defaultPredictions.last.isNaN) + testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") { + case Row(expectedIsNaN: Boolean, prediction: Float) => + assert(prediction.isNaN === expectedIsNaN) + } // check 'drop' strategy should filter out rows with unknown users/items - val dropPredictions = defaultModel - .setColdStartStrategy("drop") - .transform(test) - .select("prediction").as[Float].collect() - assert(dropPredictions.length == 1) - assert(!dropPredictions.head.isNaN) - assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + val defaultPrediction = defaultModel.transform(test).select("prediction") + .as[Float].filter(!_.isNaN).first() + testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test, + defaultModel.setColdStartStrategy("drop"), "prediction") { + case rows: Seq[Row] => + val dropPredictions = rows.map(_.getFloat(0)) + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPrediction relTol 1e-14) + } } test("case insensitive cold start param value") { @@ -693,7 +766,7 @@ class ALSSuite val data = ratings.toDF val model = new ALS().fit(data) Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => - model.setColdStartStrategy(s).transform(data) + testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index aef81c8c173a0..c328d81b4bc3a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite { } } - def checkNumericTypesALS( - estimator: ALS, - spark: SparkSession, - column: String, - baseType: NumericType) - (check: (ALSModel, ALSModel) => Unit) - (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { - val dfs = genRatingsDFWithNumericCols(spark, column) - val expected = estimator.fit(dfs(baseType)) - val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) - actuals.foreach { case (_, actual) => check(expected, actual) } - actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } - - val baseDF = dfs(baseType) - val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) - val cols = Seq(col(column).cast(StringType)) ++ others - val strDF = baseDF.select(cols: _*) - val thrown = intercept[IllegalArgumentException] { - estimator.fit(strDF) - } - assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) - } - def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) @@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } - def genRatingsDFWithNumericCols( - spark: SparkSession, - column: String): Map[NumericType, DataFrame] = { - val df = spark.createDataFrame(Seq( - (0, 10, 1.0), - (1, 20, 2.0), - (2, 30, 3.0), - (3, 40, 4.0), - (4, 50, 5.0) - )).toDF("user", "item", "rating") - - val others = df.columns.toSeq.diff(Seq(column)).map(col) - val types: Seq[NumericType] = - Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types.map { t => - val cols = Seq(col(column).cast(t)) ++ others - t -> df.select(cols: _*) - }.toMap - } - def genEvaluatorDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", From b308182f233b8840dfe0e6b5736d2f2746f40757 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 26 Feb 2018 08:39:44 -0800 Subject: [PATCH 0393/1161] [SPARK-23438][DSTREAMS] Fix DStreams data loss with WAL when driver crashes ## What changes were proposed in this pull request? There is a race condition introduced in SPARK-11141 which could cause data loss. The problem is that ReceivedBlockTracker.insertAllocatedBatch function assumes that all blocks from streamIdToUnallocatedBlockQueues allocated to the batch and clears the queue. In this PR only the allocated blocks will be removed from the queue which will prevent data loss. ## How was this patch tested? Additional unit test + manually. Author: Gabor Somogyi Closes #20620 from gaborgsomogyi/SPARK-23438. --- .../scheduler/ReceivedBlockTracker.scala | 11 +++++---- .../streaming/ReceivedBlockTrackerSuite.scala | 23 ++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 5d9a8ac0d9297..dacff69d55dd2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -193,12 +193,15 @@ private[streaming] class ReceivedBlockTracker( getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo } - // Insert the recovered block-to-batch allocations and clear the queue of received blocks - // (when the blocks were originally allocated to the batch, the queue must have been cleared). + // Insert the recovered block-to-batch allocations and removes them from queue of + // received blocks. def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) { logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") - streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } + allocatedBlocks.streamIdToAllocatedBlocks.foreach { + case (streamId, allocatedBlocksInStream) => + getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet) + } timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } @@ -227,7 +230,7 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { + private[streaming] def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { logTrace(s"Writing record: $record") try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 107c3f5dcc08d..4fa236bd39663 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult -import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.{AllocatedBlocks, _} import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -94,6 +94,27 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } + test("recovery with write ahead logs should remove only allocated blocks from received queue") { + val manualClock = new ManualClock + val batchTime = manualClock.getTimeMillis() + + val tracker1 = createTracker(clock = manualClock) + tracker1.isWriteAheadLogEnabled should be (true) + + val allocatedBlockInfos = generateBlockInfos() + val unallocatedBlockInfos = generateBlockInfos() + val receivedBlockInfos = allocatedBlockInfos ++ unallocatedBlockInfos + receivedBlockInfos.foreach { b => tracker1.writeToLog(BlockAdditionEvent(b)) } + val allocatedBlocks = AllocatedBlocks(Map(streamId -> allocatedBlockInfos)) + tracker1.writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) + tracker1.stop() + + val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true) + tracker2.getBlocksOfBatch(batchTime) shouldEqual allocatedBlocks.streamIdToAllocatedBlocks + tracker2.getUnallocatedBlocks(streamId) shouldEqual unallocatedBlockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates From 185f5bc7dd52cebe8fac9393ecb2bd0968bc5867 Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Mon, 26 Feb 2018 10:28:45 -0800 Subject: [PATCH 0394/1161] [SPARK-23449][K8S] Preserve extraJavaOptions ordering For some JVM options, like `-XX:+UnlockExperimentalVMOptions` ordering is necessary. ## What changes were proposed in this pull request? Keep original `extraJavaOptions` ordering, when passing them through environment variables inside the Docker container. ## How was this patch tested? Ran base branch a couple of times and checked startup command in logs. Ordering differed every time. Added sorting, ordering was consistent to what user had in `extraJavaOptions`. Author: Andrew Korzhuev Closes #20628 from andrusha/patch-2. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index b9090dc2852a5..3d67b0a702dd4 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -41,7 +41,7 @@ fi shift 1 SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" -env | grep SPARK_JAVA_OPT_ | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt +env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" From 7ec83658fbc88505dfc2d8a6f76e90db747f1292 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 26 Feb 2018 11:28:44 -0800 Subject: [PATCH 0395/1161] [SPARK-23491][SS] Remove explicit job cancellation from ContinuousExecution reconfiguring ## What changes were proposed in this pull request? Remove queryExecutionThread.interrupt() from ContinuousExecution. As detailed in the JIRA, interrupting the thread is only relevant in the microbatch case; for continuous processing the query execution can quickly clean itself up without. ## How was this patch tested? existing tests Author: Jose Torres Closes #20622 from jose-torres/SPARK-23441. --- .../streaming/continuous/ContinuousExecution.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 2c1d6c509d21b..daebd1dd010ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -236,9 +236,7 @@ class ContinuousExecution( startTrigger() if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { - stopSources() if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) queryExecutionThread.interrupt() } false @@ -266,12 +264,20 @@ class ContinuousExecution( SQLExecution.withNewExecutionId( sparkSessionForQuery, lastExecution)(lastExecution.toRdd) } + } catch { + case t: Throwable + if StreamExecution.isInterruptionException(t) && state.get() == RECONFIGURING => + logInfo(s"Query $id ignoring exception from reconfiguring: $t") + // interrupted by reconfiguration - swallow exception so we can restart the query } finally { epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) SparkEnv.get.rpcEnv.stop(epochEndpoint) epochUpdateThread.interrupt() epochUpdateThread.join() + + stopSources() + sparkSession.sparkContext.cancelJobGroup(runId.toString) } } From 8077bb04f350fd35df83ef896135c0672dc3f7b0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Mon, 26 Feb 2018 23:37:31 -0800 Subject: [PATCH 0396/1161] [SPARK-23445] ColumnStat refactoring ## What changes were proposed in this pull request? Refactor ColumnStat to be more flexible. * Split `ColumnStat` and `CatalogColumnStat` just like `CatalogStatistics` is split from `Statistics`. This detaches how the statistics are stored from how they are processed in the query plan. `CatalogColumnStat` keeps `min` and `max` as `String`, making it not depend on dataType information. * For `CatalogColumnStat`, parse column names from property names in the metastore (`KEY_VERSION` property), not from metastore schema. This means that `CatalogColumnStat`s can be created for columns even if the schema itself is not stored in the metastore. * Make all fields optional. `min`, `max` and `histogram` for columns were optional already. Having them all optional is more consistent, and gives flexibility to e.g. drop some of the fields through transformations if they are difficult / impossible to calculate. The added flexibility will make it possible to have alternative implementations for stats, and separates stats collection from stats and estimation processing in plans. ## How was this patch tested? Refactored existing tests to work with refactored `ColumnStat` and `CatalogColumnStat`. New tests added in `StatisticsSuite` checking that backwards / forwards compatibility is not broken. Author: Juliusz Sompolski Closes #20624 from juliuszsompolski/SPARK-23445. --- .../sql/catalyst/catalog/interface.scala | 146 ++++++++- .../optimizer/StarSchemaDetection.scala | 6 +- .../catalyst/plans/logical/Statistics.scala | 256 ++-------------- .../statsEstimation/AggregateEstimation.scala | 6 +- .../statsEstimation/EstimationUtils.scala | 20 +- .../statsEstimation/FilterEstimation.scala | 98 +++--- .../statsEstimation/JoinEstimation.scala | 55 ++-- .../catalyst/optimizer/JoinReorderSuite.scala | 25 +- .../StarJoinCostBasedReorderSuite.scala | 96 ++---- .../optimizer/StarJoinReorderSuite.scala | 77 ++--- .../AggregateEstimationSuite.scala | 24 +- .../BasicStatsEstimationSuite.scala | 12 +- .../FilterEstimationSuite.scala | 279 +++++++++--------- .../statsEstimation/JoinEstimationSuite.scala | 138 +++++---- .../ProjectEstimationSuite.scala | 70 +++-- .../StatsEstimationTestBase.scala | 10 +- .../command/AnalyzeColumnCommand.scala | 138 ++++++++- .../spark/sql/execution/command/tables.scala | 9 +- .../spark/sql/StatisticsCollectionSuite.scala | 9 +- .../sql/StatisticsCollectionTestBase.scala | 168 +++++++++-- .../spark/sql/hive/HiveExternalCatalog.scala | 63 ++-- .../spark/sql/hive/StatisticsSuite.scala | 162 +++------- 22 files changed, 995 insertions(+), 872 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 95b6fbb0cd61a..f3e67dc4e975c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -21,7 +21,9 @@ import java.net.URI import java.util.Date import scala.collection.mutable +import scala.util.control.NonFatal +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -30,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** @@ -361,7 +363,7 @@ object CatalogTable { case class CatalogStatistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, - colStats: Map[String, ColumnStat] = Map.empty) { + colStats: Map[String, CatalogColumnStat] = Map.empty) { /** * Convert [[CatalogStatistics]] to [[Statistics]], and match column stats to attributes based @@ -369,7 +371,8 @@ case class CatalogStatistics( */ def toPlanStats(planOutput: Seq[Attribute], cboEnabled: Boolean): Statistics = { if (cboEnabled && rowCount.isDefined) { - val attrStats = AttributeMap(planOutput.flatMap(a => colStats.get(a.name).map(a -> _))) + val attrStats = AttributeMap(planOutput + .flatMap(a => colStats.get(a.name).map(a -> _.toPlanStat(a.name, a.dataType)))) // Estimate size as number of rows * row size. val size = EstimationUtils.getOutputSize(planOutput, rowCount.get, attrStats) Statistics(sizeInBytes = size, rowCount = rowCount, attributeStats = attrStats) @@ -387,6 +390,143 @@ case class CatalogStatistics( } } +/** + * This class of statistics for a column is used in [[CatalogTable]] to interact with metastore. + */ +case class CatalogColumnStat( + distinctCount: Option[BigInt] = None, + min: Option[String] = None, + max: Option[String] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, + histogram: Option[Histogram] = None) { + + /** + * Returns a map from string to string that can be used to serialize the column stats. + * The key is the name of the column and name of the field (e.g. "colName.distinctCount"), + * and the value is the string representation for the value. + * min/max values are stored as Strings. They can be deserialized using + * [[CatalogColumnStat.fromExternalString]]. + * + * As part of the protocol, the returned map always contains a key called "version". + * Any of the fields that are null (None) won't appear in the map. + */ + def toMap(colName: String): Map[String, String] = { + val map = new scala.collection.mutable.HashMap[String, String] + map.put(s"${colName}.${CatalogColumnStat.KEY_VERSION}", "1") + distinctCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_DISTINCT_COUNT}", v.toString) + } + nullCount.foreach { v => + map.put(s"${colName}.${CatalogColumnStat.KEY_NULL_COUNT}", v.toString) + } + avgLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_AVG_LEN}", v.toString) } + maxLen.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_LEN}", v.toString) } + min.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MIN_VALUE}", v) } + max.foreach { v => map.put(s"${colName}.${CatalogColumnStat.KEY_MAX_VALUE}", v) } + histogram.foreach { h => + map.put(s"${colName}.${CatalogColumnStat.KEY_HISTOGRAM}", HistogramSerializer.serialize(h)) + } + map.toMap + } + + /** Convert [[CatalogColumnStat]] to [[ColumnStat]]. */ + def toPlanStat( + colName: String, + dataType: DataType): ColumnStat = + ColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.fromExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) +} + +object CatalogColumnStat extends Logging { + + // List of string keys used to serialize CatalogColumnStat + val KEY_VERSION = "version" + private val KEY_DISTINCT_COUNT = "distinctCount" + private val KEY_MIN_VALUE = "min" + private val KEY_MAX_VALUE = "max" + private val KEY_NULL_COUNT = "nullCount" + private val KEY_AVG_LEN = "avgLen" + private val KEY_MAX_LEN = "maxLen" + private val KEY_HISTOGRAM = "histogram" + + /** + * Converts from string representation of data type to the corresponding Catalyst data type. + */ + def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics serialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + + + /** + * Creates a [[CatalogColumnStat]] object from the given map. + * This is used to deserialize column stats from some external storage. + * The serialization side is defined in [[CatalogColumnStat.toMap]]. + */ + def fromMap( + table: String, + colName: String, + map: Map[String, String]): Option[CatalogColumnStat] = { + + try { + Some(CatalogColumnStat( + distinctCount = map.get(s"${colName}.${KEY_DISTINCT_COUNT}").map(v => BigInt(v.toLong)), + min = map.get(s"${colName}.${KEY_MIN_VALUE}"), + max = map.get(s"${colName}.${KEY_MAX_VALUE}"), + nullCount = map.get(s"${colName}.${KEY_NULL_COUNT}").map(v => BigInt(v.toLong)), + avgLen = map.get(s"${colName}.${KEY_AVG_LEN}").map(_.toLong), + maxLen = map.get(s"${colName}.${KEY_MAX_LEN}").map(_.toLong), + histogram = map.get(s"${colName}.${KEY_HISTOGRAM}").map(HistogramSerializer.deserialize) + )) + } catch { + case NonFatal(e) => + logWarning(s"Failed to parse column statistics for column ${colName} in table $table", e) + None + } + } +} + case class CatalogTableType private(name: String) object CatalogTableType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala index 1f20b7661489e..2aa762e2595ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala @@ -187,11 +187,11 @@ object StarSchemaDetection extends PredicateHelper { stats.rowCount match { case Some(rowCount) if rowCount >= 0 => if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) { - val colStats = stats.attributeStats.get(col) - if (colStats.get.nullCount > 0) { + val colStats = stats.attributeStats.get(col).get + if (!colStats.hasCountStats || colStats.nullCount.get > 0) { false } else { - val distinctCount = colStats.get.distinctCount + val distinctCount = colStats.distinctCount.get val relDiff = math.abs((distinctCount.toDouble / rowCount.toDouble) - 1.0d) // ndvMaxErr adjusted based on TPCDS 1TB data results relDiff <= conf.ndvMaxError * 2 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 96b199d7f20b0..b3a48860aa63b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -27,6 +27,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils} @@ -79,11 +80,10 @@ case class Statistics( /** * Statistics collected for a column. * - * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * 1. The JVM data type stored in min/max is the internal data type for the corresponding * Catalyst data type. For example, the internal type of DateType is Int, and that the internal * type of TimestampType is Long. - * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -95,240 +95,32 @@ case class Statistics( * @param histogram histogram of the values */ case class ColumnStat( - distinctCount: BigInt, - min: Option[Any], - max: Option[Any], - nullCount: BigInt, - avgLen: Long, - maxLen: Long, + distinctCount: Option[BigInt] = None, + min: Option[Any] = None, + max: Option[Any] = None, + nullCount: Option[BigInt] = None, + avgLen: Option[Long] = None, + maxLen: Option[Long] = None, histogram: Option[Histogram] = None) { - // We currently don't store min/max for binary/string type. This can change in the future and - // then we need to remove this require. - require(min.isEmpty || (!min.get.isInstanceOf[Array[Byte]] && !min.get.isInstanceOf[String])) - require(max.isEmpty || (!max.get.isInstanceOf[Array[Byte]] && !max.get.isInstanceOf[String])) - - /** - * Returns a map from string to string that can be used to serialize the column stats. - * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. min/max values are converted to the external data type. For - * example, for DateType we store java.sql.Date, and for TimestampType we store - * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. - * - * As part of the protocol, the returned map always contains a key called "version". - * In the case min/max values are null (None), they won't appear in the map. - */ - def toMap(colName: String, dataType: DataType): Map[String, String] = { - val map = new scala.collection.mutable.HashMap[String, String] - map.put(ColumnStat.KEY_VERSION, "1") - map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) - map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) - map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) - map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } - histogram.foreach { h => map.put(ColumnStat.KEY_HISTOGRAM, HistogramSerializer.serialize(h)) } - map.toMap - } - - /** - * Converts the given value from Catalyst data type to string representation of external - * data type. - */ - private def toExternalString(v: Any, colName: String, dataType: DataType): String = { - val externalValue = dataType match { - case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) - case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) - case BooleanType | _: IntegralType | FloatType | DoubleType => v - case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal - // This version of Spark does not use min/max for binary/string types so we ignore it. - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $colName of data type: $dataType.") - } - externalValue.toString - } - -} + // Are distinctCount and nullCount statistics defined? + val hasCountStats = distinctCount.isDefined && nullCount.isDefined + // Are min and max statistics defined? + val hasMinMaxStats = min.isDefined && max.isDefined -object ColumnStat extends Logging { - - // List of string keys used to serialize ColumnStat - val KEY_VERSION = "version" - private val KEY_DISTINCT_COUNT = "distinctCount" - private val KEY_MIN_VALUE = "min" - private val KEY_MAX_VALUE = "max" - private val KEY_NULL_COUNT = "nullCount" - private val KEY_AVG_LEN = "avgLen" - private val KEY_MAX_LEN = "maxLen" - private val KEY_HISTOGRAM = "histogram" - - /** Returns true iff the we support gathering column statistics on column of the given type. */ - def supportsType(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case BooleanType => true - case DateType => true - case TimestampType => true - case BinaryType | StringType => true - case _ => false - } - - /** Returns true iff the we support gathering histogram on column of the given type. */ - def supportsHistogram(dataType: DataType): Boolean = dataType match { - case _: IntegralType => true - case _: DecimalType => true - case DoubleType | FloatType => true - case DateType => true - case TimestampType => true - case _ => false - } - - /** - * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats - * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. - */ - def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { - try { - Some(ColumnStat( - distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), - // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE) - .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), - nullCount = BigInt(map(KEY_NULL_COUNT).toLong), - avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, - maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong, - histogram = map.get(KEY_HISTOGRAM).map(HistogramSerializer.deserialize) - )) - } catch { - case NonFatal(e) => - logWarning(s"Failed to parse column statistics for column ${field.name} in table $table", e) - None - } - } - - /** - * Converts from string representation of external data type to the corresponding Catalyst data - * type. - */ - private def fromExternalString(s: String, name: String, dataType: DataType): Any = { - dataType match { - case BooleanType => s.toBoolean - case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) - case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) - case ByteType => s.toByte - case ShortType => s.toShort - case IntegerType => s.toInt - case LongType => s.toLong - case FloatType => s.toFloat - case DoubleType => s.toDouble - case _: DecimalType => Decimal(s) - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column $name of data type: $dataType.") - } - } - - /** - * Constructs an expression to compute column statistics for a given column. - * - * The expression should create a single struct column with the following schema: - * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, - * distinctCountsForIntervals: Array[Long] - * - * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and - * as a result should stay in sync with it. - */ - def statExprs( - col: Attribute, - conf: SQLConf, - colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { - def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => - expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } - }) - val one = Literal(1, LongType) - - // the approximate ndv (num distinct value) should never be larger than the number of rows - val numNonNulls = if (col.nullable) Count(col) else Count(one) - val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) - val numNulls = Subtract(Count(one), numNonNulls) - val defaultSize = Literal(col.dataType.defaultSize, LongType) - val nullArray = Literal(null, ArrayType(LongType)) - - def fixedLenTypeStruct: CreateNamedStruct = { - val genHistogram = - ColumnStat.supportsHistogram(col.dataType) && colPercentiles.contains(col) - val intervalNdvsExpr = if (genHistogram) { - ApproxCountDistinctForIntervals(col, - Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) - } else { - nullArray - } - // For fixed width types, avg size should be the same as max size. - struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, - defaultSize, defaultSize, intervalNdvsExpr) - } - - col.dataType match { - case _: IntegralType => fixedLenTypeStruct - case _: DecimalType => fixedLenTypeStruct - case DoubleType | FloatType => fixedLenTypeStruct - case BooleanType => fixedLenTypeStruct - case DateType => fixedLenTypeStruct - case TimestampType => fixedLenTypeStruct - case BinaryType | StringType => - // For string and binary type, we don't compute min, max or histogram - val nullLit = Literal(null, col.dataType) - struct( - ndv, nullLit, nullLit, numNulls, - // Set avg/max size to default size if all the values are null or there is no value. - Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), - Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), - nullArray) - case _ => - throw new AnalysisException("Analyzing column statistics is not supported for column " + - s"${col.name} of data type: ${col.dataType}.") - } - } - - /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ - def rowToColumnStat( - row: InternalRow, - attr: Attribute, - rowCount: Long, - percentiles: Option[ArrayData]): ColumnStat = { - // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. - val cs = ColumnStat( - distinctCount = BigInt(row.getLong(0)), - // for string/binary min/max, get should return null - min = Option(row.get(1, attr.dataType)), - max = Option(row.get(2, attr.dataType)), - nullCount = BigInt(row.getLong(3)), - avgLen = row.getLong(4), - maxLen = row.getLong(5) - ) - if (row.isNullAt(6)) { - cs - } else { - val ndvs = row.getArray(6).toLongArray() - assert(percentiles.get.numElements() == ndvs.length + 1) - val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) - // Construct equi-height histogram - val bins = ndvs.zipWithIndex.map { case (ndv, i) => - HistogramBin(endpoints(i), endpoints(i + 1), ndv) - } - val nonNullRows = rowCount - cs.nullCount - val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) - cs.copy(histogram = Some(histogram)) - } - } + // Are avgLen and maxLen statistics defined? + val hasLenStats = avgLen.isDefined && maxLen.isDefined + def toCatalogColumnStat(colName: String, dataType: DataType): CatalogColumnStat = + CatalogColumnStat( + distinctCount = distinctCount, + min = min.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + max = max.map(CatalogColumnStat.toExternalString(_, colName, dataType)), + nullCount = nullCount, + avgLen = avgLen, + maxLen = maxLen, + histogram = histogram) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index c41fac4015ec0..111c594a53e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -32,13 +32,15 @@ object AggregateEstimation { val childStats = agg.child.stats // Check if we have column stats for all group-by columns. val colStatsExist = agg.groupingExpressions.forall { e => - e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) + e.isInstanceOf[Attribute] && + childStats.attributeStats.get(e.asInstanceOf[Attribute]).exists(_.hasCountStats) } if (rowCountsExist(agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( - (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount) + (res, expr) => res * + childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount.get) outputRows = if (agg.groupingExpressions.isEmpty) { // If there's no group-by columns, the output is a single row containing values of aggregate diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index d793f77413d18..0f147f0ffb135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -38,9 +39,18 @@ object EstimationUtils { } } + /** Check if each attribute has column stat containing distinct and null counts + * in the corresponding statistic. */ + def columnStatsWithCountsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = { + statsAndAttr.forall { case (stats, attr) => + stats.attributeStats.get(attr).map(_.hasCountStats).getOrElse(false) + } + } + + /** Statistics for a Column containing only NULLs. */ def nullColumnStat(dataType: DataType, rowCount: BigInt): ColumnStat = { - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = rowCount, - avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) + ColumnStat(distinctCount = Some(0), min = None, max = None, nullCount = Some(rowCount), + avgLen = Some(dataType.defaultSize), maxLen = Some(dataType.defaultSize)) } /** @@ -70,13 +80,13 @@ object EstimationUtils { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. val sizePerRow = 8 + attributes.map { attr => - if (attrStats.contains(attr)) { + if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => // UTF8String: base + offset + numBytes - attrStats(attr).avgLen + 8 + 4 + attrStats(attr).avgLen.get + 8 + 4 case _ => - attrStats(attr).avgLen + attrStats(attr).avgLen.get } } else { attr.dataType.defaultSize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4cc32de2d32d7..0538c9d88584b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, isNull: Boolean, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) { logDebug("[CBO] No statistics for " + attr) return None } @@ -234,14 +234,14 @@ case class FilterEstimation(plan: Filter) extends Logging { val nullPercent: Double = if (rowCountValue == 0) { 0 } else { - (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble + (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble } if (update) { val newStats = if (isNull) { - colStat.copy(distinctCount = 0, min = None, max = None) + colStat.copy(distinctCount = Some(0), min = None, max = None) } else { - colStat.copy(nullCount = 0) + colStat.copy(nullCount = Some(0)) } colStatsMap.update(attr, newStats) } @@ -322,17 +322,21 @@ case class FilterEstimation(plan: Filter) extends Logging { // value. val newStats = attr.dataType match { case StringType | BinaryType => - colStat.copy(distinctCount = 1, nullCount = 0) + colStat.copy(distinctCount = Some(1), nullCount = Some(0)) case _ => - colStat.copy(distinctCount = 1, min = Some(literal.value), - max = Some(literal.value), nullCount = 0) + colStat.copy(distinctCount = Some(1), min = Some(literal.value), + max = Some(literal.value), nullCount = Some(0)) } colStatsMap.update(attr, newStats) } if (colStat.histogram.isEmpty) { - // returns 1/ndv if there is no histogram - Some(1.0 / colStat.distinctCount.toDouble) + if (!colStat.distinctCount.isEmpty) { + // returns 1/ndv if there is no histogram + Some(1.0 / colStat.distinctCount.get.toDouble) + } else { + None + } } else { Some(computeEqualityPossibilityByHistogram(literal, colStat)) } @@ -378,13 +382,13 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, hSet: Set[Any], update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr)) { + if (!colStatsMap.hasDistinctCount(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) - val ndv = colStat.distinctCount + val ndv = colStat.distinctCount.get val dataType = attr.dataType var newNdv = ndv @@ -407,8 +411,8 @@ case class FilterEstimation(plan: Filter) extends Logging { // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), - max = Some(newMax), nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin), + max = Some(newMax), nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -416,7 +420,7 @@ case class FilterEstimation(plan: Filter) extends Logging { case StringType | BinaryType => newNdv = ndv.min(BigInt(hSet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) colStatsMap.update(attr, newStats) } } @@ -443,12 +447,17 @@ case class FilterEstimation(plan: Filter) extends Logging { literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.hasMinMaxStats(attr) || !colStatsMap.hasDistinctCount(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + val colStat = colStatsMap(attr) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] val max = statsInterval.max val min = statsInterval.min - val ndv = colStat.distinctCount.toDouble + val ndv = colStat.distinctCount.get.toDouble // determine the overlapping degree between predicate interval and column's interval val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType) @@ -520,8 +529,8 @@ case class FilterEstimation(plan: Filter) extends Logging { newMax = newValue } - val newStats = colStat.copy(distinctCount = ceil(ndv * percent), - min = newMin, max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = Some(ceil(ndv * percent)), + min = newMin, max = newMax, nullCount = Some(0)) colStatsMap.update(attr, newStats) } @@ -637,11 +646,11 @@ case class FilterEstimation(plan: Filter) extends Logging { attrRight: Attribute, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attrLeft)) { + if (!colStatsMap.hasCountStats(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) return None } - if (!colStatsMap.contains(attrRight)) { + if (!colStatsMap.hasCountStats(attrRight)) { logDebug("[CBO] No statistics for " + attrRight) return None } @@ -668,7 +677,7 @@ case class FilterEstimation(plan: Filter) extends Logging { val minRight = statsIntervalRight.min // determine the overlapping degree between predicate interval and column's interval - val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) + val allNotNull = (colStatLeft.nullCount.get == 0) && (colStatRight.nullCount.get == 0) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { // Left < Right or Left <= Right // - no overlap: @@ -707,14 +716,14 @@ case class FilterEstimation(plan: Filter) extends Logging { case _: EqualTo => ((maxLeft < minRight) || (maxRight < minLeft), (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) case _: EqualNullSafe => // For null-safe equality, we use a very restrictive condition to evaluate its overlap. // If null values exists, we set it to partial overlap. (((maxLeft < minRight) || (maxRight < minLeft)) && allNotNull, (minLeft == minRight) && (maxLeft == maxRight) && allNotNull - && (colStatLeft.distinctCount == colStatRight.distinctCount) + && (colStatLeft.distinctCount.get == colStatRight.distinctCount.get) ) } @@ -731,9 +740,9 @@ case class FilterEstimation(plan: Filter) extends Logging { if (update) { // Need to adjust new min/max after the filter condition is applied - val ndvLeft = BigDecimal(colStatLeft.distinctCount) + val ndvLeft = BigDecimal(colStatLeft.distinctCount.get) val newNdvLeft = ceil(ndvLeft * percent) - val ndvRight = BigDecimal(colStatRight.distinctCount) + val ndvRight = BigDecimal(colStatRight.distinctCount.get) val newNdvRight = ceil(ndvRight * percent) var newMaxLeft = colStatLeft.max @@ -817,10 +826,10 @@ case class FilterEstimation(plan: Filter) extends Logging { } } - val newStatsLeft = colStatLeft.copy(distinctCount = newNdvLeft, min = newMinLeft, + val newStatsLeft = colStatLeft.copy(distinctCount = Some(newNdvLeft), min = newMinLeft, max = newMaxLeft) colStatsMap(attrLeft) = newStatsLeft - val newStatsRight = colStatRight.copy(distinctCount = newNdvRight, min = newMinRight, + val newStatsRight = colStatRight.copy(distinctCount = Some(newNdvRight), min = newMinRight, max = newMaxRight) colStatsMap(attrRight) = newStatsRight } @@ -849,17 +858,35 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a) /** - * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in - * originalMap, because updatedMap has the latest (updated) column stats. + * Gets an Option of column stat for the given attribute. + * Prefer the column stat in updatedMap than that in originalMap, + * because updatedMap has the latest (updated) column stats. */ - def apply(a: Attribute): ColumnStat = { + def get(a: Attribute): Option[ColumnStat] = { if (updatedMap.contains(a.exprId)) { - updatedMap(a.exprId)._2 + updatedMap.get(a.exprId).map(_._2) } else { - originalMap(a) + originalMap.get(a) } } + def hasCountStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + def hasDistinctCount(a: Attribute): Boolean = + get(a).map(_.distinctCount.isDefined).getOrElse(false) + + def hasMinMaxStats(a: Attribute): Boolean = + get(a).map(_.hasCountStats).getOrElse(false) + + /** + * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in + * originalMap, because updatedMap has the latest (updated) column stats. + */ + def apply(a: Attribute): ColumnStat = { + get(a).get + } + /** Updates column stats in updatedMap. */ def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats) @@ -871,11 +898,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { : AttributeMap[ColumnStat] = { val newColumnStats = originalMap.map { case (attr, oriColStat) => val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat) - val newNdv = if (colStat.distinctCount > 1) { + val newNdv = if (colStat.distinctCount.isEmpty) { + // No NDV in the original stats. + None + } else if (colStat.distinctCount.get > 1) { // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows // decreases; otherwise keep it unchanged. - EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, - newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) + Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, + newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get)) } else { // no need to scale down since it is already down to 1 (for skewed distribution case) colStat.distinctCount diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index f0294a4246703..2543e38a92c0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -85,7 +85,8 @@ case class JoinEstimation(join: Join) extends Logging { // 3. Update statistics based on the output of join val inputAttrStats = AttributeMap( leftStats.attributeStats.toSeq ++ rightStats.attributeStats.toSeq) - val attributesWithStat = join.output.filter(a => inputAttrStats.contains(a)) + val attributesWithStat = join.output.filter(a => + inputAttrStats.get(a).map(_.hasCountStats).getOrElse(false)) val (fromLeft, fromRight) = attributesWithStat.partition(join.left.outputSet.contains(_)) val outputStats: Seq[(Attribute, ColumnStat)] = if (outputRows == 0) { @@ -106,10 +107,10 @@ case class JoinEstimation(join: Join) extends Logging { case FullOuter => fromLeft.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + rightRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + rightRows))) } ++ fromRight.map { a => val oriColStat = inputAttrStats(a) - (a, oriColStat.copy(nullCount = oriColStat.nullCount + leftRows)) + (a, oriColStat.copy(nullCount = Some(oriColStat.nullCount.get + leftRows))) } case _ => assert(joinType == Inner || joinType == Cross) @@ -219,19 +220,27 @@ case class JoinEstimation(join: Join) extends Logging { private def computeByNdv( leftKey: AttributeReference, rightKey: AttributeReference, - newMin: Option[Any], - newMax: Option[Any]): (BigInt, ColumnStat) = { + min: Option[Any], + max: Option[Any]): (BigInt, ColumnStat) = { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + val maxNdv = leftKeyStat.distinctCount.get.max(rightKeyStat.distinctCount.get) // Compute cardinality by the basic formula. val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) // Get the intersected column stat. - val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + val newNdv = Some(leftKeyStat.distinctCount.get.min(rightKeyStat.distinctCount.get)) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(newNdv, min, max, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -267,9 +276,17 @@ case class JoinEstimation(join: Join) extends Logging { val leftKeyStat = leftStats.attributeStats(leftKey) val rightKeyStat = rightStats.attributeStats(rightKey) - val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) - val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 - val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen) + val newMaxLen = if (leftKeyStat.maxLen.isDefined && rightKeyStat.maxLen.isDefined) { + Some(math.min(leftKeyStat.maxLen.get, rightKeyStat.maxLen.get)) + } else { + None + } + val newAvgLen = if (leftKeyStat.avgLen.isDefined && rightKeyStat.avgLen.isDefined) { + Some((leftKeyStat.avgLen.get + rightKeyStat.avgLen.get) / 2) + } else { + None + } + val newStats = ColumnStat(Some(ceil(totalNdv)), newMin, newMax, Some(0), newAvgLen, newMaxLen) (ceil(card), newStats) } @@ -292,10 +309,14 @@ case class JoinEstimation(join: Join) extends Logging { } else { val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount - val newNdv = if (join.left.outputSet.contains(a)) { - updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv) + val newNdv = if (oldNdv.isDefined) { + Some(if (join.left.outputSet.contains(a)) { + updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get) + } else { + updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get) + }) } else { - updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv) + None } val newColStat = oldColStat.copy(distinctCount = newNdv) // TODO: support nullCount updates for specific outer joins @@ -313,7 +334,7 @@ case class JoinEstimation(join: Join) extends Logging { // Note: join keys from EqualNullSafe also fall into this case (Coalesce), consider to // support it in the future by using `nullCount` in column stats. case (lk: AttributeReference, rk: AttributeReference) - if columnStatsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) + if columnStatsWithCountsExist((leftStats, lk), (rightStats, rk)) => (lk, rk) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index 2fb587d50a4cb..565b0a10154a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -62,24 +62,15 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { } } - /** Set up tables and columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("t5.v-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("t1.k-1-2") -> rangeColumnStat(2, 0), + attr("t1.v-1-10") -> rangeColumnStat(10, 0), + attr("t2.k-1-5") -> rangeColumnStat(5, 0), + attr("t3.v-1-100") -> rangeColumnStat(100, 0), + attr("t4.k-1-2") -> rangeColumnStat(2, 0), + attr("t4.v-1-10") -> rangeColumnStat(10, 0), + attr("t5.k-1-5") -> rangeColumnStat(5, 0), + attr("t5.v-1-5") -> rangeColumnStat(5, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index ada6e2a43ea0f..d4d23ad69b2c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -68,88 +68,56 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 (fact table) - attr("f1_fk1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c1") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c2") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(100, 0), + attr("f1_fk2") -> rangeColumnStat(100, 0), + attr("f1_fk3") -> rangeColumnStat(100, 0), + attr("f1_c1") -> rangeColumnStat(100, 0), + attr("f1_c2") -> rangeColumnStat(100, 0), // D1 (dimension) - attr("d1_pk") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 50, min = Some(1), max = Some(50), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk") -> rangeColumnStat(100, 0), + attr("d1_c2") -> rangeColumnStat(50, 0), + attr("d1_c3") -> rangeColumnStat(50, 0), // D2 (dimension) - attr("d2_pk") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_pk") -> rangeColumnStat(20, 0), + attr("d2_c2") -> rangeColumnStat(10, 0), + attr("d2_c3") -> rangeColumnStat(10, 0), // D3 (dimension) - attr("d3_pk") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_pk") -> rangeColumnStat(10, 0), + attr("d3_c2") -> rangeColumnStat(5, 0), + attr("d3_c3") -> rangeColumnStat(5, 0), // T1 (regular table i.e. outside star) - attr("t1_c1") -> ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c2") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t1_c3") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t1_c1") -> rangeColumnStat(20, 1), + attr("t1_c2") -> rangeColumnStat(10, 1), + attr("t1_c3") -> rangeColumnStat(10, 1), // T2 (regular table) - attr("t2_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t2_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t2_c1") -> rangeColumnStat(5, 1), + attr("t2_c2") -> rangeColumnStat(5, 1), + attr("t2_c3") -> rangeColumnStat(5, 1), // T3 (regular table) - attr("t3_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t3_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t3_c1") -> rangeColumnStat(5, 1), + attr("t3_c2") -> rangeColumnStat(5, 1), + attr("t3_c3") -> rangeColumnStat(5, 1), // T4 (regular table) - attr("t4_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t4_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t4_c1") -> rangeColumnStat(5, 1), + attr("t4_c2") -> rangeColumnStat(5, 1), + attr("t4_c3") -> rangeColumnStat(5, 1), // T5 (regular table) - attr("t5_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t5_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), + attr("t5_c1") -> rangeColumnStat(5, 1), + attr("t5_c2") -> rangeColumnStat(5, 1), + attr("t5_c3") -> rangeColumnStat(5, 1), // T6 (regular table) - attr("t6_c1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c2") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("t6_c3") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 1, avgLen = 4, maxLen = 4) + attr("t6_c1") -> rangeColumnStat(5, 1), + attr("t6_c2") -> rangeColumnStat(5, 1), + attr("t6_c3") -> rangeColumnStat(5, 1) )) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala index 777c5637201ed..4e0883e91e84a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala @@ -70,59 +70,40 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase { // Tables' cardinality: f1 > d3 > d1 > d2 > s3 private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( // F1 - attr("f1_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f1_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("f1_fk1") -> rangeColumnStat(3, 0), + attr("f1_fk2") -> rangeColumnStat(3, 0), + attr("f1_fk3") -> rangeColumnStat(4, 0), + attr("f1_c4") -> rangeColumnStat(4, 0), // D1 - attr("d1_pk1") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d1_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d1_pk1") -> rangeColumnStat(4, 0), + attr("d1_c2") -> rangeColumnStat(3, 0), + attr("d1_c3") -> rangeColumnStat(4, 0), + attr("d1_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D2 - attr("d2_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 1, avgLen = 4, maxLen = 4), - attr("d2_pk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c3") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d2_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d2_c2") -> ColumnStat(distinctCount = Some(3), min = Some("1"), max = Some("3"), + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)), + attr("d2_pk1") -> rangeColumnStat(3, 0), + attr("d2_c3") -> rangeColumnStat(3, 0), + attr("d2_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // D3 - attr("d3_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_pk1") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("d3_c4") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("d3_fk1") -> rangeColumnStat(3, 0), + attr("d3_c2") -> rangeColumnStat(3, 0), + attr("d3_pk1") -> rangeColumnStat(5, 0), + attr("d3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("2"), max = Some("3"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // S3 - attr("s3_pk1") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c2") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c3") -> ColumnStat(distinctCount = 1, min = Some(3), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("s3_c4") -> ColumnStat(distinctCount = 2, min = Some(3), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), + attr("s3_pk1") -> rangeColumnStat(2, 0), + attr("s3_c2") -> rangeColumnStat(1, 0), + attr("s3_c3") -> rangeColumnStat(1, 0), + attr("s3_c4") -> ColumnStat(distinctCount = Some(2), min = Some("3"), max = Some("4"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), // F11 - attr("f11_fk1") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk2") -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_fk3") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4), - attr("f11_c4") -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(4), - nullCount = 0, avgLen = 4, maxLen = 4) + attr("f11_fk1") -> rangeColumnStat(3, 0), + attr("f11_fk2") -> rangeColumnStat(3, 0), + attr("f11_fk3") -> rangeColumnStat(4, 0), + attr("f11_c4") -> rangeColumnStat(4, 0) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 23f95a6cc2ac2..8213d568fe85e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -29,16 +29,16 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { /** Columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key12") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key22") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key31") -> ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key11") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key12") -> ColumnStat(distinctCount = Some(4), min = Some(10), max = Some(40), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key21") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key22") -> ColumnStat(distinctCount = Some(2), min = Some(10), max = Some(20), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key31") -> ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -63,8 +63,8 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { tableRowCount = 6, groupByColumns = Seq("key21", "key22"), // Row count = product of ndv - expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount * nameToColInfo("key22")._2 - .distinctCount) + expectedOutputRowCount = nameToColInfo("key21")._2.distinctCount.get * + nameToColInfo("key22")._2.distinctCount.get) } test("empty group-by column") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 7d532ff343178..953094cb0dd52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.IntegerType class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { val attribute = attr("key") - val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStat = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val plan = StatsTestPlan( outputList = Seq(attribute), @@ -116,13 +116,17 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { sizeInBytes = 40, rowCount = Some(10), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val expectedCboStats = Statistics( sizeInBytes = 4, rowCount = Some(1), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) + AttributeReference("c1", IntegerType)() -> ColumnStat(distinctCount = Some(10), + min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))))) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) checkStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2b1fe987a7960..43440d51dede6 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -37,59 +37,61 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() - val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() - val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1) + val colStatBool = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)) // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() - val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatDate = ColumnStat(distinctCount = Some(10), + min = Some(dMin), max = Some(dMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = Decimal("0.200000000000000000") val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDecimal = ColumnStat(distinctCount = Some(4), + min = Some(decMin), max = Some(decMax), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() - val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDouble = ColumnStat(distinctCount = Some(10), min = Some(1.0), max = Some(10.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() - val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2) + val colStatString = ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)) // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint < cint2 val attrInt2 = AttributeReference("cint2", IntegerType)() - val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt2 = ColumnStat(distinctCount = Some(10), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint = cint3 without overlap at all. val attrInt3 = AttributeReference("cint3", IntegerType)() - val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt3 = ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint4 has values in the range from 1 to 10 // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 // This column is created to test complete overlap val attrInt4 = AttributeReference("cint4", IntegerType)() - val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt4 = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram. // Note that cintHgm has an even distribution with histogram information built. @@ -98,8 +100,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2), HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2), HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2))) - val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt)) + val colStatIntHgm = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt)) // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram. // Note that cintSkewHgm has a skewed distribution with histogram information built. @@ -108,8 +110,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2), HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1), HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2))) - val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew)) + val colStatIntSkewHgm = ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew)) val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, @@ -172,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 3)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -180,7 +182,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } @@ -196,23 +198,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } @@ -227,8 +229,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -243,16 +245,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -267,8 +269,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -282,8 +284,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NOT NULL") { validateEstimatedStats( Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -301,8 +303,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -310,7 +312,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 2)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(2))), expectedRowCount = 2) } @@ -318,7 +320,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 6)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(6))), expectedRowCount = 6) } @@ -326,7 +328,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 5)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -342,47 +344,47 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 9), - attrString -> colStatString.copy(distinctCount = 9)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(9)), + attrString -> colStatString.copy(distinctCount = Some(9))), expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 7)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool > false") { validateEstimatedStats( Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } @@ -391,18 +393,21 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(1), + min = Some(d20170102), max = Some(d20170102), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { + val d20170101 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170101), max = Some(d20170103), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -414,8 +419,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170103), max = Some(d20170105), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -424,42 +430,45 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(1), + min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 1) } test("cdecimal < 0.60 ") { + val dec_0_20 = Decimal("0.200000000000000000") val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(3), + min = Some(dec_0_20), max = Some(dec_0_60), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), - Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDouble -> ColumnStat(distinctCount = Some(3), min = Some(1.0), max = Some(3.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 1) } test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 10) } @@ -468,8 +477,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. // The predicate is: column IN (1, 2, 3, 4, 5). - val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildColStatInt = ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val cornerChildStatsTestplan = StatsTestPlan( outputList = Seq(attrInt), rowCount = 2L, @@ -477,16 +487,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) validateEstimatedStats( Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) } // This is a limitation test. We should remove it after the limitation is removed. test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { val attrIntLargerRange = AttributeReference("c1", IntegerType)() - val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 10, avgLen = 4, maxLen = 4) + val colStatIntLargerRange = ColumnStat(distinctCount = Some(20), + min = Some(1), max = Some(20), + nullCount = Some(10), avgLen = Some(4), maxLen = Some(4)) val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) val largerTable = StatsTestPlan( outputList = Seq(attrIntLargerRange), @@ -508,10 +519,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -519,10 +530,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -530,10 +541,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -541,10 +552,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // complete overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -552,10 +563,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -571,10 +582,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // all table records qualify. validateEstimatedStats( Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt3 -> ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -592,11 +603,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)), Seq( - attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - attrString -> colStatString.copy(distinctCount = 5)), + attrInt -> ColumnStat(distinctCount = Some(5), min = Some(3), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrString -> colStatString.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -606,15 +617,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cintHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 1) } @@ -629,8 +640,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } @@ -645,16 +656,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } test("cintHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -669,8 +680,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 5) } @@ -679,8 +690,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -688,7 +699,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -698,15 +709,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(5))), expectedRowCount = 9) } test("cintSkewHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 4) } @@ -721,8 +732,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintSkewHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -738,16 +749,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } test("cintSkewHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -764,8 +775,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(2), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 3) } @@ -774,8 +785,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 8) } @@ -783,7 +794,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(2))), expectedRowCount = 3) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 26139d85d25fb..12c0a7be21292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -33,16 +33,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { /** Set up tables and its columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key-1-5") -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-5-9") -> ColumnStat(distinctCount = Some(5), min = Some(5), max = Some(9), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-1-2") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-4") -> ColumnStat(distinctCount = Some(3), min = Some(2), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-3") -> ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -70,8 +70,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def estimateByHistogram( leftHistogram: Histogram, rightHistogram: Histogram, - expectedMin: Double, - expectedMax: Double, + expectedMin: Any, + expectedMax: Any, expectedNdv: Long, expectedRows: Long): Unit = { val col1 = attr("key1") @@ -86,9 +86,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRows), attributeStats = AttributeMap(Seq( col1 -> c1.stats.attributeStats(col1).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)), col2 -> c2.stats.attributeStats(col2).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)))) ) // Join order should not affect estimation result. @@ -100,9 +102,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def generateJoinChild( col: Attribute, histogram: Histogram, - expectedMin: Double, - expectedMax: Double): LogicalPlan = { - val colStat = inferColumnStat(histogram) + expectedMin: Any, + expectedMax: Any): LogicalPlan = { + val colStat = inferColumnStat(histogram, expectedMin, expectedMax) StatsTestPlan( outputList = Seq(col), rowCount = (histogram.height * histogram.bins.length).toLong, @@ -110,7 +112,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } /** Column statistics should be consistent with histograms in tests. */ - private def inferColumnStat(histogram: Histogram): ColumnStat = { + private def inferColumnStat( + histogram: Histogram, + expectedMin: Any, + expectedMax: Any): ColumnStat = { + var ndv = 0L for (i <- histogram.bins.indices) { val bin = histogram.bins(i) @@ -118,8 +124,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { ndv += bin.ndv } } - ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), - max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + ColumnStat(distinctCount = Some(ndv), + min = Some(expectedMin), max = Some(expectedMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(histogram)) } @@ -343,10 +350,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(5 + 3), attributeStats = AttributeMap( // Update null count in column stats. - Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3), - nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), - nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), - nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) + Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = Some(3)), + nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = Some(3)), + nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = Some(5)), + nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = Some(5))))) assert(join.stats == expectedStats) } @@ -356,11 +363,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) // Update column stats for equi-join keys (key-1-5 and key-1-2). - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it // unchanged (key-2-4). - val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5) + val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = Some(5 * 3 / 5)) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), @@ -379,10 +386,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) // Update column stats for join keys. - val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) - val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + val joinedColStat2 = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -398,8 +405,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -416,8 +423,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -466,30 +473,40 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, - min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, - min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, - min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, - min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1), + min = Some(false), max = Some(false), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toByte), max = Some(1.toByte), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toShort), max = Some(1.toShort), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1), max = Some(1), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1L), max = Some(1L), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0), max = Some(1.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0f), max = Some(1.0f), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(1), min = Some(dec), max = Some(dec), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(1), + min = Some(date), max = Some(date), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1), + min = Some(timestamp), max = Some(timestamp), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) ) } @@ -520,7 +537,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("join with null column") { val (nullColumn, nullColStat) = (attr("cnull"), - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4)) + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4))) val nullTable = StatsTestPlan( outputList = Seq(nullColumn), rowCount = 1, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index cda54fa9d64f4..dcb37017329fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.types._ class ProjectEstimationSuite extends StatsEstimationTestBase { test("project with alias") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1), - max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4)) - val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10), - max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1), + max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) + val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = Some(1), min = Some(10), + max = Some(10), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1, ar2), @@ -49,8 +49,8 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("project on empty table") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1), rowCount = 0, @@ -71,30 +71,40 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, - min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, - min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, - min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, - min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(2), + min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(6.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(7.0), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(2), min = Some(dec1), max = Some(dec2), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(2), + min = Some(d1), max = Some(d2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(2), + min = Some(t1), max = Some(t2), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) )) val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2))) val child = StatsTestPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 31dea2e3e7f1d..9dceca59f5b87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -42,8 +42,8 @@ trait StatsEstimationTestBase extends SparkFunSuite { def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes - case StringType => colStat.avgLen + 8 + 4 - case _ => colStat.avgLen + case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4 + case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize) } def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() @@ -54,6 +54,12 @@ trait StatsEstimationTestBase extends SparkFunSuite { val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2)) } + + /** Get a test ColumnStat with given distinctCount and nullCount */ + def rangeColumnStat(distinctCount: Int, nullCount: Int): ColumnStat = + ColumnStat(distinctCount = Some(distinctCount), + min = Some(1), max = Some(distinctCount), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 1122522ccb4cb..640e01336aa75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** @@ -64,12 +66,12 @@ case class AnalyzeColumnCommand( /** * Compute stats for the given columns. - * @return (row count, map from column name to ColumnStats) + * @return (row count, map from column name to CatalogColumnStats) */ private def computeColumnStats( sparkSession: SparkSession, tableIdent: TableIdentifier, - columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = { val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan @@ -81,7 +83,7 @@ case class AnalyzeColumnCommand( // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => - if (!ColumnStat.supportsType(attr.dataType)) { + if (!supportsType(attr.dataType)) { throw new AnalysisException( s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") @@ -103,7 +105,7 @@ case class AnalyzeColumnCommand( // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) + attributesToAnalyze.map(statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -111,9 +113,9 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, - attributePercentiles.get(attr))) + // according to `statExprs`, the stats struct always have 7 fields. + (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType)) }.toMap (rowCount, columnStats) } @@ -124,7 +126,7 @@ case class AnalyzeColumnCommand( sparkSession: SparkSession, relation: LogicalPlan): AttributeMap[ArrayData] = { val attrsToGenHistogram = if (conf.histogramEnabled) { - attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + attributesToAnalyze.filter(a => supportsHistogram(a.dataType)) } else { Nil } @@ -154,4 +156,120 @@ case class AnalyzeColumnCommand( AttributeMap(attributePercentiles.toSeq) } + /** Returns true iff the we support gathering column statistics on column of the given type. */ + private def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false + } + + /** Returns true iff the we support gathering histogram on column of the given type. */ + private def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + private def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) + + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } + // For fixed width types, avg size should be the same as max size. + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct + case BinaryType | StringType => + // For string and binary type, we don't compute min, max or histogram + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + private def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( + distinctCount = Option(BigInt(row.getLong(0))), + // for string/binary min/max, get should return null + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), + nullCount = Option(BigInt(row.getLong(3))), + avgLen = Option(row.getLong(4)), + maxLen = Option(row.getLong(5)) + ) + if (row.isNullAt(6) || cs.nullCount.isEmpty) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount.get + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e400975f19708..44749190c79eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -695,10 +695,11 @@ case class DescribeColumnCommand( // Show column stats when EXTENDED or FORMATTED is specified. buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL")) buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL")) - buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL")) - buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) - buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) - buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("distinct_count", + cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("avg_col_len", cs.flatMap(_.avgLen.map(_.toString)).getOrElse("NULL")) + buffer += Row("max_col_len", cs.flatMap(_.maxLen.map(_.toString)).getOrElse("NULL")) val histDesc = for { c <- cs hist <- c.histogram diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b11e798532056..ed4ea0231f1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -95,7 +96,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetchedStats2.get.sizeInBytes == 0) val expectedColStat = - "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + "key" -> CatalogColumnStat(Some(0), None, None, Some(0), + Some(IntegerType.defaultSize), Some(IntegerType.defaultSize)) // There won't be histogram for empty column. Seq("true", "false").foreach { histogramEnabled => @@ -156,7 +158,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared Seq(stats, statsWithHgms).foreach { s => s.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + val roundtrip = CatalogColumnStat.fromMap("table_is_foo", field.name, v.toMap(k)) assert(roundtrip == Some(v)) } } @@ -187,7 +189,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared }.mkString(", ")) val expectedColStats = dataTypes.map { case (tpe, idx) => - (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) + (s"col$idx", CatalogColumnStat(Some(0), None, None, Some(1), + Some(tpe.defaultSize.toLong), Some(tpe.defaultSize.toLong))) } // There won't be histograms for null columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 65ccc1915882f..bf4abb6e625c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -67,18 +67,21 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) + "cbool" -> CatalogColumnStat(Some(2), Some("false"), Some("true"), Some(1), Some(1), Some(1)), + "cbyte" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(1), Some(1), Some(1)), + "cshort" -> CatalogColumnStat(Some(2), Some("1"), Some("3"), Some(1), Some(2), Some(2)), + "cint" -> CatalogColumnStat(Some(2), Some("1"), Some("4"), Some(1), Some(4), Some(4)), + "clong" -> CatalogColumnStat(Some(2), Some("1"), Some("5"), Some(1), Some(8), Some(8)), + "cdouble" -> CatalogColumnStat(Some(2), Some("1.0"), Some("6.0"), Some(1), Some(8), Some(8)), + "cfloat" -> CatalogColumnStat(Some(2), Some("1.0"), Some("7.0"), Some(1), Some(4), Some(4)), + "cdecimal" -> CatalogColumnStat(Some(2), Some(dec1.toString), Some(dec2.toString), Some(1), + Some(16), Some(16)), + "cstring" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cbinary" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cdate" -> CatalogColumnStat(Some(2), Some(d1.toString), Some(d2.toString), Some(1), Some(4), + Some(4)), + "ctimestamp" -> CatalogColumnStat(Some(2), Some(t1.toString), Some(t2.toString), Some(1), + Some(8), Some(8)) ) /** @@ -110,6 +113,110 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats } + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { @@ -151,7 +258,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils */ def checkColStats( df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { val tableName = "column_stats_test_" + randomName.nextInt(1000) withTable(tableName) { df.write.saveAsTable(tableName) @@ -161,14 +268,24 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats.keys.mkString(", ")) // Validate statistics - val table = getCatalogTable(tableName) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } + validateColStats(tableName, colStats) + } + } + + /** + * Validate if the given catalog table has the provided statistics. + */ + def validateColStats( + tableName: String, + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { + + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) } } } @@ -215,12 +332,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + val emptyColStat = ColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) + val emptyCatalogColStat = CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) // Check catalog statistics assert(catalogTable.stats.isDefined) assert(catalogTable.stats.get.sizeInBytes == 0) assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyCatalogColStat)) // Check relation statistics withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 1ee1d57b8ebe1..28c340a176d91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -663,14 +663,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert table statistics to properties so that we can persist them through hive client val statsProperties = if (stats.isDefined) { - statsToProperties(stats.get, schema) + statsToProperties(stats.get) } else { new mutable.HashMap[String, String]() } @@ -1028,9 +1024,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat currentFullPath } - private def statsToProperties( - stats: CatalogStatistics, - schema: StructType): Map[String, String] = { + private def statsToProperties(stats: CatalogStatistics): Map[String, String] = { val statsProperties = new mutable.HashMap[String, String]() statsProperties += STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString() @@ -1038,11 +1032,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } - val colNameTypeMap: Map[String, DataType] = - schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => - statsProperties += (columnStatKeyPropName(colName, k) -> v) + colStat.toMap(colName).foreach { case (k, v) => + // Fully qualified name used in table properties for a particular column stat. + // For example, for column "mycol", and "min" stat, this should return + // "spark.sql.statistics.colStats.mycol.min". + statsProperties += (STATISTICS_COL_STATS_PREFIX + k -> v) } } @@ -1058,23 +1053,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (statsProps.isEmpty) { None } else { + val colStats = new mutable.HashMap[String, CatalogColumnStat] + val colStatsProps = properties.filterKeys(_.startsWith(STATISTICS_COL_STATS_PREFIX)).map { + case (k, v) => k.drop(STATISTICS_COL_STATS_PREFIX.length) -> v + } - val colStats = new mutable.HashMap[String, ColumnStat] - - // For each column, recover its column stats. Note that this is currently a O(n^2) operation, - // but given the number of columns it usually not enormous, this is probably OK as a start. - // If we want to map this a linear operation, we'd need a stronger contract between the - // naming convention used for serialization. - schema.foreach { field => - if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) { - // If "version" field is defined, then the column stat is defined. - val keyPrefix = columnStatKeyPropName(field.name, "") - val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) => - (k.drop(keyPrefix.length), v) - } - ColumnStat.fromMap(table, field, colStatMap).foreach { cs => - colStats += field.name -> cs - } + // Find all the column names by matching the KEY_VERSION properties for them. + colStatsProps.keys.filter { + k => k.endsWith(CatalogColumnStat.KEY_VERSION) + }.map { k => + k.dropRight(CatalogColumnStat.KEY_VERSION.length + 1) + }.foreach { fieldName => + // and for each, create a column stat. + CatalogColumnStat.fromMap(table, fieldName, colStatsProps).foreach { cs => + colStats += fieldName -> cs } } @@ -1093,14 +1085,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val rawTable = getRawTable(db, table) - // For datasource tables and hive serde tables created by spark 2.1 or higher, - // the data schema is stored in the table properties. - val schema = restoreTableMetadata(rawTable).schema - // convert partition statistics to properties so that we can persist them through hive api val withStatsProps = lowerCasedParts.map { p => if (p.stats.isDefined) { - val statsProperties = statsToProperties(p.stats.get, schema) + val statsProperties = statsToProperties(p.stats.get) p.copy(parameters = p.parameters ++ statsProperties) } else { p @@ -1310,15 +1298,6 @@ object HiveExternalCatalog { val EMPTY_DATA_SCHEMA = new StructType() .add("col", "array", nullable = true, comment = "from deserializer") - /** - * Returns the fully qualified name used in table properties for a particular column stat. - * For example, for column "mycol", and "min" stat, this should return - * "spark.sql.statistics.colStats.mycol.min". - */ - private def columnStatKeyPropName(columnName: String, statKey: String): String = { - STATISTICS_COL_STATS_PREFIX + columnName + "." + statKey - } - // A persisted data source table always store its schema in the catalog. private def getSchemaFromTableProperties(metadata: CatalogTable): StructType = { val errorMessage = "Could not read schema from the hive metastore because it is corrupted." diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3af8af0814bb4..61cec82984795 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, HistogramBin, HistogramSerializer} import org.apache.spark.sql.catalyst.util.{DateTimeUtils, StringUtils} import org.apache.spark.sql.execution.command.DDLUtils @@ -177,8 +177,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)) assert(fetchedStats0.get.colStats == Map( - "a" -> ColumnStat(2, Some(1), Some(2), 0, 4, 4), - "b" -> ColumnStat(1, Some(1), Some(1), 0, 4, 4))) + "a" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(0), Some(4), Some(4)), + "b" -> CatalogColumnStat(Some(1), Some("1"), Some("1"), Some(0), Some(4), Some(4)))) } } @@ -208,8 +208,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "C1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "C1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) } } @@ -596,7 +596,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c1") val fetchedStats0 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) - assert(fetchedStats0.get.colStats == Map("c1" -> ColumnStat(0, None, None, 0, 4, 4))) + assert(fetchedStats0.get.colStats == + Map("c1" -> CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)))) // Insert new data and analyze: have the latest column stats. sql(s"INSERT INTO TABLE $table SELECT 1, 'a', 10.0") @@ -604,18 +605,18 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats1.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))) // Analyze another column: since the table is not changed, the precious column stats are kept. sql(s"ANALYZE TABLE $table COMPUTE STATISTICS FOR COLUMNS c2") val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(1)).get assert(fetchedStats2.colStats == Map( - "c1" -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(1), nullCount = 0, - avgLen = 4, maxLen = 4), - "c2" -> ColumnStat(distinctCount = 1, min = None, max = None, nullCount = 0, - avgLen = 1, maxLen = 1))) + "c1" -> CatalogColumnStat(distinctCount = Some(1), min = Some("1"), max = Some("1"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c2" -> CatalogColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)))) // Insert new data and analyze: stale column stats are removed and newly collected column // stats are added. @@ -624,10 +625,10 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val fetchedStats3 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(2)).get assert(fetchedStats3.colStats == Map( - "c1" -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - "c3" -> ColumnStat(distinctCount = 2, min = Some(10.0), max = Some(20.0), nullCount = 0, - avgLen = 8, maxLen = 8))) + "c1" -> CatalogColumnStat(distinctCount = Some(2), min = Some("1"), max = Some("2"), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + "c3" -> CatalogColumnStat(distinctCount = Some(2), min = Some("10.0"), max = Some("20.0"), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)))) } } @@ -999,115 +1000,11 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("verify serialized column stats after analyzing columns") { import testImplicits._ - val tableName = "column_stats_test2" + val tableName = "column_stats_test_ser" // (data.head.productArity - 1) because the last column does not support stats collection. assert(stats.size == data.head.productArity - 1) val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - val expectedSerializedColStats = Map( - "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", - "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", - "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", - "spark.sql.statistics.colStats.cbinary.version" -> "1", - "spark.sql.statistics.colStats.cbool.avgLen" -> "1", - "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbool.max" -> "true", - "spark.sql.statistics.colStats.cbool.maxLen" -> "1", - "spark.sql.statistics.colStats.cbool.min" -> "false", - "spark.sql.statistics.colStats.cbool.nullCount" -> "1", - "spark.sql.statistics.colStats.cbool.version" -> "1", - "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", - "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", - "spark.sql.statistics.colStats.cbyte.max" -> "2", - "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", - "spark.sql.statistics.colStats.cbyte.min" -> "1", - "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", - "spark.sql.statistics.colStats.cbyte.version" -> "1", - "spark.sql.statistics.colStats.cdate.avgLen" -> "4", - "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", - "spark.sql.statistics.colStats.cdate.maxLen" -> "4", - "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", - "spark.sql.statistics.colStats.cdate.nullCount" -> "1", - "spark.sql.statistics.colStats.cdate.version" -> "1", - "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", - "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", - "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", - "spark.sql.statistics.colStats.cdecimal.version" -> "1", - "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", - "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", - "spark.sql.statistics.colStats.cdouble.max" -> "6.0", - "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", - "spark.sql.statistics.colStats.cdouble.min" -> "1.0", - "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", - "spark.sql.statistics.colStats.cdouble.version" -> "1", - "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", - "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", - "spark.sql.statistics.colStats.cfloat.max" -> "7.0", - "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", - "spark.sql.statistics.colStats.cfloat.min" -> "1.0", - "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", - "spark.sql.statistics.colStats.cfloat.version" -> "1", - "spark.sql.statistics.colStats.cint.avgLen" -> "4", - "spark.sql.statistics.colStats.cint.distinctCount" -> "2", - "spark.sql.statistics.colStats.cint.max" -> "4", - "spark.sql.statistics.colStats.cint.maxLen" -> "4", - "spark.sql.statistics.colStats.cint.min" -> "1", - "spark.sql.statistics.colStats.cint.nullCount" -> "1", - "spark.sql.statistics.colStats.cint.version" -> "1", - "spark.sql.statistics.colStats.clong.avgLen" -> "8", - "spark.sql.statistics.colStats.clong.distinctCount" -> "2", - "spark.sql.statistics.colStats.clong.max" -> "5", - "spark.sql.statistics.colStats.clong.maxLen" -> "8", - "spark.sql.statistics.colStats.clong.min" -> "1", - "spark.sql.statistics.colStats.clong.nullCount" -> "1", - "spark.sql.statistics.colStats.clong.version" -> "1", - "spark.sql.statistics.colStats.cshort.avgLen" -> "2", - "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", - "spark.sql.statistics.colStats.cshort.max" -> "3", - "spark.sql.statistics.colStats.cshort.maxLen" -> "2", - "spark.sql.statistics.colStats.cshort.min" -> "1", - "spark.sql.statistics.colStats.cshort.nullCount" -> "1", - "spark.sql.statistics.colStats.cshort.version" -> "1", - "spark.sql.statistics.colStats.cstring.avgLen" -> "3", - "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", - "spark.sql.statistics.colStats.cstring.maxLen" -> "3", - "spark.sql.statistics.colStats.cstring.nullCount" -> "1", - "spark.sql.statistics.colStats.cstring.version" -> "1", - "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", - "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", - "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", - "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", - "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", - "spark.sql.statistics.colStats.ctimestamp.version" -> "1" - ) - - val expectedSerializedHistograms = Map( - "spark.sql.statistics.colStats.cbyte.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), - "spark.sql.statistics.colStats.cshort.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), - "spark.sql.statistics.colStats.cint.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), - "spark.sql.statistics.colStats.clong.histogram" -> - HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), - "spark.sql.statistics.colStats.cdouble.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), - "spark.sql.statistics.colStats.cfloat.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), - "spark.sql.statistics.colStats.cdecimal.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), - "spark.sql.statistics.colStats.cdate.histogram" -> - HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), - "spark.sql.statistics.colStats.ctimestamp.histogram" -> - HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) - ) - def checkColStatsProps(expected: Map[String, String]): Unit = { sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) val table = hiveClient.getTable("default", tableName) @@ -1129,6 +1026,29 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("verify column stats can be deserialized from tblproperties") { + import testImplicits._ + + val tableName = "column_stats_test_de" + // (data.head.productArity - 1) because the last column does not support stats collection. + assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + + withTable(tableName) { + df.write.saveAsTable(tableName) + + // Put in stats properties manually. + val table = getCatalogTable(tableName) + val newTable = table.copy( + properties = table.properties ++ + expectedSerializedColStats ++ expectedSerializedHistograms + + ("spark.sql.statistics.totalSize" -> "1") /* totalSize always required */) + hiveClient.alterTable(newTable) + + validateColStats(tableName, statsWithHgms) + } + } + test("serialization and deserialization of histograms to/from hive metastore") { import testImplicits._ From 649ed9c5732f85ef1306576fdd3a9278a2a6410c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 27 Feb 2018 08:18:41 -0600 Subject: [PATCH 0397/1161] [SPARK-23509][BUILD] Upgrade commons-net from 2.2 to 3.1 ## What changes were proposed in this pull request? This PR avoids version conflicts of `commons-net` by upgrading commons-net from 2.2 to 3.1. We are seeing the following message during the build using sbt. ``` [warn] Found version conflict(s) in library dependencies; some are suspected to be binary incompatible: ... [warn] * commons-net:commons-net:3.1 is selected over 2.2 [warn] +- org.apache.hadoop:hadoop-common:2.6.5 (depends on 3.1) [warn] +- org.apache.spark:spark-core_2.11:2.4.0-SNAPSHOT (depends on 2.2) [warn] ``` [Here](https://commons.apache.org/proper/commons-net/changes-report.html) is a release history. [Here](https://commons.apache.org/proper/commons-net/migration.html) is a migration guide from 2.x to 3.0. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20672 from kiszk/SPARK-23509. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index ed310507d14ed..c3d1dd444b506 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 04dec04796af4..290867035f91d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -48,7 +48,7 @@ commons-lang-2.6.jar commons-lang3-3.5.jar commons-logging-1.1.3.jar commons-math3-3.4.1.jar -commons-net-2.2.jar +commons-net-3.1.jar commons-pool-1.5.4.jar compress-lzf-1.0.3.jar core-1.1.2.jar diff --git a/pom.xml b/pom.xml index ac30107066389..b8396166f6b1b 100644 --- a/pom.xml +++ b/pom.xml @@ -579,7 +579,7 @@ commons-net commons-net - 2.2 + 3.1 io.netty From eac0b067222a3dfa52be20360a453cb7bd420bf2 Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Tue, 27 Feb 2018 08:21:11 -0600 Subject: [PATCH 0398/1161] [SPARK-17147][STREAMING][KAFKA] Allow non-consecutive offsets ## What changes were proposed in this pull request? Add a configuration spark.streaming.kafka.allowNonConsecutiveOffsets to allow streaming jobs to proceed on compacted topics (or other situations involving gaps between offsets in the log). ## How was this patch tested? Added new unit test justinrmiller has been testing this branch in production for a few weeks Author: cody koeninger Closes #20572 from koeninger/SPARK-17147. --- .../kafka010/CachedKafkaConsumer.scala | 55 +++- .../spark/streaming/kafka010/KafkaRDD.scala | 236 +++++++++++++----- .../streaming/kafka010/KafkaRDDSuite.scala | 106 ++++++++ .../streaming/kafka010/KafkaTestUtils.scala | 25 +- .../kafka010/mocks/MockScheduler.scala | 96 +++++++ .../streaming/kafka010/mocks/MockTime.scala | 51 ++++ 6 files changed, 487 insertions(+), 82 deletions(-) create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala index fa3ea6131a507..aeb8c1dc342b3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala @@ -22,10 +22,8 @@ import java.{ util => ju } import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } import org.apache.kafka.common.{ KafkaException, TopicPartition } -import org.apache.spark.SparkConf import org.apache.spark.internal.Logging - /** * Consumer of single topicpartition, intended for cached reuse. * Underlying consumer is not threadsafe, so neither is this, @@ -38,7 +36,7 @@ class CachedKafkaConsumer[K, V] private( val partition: Int, val kafkaParams: ju.Map[String, Object]) extends Logging { - assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), + require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), "groupId used for cache key must match the groupId in kafkaParams") val topicPartition = new TopicPartition(topic, partition) @@ -53,7 +51,7 @@ class CachedKafkaConsumer[K, V] private( // TODO if the buffer was kept around as a random-access structure, // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator + protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() protected var nextOffset = -2L def close(): Unit = consumer.close() @@ -71,7 +69,7 @@ class CachedKafkaConsumer[K, V] private( } if (!buffer.hasNext()) { poll(timeout) } - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") var record = buffer.next() @@ -79,17 +77,56 @@ class CachedKafkaConsumer[K, V] private( logInfo(s"Buffer miss for $groupId $topic $partition $offset") seek(offset) poll(timeout) - assert(buffer.hasNext(), + require(buffer.hasNext(), s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") record = buffer.next() - assert(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset") + require(record.offset == offset, + s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) } nextOffset = offset + 1 record } + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, timeout: Long): Unit = { + logDebug(s"compacted start $groupId $topic $partition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") + seek(offset) + poll(timeout) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(timeout: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + private def seek(offset: Long): Unit = { logDebug(s"Seeking to $topicPartition $offset") consumer.seek(topicPartition, offset) @@ -99,7 +136,7 @@ class CachedKafkaConsumer[K, V] private( val p = consumer.poll(timeout) val r = p.records(topicPartition) logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.iterator + buffer = r.listIterator } } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index d9fc9cc206647..07239eda64d2e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -55,12 +55,12 @@ private[spark] class KafkaRDD[K, V]( useConsumerCache: Boolean ) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges { - assert("none" == + require("none" == kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String], ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " must be set to none for executor kafka params, else messages may not match offsetRange") - assert(false == + require(false == kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean], ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG + " must be set to false for executor kafka params, else offsets may commit before processing") @@ -74,6 +74,8 @@ private[spark] class KafkaRDD[K, V]( conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64) private val cacheLoadFactor = conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat + private val compacted = + conf.getBoolean("spark.streaming.kafka.allowNonConsecutiveOffsets", false) override def persist(newLevel: StorageLevel): this.type = { logError("Kafka ConsumerRecord is not serializable. " + @@ -87,48 +89,63 @@ private[spark] class KafkaRDD[K, V]( }.toArray } - override def count(): Long = offsetRanges.map(_.count).sum + override def count(): Long = + if (compacted) { + super.count() + } else { + offsetRanges.map(_.count).sum + } override def countApprox( timeout: Long, confidence: Double = 0.95 - ): PartialResult[BoundedDouble] = { - val c = count - new PartialResult(new BoundedDouble(c, 1.0, c, c), true) - } - - override def isEmpty(): Boolean = count == 0L - - override def take(num: Int): Array[ConsumerRecord[K, V]] = { - val nonEmptyPartitions = this.partitions - .map(_.asInstanceOf[KafkaRDDPartition]) - .filter(_.count > 0) + ): PartialResult[BoundedDouble] = + if (compacted) { + super.countApprox(timeout, confidence) + } else { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } - if (num < 1 || nonEmptyPartitions.isEmpty) { - return new Array[ConsumerRecord[K, V]](0) + override def isEmpty(): Boolean = + if (compacted) { + super.isEmpty() + } else { + count == 0L } - // Determine in advance how many messages need to be taken from each partition - val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => - val remain = num - result.values.sum - if (remain > 0) { - val taken = Math.min(remain, part.count) - result + (part.index -> taken.toInt) + override def take(num: Int): Array[ConsumerRecord[K, V]] = + if (compacted) { + super.take(num) + } else if (num < 1) { + Array.empty[ConsumerRecord[K, V]] + } else { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (nonEmptyPartitions.isEmpty) { + Array.empty[ConsumerRecord[K, V]] } else { - result + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ).flatten } } - val buf = new ArrayBuffer[ConsumerRecord[K, V]] - val res = context.runJob( - this, - (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => - it.take(parts(tc.partitionId)).toArray, parts.keys.toArray - ) - res.foreach(buf ++= _) - buf.toArray - } - private def executors(): Array[ExecutorCacheTaskLocation] = { val bm = sparkContext.env.blockManager bm.master.getPeers(bm.blockManagerId).toArray @@ -172,57 +189,138 @@ private[spark] class KafkaRDD[K, V]( override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = { val part = thePart.asInstanceOf[KafkaRDDPartition] - assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + require(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) if (part.fromOffset == part.untilOffset) { logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " + s"skipping ${part.topic} ${part.partition}") Iterator.empty } else { - new KafkaRDDIterator(part, context) + logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + if (compacted) { + new CompactedKafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } else { + new KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) + } } } +} - /** - * An iterator that fetches messages directly from Kafka for the offsets in partition. - * Uses a cached consumer where possible to take advantage of prefetching - */ - private class KafkaRDDIterator( - part: KafkaRDDPartition, - context: TaskContext) extends Iterator[ConsumerRecord[K, V]] { - - logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + - s"offsets ${part.fromOffset} -> ${part.untilOffset}") - - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching + */ +private class KafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float +) extends Iterator[ConsumerRecord[K, V]] { + + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + context.addTaskCompletionListener(_ => closeIfNeeded()) + + val consumer = if (useConsumerCache) { + CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + if (context.attemptNumber >= 1) { + // just in case the prior attempt failures were cache related + CachedKafkaConsumer.remove(groupId, part.topic, part.partition) + } + CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) + } else { + CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + } - context.addTaskCompletionListener{ context => closeIfNeeded() } + var requestOffset = part.fromOffset - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + def closeIfNeeded(): Unit = { + if (!useConsumerCache && consumer != null) { + consumer.close() } + } - var requestOffset = part.fromOffset + override def hasNext(): Boolean = requestOffset < part.untilOffset - def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close - } + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") } + val r = consumer.get(requestOffset, pollTimeout) + requestOffset += 1 + r + } +} - override def hasNext(): Boolean = requestOffset < part.untilOffset - - override def next(): ConsumerRecord[K, V] = { - assert(hasNext(), "Can't call getNext() once untilOffset has been reached") - val r = consumer.get(requestOffset, pollTimeout) - requestOffset += 1 - r +/** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching. + * Intended for compacted topics, or other cases when non-consecutive offsets are ok. + */ +private class CompactedKafkaRDDIterator[K, V]( + part: KafkaRDDPartition, + context: TaskContext, + kafkaParams: ju.Map[String, Object], + useConsumerCache: Boolean, + pollTimeout: Long, + cacheInitialCapacity: Int, + cacheMaxCapacity: Int, + cacheLoadFactor: Float + ) extends KafkaRDDIterator[K, V]( + part, + context, + kafkaParams, + useConsumerCache, + pollTimeout, + cacheInitialCapacity, + cacheMaxCapacity, + cacheLoadFactor + ) { + + consumer.compactedStart(part.fromOffset, pollTimeout) + + private var nextRecord = consumer.compactedNext(pollTimeout) + + private var okNext: Boolean = true + + override def hasNext(): Boolean = okNext + + override def next(): ConsumerRecord[K, V] = { + if (!hasNext) { + throw new ju.NoSuchElementException("Can't call getNext() once untilOffset has been reached") + } + val r = nextRecord + if (r.offset + 1 >= part.untilOffset) { + okNext = false + } else { + nextRecord = consumer.compactedNext(pollTimeout) + if (nextRecord.offset >= part.untilOffset) { + okNext = false + consumer.compactedPrevious() + } } + r } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala index be373af0599cc..271adea1df731 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -18,16 +18,22 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } +import java.io.File import scala.collection.JavaConverters._ import scala.util.Random +import kafka.common.TopicAndPartition +import kafka.log._ +import kafka.message._ +import kafka.utils.Pool import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.BeforeAndAfterAll import org.apache.spark._ import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.streaming.kafka010.mocks.MockTime class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -64,6 +70,41 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private val preferredHosts = LocationStrategies.PreferConsistent + private def compactLogs(topic: String, partition: Int, messages: Array[(String, String)]) { + val mockTime = new MockTime() + // LogCleaner in 0.10 version of Kafka is still expecting the old TopicAndPartition api + val logs = new Pool[TopicAndPartition, Log]() + val logDir = kafkaTestUtils.brokerLogDir + val dir = new File(logDir, topic + "-" + partition) + dir.mkdirs() + val logProps = new ju.Properties() + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps.put(LogConfig.MinCleanableDirtyRatioProp, java.lang.Float.valueOf(0.1f)) + val log = new Log( + dir, + LogConfig(logProps), + 0L, + mockTime.scheduler, + mockTime + ) + messages.foreach { case (k, v) => + val msg = new ByteBufferMessageSet( + NoCompressionCodec, + new Message(v.getBytes, k.getBytes, Message.NoTimestamp, Message.CurrentMagicValue)) + log.append(msg) + } + log.roll() + logs.put(TopicAndPartition(topic, partition), log) + + val cleaner = new LogCleaner(CleanerConfig(), logDirs = Array(dir), logs = logs) + cleaner.startup() + cleaner.awaitCleaned(topic, partition, log.activeSegment.baseOffset, 1000) + + cleaner.shutdown() + mockTime.scheduler.shutdown() + } + + test("basic usage") { val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}" kafkaTestUtils.createTopic(topic) @@ -102,6 +143,71 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("compacted topic") { + val compactConf = sparkConf.clone() + compactConf.set("spark.streaming.kafka.allowNonConsecutiveOffsets", "true") + sc.stop() + sc = new SparkContext(compactConf) + val topic = s"topiccompacted-${Random.nextInt}-${System.currentTimeMillis}" + + val messages = Array( + ("a", "1"), + ("a", "2"), + ("b", "1"), + ("c", "1"), + ("c", "2"), + ("b", "2"), + ("b", "3") + ) + val compactedMessages = Array( + ("a", "2"), + ("b", "3"), + ("c", "2") + ) + + compactLogs(topic, 0, messages) + + val props = new ju.Properties() + props.put("cleanup.policy", "compact") + props.put("flush.messages", "1") + props.put("segment.ms", "1") + props.put("segment.bytes", "256") + kafkaTestUtils.createTopic(topic, 1, props) + + + val kafkaParams = getKafkaParams() + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, offsetRanges, preferredHosts + ).map(m => m.key -> m.value) + + val received = rdd.collect.toSet + assert(received === compactedMessages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === compactedMessages.size) + assert(rdd.countApprox(0).getFinalValue.mean === compactedMessages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head === compactedMessages.head) + assert(rdd.take(messages.size + 10).size === compactedMessages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts) + .map(_.value) + .collect() + } + } + test("iterator boundary conditions") { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}" diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala index 6c7024ea4b5a5..70b579d96d692 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -162,17 +162,22 @@ private[kafka010] class KafkaTestUtils extends Logging { } /** Create a Kafka topic and wait until it is propagated to the whole cluster */ - def createTopic(topic: String, partitions: Int): Unit = { - AdminUtils.createTopic(zkUtils, topic, partitions, 1) + def createTopic(topic: String, partitions: Int, config: Properties): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1, config) // wait until metadata is propagated (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) } } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + createTopic(topic, partitions, new Properties()) + } + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ def createTopic(topic: String): Unit = { - createTopic(topic, 1) + createTopic(topic, 1, new Properties()) } /** Java-friendly function for sending messages to the Kafka broker */ @@ -196,12 +201,24 @@ private[kafka010] class KafkaTestUtils extends Logging { producer = null } + /** Send the array of (key, value) messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[(String, String)]): Unit = { + producer = new KafkaProducer[String, String](producerConfiguration) + messages.foreach { message => + producer.send(new ProducerRecord[String, String](topic, message._1, message._2)) + } + producer.close() + producer = null + } + + val brokerLogDir = Utils.createTempDir().getAbsolutePath + private def brokerConfiguration: Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") props.put("port", brokerPort.toString) - props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("log.dir", brokerLogDir) props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala new file mode 100644 index 0000000000000..928e1a6ef54b9 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockScheduler.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.PriorityQueue + +import kafka.utils.{Scheduler, Time} + +/** + * A mock scheduler that executes tasks synchronously using a mock time instance. + * Tasks are executed synchronously when the time is advanced. + * This class is meant to be used in conjunction with MockTime. + * + * Example usage + * + * val time = new MockTime + * time.scheduler.schedule("a task", println("hello world: " + time.milliseconds), delay = 1000) + * time.sleep(1001) // this should cause our scheduled task to fire + * + * + * Incrementing the time to the exact next execution time of a task will result in that task + * executing (it as if execution itself takes no time). + */ +private[kafka010] class MockScheduler(val time: Time) extends Scheduler { + + /* a priority queue of tasks ordered by next execution time */ + var tasks = new PriorityQueue[MockTask]() + + def isStarted: Boolean = true + + def startup(): Unit = {} + + def shutdown(): Unit = synchronized { + tasks.foreach(_.fun()) + tasks.clear() + } + + /** + * Check for any tasks that need to execute. Since this is a mock scheduler this check only occurs + * when this method is called and the execution happens synchronously in the calling thread. + * If you are using the scheduler associated with a MockTime instance this call + * will be triggered automatically. + */ + def tick(): Unit = synchronized { + val now = time.milliseconds + while(!tasks.isEmpty && tasks.head.nextExecution <= now) { + /* pop and execute the task with the lowest next execution time */ + val curr = tasks.dequeue + curr.fun() + /* if the task is periodic, reschedule it and re-enqueue */ + if(curr.periodic) { + curr.nextExecution += curr.period + this.tasks += curr + } + } + } + + def schedule( + name: String, + fun: () => Unit, + delay: Long = 0, + period: Long = -1, + unit: TimeUnit = TimeUnit.MILLISECONDS): Unit = synchronized { + tasks += MockTask(name, fun, time.milliseconds + delay, period = period) + tick() + } + +} + +case class MockTask( + val name: String, + val fun: () => Unit, + var nextExecution: Long, + val period: Long) extends Ordered[MockTask] { + def periodic: Boolean = period >= 0 + def compare(t: MockTask): Int = { + java.lang.Long.compare(t.nextExecution, nextExecution) + } +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala new file mode 100644 index 0000000000000..a68f94db1f689 --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/mocks/MockTime.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010.mocks + +import java.util.concurrent._ + +import kafka.utils.Time + +/** + * A class used for unit testing things which depend on the Time interface. + * + * This class never manually advances the clock, it only does so when you call + * sleep(ms) + * + * It also comes with an associated scheduler instance for managing background tasks in + * a deterministic way. + */ +private[kafka010] class MockTime(@volatile private var currentMs: Long) extends Time { + + val scheduler = new MockScheduler(this) + + def this() = this(System.currentTimeMillis) + + def milliseconds: Long = currentMs + + def nanoseconds: Long = + TimeUnit.NANOSECONDS.convert(currentMs, TimeUnit.MILLISECONDS) + + def sleep(ms: Long) { + this.currentMs += ms + scheduler.tick() + } + + override def toString(): String = s"MockTime($milliseconds)" + +} From 414ee867ba0835b97aae2e8d4e489e1879c251dd Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 27 Feb 2018 08:44:25 -0800 Subject: [PATCH 0399/1161] [SPARK-23523][SQL] Fix the incorrect result caused by the rule OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? ```Scala val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") .write.json(tablePath.getCanonicalPath) val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() df.show() ``` It generates a wrong result. ``` [c,e,a] ``` We have a bug in the rule `OptimizeMetadataOnlyQuery `. We should respect the attribute order in the original leaf node. This PR is to fix it. ## How was this patch tested? Added a test case Author: gatorsmile Closes #20684 from gatorsmile/optimizeMetadataOnly. --- .../plans/logical/LocalRelation.scala | 9 ++++---- .../execution/OptimizeMetadataOnlyQuery.scala | 12 ++++++++-- .../datasources/HadoopFsRelation.scala | 3 +++ .../OptimizeMetadataOnlyQuerySuite.scala | 22 +++++++++++++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index d73d7e73f28d5..b05508db786ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,10 +43,11 @@ object LocalRelation { } } -case class LocalRelation(output: Seq[Attribute], - data: Seq[InternalRow] = Nil, - // Indicates whether this relation has data from a streaming source. - override val isStreaming: Boolean = false) +case class LocalRelation( + output: Seq[Attribute], + data: Seq[InternalRow] = Nil, + // Indicates whether this relation has data from a streaming source. + override val isStreaming: Boolean = false) extends LeafNode with analysis.MultiInstanceRelation { // A local relation must have resolved output. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 18f6f697bc857..0613d9053f826 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution +import java.util.Locale + +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -80,8 +83,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val partColumns = partitionColumnNames.map(_.toLowerCase).toSet - relation.output.filter(a => partColumns.contains(a.name.toLowerCase)) + val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + partitionColumnNames.map { colName => + attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), + throw new AnalysisException(s"Unable to find the column `$colName` " + + s"given [${relation.output.map(_.name).mkString(", ")}]") + ) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 6b34638529770..ac574b07ec497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,6 +67,9 @@ case class HadoopFsRelation( } } + // When data schema and partition schema have the overlapped columns, the output + // schema respects the order of data schema for the overlapped columns, but respect + // the data types of partition schema val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 78c1e5dae566d..a543eb8351656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution +import java.io.File + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SharedSQLContext class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { @@ -125,4 +128,23 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect() } } + + test("Incorrect result caused by the rule OptimizeMetadataOnlyQuery") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + withTempPath { path => + val tablePath = new File(s"${path.getCanonicalPath}/cOl3=c/cOl1=a/cOl5=e") + Seq(("a", "b", "c", "d", "e")).toDF("cOl1", "cOl2", "cOl3", "cOl4", "cOl5") + .write.json(tablePath.getCanonicalPath) + + val df = spark.read.json(path.getCanonicalPath).select("CoL1", "CoL5", "CoL3").distinct() + checkAnswer(df, Row("a", "e", "c")) + + val localRelation = df.queryExecution.optimizedPlan.collectFirst { + case l: LocalRelation => l + } + assert(localRelation.nonEmpty, "expect to see a LocalRelation") + assert(localRelation.get.output.map(_.name) == Seq("cOl3", "cOl1", "cOl5")) + } + } + } } From ecb8b383af1cf1b67f3111c148229e00c9c17c40 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 27 Feb 2018 11:12:32 -0800 Subject: [PATCH 0400/1161] [SPARK-23365][CORE] Do not adjust num executors when killing idle executors. The ExecutorAllocationManager should not adjust the target number of executors when killing idle executors, as it has already adjusted the target number down based on the task backlog. The name `replace` was misleading with DynamicAllocation on, as the target number of executors is changed outside of the call to `killExecutors`, so I adjusted that name. Also separated out the logic of `countFailures` as you don't always want that tied to `replace`. While I was there I made two changes that weren't directly related to this: 1) Fixed `countFailures` in a couple cases where it was getting an incorrect value since it used to be tied to `replace`, eg. when killing executors on a blacklisted node. 2) hard error if you call `sc.killExecutors` with dynamic allocation on, since that's another way the ExecutorAllocationManager and the CoarseGrainedSchedulerBackend would get out of sync. Added a unit test case which verifies that the calls to ExecutorAllocationClient do not adjust the number of executors. Author: Imran Rashid Closes #20604 from squito/SPARK-23365. --- .../spark/ExecutorAllocationClient.scala | 15 +++-- .../spark/ExecutorAllocationManager.scala | 20 ++++-- .../scala/org/apache/spark/SparkContext.scala | 13 +++- .../spark/scheduler/BlacklistTracker.scala | 3 +- .../CoarseGrainedSchedulerBackend.scala | 22 ++++--- .../ExecutorAllocationManagerSuite.scala | 66 ++++++++++++++++++- .../StandaloneDynamicAllocationSuite.scala | 3 +- .../scheduler/BlacklistTrackerSuite.scala | 14 ++-- 8 files changed, 121 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 9112d93a86b2a..63d87b4cd385c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -55,18 +55,18 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors will be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * to count those failures toward task failure limits * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ def killExecutors( executorIds: Seq[String], - replace: Boolean = false, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean = false): Seq[String] /** @@ -81,7 +81,8 @@ private[spark] trait ExecutorAllocationClient { * @return whether the request is acknowledged by the cluster manager. */ def killExecutor(executorId: String): Boolean = { - val killedExecutors = killExecutors(Seq(executorId)) + val killedExecutors = killExecutors(Seq(executorId), adjustTargetNumExecutors = true, + countFailures = false) killedExecutors.nonEmpty && killedExecutors(0).equals(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6c59038f2a6c1..189d91333c045 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** @@ -81,7 +82,8 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} private[spark] class ExecutorAllocationManager( client: ExecutorAllocationClient, listenerBus: LiveListenerBus, - conf: SparkConf) + conf: SparkConf, + blockManagerMaster: BlockManagerMaster) extends Logging { allocationManager => @@ -151,7 +153,7 @@ private[spark] class ExecutorAllocationManager( private var clock: Clock = new SystemClock() // Listener for Spark events that impact the allocation policy - private val listener = new ExecutorAllocationListener + val listener = new ExecutorAllocationListener // Executor that handles the scheduling task. private val executor = @@ -334,6 +336,11 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { + // We lower the target number of executors but don't actively kill any yet. Killing is + // controlled separately by an idle timeout. It's still helpful to reduce the target number + // in case an executor just happens to get lost (eg., bad hardware, or the cluster manager + // preempts it) -- in that case, there is no point in trying to immediately get a new + // executor, since we wouldn't even use it yet. client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") @@ -455,7 +462,10 @@ private[spark] class ExecutorAllocationManager( val executorsRemoved = if (testing) { executorIdsToBeRemoved } else { - client.killExecutors(executorIdsToBeRemoved) + // We don't want to change our target number of executors, because we already did that + // when the task backlog decreased. + client.killExecutors(executorIdsToBeRemoved, adjustTargetNumExecutors = false, + countFailures = false, force = false) } // [SPARK-21834] killExecutors api reduces the target number of executors. // So we need to update the target with desired value. @@ -575,7 +585,7 @@ private[spark] class ExecutorAllocationManager( // Note that it is not necessary to query the executors since all the cached // blocks we are concerned with are reported to the driver. Note that this // does not include broadcast blocks. - val hasCachedBlocks = SparkEnv.get.blockManager.master.hasCachedBlocks(executorId) + val hasCachedBlocks = blockManagerMaster.hasCachedBlocks(executorId) val now = clock.getTimeMillis() val timeout = { if (hasCachedBlocks) { @@ -610,7 +620,7 @@ private[spark] class ExecutorAllocationManager( * This class is intentionally conservative in its assumptions about the relative ordering * and consistency of events returned by the listener. */ - private class ExecutorAllocationListener extends SparkListener { + private[spark] class ExecutorAllocationListener extends SparkListener { private val stageIdToNumTasks = new mutable.HashMap[Int, Int] // Number of running tasks per stage including speculative tasks. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dc531e3337014..5e8595603cc90 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -534,7 +534,8 @@ class SparkContext(config: SparkConf) extends Logging { schedulerBackend match { case b: ExecutorAllocationClient => Some(new ExecutorAllocationManager( - schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf)) + schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf, + _env.blockManager.master)) case _ => None } @@ -1633,6 +1634,8 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * + * This is not supported when dynamic allocation is turned on. + * * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to @@ -1644,7 +1647,10 @@ class SparkContext(config: SparkConf) extends Logging { def killExecutors(executorIds: Seq[String]): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(executorIds, replace = false, force = true).nonEmpty + require(executorAllocationManager.isEmpty, + "killExecutors() unsupported with Dynamic Allocation turned on") + b.killExecutors(executorIds, adjustTargetNumExecutors = true, countFailures = false, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false @@ -1682,7 +1688,8 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def killAndReplaceExecutor(executorId: String): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.killExecutors(Seq(executorId), replace = true, force = true).nonEmpty + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = false, countFailures = true, + force = true).nonEmpty case _ => logWarning("Killing executors is not supported by current scheduler.") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index cd8e61d6d0208..952598f6de19d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -152,7 +152,8 @@ private[scheduler] class BlacklistTracker ( case Some(a) => logInfo(s"Killing blacklisted executor id $exec " + s"since ${config.BLACKLIST_KILL_ENABLED.key} is set.") - a.killExecutors(Seq(exec), true, true) + a.killExecutors(Seq(exec), adjustTargetNumExecutors = false, countFailures = false, + force = true) case None => logWarning(s"Not attempting to kill blacklisted executor id $exec " + s"since allocation client is not defined.") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4d75063fbf1c5..5627a557a12f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -147,7 +147,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case KillExecutorsOnHost(host) => scheduler.getExecutorsAliveOnHost(host).foreach { exec => - killExecutors(exec.toSeq, replace = true, force = true) + killExecutors(exec.toSeq, adjustTargetNumExecutors = false, countFailures = false, + force = true) } case UpdateDelegationTokens(newDelegationTokens) => @@ -584,18 +585,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * - * When asking the executor to be replaced, the executor loss is considered a failure, and - * killed tasks that are running on the executor will count towards the failure limits. If no - * replacement is being requested, then the tasks will not count towards the limit. - * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones, default false + * @param adjustTargetNumExecutors whether the target number of executors be adjusted down + * after these executors have been killed + * @param countFailures if there are tasks running on the executors when they are killed, whether + * those failures be counted to task failure limits? * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ final override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") @@ -610,7 +611,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !countFailures } logInfo(s"Actual list of executor(s) to be killed is ${executorsToKill.mkString(", ")}") @@ -618,12 +619,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // with the cluster manager to avoid allocating new ones. When computing the new target, // take into account executors that are pending to be added or removed. val adjustTotalExecutors = - if (!replace) { + if (adjustTargetNumExecutors) { requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) if (requestedTotalExecutors != (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { logDebug( - s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + s"""killExecutors($executorIds, $adjustTargetNumExecutors, $countFailures, $force): + |Executor counts do not match: |requestedTotalExecutors = $requestedTotalExecutors |numExistingExecutors = $numExistingExecutors |numPendingExecutors = $numPendingExecutors diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index a0cae5a9e011c..9807d1269e3d4 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark import scala.collection.mutable +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{mock, never, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics @@ -26,6 +28,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ExternalClusterManager import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.storage.BlockManagerMaster import org.apache.spark.util.ManualClock /** @@ -1050,6 +1053,66 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager) === Map.empty) } + test("SPARK-23365 Don't update target num executors when killing idle executors") { + val minExecutors = 1 + val initialExecutors = 1 + val maxExecutors = 2 + val conf = new SparkConf() + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.shuffle.service.enabled", "true") + .set("spark.dynamicAllocation.minExecutors", minExecutors.toString) + .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString) + .set("spark.dynamicAllocation.initialExecutors", initialExecutors.toString) + .set("spark.dynamicAllocation.schedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "1000ms") + .set("spark.dynamicAllocation.executorIdleTimeout", s"3000ms") + val mockAllocationClient = mock(classOf[ExecutorAllocationClient]) + val mockBMM = mock(classOf[BlockManagerMaster]) + val manager = new ExecutorAllocationManager( + mockAllocationClient, mock(classOf[LiveListenerBus]), conf, mockBMM) + val clock = new ManualClock() + manager.setClock(clock) + + when(mockAllocationClient.requestTotalExecutors(meq(2), any(), any())).thenReturn(true) + // test setup -- job with 2 tasks, scale up to two executors + assert(numExecutorsTarget(manager) === 1) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty))) + manager.listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 2))) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 2) + val taskInfo0 = createTaskInfo(0, 0, "executor-1") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo0)) + manager.listener.onExecutorAdded(SparkListenerExecutorAdded( + clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty))) + val taskInfo1 = createTaskInfo(1, 1, "executor-2") + manager.listener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo1)) + assert(numExecutorsTarget(manager) === 2) + + // have one task finish -- we should adjust the target number of executors down + // but we should *not* kill any executors yet + manager.listener.onTaskEnd(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, null)) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 2) + clock.advance(1000) + manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.getTimeMillis()) + assert(numExecutorsTarget(manager) === 1) + verify(mockAllocationClient, never).killExecutors(any(), any(), any(), any()) + + // now we cross the idle timeout for executor-1, so we kill it. the really important + // thing here is that we do *not* ask the executor allocation client to adjust the target + // number of executors down + when(mockAllocationClient.killExecutors(Seq("executor-1"), false, false, false)) + .thenReturn(Seq("executor-1")) + clock.advance(3000) + schedule(manager) + assert(maxNumExecutorsNeeded(manager) === 1) + assert(numExecutorsTarget(manager) === 1) + // here's the important verify -- we did kill the executors, but did not adjust the target count + verify(mockAllocationClient).killExecutors(Seq("executor-1"), false, false, false) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -1268,7 +1331,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def killExecutors( executorIds: Seq[String], - replace: Boolean, + adjustTargetNumExecutors: Boolean, + countFailures: Boolean, force: Boolean): Seq[String] = executorIds override def start(): Unit = sb.start() diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index c21ee7d26f8ca..27cc47496c805 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -573,7 +573,8 @@ class StandaloneDynamicAllocationSuite syncExecutors(sc) sc.schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.killExecutors(Seq(executorId), replace = false, force) + b.killExecutors(Seq(executorId), adjustTargetNumExecutors = true, countFailures = false, + force) case _ => fail("expected coarse grained scheduler") } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index afebcdd7b9e31..06d7afaaff55c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -479,7 +479,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -517,7 +517,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -533,7 +533,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1) // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole @@ -545,13 +545,13 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) - verify(allocationClientMock).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutors(Seq("2"), false, false, true) verify(allocationClientMock).killExecutorsOnHost("hostA") } test("fetch failure blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { val allocationClientMock = mock[ExecutorAllocationClient] - when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutors(any(), any(), any(), any())).thenReturn(Seq("called")) when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist // is updated before we ask the executor allocation client to kill all the executors @@ -571,7 +571,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M conf.set(config.BLACKLIST_KILL_ENABLED, false) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. @@ -580,7 +580,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M clock.advance(1000) blacklist.updateBlacklistForFetchFailure("hostA", exec = "1") - verify(allocationClientMock).killExecutors(Seq("1"), true, true) + verify(allocationClientMock).killExecutors(Seq("1"), false, false, true) verify(allocationClientMock, never).killExecutorsOnHost(any()) assert(blacklist.executorIdToBlacklistStatus.contains("1")) From 598446b74b61fee272d3aee3a2e9a3fc90a70d6a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 27 Feb 2018 11:33:10 -0800 Subject: [PATCH 0401/1161] [SPARK-23501][UI] Refactor AllStagesPage in order to avoid redundant code As suggested in #20651, the code is very redundant in `AllStagesPage` and modifying it is a copy-and-paste work. We should avoid such a pattern, which is error prone, and have a cleaner solution which avoids code redundancy. existing UTs Author: Marco Gaido Closes #20663 from mgaido91/SPARK-23475_followup. --- .../apache/spark/ui/jobs/AllStagesPage.scala | 261 +++++++----------- 1 file changed, 102 insertions(+), 159 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 38450b9126ff0..4658aa1cea3f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -19,46 +19,20 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest -import scala.xml.{Node, NodeSeq} +import scala.xml.{Attribute, Elem, Node, NodeSeq, Null, Text} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.status.PoolData -import org.apache.spark.status.api.v1._ +import org.apache.spark.status.{AppSummary, PoolData} +import org.apache.spark.status.api.v1.{StageData, StageStatus} import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { private val sc = parent.sc + private val subPath = "stages" private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { - val allStages = parent.store.stageList(null) - - val activeStages = allStages.filter(_.status == StageStatus.ACTIVE) - val pendingStages = allStages.filter(_.status == StageStatus.PENDING) - val skippedStages = allStages.filter(_.status == StageStatus.SKIPPED) - val completedStages = allStages.filter(_.status == StageStatus.COMPLETE) - val failedStages = allStages.filter(_.status == StageStatus.FAILED).reverse - - val numFailedStages = failedStages.size - val subPath = "stages" - - val activeStagesTable = - new StageTableBase(parent.store, request, activeStages, "active", "activeStage", - parent.basePath, subPath, parent.isFairScheduler, parent.killEnabled, false) - val pendingStagesTable = - new StageTableBase(parent.store, request, pendingStages, "pending", "pendingStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val completedStagesTable = - new StageTableBase(parent.store, request, completedStages, "completed", "completedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val skippedStagesTable = - new StageTableBase(parent.store, request, skippedStages, "skipped", "skippedStage", - parent.basePath, subPath, parent.isFairScheduler, false, false) - val failedStagesTable = - new StageTableBase(parent.store, request, failedStages, "failed", "failedStage", - parent.basePath, subPath, parent.isFairScheduler, false, true) - // For now, pool information is only accessible in live UIs val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]).map { pool => val uiPool = parent.store.asOption(parent.store.pool(pool.name)).getOrElse( @@ -67,152 +41,121 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { }.toMap val poolTable = new PoolTable(pools, parent) - val shouldShowActiveStages = activeStages.nonEmpty - val shouldShowPendingStages = pendingStages.nonEmpty - val shouldShowCompletedStages = completedStages.nonEmpty - val shouldShowSkippedStages = skippedStages.nonEmpty - val shouldShowFailedStages = failedStages.nonEmpty + val allStatuses = Seq(StageStatus.ACTIVE, StageStatus.PENDING, StageStatus.COMPLETE, + StageStatus.SKIPPED, StageStatus.FAILED) + val allStages = parent.store.stageList(null) val appSummary = parent.store.appSummary() - val completedStageNumStr = if (appSummary.numCompletedStages == completedStages.size) { - s"${appSummary.numCompletedStages}" - } else { - s"${appSummary.numCompletedStages}, only showing ${completedStages.size}" - } + + val (summaries, tables) = allStatuses.map( + summaryAndTableForStatus(allStages, appSummary, _, request)).unzip val summary: NodeSeq =
      - { - if (shouldShowActiveStages) { -
    • - Active Stages: - {activeStages.size} -
    • - } - } - { - if (shouldShowPendingStages) { -
    • - Pending Stages: - {pendingStages.size} -
    • - } - } - { - if (shouldShowCompletedStages) { -
    • - Completed Stages: - {completedStageNumStr} -
    • - } - } - { - if (shouldShowSkippedStages) { -
    • - Skipped Stages: - {skippedStages.size} -
    • - } - } - { - if (shouldShowFailedStages) { -
    • - Failed Stages: - {numFailedStages} -
    • - } - } + {summaries.flatten}
    - var content = summary ++ - { - if (sc.isDefined && isFairScheduler) { - -

    - - Fair Scheduler Pools ({pools.size}) -

    -
    ++ -
    - {poolTable.toNodeSeq} -
    - } else { - Seq.empty[Node] - } - } - if (shouldShowActiveStages) { - content ++= - -

    - - Active Stages ({activeStages.size}) -

    -
    ++ -
    - {activeStagesTable.toNodeSeq} -
    - } - if (shouldShowPendingStages) { - content ++= - + val poolsDescription = if (sc.isDefined && isFairScheduler) { +

    - Pending Stages ({pendingStages.size}) + Fair Scheduler Pools ({pools.size})

    ++ -
    - {pendingStagesTable.toNodeSeq} +
    + {poolTable.toNodeSeq}
    + } else { + Seq.empty[Node] + } + + val content = summary ++ poolsDescription ++ tables.flatten.flatten + + UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def summaryAndTableForStatus( + allStages: Seq[StageData], + appSummary: AppSummary, + status: StageStatus, + request: HttpServletRequest): (Option[Elem], Option[NodeSeq]) = { + val stages = if (status == StageStatus.FAILED) { + allStages.filter(_.status == status).reverse + } else { + allStages.filter(_.status == status) } - if (shouldShowCompletedStages) { - content ++= - -

    - - Completed Stages ({completedStageNumStr}) -

    -
    ++ -
    - {completedStagesTable.toNodeSeq} -
    + + if (stages.isEmpty) { + (None, None) + } else { + val killEnabled = status == StageStatus.ACTIVE && parent.killEnabled + val isFailedStage = status == StageStatus.FAILED + + val stagesTable = + new StageTableBase(parent.store, request, stages, statusName(status), stageTag(status), + parent.basePath, subPath, parent.isFairScheduler, killEnabled, isFailedStage) + val stagesSize = stages.size + (Some(summary(appSummary, status, stagesSize)), + Some(table(appSummary, status, stagesTable, stagesSize))) } - if (shouldShowSkippedStages) { - content ++= - -

    - - Skipped Stages ({skippedStages.size}) -

    -
    ++ -
    - {skippedStagesTable.toNodeSeq} -
    + } + + private def statusName(status: StageStatus): String = status match { + case StageStatus.ACTIVE => "active" + case StageStatus.COMPLETE => "completed" + case StageStatus.FAILED => "failed" + case StageStatus.PENDING => "pending" + case StageStatus.SKIPPED => "skipped" + } + + private def stageTag(status: StageStatus): String = s"${statusName(status)}Stage" + + private def headerDescription(status: StageStatus): String = statusName(status).capitalize + + private def summaryContent(appSummary: AppSummary, status: StageStatus, size: Int): String = { + if (status == StageStatus.COMPLETE && appSummary.numCompletedStages != size) { + s"${appSummary.numCompletedStages}, only showing $size" + } else { + s"$size" } - if (shouldShowFailedStages) { - content ++= - -

    - - Failed Stages ({numFailedStages}) -

    -
    ++ -
    - {failedStagesTable.toNodeSeq} -
    + } + + private def summary(appSummary: AppSummary, status: StageStatus, size: Int): Elem = { + val summary = +
  • + + {headerDescription(status)} Stages: + + {summaryContent(appSummary, status, size)} +
  • + + if (status == StageStatus.COMPLETE) { + summary % Attribute(None, "id", Text("completed-summary"), Null) + } else { + summary } - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + } + + private def table( + appSummary: AppSummary, + status: StageStatus, + stagesTable: StageTableBase, + size: Int): NodeSeq = { + val classSuffix = s"${statusName(status).capitalize}Stages" + +

    + + {headerDescription(status)} Stages ({summaryContent(appSummary, status, size)}) +

    +
    ++ +
    + {stagesTable.toNodeSeq} +
    } } From 23ac3aaba4a33bc3d31d01f21e93c4681ef6de03 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Wed, 28 Feb 2018 09:25:02 +0900 Subject: [PATCH 0402/1161] [SPARK-23417][PYTHON] Fix the build instructions supplied by exception messages in python streaming tests ## What changes were proposed in this pull request? Fix the build instructions supplied by exception messages in python streaming tests. I also added -DskipTests to the maven instructions to avoid the 170 minutes of scala tests that occurs each time one wants to add a jar to the assembly directory. ## How was this patch tested? - clone branch - run build/sbt package - run python/run-tests --modules "pyspark-streaming" , expect error message - follow instructions in error message. i.e., run build/sbt assembly/package streaming-kafka-0-8-assembly/assembly - rerun python tests, expect error message - follow instructions in error message. i.e run build/sbt -Pflume assembly/package streaming-flume-assembly/assembly - rerun python tests, see success. - repeated all of the above for mvn version of the process. Author: Bruce Robbins Closes #20638 from bersprockets/SPARK-23417_propa. --- python/pyspark/streaming/tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c390..71f8101e34c50 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1477,8 +1477,8 @@ def search_kafka_assembly_jar(): raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn -Pkafka-0-8 package' before running this test.") + "'build/sbt -Pkafka-0-8 assembly/package streaming-kafka-0-8-assembly/assembly' or " + "'build/mvn -DskipTests -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1494,8 +1494,8 @@ def search_flume_assembly_jar(): raise Exception( ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn -Pflume package' before running this test.") + "'build/sbt -Pflume assembly/package streaming-flume-assembly/assembly' or " + "'build/mvn -DskipTests -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) From b14993e1fcb68e1c946a671c6048605ab4afdf58 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Feb 2018 11:00:54 +0900 Subject: [PATCH 0403/1161] [SPARK-23448][SQL] Clarify JSON and CSV parser behavior in document ## What changes were proposed in this pull request? Clarify JSON and CSV reader behavior in document. JSON doesn't support partial results for corrupted records. CSV only supports partial results for the records with more or less tokens. ## How was this patch tested? Pass existing tests. Author: Liang-Chi Hsieh Closes #20666 from viirya/SPARK-23448-2. --- python/pyspark/sql/readwriter.py | 30 ++++++++++--------- python/pyspark/sql/streaming.py | 30 ++++++++++--------- .../sql/catalyst/json/JacksonParser.scala | 3 ++ .../apache/spark/sql/DataFrameReader.scala | 22 +++++++------- .../datasources/csv/UnivocityParser.scala | 5 ++++ .../sql/streaming/DataStreamReader.scala | 22 +++++++------- 6 files changed, 64 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 49af1bcee5ef8..9d05ac7cb39be 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -209,13 +209,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -393,13 +393,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e2a97acb5e2a7..cc622decfd682 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -442,13 +442,13 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an user-defined \ - schema. If a schema does not have the field, it drops corrupt records during \ - parsing. When inferring a schema, it implicitly adds a \ - ``columnNameOfCorruptRecord`` field in an output schema. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + When inferring a schema, it implicitly adds a ``columnNameOfCorruptRecord`` \ + field in an output schema. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. @@ -621,13 +621,15 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non :param mode: allows a mode for dealing with corrupt records during parsing. If None is set, it uses the default value, ``PERMISSIVE``. - * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ - record, and puts the malformed string into a field configured by \ - ``columnNameOfCorruptRecord``. To keep corrupt records, an user can set \ - a string type field named ``columnNameOfCorruptRecord`` in an \ - user-defined schema. If a schema does not have the field, it drops corrupt \ - records during parsing. When a length of parsed CSV tokens is shorter than \ - an expected length of a schema, it sets `null` for extra fields. + * ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \ + into a field configured by ``columnNameOfCorruptRecord``, and sets other \ + fields to ``null``. To keep corrupt records, an user can set a string type \ + field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \ + schema does not have the field, it drops corrupt records during parsing. \ + A record with less/more tokens than schema is not a corrupted record to CSV. \ + When it meets a record having fewer tokens than the length of the schema, \ + sets ``null`` to extra fields. When the record has more tokens than the \ + length of the schema, it drops extra tokens. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index bd144c9575c72..7f6956994f31f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -357,6 +357,9 @@ class JacksonParser( } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => + // JSON parser currently doesn't support partial results for corrupted records. + // For such records, all fields other than the field configured by + // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 4274f120a375a..0139913aaa4e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -345,12 +345,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -550,12 +550,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 7d6d7e7eef926..3d6cc30f2ba83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -203,6 +203,8 @@ class UnivocityParser( case _: BadRecordException => None } } + // For records with less or more tokens than the schema, tries to return partial results + // if possible. throw BadRecordException( () => getCurrentInput, () => getPartialResult(), @@ -218,6 +220,9 @@ class UnivocityParser( row } catch { case NonFatal(e) => + // For corrupted records with the number of tokens same as the schema, + // CSV reader doesn't support partial results. All fields other than the field + // configured by `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => getCurrentInput, () => None, e) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index f23851655350a..61e22fac854f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -236,12 +236,12 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep - * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` - * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When inferring a schema, it implicitly adds a `columnNameOfCorruptRecord` - * field in an output schema.
    • + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To + * keep corrupt records, an user can set a string type field named + * `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the + * field, it drops corrupt records during parsing. When inferring a schema, it implicitly + * adds a `columnNameOfCorruptRecord` field in an output schema.
    • *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    @@ -316,12 +316,14 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records * during parsing. It supports the following case-insensitive modes. *
      - *
    • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts - * the malformed string into a field configured by `columnNameOfCorruptRecord`. To keep + *
    • `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a + * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord` * in an user-defined schema. If a schema does not have the field, it drops corrupt records - * during parsing. When a length of parsed CSV tokens is shorter than an expected length - * of a schema, it sets `null` for extra fields.
    • + * during parsing. A record with less/more tokens than schema is not a corrupted record to + * CSV. When it meets a record having fewer tokens than the length of the schema, sets + * `null` to extra fields. When the record has more tokens than the length of the schema, + * it drops extra tokens. *
    • `DROPMALFORMED` : ignores the whole corrupted records.
    • *
    • `FAILFAST` : throws an exception when it meets corrupted records.
    • *
    From 6a8abe29ef3369b387d9bc2ee3459a6611246ab1 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Wed, 28 Feb 2018 23:16:29 +0800 Subject: [PATCH 0404/1161] [SPARK-23508][CORE] Fix BlockmanagerId in case blockManagerIdCache cause oom MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … cause oom ## What changes were proposed in this pull request? blockManagerIdCache in BlockManagerId will not remove old values which may cause oom `val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()` Since whenever we apply a new BlockManagerId, it will put into this map. This patch will use guava cahce for blockManagerIdCache instead. A heap dump show in [SPARK-23508](https://issues.apache.org/jira/browse/SPARK-23508) ## How was this patch tested? Exist tests. Author: zhoukang Closes #20667 from caneGuy/zhoukang/fix-history. --- .../org/apache/spark/storage/BlockManagerId.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 2c3da0ee85e06..d4a59c33b974c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -18,7 +18,8 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.concurrent.ConcurrentHashMap + +import com.google.common.cache.{CacheBuilder, CacheLoader} import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi @@ -132,10 +133,17 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } - val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + /** + * The max cache size is hardcoded to 10000, since the size of a BlockManagerId + * object is about 48B, the total memory cost should be below 1MB which is feasible. + */ + val blockManagerIdCache = CacheBuilder.newBuilder() + .maximumSize(10000) + .build(new CacheLoader[BlockManagerId, BlockManagerId]() { + override def load(id: BlockManagerId) = id + }) def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - blockManagerIdCache.putIfAbsent(id, id) blockManagerIdCache.get(id) } } From fab563b9bd1581112462c0fc0b299ad6510b6564 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 1 Mar 2018 00:44:13 +0900 Subject: [PATCH 0405/1161] [SPARK-23517][PYTHON] Make `pyspark.util._exception_message` produce the trace from Java side by Py4JJavaError ## What changes were proposed in this pull request? This PR proposes for `pyspark.util._exception_message` to produce the trace from Java side by `Py4JJavaError`. Currently, in Python 2, it uses `message` attribute which `Py4JJavaError` didn't happen to have: ```python >>> from pyspark.util import _exception_message >>> try: ... sc._jvm.java.lang.String(None) ... except Exception as e: ... pass ... >>> e.message '' ``` Seems we should use `str` instead for now: https://github.com/bartdag/py4j/blob/aa6c53b59027925a426eb09b58c453de02c21b7c/py4j-python/src/py4j/protocol.py#L412 but this doesn't address the problem with non-ascii string from Java side - `https://github.com/bartdag/py4j/issues/306` So, we could directly call `__str__()`: ```python >>> e.__str__() u'An error occurred while calling None.java.lang.String.\n: java.lang.NullPointerException\n\tat java.lang.String.(String.java:588)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)\n\tat sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)\n\tat sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)\n\tat java.lang.reflect.Constructor.newInstance(Constructor.java:422)\n\tat py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:247)\n\tat py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\n\tat py4j.Gateway.invoke(Gateway.java:238)\n\tat py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)\n\tat py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)\n\tat py4j.GatewayConnection.run(GatewayConnection.java:214)\n\tat java.lang.Thread.run(Thread.java:745)\n' ``` which doesn't type coerce unicodes to `str` in Python 2. This can be actually a problem: ```python from pyspark.sql.functions import udf spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.range(1).select(udf(lambda x: [[]])()).toPandas() ``` **Before** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` **After** ``` Traceback (most recent call last): File "", line 1, in File "/.../spark/python/pyspark/sql/dataframe.py", line 2009, in toPandas raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) RuntimeError: An error occurred while calling o47.collectAsArrowToPython. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 7 in stage 0.0 failed 1 times, most recent failure: Lost task 7.0 in stage 0.0 (TID 7, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last): File "/.../spark/python/pyspark/worker.py", line 245, in main process() File "/.../spark/python/pyspark/worker.py", line 240, in process ... Note: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.enabled' is set to true. Please set it to false to disable this. ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20680 from HyukjinKwon/SPARK-23517. --- python/pyspark/tests.py | 11 +++++++++++ python/pyspark/util.py | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 511585763cb01..9111dbbed5929 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2293,6 +2293,17 @@ def set(self, x=None, other=None, other_x=None): self.assertEqual(b._x, 2) +class UtilTests(PySparkTestCase): + def test_py4j_exception_message(self): + from pyspark.util import _exception_message + + with self.assertRaises(Py4JJavaError) as context: + # This attempts java.lang.String(null) which throws an NPE. + self.sc._jvm.java.lang.String(None) + + self.assertTrue('NullPointerException' in _exception_message(context.exception)) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e5d332ce54429..ad4a0bc68ef41 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from py4j.protocol import Py4JJavaError __all__ = [] @@ -33,6 +34,12 @@ def _exception_message(excp): >>> msg == _exception_message(excp) True """ + if isinstance(excp, Py4JJavaError): + # 'Py4JJavaError' doesn't contain the stack trace available on the Java side in 'message' + # attribute in Python 2. We should call 'str' function on this exception in general but + # 'Py4JJavaError' has an issue about addressing non-ascii strings. So, here we work + # around by the direct call, '__str__()'. Please see SPARK-23517. + return excp.__str__() if hasattr(excp, "message"): return excp.message return str(excp) From 476a7f026bc45462067ebd39cd269147e84cd641 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 28 Feb 2018 08:44:53 -0800 Subject: [PATCH 0406/1161] [SPARK-23514] Use SessionState.newHadoopConf() to propage hadoop configs set in SQLConf. ## What changes were proposed in this pull request? A few places in `spark-sql` were using `sc.hadoopConfiguration` directly. They should be using `sessionState.newHadoopConf()` to blend in configs that were set through `SQLConf`. Also, for better UX, for these configs blended in from `SQLConf`, we should consider removing the `spark.hadoop` prefix, so that the settings are recognized whether or not they were specified by the user. ## How was this patch tested? Tested that AlterTableRecoverPartitions now correctly recognizes settings that are passed in to the FileSystem through SQLConf. Author: Juliusz Sompolski Closes #20679 from juliuszsompolski/SPARK-23514. --- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0142f17ce62e2..964cbca049b27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -610,10 +610,10 @@ case class AlterTableRecoverPartitionsCommand( val root = new Path(table.location) logInfo(s"Recover all the partitions in $root") - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val hadoopConf = spark.sessionState.newHadoopConf() + val fs = root.getFileSystem(hadoopConf) val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt - val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) @@ -697,7 +697,7 @@ case class AlterTableRecoverPartitionsCommand( pathFilter: PathFilter, threshold: Int): GenMap[String, PartitionStatistics] = { if (partitionSpecsAndLocs.length > threshold) { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val serializableConfiguration = new SerializableConfiguration(hadoopConf) val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 19028939f3673..fcf2025d34432 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -518,8 +518,9 @@ private[hive] class TestHiveSparkSession( // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to // delete it. Later, it will be re-created with the right permission. - val location = new Path(sc.hadoopConfiguration.get(ConfVars.SCRATCHDIR.varname)) - val fs = location.getFileSystem(sc.hadoopConfiguration) + val hadoopConf = sessionState.newHadoopConf() + val location = new Path(hadoopConf.get(ConfVars.SCRATCHDIR.varname)) + val fs = location.getFileSystem(hadoopConf) fs.delete(location, true) // Some tests corrupt this value on purpose, which breaks the RESET call below. From 25c2776dd9ae3f9792048c78be2cbd958fd99841 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 28 Feb 2018 12:16:26 -0800 Subject: [PATCH 0407/1161] [SPARK-23523][SQL][FOLLOWUP] Minor refactor of OptimizeMetadataOnlyQuery ## What changes were proposed in this pull request? Inside `OptimizeMetadataOnlyQuery.getPartitionAttrs`, avoid using `zip` to generate attribute map. Also include other minor update of comments and format. ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #20693 from jiangxb1987/SPARK-23523. --- .../spark/sql/execution/OptimizeMetadataOnlyQuery.scala | 2 +- .../spark/sql/execution/datasources/HadoopFsRelation.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 0613d9053f826..dc4aff9f12580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -83,7 +83,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic private def getPartitionAttrs( partitionColumnNames: Seq[String], relation: LogicalPlan): Seq[Attribute] = { - val attrMap = relation.output.map(_.name.toLowerCase(Locale.ROOT)).zip(relation.output).toMap + val attrMap = relation.output.map(a => a.name.toLowerCase(Locale.ROOT) -> a).toMap partitionColumnNames.map { colName => attrMap.getOrElse(colName.toLowerCase(Locale.ROOT), throw new AnalysisException(s"Unable to find the column `$colName` " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index ac574b07ec497..b2f73b7f8d1fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -67,9 +67,9 @@ case class HadoopFsRelation( } } - // When data schema and partition schema have the overlapped columns, the output - // schema respects the order of data schema for the overlapped columns, but respect - // the data types of partition schema + // When data and partition schemas have overlapping columns, the output + // schema respects the order of the data schema for the overlapping columns, and it + // respects the data types of the partition schema. val schema: StructType = { StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) From 22f3d3334c85c042c6e90f5a02f308d7cd1c1498 Mon Sep 17 00:00:00 2001 From: liuxian Date: Thu, 1 Mar 2018 14:28:28 +0800 Subject: [PATCH 0408/1161] [SPARK-23389][CORE] When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine =false`, we should be able to use serialized sorting. ## What changes were proposed in this pull request? When the shuffle dependency specifies aggregation ,and `dependency.mapSideCombine=false`, in the map side,there is no need for aggregation and sorting, so we should be able to use serialized sorting. ## How was this patch tested? Existing unit test Author: liuxian Closes #20576 from 10110346/mapsidecombine. --- .../scala/org/apache/spark/Dependency.scala | 3 +++ .../spark/shuffle/BlockStoreShuffleReader.scala | 1 - .../spark/shuffle/sort/SortShuffleManager.scala | 6 +++--- .../spark/shuffle/sort/SortShuffleWriter.scala | 2 -- .../shuffle/sort/SortShuffleManagerSuite.scala | 17 +++++++++-------- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index ca52ecafa2cc8..9ea6d2fa2fd95 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -76,6 +76,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false) extends Dependency[Product2[K, V]] { + if (mapSideCombine) { + require(aggregator.isDefined, "Map-side combine without Aggregator specified!") + } override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0562d45ff57c5..edd69715c9602 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -90,7 +90,6 @@ private[spark] class BlockStoreShuffleReader[K, C]( dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index bfb4dc698e325..d9fad64f34c7c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -188,9 +188,9 @@ private[spark] object SortShuffleManager extends Logging { log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + s"${dependency.serializer.getClass.getName}, does not support object relocation") false - } else if (dependency.aggregator.isDefined) { - log.debug( - s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + } else if (dependency.mapSideCombine) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " + + s"map-side aggregation") false } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 636b88e792bf3..274399b9cc1f3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -50,7 +50,6 @@ private[spark] class SortShuffleWriter[K, V, C]( /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { @@ -107,7 +106,6 @@ private[spark] object SortShuffleWriter { def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { // We cannot bypass sorting if we need to do map-side aggregation. if (dep.mapSideCombine) { - require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") false } else { val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala index 55cebe7c8b6a8..f29dac965c803 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala @@ -85,6 +85,14 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) + // We support serialized shuffle if we do not need to do map-side aggregation + assert(canUseSerializedShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) } test("unsupported shuffle dependencies for serialized shuffle") { @@ -111,14 +119,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers { mapSideCombine = false ))) - // We do not support shuffles that perform aggregation - assert(!canUseSerializedShuffle(shuffleDep( - partitioner = new HashPartitioner(2), - serializer = kryo, - keyOrdering = None, - aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), - mapSideCombine = false - ))) + // We do not support serialized shuffle if we need to do map-side aggregation assert(!canUseSerializedShuffle(shuffleDep( partitioner = new HashPartitioner(2), serializer = kryo, From ff1480189b827af0be38605d566a4ee71b4c36f6 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 1 Mar 2018 16:26:11 +0800 Subject: [PATCH 0409/1161] [SPARK-23510][SQL] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? This is based on https://github.com/apache/spark/pull/20668 for supporting Hive 2.2 and Hive 2.3 metastore. When we merge the PR, we should give the major credit to wangyum ## How was this patch tested? Added the test cases Author: Yuming Wang Author: gatorsmile Closes #20671 from gatorsmile/pr-20668. --- .../org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../sql/hive/client/HiveClientImpl.scala | 3 +- .../spark/sql/hive/client/HiveShim.scala | 8 +-- .../hive/client/IsolatedClientLoader.scala | 2 + .../spark/sql/hive/client/package.scala | 10 +++- .../sql/hive/execution/SaveAsHiveFile.scala | 3 +- .../sql/hive/client/HiveClientVersions.scala | 3 +- .../sql/hive/client/HiveVersionSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 51 +++++++++++++++++-- 9 files changed, 72 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c448c5a9821be..10c9603745379 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.1.1.") + s"0.12.0 through 2.3.2.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 146fa54a1bce4..da9fe2d3088b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -25,7 +25,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf @@ -104,6 +103,8 @@ private[hive] class HiveClientImpl( case hive.v1_2 => new Shim_v1_2() case hive.v2_0 => new Shim_v2_0() case hive.v2_1 => new Shim_v2_1() + case hive.v2_2 => new Shim_v2_2() + case hive.v2_3 => new Shim_v2_3() } // Create an internal session state for this HiveClientImpl. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 1eac70dbf19cd..948ba542b5733 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -880,9 +880,7 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { } -private[client] class Shim_v1_0 extends Shim_v0_14 { - -} +private[client] class Shim_v1_0 extends Shim_v0_14 private[client] class Shim_v1_1 extends Shim_v1_0 { @@ -1146,3 +1144,7 @@ private[client] class Shim_v2_1 extends Shim_v2_0 { alterPartitionsMethod.invoke(hive, tableName, newParts, environmentContextInAlterTable) } } + +private[client] class Shim_v2_2 extends Shim_v2_1 + +private[client] class Shim_v2_3 extends Shim_v2_1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index dac0e333b63bc..12975bc85b971 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -97,6 +97,8 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.2" | "1.2.0" | "1.2.1" | "1.2.2" => hive.v1_2 case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 + case "2.2" | "2.2.0" => hive.v2_2 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index c14154a3b3c21..681ee9200f02b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -71,7 +71,15 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1) + case object v2_2 extends HiveVersion("2.2.0", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v2_3 extends HiveVersion("2.3.2", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index e484356906e87..6a7b25b36d9a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -114,7 +114,8 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) - val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = + Set(v1_1, v1_2, v2_0, v2_1, v2_2, v2_3) // Ensure all the supported versions are considered here. assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala index 2e7dfde8b2fa5..30592a3f85428 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientVersions.scala @@ -22,5 +22,6 @@ import scala.collection.immutable.IndexedSeq import org.apache.spark.SparkFunSuite private[client] trait HiveClientVersions { - protected val versions = IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + protected val versions = + IndexedSeq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index a70fb6464cc1d..e5963d03f6b52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -34,7 +34,7 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 72536b833481a..6176273c88db1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayOutputStream, File, PrintStream, PrintWriter} import java.net.URI import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred.TextInputFormat @@ -110,7 +111,8 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1") + private val versions = + Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0", "2.1", "2.2", "2.3") private var client: HiveClient = null @@ -125,7 +127,7 @@ class VersionsSuite extends SparkFunSuite with Logging { // Hive changed the default of datanucleus.schema.autoCreateAll from true to false and // hive.metastore.schema.verification from false to true since 2.0 // For details, see the JIRA HIVE-6113 and HIVE-12463 - if (version == "2.0" || version == "2.1") { + if (version == "2.0" || version == "2.1" || version == "2.2" || version == "2.3") { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } @@ -422,15 +424,18 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") + val parameters = Map(StatsSetupConst.TOTAL_SIZE -> "0", StatsSetupConst.NUM_FILES -> "1") val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) - val partition = CatalogTablePartition(spec, storage) + val partition = CatalogTablePartition(spec, storage, parameters) client.alterPartitions("default", "src_part", Seq(partition)) assert(client.getPartition("default", "src_part", spec) .storage.locationUri == Some(newLocation)) + assert(client.getPartition("default", "src_part", spec) + .parameters.get(StatsSetupConst.TOTAL_SIZE) == Some("0")) } test(s"$version: dropPartitions") { @@ -633,6 +638,46 @@ class VersionsSuite extends SparkFunSuite with Logging { } } + test(s"$version: CREATE Partitioned TABLE AS SELECT") { + withTable("tbl") { + versionSpark.sql( + """ + |CREATE TABLE tbl(c1 string) + |PARTITIONED BY (ds STRING) + """.stripMargin) + versionSpark.sql("INSERT OVERWRITE TABLE tbl partition (ds='2') SELECT '1'") + + assert(versionSpark.table("tbl").collect().toSeq == Seq(Row("1", "2"))) + val partMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + val totalSize = partMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val numFiles = partMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(totalSize.isEmpty && numFiles.isEmpty) + } else { + assert(totalSize.nonEmpty && numFiles.nonEmpty) + } + + versionSpark.sql( + """ + |ALTER TABLE tbl PARTITION (ds='2') + |SET SERDEPROPERTIES ('newKey' = 'vvv') + """.stripMargin) + val newPartMeta = versionSpark.sessionState.catalog.getPartition( + TableIdentifier("tbl"), spec = Map("ds" -> "2")).parameters + + val newTotalSize = newPartMeta.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val newNumFiles = newPartMeta.get(StatsSetupConst.NUM_FILES).map(_.toLong) + // Except 0.12, all the following versions will fill the Hive-generated statistics + if (version == "0.12") { + assert(newTotalSize.isEmpty && newNumFiles.isEmpty) + } else { + assert(newTotalSize.nonEmpty && newNumFiles.nonEmpty) + } + } + } + test(s"$version: Delete the temporary staging directory and files after each insert") { withTempDir { tmpDir => withTable("tab") { From cdcccd7b41c43d79edff2fec7a84cd00e9524f75 Mon Sep 17 00:00:00 2001 From: KaiXinXiaoLei <584620569@qq.com> Date: Fri, 2 Mar 2018 00:09:44 +0800 Subject: [PATCH 0410/1161] [SPARK-23405] Generate additional constraints for Join's children ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) I run a sql: `select ls.cs_order_number from ls left semi join catalog_sales cs on ls.cs_order_number = cs.cs_order_number`, The `ls` table is a small table ,and the number is one. The `catalog_sales` table is a big table, and the number is 10 billion. The task will be hang up. And i find the many null values of `cs_order_number` in the `catalog_sales` table. I think the null value should be removed in the logical plan. >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > +- MetastoreRelation 100t, catalog_sales Now, use this patch, the plan will be: >== Optimized Logical Plan == >Join LeftSemi, (cs_order_number#1 = cs_order_number#22) >:- Project cs_order_number#1 > : +- Filter isnotnull(cs_order_number#1) > : +- MetastoreRelation 100t, ls >+- Project cs_order_number#22 > : **+- Filter isnotnull(cs_order_number#22)** > :+- MetastoreRelation 100t, catalog_sales ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: KaiXinXiaoLei <584620569@qq.com> Author: hanghang <584620569@qq.com> Closes #20670 from KaiXinXiaoLei/Spark-23405. --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../plans/logical/QueryPlanConstraints.scala | 27 ++++++++++--------- .../InferFiltersFromConstraintsSuite.scala | 12 +++++++++ 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a28b6a0feb8f9..91208479be03b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -661,7 +661,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe case join @ Join(left, right, joinType, conditionOpt) => // Only consider constraints that can be pushed down completely to either the left or the // right child - val constraints = join.constraints.filter { c => + val constraints = join.allConstraints.filter { c => c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) } // Remove those constraints that are already enforced by either the left or the right child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 5c7b8e5b97883..046848875548b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -23,25 +23,28 @@ import org.apache.spark.sql.catalyst.expressions._ trait QueryPlanConstraints { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. + * An [[ExpressionSet]] that contains an additional set of constraints, such as equality + * constraints and `isNotNull` constraints, etc. */ - lazy val constraints: ExpressionSet = { + lazy val allConstraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet( - validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints)) - .filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - } - ) + ExpressionSet(validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints))) } else { ExpressionSet(Set.empty) } } + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. + */ + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + }) + /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 178c4b8c270a0..f78c2356e35a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -192,4 +192,16 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(Optimize.execute(original.analyze), correct.analyze) } + + test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftSemi, condition).analyze + val left = x.where(IsNotNull('a)) + val right = y.where(IsNotNull('a)) + val correctAnswer = left.join(right, LeftSemi, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 34811e0b908449fd59bca476604612b1d200778d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 1 Mar 2018 17:26:39 -0800 Subject: [PATCH 0411/1161] [SPARK-23551][BUILD] Exclude `hadoop-mapreduce-client-core` dependency from `orc-mapreduce` ## What changes were proposed in this pull request? This PR aims to prevent `orc-mapreduce` dependency from making IDEs and maven confused. **BEFORE** Please note that `2.6.4` at `Spark Project SQL`. ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.orc:orc-mapreduce:jar:nohive:1.4.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.6.4:compile ``` **AFTER** ``` $ mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ... [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Catalyst 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-catalyst_2.11 --- [INFO] org.apache.spark:spark-catalyst_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project SQL 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- maven-dependency-plugin:3.0.2:tree (default-cli) spark-sql_2.11 --- [INFO] org.apache.spark:spark-sql_2.11:jar:2.4.0-SNAPSHOT [INFO] \- org.apache.spark:spark-core_2.11:jar:2.4.0-SNAPSHOT:compile [INFO] \- org.apache.hadoop:hadoop-client:jar:2.7.3:compile [INFO] \- org.apache.hadoop:hadoop-mapreduce-client-core:jar:2.7.3:compile ``` ## How was this patch tested? 1. Pass the Jenkins with `dev/test-dependencies.sh` with the existing dependencies. 2. Manually do the following and see the change. ``` mvn dependency:tree -Phadoop-2.7 -Dincludes=org.apache.hadoop:hadoop-mapreduce-client-core ``` Author: Dongjoon Hyun Closes #20704 from dongjoon-hyun/SPARK-23551. --- pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pom.xml b/pom.xml index b8396166f6b1b..0a711f287a53f 100644 --- a/pom.xml +++ b/pom.xml @@ -1753,6 +1753,10 @@ org.apache.hadoop hadoop-common + + org.apache.hadoop + hadoop-mapreduce-client-core + org.apache.orc orc-core From 119f6a0e4729aa952e811d2047790a32ee90bf69 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 1 Mar 2018 21:04:01 -0800 Subject: [PATCH 0412/1161] [SPARK-22883][ML][TEST] Streaming tests for spark.ml.feature, from A to H ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * BinarizerSuite * BucketedRandomProjectionLSHSuite * BucketizerSuite * ChiSqSelectorSuite * CountVectorizerSuite * DCTSuite.scala * ElementwiseProductSuite * FeatureHasherSuite * HashingTFSuite ## How was this patch tested? It tests itself because it is a bunch of tests! Author: Joseph K. Bradley Closes #20111 from jkbradley/SPARK-22883-streaming-featureAM. --- .../spark/ml/feature/BinarizerSuite.scala | 8 ++-- .../BucketedRandomProjectionLSHSuite.scala | 26 ++++++++--- .../spark/ml/feature/BucketizerSuite.scala | 11 +++-- .../spark/ml/feature/ChiSqSelectorSuite.scala | 36 +++++++-------- .../ml/feature/CountVectorizerSuite.scala | 23 +++++----- .../apache/spark/ml/feature/DCTSuite.scala | 14 +++--- .../ml/feature/ElementwiseProductSuite.scala | 30 ++++++++++--- .../spark/ml/feature/FeatureHasherSuite.scala | 45 +++++++++---------- .../spark/ml/feature/HashingTFSuite.scala | 34 ++++++++------ 9 files changed, 126 insertions(+), 101 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 4455d35210878..05d4a6ee2dabf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BinarizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau .setInputCol("feature") .setOutputCol("binarized_feature") - binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") { case Row(x: Double, y: Double) => assert(x === y, "The feature value is not correct after binarization.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index 7175c721bff36..ed9a39d8d1512 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -20,16 +20,15 @@ package org.apache.spark.ml.feature import breeze.numerics.{cos, sin} import breeze.numerics.constants.Pi -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{Dataset, Row} -class BucketedRandomProjectionLSHSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -98,6 +97,21 @@ class BucketedRandomProjectionLSHSuite MLTestingUtils.checkCopyAndUids(brp, brpModel) } + test("BucketedRandomProjectionLSH: streaming transform") { + val brp = new BucketedRandomProjectionLSH() + .setNumHashTables(2) + .setInputCol("keys") + .setOutputCol("values") + .setBucketLength(1.0) + .setSeed(12345) + val brpModel = brp.fit(dataset) + + testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") { + case Row(values: Seq[_]) => + assert(values.length === brp.getNumHashTables) + } + } + test("BucketedRandomProjectionLSH: test of LSH property") { // Project from 2 dimensional Euclidean Space to 1 dimensions val brp = new BucketedRandomProjectionLSH() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 41cf72fe3470a..9ea15e1918532 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class BucketizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setOutputCol("result") .setSplits(splits) - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(splits) bucketizer.setHandleInvalid("keep") - bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index c83909c4498f2..c843df9f33e3e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ @transient var dataset: Dataset[_] = _ @@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext test("Test Chi-Square selector: numTopFeatures") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) - val model = ChiSqSelectorSuite.testSelector(selector, dataset) + val model = testSelector(selector, dataset) MLTestingUtils.checkCopyAndUids(selector, model) } test("Test Chi-Square selector: percentile") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fpr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fdr") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("Test Chi-Square selector: fwe") { val selector = new ChiSqSelector() .setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12) - ChiSqSelectorSuite.testSelector(selector, dataset) + testSelector(selector, dataset) } test("read/write") { @@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(expected.selectedFeatures === actual.selectedFeatures) } } -} -object ChiSqSelectorSuite { - - private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = { - val selectorModel = selector.fit(dataset) - selectorModel.transform(dataset).select("filtered", "topFeature").collect() - .foreach { case Row(vec1: Vector, vec2: Vector) => + private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = { + val selectorModel = selector.fit(data) + testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel, + "filtered", "topFeature") { + case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) - } + } selectorModel } +} + +object ChiSqSelectorSuite { /** * Mapping from all Params to valid settings which differ from the defaults. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 1784c07ca23e3..61217669d9277 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) .setInputCol("words") .setOutputCol("features") - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(cv, cvm) assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e")) - cvm.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel2.vocabulary === Array("a", "b")) - cvModel2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -113,7 +110,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .fit(df) assert(cvModel3.vocabulary === Array("a", "b")) - cvModel3.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -219,7 +216,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -238,7 +235,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setMinTF(0.3) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -258,7 +255,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setOutputCol("features") .setBinary(true) .fit(df) - cv.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } @@ -268,7 +265,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext .setInputCol("words") .setOutputCol("features") .setBinary(true) - cv2.transform(df).select("features", "expected").collect().foreach { + testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 8dd3dd75e1be5..6734336aac39c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -21,16 +21,14 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DCTSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setOutputCol("resultVec") .setInverse(inverse) - transformer.transform(dataset) - .select("resultVec", "wantedVec") - .collect() - .foreach { case Row(resultVec: Vector, wantedVec: Vector) => - assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") { + case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala index a4cca27be7815..3a8d0762e2ab7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala @@ -17,13 +17,31 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.sql.Row -class ElementwiseProductSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ + + test("streaming transform") { + val scalingVec = Vectors.dense(0.1, 10.0) + val data = Seq( + (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)), + (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0)) + ) + val df = spark.createDataFrame(data).toDF("features", "expected") + val ep = new ElementwiseProduct() + .setInputCol("features") + .setOutputCol("actual") + .setScalingVec(scalingVec) + testTransformer[(Vector, Vector)](df, ep, "actual", "expected") { + case Row(actual: Vector, expected: Vector) => + assert(actual ~== expected relTol 1e-14) + } + } test("read/write") { val ep = new ElementwiseProduct() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala index 7bc1825b69c43..d799ba6011fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala @@ -17,27 +17,24 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FeatureHasherSuite extends SparkFunSuite - with MLlibTestSparkContext - with DefaultReadWriteTest { +class FeatureHasherSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import FeatureHasherSuite.murmur3FeatureIdx - implicit private val vectorEncoder = ExpressionEncoder[Vector]() + implicit private val vectorEncoder: ExpressionEncoder[Vector] = ExpressionEncoder[Vector]() test("params") { ParamsSuite.checkParams(new FeatureHasher) @@ -52,31 +49,31 @@ class FeatureHasherSuite extends SparkFunSuite } test("feature hashing") { + val numFeatures = 100 + // Assume perfect hash on field names in computing expected results + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val df = Seq( - (2.0, true, "1", "foo"), - (3.0, false, "2", "bar") - ).toDF("real", "bool", "stringNum", "string") + (2.0, true, "1", "foo", + Vectors.sparse(numFeatures, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), + (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0)))), + (3.0, false, "2", "bar", + Vectors.sparse(numFeatures, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), + (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))) + ).toDF("real", "bool", "stringNum", "string", "expected") - val n = 100 val hasher = new FeatureHasher() .setInputCols("real", "bool", "stringNum", "string") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hasher.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - assert(attrGroup.numAttributes === Some(n)) + assert(attrGroup.numAttributes === Some(numFeatures)) - val features = output.select("features").as[Vector].collect() - // Assume perfect hash on field names - def idx: Any => Int = murmur3FeatureIdx(n) - // check expected indices - val expected = Seq( - Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0), - (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))), - Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0), - (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0))) - ) - assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 }) + testTransformer[(Double, Boolean, String, String, Vector)](df, hasher, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14 ) + } } test("setting explicit numerical columns to treat as categorical") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index a46272fdce1fb..c5183ecfef7d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class HashingTFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import HashingTFSuite.murmur3FeatureIdx @@ -37,21 +36,28 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("hashingTF") { - val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words") - val n = 100 + val numFeatures = 100 + // Assume perfect hash when computing expected features. + def idx: Any => Int = murmur3FeatureIdx(numFeatures) + val data = Seq( + ("a a b b c d".split(" ").toSeq, + Vectors.sparse(numFeatures, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))) + ) + + val df = data.toDF("words", "expected") val hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("features") - .setNumFeatures(n) + .setNumFeatures(numFeatures) val output = hashingTF.transform(df) val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - require(attrGroup.numAttributes === Some(n)) - val features = output.select("features").first().getAs[Vector](0) - // Assume perfect hash on "a", "b", "c", and "d". - def idx: Any => Int = murmur3FeatureIdx(n) - val expected = Vectors.sparse(n, - Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) - assert(features ~== expected absTol 1e-14) + require(attrGroup.numAttributes === Some(numFeatures)) + + testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } } test("applying binary term freqs") { From 0b6ceadeb563205cbd6bd03bc88e608086273b5b Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 2 Mar 2018 09:23:39 -0800 Subject: [PATCH 0413/1161] [SPARKR][DOC] fix link in vignettes ## What changes were proposed in this pull request? Fix doc link that was changed in 2.3 shivaram Author: Felix Cheung Closes #20711 from felixcheung/rvigmean. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index feca617c2554c..d4713de7806a1 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -46,7 +46,7 @@ Sys.setenv("_JAVA_OPTIONS" = paste("-XX:-UsePerfData", old_java_opt, sep = " ")) ## Overview -SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](https://spark.apache.org/mllib/). ## Getting Started @@ -132,7 +132,7 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](https://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() @@ -147,7 +147,7 @@ sparkR.session(sparkHome = "/HOME/spark") ### Spark Session {#SetupSparkSession} -In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](https://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](https://spark.apache.org/docs/latest/api/R/sparkR.session.html). In particular, the following Spark driver properties can be set in `sparkConfig`. @@ -169,7 +169,7 @@ sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) #### Cluster Mode -SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with ```{r, echo=FALSE, tidy = TRUE} @@ -177,7 +177,7 @@ paste("Spark", packageVersion("SparkR")) ``` It should be used both on the local computer and on the remote cluster. -To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](https://spark.apache.org/docs/latest/submitting-applications.html#master-urls). For example, to connect to a local standalone Spark master, we can call ```{r, eval=FALSE} @@ -317,7 +317,7 @@ A common flow of grouping and aggregation is 2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. -A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. +A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for aggregate functions](https://spark.apache.org/docs/latest/api/R/column_aggregate_functions.html) linked there. For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. @@ -935,7 +935,7 @@ perplexity #### Alternating Least Squares -`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). +`spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](https://dl.acm.org/citation.cfm?id=1608614). There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. @@ -1171,11 +1171,11 @@ env | map ## References -* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) +* [Spark Cluster Mode Overview](https://spark.apache.org/docs/latest/cluster-overview.html) -* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html) +* [Submitting Spark Applications](https://spark.apache.org/docs/latest/submitting-applications.html) -* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html) +* [Machine Learning Library Guide (MLlib)](https://spark.apache.org/docs/latest/ml-guide.html) * [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016. From 3a4d15e5d2b9ddbaeb2a6ab2d86d059ada6407b2 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 2 Mar 2018 10:38:50 -0800 Subject: [PATCH 0414/1161] [SPARK-23518][SQL] Avoid metastore access when the users only want to read and write data frames ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/18944 added one patch, which allowed a spark session to be created when the hive metastore server is down. However, it did not allow running any commands with the spark session. This brings troubles to the user who only wants to read / write data frames without metastore setup. ## How was this patch tested? Added some unit tests to read and write data frames based on the original HiveMetastoreLazyInitializationSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu Closes #20681 from liufengdb/completely-lazy. --- R/pkg/tests/fulltests/test_sparkSQL.R | 2 ++ .../sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++---- .../sql/internal/BaseSessionStateBuilder.scala | 4 ++-- .../HiveMetastoreLazyInitializationSuite.scala | 14 ++++++++++++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 8 ++++---- .../spark/sql/hive/HiveSessionStateBuilder.scala | 10 +++++----- 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 5197838eaac66..bd0a0dcd0674c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -67,6 +67,8 @@ sparkSession <- if (windows_with_hadoop()) { sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) +# materialize the catalog implementation +listTables() mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 4b119c75260a7..64e7ca11270b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -54,8 +54,8 @@ object SessionCatalog { * This class must be thread-safe. */ class SessionCatalog( - val externalCatalog: ExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => ExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, functionRegistry: FunctionRegistry, conf: SQLConf, hadoopConf: Configuration, @@ -70,8 +70,8 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: SQLConf) { this( - externalCatalog, - new GlobalTempViewManager("global_temp"), + () => externalCatalog, + () => new GlobalTempViewManager("global_temp"), functionRegistry, conf, new Configuration(), @@ -87,6 +87,9 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } + lazy val externalCatalog = externalCatalogBuilder() + lazy val globalTempViewManager = globalTempViewManagerBuilder() + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") protected val tempViews = new mutable.HashMap[String, LogicalPlan] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 007f8760edf82..3a0db7e16c23a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -130,8 +130,8 @@ abstract class BaseSessionStateBuilder( */ protected lazy val catalog: SessionCatalog = { val catalog = new SessionCatalog( - session.sharedState.externalCatalog, - session.sharedState.globalTempViewManager, + () => session.sharedState.externalCatalog, + () => session.sharedState.globalTempViewManager, functionRegistry, conf, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala index 3f135cc864983..277df548aefd0 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreLazyInitializationSuite.scala @@ -38,6 +38,20 @@ class HiveMetastoreLazyInitializationSuite extends SparkFunSuite { // We should be able to run Spark jobs without Hive client. assert(spark.sparkContext.range(0, 1).count() === 1) + // We should be able to use Spark SQL if no table references. + assert(spark.sql("select 1 + 1").count() === 1) + assert(spark.range(0, 1).count() === 1) + + // We should be able to use fs + val path = Utils.createTempDir() + path.delete() + try { + spark.range(0, 1).write.parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).count() === 1) + } finally { + Utils.deleteRecursively(path) + } + // Make sure that we are not using the local derby metastore. val exceptionString = Utils.exceptionString(intercept[AnalysisException] { spark.sql("show tables") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 1f11adbd4f62e..e5aff3b99d0b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalog: HiveExternalCatalog, - globalTempViewManager: GlobalTempViewManager, + externalCatalogBuilder: () => HiveExternalCatalog, + globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, conf: SQLConf, @@ -48,8 +48,8 @@ private[sql] class HiveSessionCatalog( parser: ParserInterface, functionResourceLoader: FunctionResourceLoader) extends SessionCatalog( - externalCatalog, - globalTempViewManager, + externalCatalogBuilder, + globalTempViewManagerBuilder, functionRegistry, conf, hadoopConf, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 12c74368dd184..40b9bb51ca9a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,8 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client - new HiveSessionResourceLoader(session, client) + new HiveSessionResourceLoader(session, () => externalCatalog.client) } /** @@ -51,8 +50,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session */ override protected lazy val catalog: HiveSessionCatalog = { val catalog = new HiveSessionCatalog( - externalCatalog, - session.sharedState.globalTempViewManager, + () => externalCatalog, + () => session.sharedState.globalTempViewManager, new HiveMetastoreCatalog(session), functionRegistry, conf, @@ -105,8 +104,9 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session class HiveSessionResourceLoader( session: SparkSession, - client: HiveClient) + clientBuilder: () => HiveClient) extends SessionResourceLoader(session) { + private lazy val client = clientBuilder() override def addJar(path: String): Unit = { client.addJar(path) super.addJar(path) From 707e6506d0dbdb598a6c99d666f3c66746113b67 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 2 Mar 2018 12:27:42 -0800 Subject: [PATCH 0415/1161] [SPARK-23097][SQL][SS] Migrate text socket source to V2 ## What changes were proposed in this pull request? This PR moves structured streaming text socket source to V2. Questions: do we need to remove old "socket" source? ## How was this patch tested? Unit test and manual verification. Author: jerryshao Closes #20382 from jerryshao/SPARK-23097. --- ...pache.spark.sql.sources.DataSourceRegister | 2 +- .../execution/datasources/DataSource.scala | 5 +- .../streaming/{ => sources}/socket.scala | 178 ++++++---- .../sql/streaming/DataStreamReader.scala | 21 +- .../streaming/TextSocketStreamSuite.scala | 231 ------------- .../sources/TextSocketStreamSuite.scala | 306 ++++++++++++++++++ 6 files changed, 434 insertions(+), 309 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{ => sources}/socket.scala (51%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 0259c774bbf4a..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 6e1b5727e3fd5..35fcff69b14d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -563,6 +564,7 @@ object DataSource extends Logging { val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName + val socket = classOf[TextSocketSourceProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -583,7 +585,8 @@ object DataSource extends Logging { "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc, "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, - "com.databricks.spark.csv" -> csv + "com.databricks.spark.csv" -> csv, + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala similarity index 51% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 0b22cbc46e6bf..5aae46b463398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -15,27 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, List => JList, Locale, Optional} import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -import org.apache.spark.unsafe.types.UTF8String - -object TextSocketSource { +object TextSocketMicroBatchReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: StructField("timestamp", TimestampType) :: Nil) @@ -43,12 +45,17 @@ object TextSocketSource { } /** - * A source that reads text lines through a TCP socket, designed only for tutorials and debugging. - * This source will *not* work in production applications due to multiple reasons, including no - * support for fault recovery and keeping all of the text read in memory forever. + * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This MicroBatchReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. */ -class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) - extends Source with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { + + private var startOffset: Offset = _ + private var endOffset: Offset = _ + + private val host: String = options.get("host").get() + private val port: Int = options.get("port").get().toInt @GuardedBy("this") private var socket: Socket = null @@ -61,16 +68,21 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo * Stored in a ListBuffer to facilitate removing committed batches. */ @GuardedBy("this") - protected val batches = new ListBuffer[(String, Timestamp)] + private val batches = new ListBuffer[(String, Timestamp)] @GuardedBy("this") - protected var currentOffset: LongOffset = new LongOffset(-1) + private var currentOffset: LongOffset = LongOffset(-1L) @GuardedBy("this") - protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + private var lastOffsetCommitted: LongOffset = LongOffset(-1L) initialize() + /** This method is only used for unit test */ + private[sources] def getCurrentOffset(): LongOffset = synchronized { + currentOffset.copy() + } + private def initialize(): Unit = synchronized { socket = new Socket(host, port) val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) @@ -86,12 +98,12 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo logWarning(s"Stream closed by $host:$port") return } - TextSocketSource.this.synchronized { + TextSocketMicroBatchReader.this.synchronized { val newData = (line, Timestamp.valueOf( - TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) - ) - currentOffset = currentOffset + 1 + TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + currentOffset += 1 batches.append(newData) } } @@ -103,23 +115,37 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo readThread.start() } - /** Returns the schema of the data from this source */ - override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP - else TextSocketSource.SCHEMA_REGULAR + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { + startOffset = start.orElse(LongOffset(-1L)) + endOffset = end.orElse(currentOffset) + } + + override def getStartOffset(): Offset = { + Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } - override def getOffset: Option[Offset] = synchronized { - if (currentOffset.offset == -1) { - None + override def readSchema(): StructType = { + if (options.getBoolean("includeTimestamp", false)) { + TextSocketMicroBatchReader.SCHEMA_TIMESTAMP } else { - Some(currentOffset) + TextSocketMicroBatchReader.SCHEMA_REGULAR } } - /** Returns the data that is between the offsets (`start`, `end`]. */ - override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { - val startOrdinal = - start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1 - val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1 + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + assert(startOffset != null && endOffset != null, + "start offset and end offset should already be set before create read tasks.") + + val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 + val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -128,10 +154,34 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo batches.slice(sliceStart, sliceEnd) } - val rdd = sqlContext.sparkContext - .parallelize(rawList) - .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + assert(SparkSession.getActiveSession.isDefined) + val spark = SparkSession.getActiveSession.get + val numPartitions = spark.sparkContext.defaultParallelism + + val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + rawList.zipWithIndex.foreach { case (r, idx) => + slices(idx % numPartitions).append(r) + } + + (0 until numPartitions).map { i => + val slice = slices(i) + new DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new DataReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } + + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } + + override def close(): Unit = {} + } + } + }.toList.asJava } override def commit(end: Offset): Unit = synchronized { @@ -164,54 +214,40 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo } } - override def toString: String = s"TextSocketSource[host: $host, port: $port]" + override def toString: String = s"TextSocket[host: $host, port: $port]" } -class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { - private def parseIncludeTimestamp(params: Map[String, String]): Boolean = { - Try(params.getOrElse("includeTimestamp", "false").toBoolean) match { - case Success(bool) => bool - case Failure(_) => - throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") - } - } +class TextSocketSourceProvider extends DataSourceV2 + with MicroBatchReadSupport with DataSourceRegister with Logging { - /** Returns the name and schema of the source that can be used to continually read data. */ - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { + private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + "It does not support recovery.") - if (!parameters.contains("host")) { + if (!params.get("host").isPresent) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } - if (!parameters.contains("port")) { + if (!params.get("port").isPresent) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } - if (schema.nonEmpty) { - throw new AnalysisException("The socket source does not support a user-specified schema.") + Try { + params.get("includeTimestamp").orElse("false").toBoolean + } match { + case Success(_) => + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") } - - val sourceSchema = - if (parseIncludeTimestamp(parameters)) { - TextSocketSource.SCHEMA_TIMESTAMP - } else { - TextSocketSource.SCHEMA_REGULAR - } - ("textSocket", sourceSchema) } - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val host = parameters("host") - val port = parameters("port").toInt - new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext) + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + checkParameters(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + new TextSocketMicroBatchReader(options) } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 61e22fac854f9..c393dcdfdd7e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,15 +173,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } ds match { case s: MicroBatchReadSupport => - val tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + var tempReader: MicroBatchReader = null + val schema = try { + tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + tempReader.readSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReader != null) { + tempReader.stop() + tempReader = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case s: ContinuousReadSupport => val tempReader = s.createContinuousReader( Optional.ofNullable(userSpecifiedSchema.orNull), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala deleted file mode 100644 index ec11549073650..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io.{IOException, OutputStreamWriter} -import java.net.ServerSocket -import java.sql.Timestamp -import java.util.concurrent.LinkedBlockingQueue - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} - -class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { - import testImplicits._ - - override def afterEach() { - sqlContext.streams.active.foreach(_.stop()) - if (serverThread != null) { - serverThread.interrupt() - serverThread.join() - serverThread = null - } - if (source != null) { - source.stop() - source = null - } - } - - private var serverThread: ServerThread = null - private var source: Source = null - - test("basic usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - assert(batch1.as[String].collect().toSeq === Seq("hello")) - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - assert(batch2.as[String].collect().toSeq === Seq("world")) - - val both = source.getBatch(None, offset2) - assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("timestamped usage") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString, - "includeTimestamp" -> "true") - val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 - assert(schema === StructType(StructField("value", StringType) :: - StructField("timestamp", TimestampType) :: Nil)) - - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val offset1 = source.getOffset.get - val batch1 = source.getBatch(None, offset1) - val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq - assert(batch1Seq.map(_._1) === Seq("hello")) - val batch1Stamp = batch1Seq(0)._2 - - serverThread.enqueue("world") - while (source.getOffset.get === offset1) { - Thread.sleep(10) - } - val offset2 = source.getOffset.get - val batch2 = source.getBatch(Some(offset1), offset2) - val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq - assert(batch2Seq.map(_._1) === Seq("world")) - val batch2Stamp = batch2Seq(0)._2 - assert(!batch2Stamp.before(batch1Stamp)) - } - - // Try stopping the source to make sure this does not block forever. - source.stop() - source = null - } - } - - test("params not given") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map()) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost")) - } - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234")) - } - } - - test("non-boolean includeTimestamp") { - val provider = new TextSocketSourceProvider - intercept[AnalysisException] { - provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost", - "port" -> "1234", "includeTimestamp" -> "fasle")) - } - } - - test("user-specified schema given") { - val provider = new TextSocketSourceProvider - val userSpecifiedSchema = StructType( - StructField("name", StringType) :: - StructField("area", StringType) :: Nil) - val exception = intercept[AnalysisException] { - provider.sourceSchema( - sqlContext, Some(userSpecifiedSchema), - "", - Map("host" -> "localhost", "port" -> "1234")) - } - assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) - } - - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - source = provider.createSource(sqlContext, "", None, "", parameters) - } - } - - test("input row metrics") { - serverThread = new ServerThread() - serverThread.start() - - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) - source = provider.createSource(sqlContext, "", None, "", parameters) - - failAfter(streamingTimeout) { - serverThread.enqueue("hello") - while (source.getOffset.isEmpty) { - Thread.sleep(10) - } - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val batch = source.getBatch(None, source.getOffset.get).as[String] - batch.collect() - val numRowsMetric = - batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") - assert(numRowsMetric.nonEmpty) - assert(numRowsMetric.get.value === 1) - } - source.stop() - source = null - } - } - - private class ServerThread extends Thread with Logging { - private val serverSocket = new ServerSocket(0) - private val messageQueue = new LinkedBlockingQueue[String]() - - val port = serverSocket.getLocalPort - - override def run(): Unit = { - try { - val clientSocket = serverSocket.accept() - clientSocket.setTcpNoDelay(true) - val out = new OutputStreamWriter(clientSocket.getOutputStream) - while (true) { - val line = messageQueue.take() - out.write(line + "\n") - out.flush() - } - } catch { - case e: InterruptedException => - } finally { - serverSocket.close() - } - } - - def enqueue(line: String): Unit = { - messageQueue.put(line) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala new file mode 100644 index 0000000000000..a15a980bb92fd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.IOException +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.ServerSocketChannel +import java.sql.Timestamp +import java.util.Optional +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { + + override def afterEach() { + sqlContext.streams.active.foreach(_.stop()) + if (serverThread != null) { + serverThread.interrupt() + serverThread.join() + serverThread = null + } + if (batchReader != null) { + batchReader.stop() + batchReader = null + } + } + + private var serverThread: ServerThread = null + private var batchReader: MicroBatchReader = null + + case class AddSocketData(data: String*) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active socket source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + } + if (sources.isEmpty) { + throw new Exception( + "Could not find socket source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the socket source in the StreamExecution logical plan as there" + + "are multiple socket sources:\n\t" + sources.mkString("\n\t")) + } + val socketSource = sources.head + + assert(serverThread != null && serverThread.port != 0) + val currOffset = socketSource.getCurrentOffset() + data.foreach(serverThread.enqueue) + + val newOffset = LongOffset(currOffset.offset + data.size) + (socketSource, newOffset) + } + + override def toString: String = s"AddSocketData(data = $data)" + } + + test("backward compatibility with old path") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[TextSocketSourceProvider]) + case _ => + throw new IllegalStateException("Could not find socket source") + } + } + + test("basic usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + } + } + + test("timestamped usage") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val socket = spark + .readStream + .format("socket") + .options(Map( + "host" -> "localhost", + "port" -> serverThread.port.toString, + "includeTimestamp" -> "true")) + .load() + + assert(socket.schema === StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil)) + + var batch1Stamp: Timestamp = null + var batch2Stamp: Timestamp = null + + val curr = System.currentTimeMillis() + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "hello") + batch1Stamp = rows.head.getAs[Timestamp](1) + Thread.sleep(10) + }, + true), + AddSocketData("world"), + CheckAnswerRowsByFunc( + rows => { + assert(rows.size === 1) + assert(rows.head.getAs[String](0) === "world") + batch2Stamp = rows.head.getAs[Timestamp](1) + }, + true), + StopStream + ) + + // Timestamp for rate stream is round to second which leads to milliseconds lost, that will + // make batch1stamp smaller than current timestamp if both of them are in the same second. + // Comparing by second to make sure the correct behavior. + assert(batch1Stamp.getTime >= curr / 1000 * 1000) + assert(!batch2Stamp.before(batch1Stamp)) + } + } + + test("params not given") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map.empty[String, String].asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("host" -> "localhost").asJava)) + } + intercept[AnalysisException] { + provider.createMicroBatchReader(Optional.empty(), "", + new DataSourceOptions(Map("port" -> "1234").asJava)) + } + } + + test("non-boolean includeTimestamp") { + val provider = new TextSocketSourceProvider + val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") + intercept[AnalysisException] { + val a = new DataSourceOptions(params.asJava) + provider.createMicroBatchReader(Optional.empty(), "", a) + } + } + + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val params = Map("host" -> "localhost", "port" -> "1234") + val exception = intercept[AnalysisException] { + provider.createMicroBatchReader( + Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + } + assert(exception.getMessage.contains( + "socket source does not support a user-specified schema")) + } + + test("no server up") { + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> "0") + intercept[IOException] { + batchReader = provider.createMicroBatchReader( + Optional.empty(), "", new DataSourceOptions(parameters.asJava)) + } + } + + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AssertOnQuery { q => + val numRowMetric = + q.lastExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + numRowMetric.nonEmpty && numRowMetric.get.value == 1 + }, + StopStream + ) + } + } + + private class ServerThread extends Thread with Logging { + private val serverSocketChannel = ServerSocketChannel.open() + serverSocketChannel.bind(new InetSocketAddress(0)) + private val messageQueue = new LinkedBlockingQueue[String]() + + val port = serverSocketChannel.socket().getLocalPort + + override def run(): Unit = { + try { + while (true) { + val clientSocketChannel = serverSocketChannel.accept() + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + + // Check whether remote client is closed but still send data to this closed socket. + // This happens in DataStreamReader where a source will be created to get the schema. + var remoteIsClosed = false + var cnt = 0 + while (cnt < 3 && !remoteIsClosed) { + if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { + cnt += 1 + Thread.sleep(100) + } else { + remoteIsClosed = true + } + } + + if (remoteIsClosed) { + logInfo(s"remote client ${clientSocketChannel.socket()} is closed") + } else { + while (true) { + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) + } + } + } + } catch { + case e: InterruptedException => + } finally { + serverSocketChannel.close() + } + } + + def enqueue(line: String): Unit = { + messageQueue.put(line) + } + } +} From 487377e693af65b2ff3d6b874ca7326c1ff0076c Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 2 Mar 2018 14:30:37 -0800 Subject: [PATCH 0416/1161] [SPARK-23570][SQL] Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Add Spark 2.3.0 in HiveExternalCatalogVersionsSuite since Spark 2.3.0 is released for ensuring backward compatibility. ## How was this patch tested? N/A Author: gatorsmile Closes #20720 from gatorsmile/add2.3. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index c13a750dbb270..6ca58e68d31eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") protected var spark: SparkSession = _ From 9e26473c0f29ee4281519104ac5e182a3bd4bf23 Mon Sep 17 00:00:00 2001 From: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Date: Fri, 2 Mar 2018 16:24:29 -0800 Subject: [PATCH 0417/1161] [SPARK-3159][ML] Add decision tree pruning ## What changes were proposed in this pull request? Added subtree pruning in the translation from LearningNode to Node: a learning node having a single prediction value for all the leaves in the subtree rooted at it is translated into a LeafNode, instead of a (redundant) InternalNode ## How was this patch tested? Added two unit tests under "mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala": - test("SPARK-3159 tree model redundancy - classification") - test("SPARK-3159 tree model redundancy - regression") 4 existing unit tests relying on the tree structure (existence of a specific redundant subtree) had to be adapted as the tested components in the output tree are now pruned (fixed by adding an extra _prune_ parameter which can be used to disable pruning for testing) Author: Alessandro Solimando <18898964+asolimando@users.noreply.github.com> Closes #20632 from asolimando/master. --- .../scala/org/apache/spark/ml/tree/Node.scala | 22 ++-- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../DecisionTreeClassifierSuite.scala | 38 ------- .../ml/tree/impl/RandomForestSuite.scala | 100 ++++++++++++++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 10 +- 5 files changed, 115 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10e..d30be452a436e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, - InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. @@ -266,15 +265,23 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { + def toNode: Node = toNode(prune = true) + /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode: Node = { - if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats != null, + def toNode(prune: Boolean = true): Node = { + + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => + new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + case (l, r) => + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get, stats.impurityCalculator) + } } else { if (stats.valid) { new LeafNode(stats.impurityCalculator.predict, stats.impurity, @@ -283,7 +290,6 @@ private[tree] class LearningNode( // Here we want to keep same behavior with the old mllib.DecisionTreeModel new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) } - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553b..8e514f11e78ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -92,6 +92,7 @@ private[spark] object RandomForest extends Logging { featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation[_]], + prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { val timer = new TimeTracker() @@ -223,22 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) + topNodes.map(rootNode => + new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 98c879ece62d6..38b265d62611b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -280,44 +280,6 @@ class DecisionTreeClassifierSuite dt.fit(df) } - test("Use soft prediction for binary classification with ordered categorical features") { - // The following dataset is set up such that the best split is {1} vs. {0, 2}. - // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. - val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(0.0)), - LabeledPoint(1.0, Vectors.dense(0.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(1.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(0.0, Vectors.dense(2.0)), - LabeledPoint(1.0, Vectors.dense(2.0))) - val data = sc.parallelize(arr) - val df = TreeTests.setMetadata(data, Map(0 -> 3), 2) - - // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. - val dt = new DecisionTreeClassifier() - .setImpurity("gini") - .setMaxDepth(1) - .setMaxBins(3) - val model = dt.fit(df) - model.rootNode match { - case n: InternalNode => - n.split match { - case s: CategoricalSplit => - assert(s.leftCategories === Array(1.0)) - case other => - fail(s"All splits should be categorical, but got ${other.getClass.getName}: $other.") - } - case other => - fail(s"Root node should be an internal node, but got ${other.getClass.getName}: $other.") - } - } - test("Feature importance with toy data") { val dt = new DecisionTreeClassifier() .setImpurity("gini") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dbe2ea931fb9c..5f0d26eb5c058 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tree.impl +import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkFunSuite @@ -38,6 +39,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestSuite.mapToVec + private val seed = 42 + ///////////////////////////////////////////////////////////////////////////// // Tests for split calculation ///////////////////////////////////////////////////////////////////////////// @@ -320,10 +323,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -362,10 +365,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.isLeaf === false) assert(topNode.stats === null) - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new RandomForest.NodeIndexInfo(0, None) + )) val nodeStack = new mutable.ArrayStack[(Int, LearningNode)] RandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack) @@ -407,7 +410,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 42, instr = None).head + seed = 42, instr = None, prune = false).head + model.rootNode match { case n: InternalNode => n.split match { case s: CategoricalSplit => @@ -631,13 +635,89 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) } + + /////////////////////////////////////////////////////////////////////////////// + // Tests for pruning of redundant subtrees (generated by a split improving the + // impurity measure, but always leading to the same prediction). + /////////////////////////////////////////////////////////////////////////////// + + test("SPARK-3159 tree model redundancy - classification") { + // The following dataset is set up such that splitting over feature_1 for points having + // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val numClasses = 2 + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 5) + assert(unprunedTree.numNodes === 7) + + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } + + test("SPARK-3159 tree model redundancy - regression") { + // The following dataset is set up such that splitting over feature_0 for points having + // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) + + val prunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None).head + + val unprunedTree = RandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false).head + + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) + assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) + } } private object RandomForestSuite { - def mapToVec(map: Map[Int, Double]): Vector = { val size = (map.keys.toSeq :+ 0).max + 1 val (indices, values) = map.toSeq.sortBy(_._1).unzip Vectors.sparse(size, indices.toArray, values.toArray) } + + @tailrec + private def getSumLeafCounters(nodes: List[Node], acc: Long = 0): Long = { + if (nodes.isEmpty) { + acc + } + else { + nodes.head match { + case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) + case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 441d0f7614bf6..bc59f3f4125fb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -363,10 +363,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -541,7 +541,7 @@ object DecisionTreeSuite extends SparkFunSuite { Array[LabeledPoint] = { val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { - if (i < 1000) { + if (i < 1001) { arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) From dea381dfaa73e0cfb9a833b79c741b15ae274f64 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Sat, 3 Mar 2018 09:10:48 +0800 Subject: [PATCH 0418/1161] [SPARK-23514][FOLLOW-UP] Remove more places using sparkContext.hadoopConfiguration directly ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/20679 I missed a few places in SQL tests. For hygiene, they should also use the sessionState interface where possible. ## How was this patch tested? Modified existing tests. Author: Juliusz Sompolski Closes #20718 from juliuszsompolski/SPARK-23514-followup. --- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 2 +- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 2 +- .../spark/sql/execution/datasources/FileIndexSuite.scala | 2 +- .../execution/datasources/parquet/ParquetCommitterSuite.scala | 2 +- .../datasources/parquet/ParquetFileFormatSuite.scala | 4 ++-- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- .../sql/execution/datasources/parquet/ParquetQuerySuite.scala | 2 +- .../org/apache/spark/sql/streaming/FileStreamSinkSuite.scala | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index b5d4c558f0d3e..73e3df3b6202e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -124,7 +124,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("1").toDF("a").write.format(format).save(new Path(basePath, "second").toString) val thirdPath = new Path(basePath, "third") - val fs = thirdPath.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = thirdPath.getFileSystem(spark.sessionState.newHadoopConf()) Seq("2").toDF("a").write.format(format).save(thirdPath.toString) val files = fs.listStatus(thirdPath).filter(_.isFile).map(_.getPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b800e6ff5b0ce..db9023b7ec8b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1052,7 +1052,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part2 = Map("a" -> "2", "b" -> "6") val root = new Path(catalog.getTableMetadata(tableIdent).location) - val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + val fs = root.getFileSystem(spark.sessionState.newHadoopConf()) // valid fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index b4616826e40b3..18bb4bfe661ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -59,7 +59,7 @@ class FileIndexSuite extends SharedSQLContext { require(!unqualifiedDirPath.toString.contains("file:")) require(!unqualifiedFilePath.toString.contains("file:")) - val fs = unqualifiedDirPath.getFileSystem(sparkContext.hadoopConfiguration) + val fs = unqualifiedDirPath.getFileSystem(spark.sessionState.newHadoopConf()) val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) require(qualifiedFilePath.toString.startsWith("file:")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index caa4f6d70c6a9..f3ecc5ced689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -101,7 +101,7 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils if (check) { result = Some(MarkingFileOutput.checkMarker( destPath, - spark.sparkContext.hadoopConfiguration)) + spark.sessionState.newHadoopConf())) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala index ccb34355f1bac..3a0867fd2b78b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormatSuite.scala @@ -29,7 +29,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo test("read parquet footers in parallel") { def testReadFooters(ignoreCorruptFiles: Boolean): Unit = { withTempDir { dir => - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val basePath = dir.getCanonicalPath val path1 = new Path(basePath, "first") @@ -44,7 +44,7 @@ class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLCo Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten val footers = ParquetFileFormat.readParquetFootersInParallel( - sparkContext.hadoopConfiguration, fileStatuses, ignoreCorruptFiles) + spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles) assert(footers.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index e3edafa9c84e1..fbd83a0fa425a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -163,7 +163,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // Just to be defensive in case anything ever changes in parquet, this test checks // the assumption on column stats, and also the end-to-end behavior. - val hadoopConf = sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() val fs = FileSystem.get(hadoopConf) val parts = fs.listStatus(new Path(tableDir.getAbsolutePath), new PathFilter { override def accept(path: Path): Boolean = !path.getName.startsWith("_") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 55b0f729be8ce..e1f094d0a7af3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -819,7 +819,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val path = dir.getCanonicalPath spark.range(3).write.parquet(path) - val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) val files = fs.listFiles(new Path(path), true) while (files.hasNext) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index ba48bc1ce0c4d..31e5527d7366a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -353,7 +353,7 @@ class FileStreamSinkSuite extends StreamTest { } test("FileStreamSink.ancestorIsMetadataDirectory()") { - val hadoopConf = spark.sparkContext.hadoopConfiguration + val hadoopConf = spark.sessionState.newHadoopConf() def assertAncestorIsMetadataDirectory(path: String): Unit = assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) def assertAncestorIsNotMetadataDirectory(path: String): Unit = From 486f99eefead4e664a30a861eca65cab8568e70b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 2 Mar 2018 18:14:13 -0800 Subject: [PATCH 0419/1161] [SPARK-23541][SS] Allow Kafka source to read data with greater parallelism than the number of topic-partitions ## What changes were proposed in this pull request? Currently, when the Kafka source reads from Kafka, it generates as many tasks as the number of partitions in the topic(s) to be read. In some case, it may be beneficial to read the data with greater parallelism, that is, with more number partitions/tasks. That means, offset ranges must be divided up into smaller ranges such the number of records in partition ~= total records in batch / desired partitions. This would also balance out any data skews between topic-partitions. In this patch, I have added a new option called `minPartitions`, which allows the user to specify the desired level of parallelism. ## How was this patch tested? New tests in KafkaMicroBatchV2SourceSuite. Author: Tathagata Das Closes #20698 from tdas/SPARK-23541. --- .../sql/kafka010/KafkaMicroBatchReader.scala | 109 ++++++------- .../kafka010/KafkaOffsetRangeCalculator.scala | 105 +++++++++++++ .../sql/kafka010/KafkaSourceProvider.scala | 7 + .../apache/spark/sql/kafka010/package.scala | 24 +++ .../kafka010/KafkaMicroBatchSourceSuite.scala | 56 ++++++- .../KafkaOffsetRangeCalculatorSuite.scala | 147 ++++++++++++++++++ 6 files changed, 388 insertions(+), 60 deletions(-) create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala create mode 100644 external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index fb647ca7e70dd..8a5f3a249b11c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import org.apache.commons.io.IOUtils -import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging @@ -64,8 +63,6 @@ private[kafka010] class KafkaMicroBatchReader( failOnDataLoss: Boolean) extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - type PartitionOffsetMap = Map[TopicPartition, Long] - private var startPartitionOffsets: PartitionOffsetMap = _ private var endPartitionOffsets: PartitionOffsetMap = _ @@ -76,6 +73,7 @@ private[kafka010] class KafkaMicroBatchReader( private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + private val rangeCalculator = KafkaOffsetRangeCalculator(options) /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running @@ -106,15 +104,15 @@ private[kafka010] class KafkaMicroBatchReader( override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) - val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) - if (newPartitionOffsets.keySet != newPartitions) { + val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + if (newPartitionInitialOffsets.keySet != newPartitions) { // We cannot get from offsets for some partitions. It means they got deleted. - val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + val deletedPartitions = newPartitions.diff(newPartitionInitialOffsets.keySet) reportDataLoss( s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") } - logInfo(s"Partitions added: $newPartitionOffsets") - newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + logInfo(s"Partitions added: $newPartitionInitialOffsets") + newPartitionInitialOffsets.filter(_._2 != 0).foreach { case (p, o) => reportDataLoss( s"Added partition $p starts from $o instead of 0. Some data may have been missed") } @@ -125,46 +123,28 @@ private[kafka010] class KafkaMicroBatchReader( reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") } - // Use the until partitions to calculate offset ranges to ignore partitions that have + // Use the end partitions to calculate offset ranges to ignore partitions that have // been deleted val topicPartitions = endPartitionOffsets.keySet.filter { tp => // Ignore partitions that we don't know the from offsets. - newPartitionOffsets.contains(tp) || startPartitionOffsets.contains(tp) + newPartitionInitialOffsets.contains(tp) || startPartitionOffsets.contains(tp) }.toSeq logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) - val sortedExecutors = getSortedExecutorList() - val numExecutors = sortedExecutors.length - logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) - // Calculate offset ranges - val factories = topicPartitions.flatMap { tp => - val fromOffset = startPartitionOffsets.get(tp).getOrElse { - newPartitionOffsets.getOrElse( - tp, { - // This should not happen since newPartitionOffsets contains all partitions not in - // fromPartitionOffsets - throw new IllegalStateException(s"$tp doesn't have a from offset") - }) - } - val untilOffset = endPartitionOffsets(tp) - - if (untilOffset >= fromOffset) { - // This allows cached KafkaConsumers in the executors to be re-used to read the same - // partition in every batch. - val preferredLoc = if (numExecutors > 0) { - Some(sortedExecutors(Math.floorMod(tp.hashCode, numExecutors))) - } else None - val range = KafkaOffsetRange(tp, fromOffset, untilOffset) - Some( - new KafkaMicroBatchDataReaderFactory( - range, preferredLoc, executorKafkaParams, pollTimeoutMs, failOnDataLoss)) - } else { - reportDataLoss( - s"Partition $tp's offset was changed from " + - s"$fromOffset to $untilOffset, some data may have been missed") - None - } + val offsetRanges = rangeCalculator.getRanges( + fromOffsets = startPartitionOffsets ++ newPartitionInitialOffsets, + untilOffsets = endPartitionOffsets, + executorLocations = getSortedExecutorList()) + + // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions, + // that is, concurrent tasks will not read the same TopicPartitions. + val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size + + // Generate factories based on the offset ranges + val factories = offsetRanges.map { range => + new KafkaMicroBatchDataReaderFactory( + range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava } @@ -320,28 +300,39 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReaderFactory( - range: KafkaOffsetRange, - preferredLoc: Option[String], +private[kafka010] case class KafkaMicroBatchDataReaderFactory( + offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { - override def preferredLocations(): Array[String] = preferredLoc.toArray + override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - range, executorKafkaParams, pollTimeoutMs, failOnDataLoss) + offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } /** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] class KafkaMicroBatchDataReader( +private[kafka010] case class KafkaMicroBatchDataReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReader[UnsafeRow] with Logging { + failOnDataLoss: Boolean, + reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + + private val consumer = { + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We + // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. + CachedKafkaConsumer.createUncached( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } + } - private val consumer = CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -369,9 +360,14 @@ private[kafka010] class KafkaMicroBatchDataReader( } override def close(): Unit = { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + if (!reuseKafkaConsumer) { + // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! + consumer.close() + } else { + // Indicate that we're no longer using this consumer + CachedKafkaConsumer.releaseKafkaConsumer( + offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) + } } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { @@ -392,12 +388,9 @@ private[kafka010] class KafkaMicroBatchDataReader( } else { range.untilOffset } - KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset) + KafkaOffsetRange(range.topicPartition, fromOffset, untilOffset, None) } else { range } } } - -private[kafka010] case class KafkaOffsetRange( - topicPartition: TopicPartition, fromOffset: Long, untilOffset: Long) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala new file mode 100644 index 0000000000000..6631ae84167c8 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.sources.v2.DataSourceOptions + + +/** + * Class to calculate offset ranges to process based on the the from and until offsets, and + * the configured `minPartitions`. + */ +private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int]) { + require(minPartitions.isEmpty || minPartitions.get > 0) + + import KafkaOffsetRangeCalculator._ + /** + * Calculate the offset ranges that we are going to process this batch. If `minPartitions` + * is not set or is set less than or equal the number of `topicPartitions` that we're going to + * consume, then we fall back to a 1-1 mapping of Spark tasks to Kafka partitions. If + * `numPartitions` is set higher than the number of our `topicPartitions`, then we will split up + * the read tasks of the skewed partitions to multiple Spark tasks. + * The number of Spark tasks will be *approximately* `numPartitions`. It can be less or more + * depending on rounding errors or Kafka partitions that didn't receive any new data. + */ + def getRanges( + fromOffsets: PartitionOffsetMap, + untilOffsets: PartitionOffsetMap, + executorLocations: Seq[String] = Seq.empty): Seq[KafkaOffsetRange] = { + val partitionsToRead = untilOffsets.keySet.intersect(fromOffsets.keySet) + + val offsetRanges = partitionsToRead.toSeq.map { tp => + KafkaOffsetRange(tp, fromOffsets(tp), untilOffsets(tp), preferredLoc = None) + }.filter(_.size > 0) + + // If minPartitions not set or there are enough partitions to satisfy minPartitions + if (minPartitions.isEmpty || offsetRanges.size > minPartitions.get) { + // Assign preferred executor locations to each range such that the same topic-partition is + // preferentially read from the same executor and the KafkaConsumer can be reused. + offsetRanges.map { range => + range.copy(preferredLoc = getLocation(range.topicPartition, executorLocations)) + } + } else { + + // Splits offset ranges with relatively large amount of data to smaller ones. + val totalSize = offsetRanges.map(_.size).sum + val idealRangeSize = totalSize.toDouble / minPartitions.get + + offsetRanges.flatMap { range => + // Split the current range into subranges as close to the ideal range size + val numSplitsInRange = math.round(range.size.toDouble / idealRangeSize).toInt + + (0 until numSplitsInRange).map { i => + val splitStart = range.fromOffset + range.size * (i.toDouble / numSplitsInRange) + val splitEnd = range.fromOffset + range.size * ((i.toDouble + 1) / numSplitsInRange) + KafkaOffsetRange( + range.topicPartition, splitStart.toLong, splitEnd.toLong, preferredLoc = None) + } + } + } + } + + private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = { + def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b + + val numExecutors = executorLocations.length + if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + Some(executorLocations(floorMod(tp.hashCode, numExecutors))) + } else None + } +} + +private[kafka010] object KafkaOffsetRangeCalculator { + + def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = { + val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt) + new KafkaOffsetRangeCalculator(optionalValue) + } +} + +private[kafka010] case class KafkaOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + lazy val size: Long = untilOffset - fromOffset +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 0aa64a6a9cf90..36b9f0466566b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -348,6 +348,12 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister throw new IllegalArgumentException("Unknown option") } + // Validate minPartitions value if present + if (caseInsensitiveParams.contains(MIN_PARTITIONS_OPTION_KEY)) { + val p = caseInsensitiveParams(MIN_PARTITIONS_OPTION_KEY).toInt + if (p <= 0) throw new IllegalArgumentException("minPartitions must be positive") + } + // Validate user-specified Kafka options if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { @@ -455,6 +461,7 @@ private[kafka010] object KafkaSourceProvider extends Logging { private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + private val MIN_PARTITIONS_OPTION_KEY = "minpartitions" val TOPIC_OPTION_KEY = "topic" diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala new file mode 100644 index 0000000000000..43acd6a8d9473 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.kafka.common.TopicPartition + +package object kafka010 { // scalastyle:ignore + // ^^ scalastyle:ignore is for ignoring warnings about digits in package name + type PartitionOffsetMap = Map[TopicPartition, Long] +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 89c9ef4cc73b5..f2b3ff7615e74 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.kafka010 import java.io._ import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files, Paths} -import java.util.{Locale, Properties} +import java.util.{Locale, Optional, Properties} import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.io.Source import scala.util.Random @@ -34,15 +35,19 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext -import org.apache.spark.sql.{Dataset, ForeachWriter} +import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} +import org.apache.spark.sql.types.StructType abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { @@ -642,6 +647,53 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { } ) } + + testWithUninterruptibleThread("minPartitions is supported") { + import testImplicits._ + + val topic = newTopic() + val tp = new TopicPartition(topic, 0) + testUtils.createTopic(topic, partitions = 1) + + def test( + minPartitions: String, + numPartitionsGenerated: Int, + reusesConsumers: Boolean): Unit = { + + SparkSession.setActiveSession(spark) + withTempDir { dir => + val provider = new KafkaSourceProvider() + val options = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "subscribe" -> topic + ) ++ Option(minPartitions).map { p => "minPartitions" -> p} + val reader = provider.createMicroBatchReader( + Optional.empty[StructType], dir.getAbsolutePath, new DataSourceOptions(options.asJava)) + reader.setOffsetRange( + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), + Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) + ) + val factories = reader.createUnsafeRowReaderFactories().asScala + .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { + assert(factories.size == numPartitionsGenerated) + factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } + } + } + } + + // Test cases when minPartitions is used and not used + test(minPartitions = null, numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "1", numPartitionsGenerated = 1, reusesConsumers = true) + test(minPartitions = "4", numPartitionsGenerated = 4, reusesConsumers = false) + + // Test illegal minPartitions values + intercept[IllegalArgumentException] { test(minPartitions = "a", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "1.0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "0", 1, true) } + intercept[IllegalArgumentException] { test(minPartitions = "-1", 1, true) } + } + } abstract class KafkaSourceSuiteBase extends KafkaSourceTest { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala new file mode 100644 index 0000000000000..2ccf3e291bea7 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.sources.v2.DataSourceOptions + +class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { + + def testWithMinPartitions(name: String, minPartition: Int) + (f: KafkaOffsetRangeCalculator => Unit): Unit = { + val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava) + test(s"with minPartition = $minPartition: $name") { + f(KafkaOffsetRangeCalculator(options)) + } + } + + + test("with no minPartition: N TopicPartitions to N offset ranges") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1), Seq.empty) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2), + executorLocations = Seq("location")) == + Seq(KafkaOffsetRange(tp1, 1, 2, Some("location")))) + } + + test("with no minPartition: empty ranges ignored") { + val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 1)) == + Seq(KafkaOffsetRange(tp1, 1, 2, None))) + } + + testWithMinPartitions("N TopicPartitions to N offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 2, tp2 -> 2, tp3 -> 2)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp2, 1, 2, None), + KafkaOffsetRange(tp3, 1, 2, None))) + } + + testWithMinPartitions("1 TopicPartition to N offset ranges", 4) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5)) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) + + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 5), + executorLocations = Seq("location")) == + Seq( + KafkaOffsetRange(tp1, 1, 2, None), + KafkaOffsetRange(tp1, 2, 3, None), + KafkaOffsetRange(tp1, 3, 4, None), + KafkaOffsetRange(tp1, 4, 5, None))) // location pref not set when minPartition is set + } + + testWithMinPartitions("N skewed TopicPartitions to M offset ranges", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + testWithMinPartitions("range inexact multiple of minPartitions", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1), + untilOffsets = Map(tp1 -> 11)) == + Seq( + KafkaOffsetRange(tp1, 1, 4, None), + KafkaOffsetRange(tp1, 4, 7, None), + KafkaOffsetRange(tp1, 7, 11, None))) + } + + testWithMinPartitions("empty ranges ignored", 3) { calc => + assert( + calc.getRanges( + fromOffsets = Map(tp1 -> 1, tp2 -> 1, tp3 -> 1), + untilOffsets = Map(tp1 -> 5, tp2 -> 21, tp3 -> 1)) == + Seq( + KafkaOffsetRange(tp1, 1, 5, None), + KafkaOffsetRange(tp2, 1, 7, None), + KafkaOffsetRange(tp2, 7, 14, None), + KafkaOffsetRange(tp2, 14, 21, None))) + } + + private val tp1 = new TopicPartition("t1", 1) + private val tp2 = new TopicPartition("t2", 1) + private val tp3 = new TopicPartition("t3", 1) +} From a89cdf55fa76fa23a524f0443e323498c3cc8664 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 5 Mar 2018 07:32:24 +0900 Subject: [PATCH 0420/1161] [SQL][MINOR] XPathDouble prettyPrint should say 'double' not 'float' ## What changes were proposed in this pull request? It looks like this was incorrectly copied from `XPathFloat` in the class above. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Eric Liang Closes #20730 from ericl/fix-typo-xpath. --- .../org/apache/spark/sql/catalyst/expressions/xml/xpath.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index d0185562c9cfc..aacf1a44e2ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -160,7 +160,7 @@ case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { """) // scalastyle:on line.size.limit case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { - override def prettyName: String = "xpath_float" + override def prettyName: String = "xpath_double" override def dataType: DataType = DoubleType override def nullSafeEval(xml: Any, path: Any): Any = { From 7965c91d8a67c213ca5eebda5e46e7c49a8ba121 Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 5 Mar 2018 13:36:42 +0900 Subject: [PATCH 0421/1161] [SPARK-23569][PYTHON] Allow pandas_udf to work with python3 style type-annotated functions ## What changes were proposed in this pull request? Check python version to determine whether to use `inspect.getargspec` or `inspect.getfullargspec` before applying `pandas_udf` core logic to a function. The former is python2.7 (deprecated in python3) and the latter is python3.x. The latter correctly accounts for type annotations, which are syntax errors in python2.x. ## How was this patch tested? Locally, on python 2.7 and 3.6. Author: Michael (Stu) Stewart Closes #20728 from mstewart141/pandas_udf_fix. --- python/pyspark/sql/tests.py | 18 ++++++++++++++++++ python/pyspark/sql/udf.py | 9 ++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 19653072ea316..fa3b7203e10ac 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4381,6 +4381,24 @@ def test_timestamp_dst(self): result = df.withColumn('time', foo_udf(df.time)) self.assertEquals(df.collect(), result.collect()) + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") + def test_type_annotation(self): + from pyspark.sql.functions import pandas_udf + # Regression test to check if type hints can be used. See SPARK-23569. + # Note that it throws an error during compilation in lower Python versions if 'exec' + # is not used. Also, note that we explicitly use another dictionary to avoid modifications + # in the current 'locals()'. + # + # Hyukjin: I think it's an ugly way to test issues about syntax specific in + # higher versions of Python, which we shouldn't encourage. This was the last resort + # I could come up with at that time. + _locals = {} + exec( + "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col", + _locals) + df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) + self.assertEqual(df.first()[0], 0) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index e5b35fc60e167..b9b490874f4fb 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -42,10 +42,17 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): import inspect + import sys from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - argspec = inspect.getargspec(f) + + if sys.version_info[0] < 3: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: From 269cd53590dd155aeb5269efc909a6e228f21e22 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 4 Mar 2018 21:22:30 -0800 Subject: [PATCH 0422/1161] [MINOR][DOCS] Fix a link in "Compatibility with Apache Hive" ## What changes were proposed in this pull request? This PR fixes a broken link as below: **Before:** 2018-03-05 12 23 58 **After:** 2018-03-05 12 23 20 Also see https://spark.apache.org/docs/2.3.0/sql-programming-guide.html#compatibility-with-apache-hive ## How was this patch tested? Manually tested. I checked the same instances in `docs` directory. Seems this is the only one. Author: hyukjinkwon Closes #20733 from HyukjinKwon/minor-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c37c338a134f3..4d0f015f401bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2223,7 +2223,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 2ce37b50fc01558f49ad22f89c8659f50544ffec Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 5 Mar 2018 11:39:01 +0100 Subject: [PATCH 0423/1161] [SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext ## What changes were proposed in this pull request? A current `CodegenContext` class has immutable value or method without mutable state, too. This refactoring moves them to `CodeGenerator` object class which can be accessed from anywhere without an instantiated `CodegenContext` in the program. ## How was this patch tested? Existing tests Author: Kazuaki Ishizaki Closes #20700 from kiszk/SPARK-23546. --- .../catalyst/expressions/BoundAttribute.scala | 9 +- .../spark/sql/catalyst/expressions/Cast.scala | 35 +- .../sql/catalyst/expressions/Expression.scala | 16 +- .../MonotonicallyIncreasingID.scala | 8 +- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../expressions/SparkPartitionID.scala | 7 +- .../sql/catalyst/expressions/TimeWindow.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 51 +- .../expressions/bitwiseExpressions.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 458 +++++++++--------- .../expressions/codegen/CodegenFallback.scala | 7 +- .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateOrdering.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 6 +- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/collectionOperations.scala | 6 +- .../expressions/complexTypeCreator.scala | 4 +- .../expressions/complexTypeExtractors.scala | 15 +- .../expressions/conditionalExpressions.scala | 10 +- .../expressions/datetimeExpressions.scala | 18 +- .../spark/sql/catalyst/expressions/hash.scala | 25 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 8 +- .../expressions/mathExpressions.scala | 5 +- .../expressions/nullExpressions.scala | 22 +- .../expressions/objects/objects.scala | 99 ++-- .../sql/catalyst/expressions/predicates.scala | 14 +- .../expressions/randomExpressions.scala | 8 +- .../expressions/regexpExpressions.scala | 8 +- .../expressions/stringExpressions.scala | 39 +- .../expressions/CodeGenerationSuite.scala | 4 +- .../sql/execution/ColumnarBatchScan.scala | 13 +- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 8 +- .../apache/spark/sql/execution/SortExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 2 +- .../aggregate/HashAggregateExec.scala | 16 +- .../aggregate/HashMapGenerator.scala | 8 +- .../aggregate/RowBasedHashMapGenerator.scala | 8 +- .../VectorizedHashMapGenerator.scala | 11 +- .../execution/basicPhysicalOperators.scala | 10 +- .../columnar/GenerateColumnAccessor.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 5 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- .../apache/spark/sql/execution/limit.scala | 7 +- 45 files changed, 535 insertions(+), 497 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3ef2..89ffbb0016916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = s""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + |$javaType ${ev.value} = ${ev.isNull} ? + | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 79b051670e9e4..12330bfa55ab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultIsNull = $inputIsNull; - ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val funcName = ctx.freshName("elementToString") val elementToStringFunc = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(et)} element) { + |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { | UTF8String elementStr = null; | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} | return elementStr; @@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); | } | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { | $buffer.append(","); | if (!$array.isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); | } | } |} @@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val dataToStringCode = castToStringCode(dataType, ctx) ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { | UTF8String dataStr = null; | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} | return dataStr; @@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) val loopIndex = ctx.freshName("loopIndex") + val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") + val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") + val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) s""" |$buffer.append("["); |if ($map.numElements() > 0) { - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append($keyToStringFunc($getMapFirstKey)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt(0)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | $buffer.append($valueToStringFunc($getMapFirstValue)); | } | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { | $buffer.append(", "); - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append($keyToStringFunc($getMapKeyArray)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc( - | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | $buffer.append($valueToStringFunc($getMapValueArray)); | } | } |} @@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | ${if (i != 0) s"""$buffer.append(" ");""" else ""} | | // Append $i field into the string buffer - | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${ctx.javaType(fromType)} $fromElementPrim = - ${ctx.getValue(c, fromType, j)}; + ${CodeGenerator.javaType(fromType)} $fromElementPrim = + ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} if ($toElementNull) { @@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") val toFieldNull = ctx.freshName("tfn") - val fromType = ctx.javaType(from.fields(i).dataType) + val fromType = CodeGenerator.javaType(from.fields(i).dataType) s""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; + ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4568714933095..ed90b185865a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") + val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = globalIsNull s"$globalIsNull = $localIsNull;" @@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] { "" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) @@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression { ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { @@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression { boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { ev.copy(code = s""" @@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression { ${leftGen.code} ${midGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 11fb579dfa88c..4523079060896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val partitionMaskTerm = "partitionMask" - ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 989c02305620a..e869258469a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1018,11 +1018,12 @@ case class ScalaUDF( val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" val resultConverter = s"$convertersTerm[${children.length}]" + val boxedType = CodeGenerator.boxedType(dataType) val callFunc = s""" - |${ctx.boxedType(dataType)} $resultTerm = null; + |$boxedType $resultTerm = null; |try { - | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult); + | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult); |} catch (Exception e) { | throw new org.apache.spark.SparkException($errorMsgTerm, e); |} @@ -1035,7 +1036,7 @@ case class ScalaUDF( |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $resultTerm; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index a160b9b275290..cc6a769d032d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = "partitionId" - ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 9a9f579b37f58..6c4a3601c1730 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -165,7 +165,7 @@ case class PreciseTimestampConversion( val eval = child.genCode(ctx) ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; - |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } override def nullSafeEval(input: Any): Any = input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8bb14598a6d7b..508bdd5050b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression // codegen would fail to compile if we just write (-($c)) // for example, we could not write --9223372036854775808L in code s""" - ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); - ${ev.value} = (${ctx.javaType(dt)})(-($originValue)); + ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); + ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } @@ -107,7 +107,7 @@ case class Abs(child: Expression) case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) @@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => @@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => @@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val divide = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val remainder = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s"${eval2.value} == 0" } val remainder = ctx.freshName("remainder") - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val result = dataType match { case DecimalType.Fixed(_, _) => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value}); + $javaType $remainder = ${eval1.value}.remainder(${eval2.value}); if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { ${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value}); } else { @@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} $remainder = - (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value}); + $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value}); if ($remainder < 0) { - ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value}); + ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value}); } else { ${ev.value}=$remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value}; + $javaType $remainder = ${eval1.value} % ${eval2.value}; if ($remainder < 0) { ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value}; } else { @@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "least", @@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } @@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "greatest", @@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 173481f06a716..cc24e397cc14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60a6f50472504..793824b0b0a2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: String, var value: String) object ExprCode { + def forNullValue(dataType: DataType): ExprCode = { + val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + } + def forNonNullValue(value: String): ExprCode = { ExprCode(code = "", isNull = "false", value = value) } @@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec( */ class CodegenContext { + import CodeGenerator._ + /** * Holding a list of objects that could be used passed into generated class. */ @@ -196,11 +203,11 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { - if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) { val res = s"${arrayNames.last}[$currentIndex]" currentIndex += 1 res @@ -247,10 +254,10 @@ class CodegenContext { * are satisfied: * 1. forceInline is true * 2. its type is primitive type and the total number of the inlined mutable variables - * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * is less than `OUTER_CLASS_VARIABLES_THRESHOLD` * 3. its type is multi-dimensional array * When a variable is compacted into an array, the max size of the array for compaction - * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`. */ def addMutableState( javaType: String, @@ -261,7 +268,7 @@ class CodegenContext { // want to put a primitive type variable at outerClass for performance val canInlinePrimitive = isPrimitiveType(javaType) && - (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD) if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) @@ -339,7 +346,7 @@ class CodegenContext { val length = if (index + 1 == numArrays) { mutableStateArrays.getCurrentIndex } else { - CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + MUTABLESTATEARRAY_SIZE_LIMIT } if (javaType.contains("[]")) { // initializer had an one-dimensional array variable @@ -468,7 +475,7 @@ class CodegenContext { inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { + } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -537,14 +544,6 @@ class CodegenContext { extraClasses.append(code) } - final val JAVA_BOOLEAN = "boolean" - final val JAVA_BYTE = "byte" - final val JAVA_SHORT = "short" - final val JAVA_INT = "int" - final val JAVA_LONG = "long" - final val JAVA_FLOAT = "float" - final val JAVA_DOUBLE = "double" - /** * The map from a variable name to it's next ID. */ @@ -580,196 +579,6 @@ class CodegenContext { } } - /** - * Returns the specialized code to access a value from `inputRow` at `ordinal`. - */ - def getValue(input: String, dataType: DataType, ordinal: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" - case BinaryType => s"$input.getBinary($ordinal)" - case CalendarIntervalType => s"$input.getInterval($ordinal)" - case t: StructType => s"$input.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$input.getArray($ordinal)" - case _: MapType => s"$input.getMap($ordinal)" - case NullType => "null" - case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) - case _ => s"($jt)$input.get($ordinal, null)" - } - } - - /** - * Returns the code to update a column in Row for a given DataType. - */ - def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) - // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy - // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. - case StringType | _: StructType | _: ArrayType | _: MapType => - s"$row.update($ordinal, $value.copy())" - case _ => s"$row.update($ordinal, $value)" - } - } - - /** - * Update a column in MutableRow from ExprCode. - * - * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise - */ - def updateColumn( - row: String, - dataType: DataType, - ordinal: Int, - ev: ExprCode, - nullable: Boolean, - isVectorized: Boolean = false): String = { - if (nullable) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - if (!isVectorized && dataType.isInstanceOf[DecimalType]) { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - ${setColumn(row, dataType, ordinal, "null")}; - } - """ - } else { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - $row.setNullAt($ordinal); - } - """ - } - } else { - s"""${setColumn(row, dataType, ordinal, ev.value)};""" - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType`. - */ - def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" - case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType` - * that could potentially be nullable. - */ - def updateColumn( - vector: String, - rowId: String, - dataType: DataType, - ev: ExprCode, - nullable: Boolean): String = { - if (nullable) { - s""" - if (!${ev.isNull}) { - ${setValue(vector, rowId, dataType, ev.value)} - } else { - $vector.putNull($rowId); - } - """ - } else { - s"""${setValue(vector, rowId, dataType, ev.value)};""" - } - } - - /** - * Returns the specialized code to access a value from a column vector for a given `DataType`. - */ - def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { - if (dataType.isInstanceOf[StructType]) { - // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an - // `ordinal` parameter. - s"$vector.getStruct($rowId)" - } else { - getValue(vector, dataType, rowId) - } - } - - /** - * Returns the name used in accessor and setter for a Java primitive type. - */ - def primitiveTypeName(jt: String): String = jt match { - case JAVA_INT => "Int" - case _ => boxedType(jt) - } - - def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) - - /** - * Returns the Java type for a DataType. - */ - def javaType(dt: DataType): String = dt match { - case BooleanType => JAVA_BOOLEAN - case ByteType => JAVA_BYTE - case ShortType => JAVA_SHORT - case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG - case FloatType => JAVA_FLOAT - case DoubleType => JAVA_DOUBLE - case dt: DecimalType => "Decimal" - case BinaryType => "byte[]" - case StringType => "UTF8String" - case CalendarIntervalType => "CalendarInterval" - case _: StructType => "InternalRow" - case _: ArrayType => "ArrayData" - case _: MapType => "MapData" - case udt: UserDefinedType[_] => javaType(udt.sqlType) - case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" - case ObjectType(cls) => cls.getName - case _ => "Object" - } - - /** - * Returns the boxed type in Java. - */ - def boxedType(jt: String): String = jt match { - case JAVA_BOOLEAN => "Boolean" - case JAVA_BYTE => "Byte" - case JAVA_SHORT => "Short" - case JAVA_INT => "Integer" - case JAVA_LONG => "Long" - case JAVA_FLOAT => "Float" - case JAVA_DOUBLE => "Double" - case other => other - } - - def boxedType(dt: DataType): String = boxedType(javaType(dt)) - - /** - * Returns the representation of default value for a given Java Type. - */ - def defaultValue(jt: String): String = jt match { - case JAVA_BOOLEAN => "false" - case JAVA_BYTE => "(byte)-1" - case JAVA_SHORT => "(short)-1" - case JAVA_INT => "-1" - case JAVA_LONG => "-1L" - case JAVA_FLOAT => "-1.0f" - case JAVA_DOUBLE => "-1.0" - case _ => "null" - } - - def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) - /** * Generates code for equal expression in Java. */ @@ -812,6 +621,7 @@ class CodegenContext { val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") + val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -833,8 +643,8 @@ class CodegenContext { } else if ($isNullB) { return 1; } else { - ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; - ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + $jt $elementA = ${getValue("a", elementType, "i")}; + $jt $elementB = ${getValue("b", elementType, "i")}; int comp = ${genComp(elementType, elementA, elementB)}; if (comp != 0) { return comp; @@ -906,19 +716,6 @@ class CodegenContext { } } - /** - * List of java data types that have special accessors and setters in [[InternalRow]]. - */ - val primitiveTypes = - Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) - - /** - * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. - */ - def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - - def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) - /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -1089,7 +886,7 @@ class CodegenContext { // for performance reasons, the functions are prepended, instead of appended, // thus here they are in reversed order val orderedFunctions = innerClassFunctions.reverse - if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) { // Adding a new function to each inner class which contains the invocation of all the // ones which have been added to that inner class. For example, // private class NestedClass { @@ -1289,7 +1086,7 @@ class CodegenContext { * length less than a pre-defined constant. */ def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH } } @@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging { result } }) + + /** + * Name of Java primitive data type + */ + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + + /** + * List of java primitive data types + */ + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) + + /** + * Returns true if a Java type is Java primitive primitive type + */ + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) + + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Returns the specialized code to access a value from `inputRow` at `ordinal`. + */ + def getValue(input: String, dataType: DataType, ordinal: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$input.getUTF8String($ordinal)" + case BinaryType => s"$input.getBinary($ordinal)" + case CalendarIntervalType => s"$input.getInterval($ordinal)" + case t: StructType => s"$input.getStruct($ordinal, ${t.size})" + case _: ArrayType => s"$input.getArray($ordinal)" + case _: MapType => s"$input.getMap($ordinal)" + case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case _ => s"($jt)$input.get($ordinal, null)" + } + } + + /** + * Returns the code to update a column in Row for a given DataType. + */ + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" + case _ => s"$row.update($ordinal, $value)" + } + } + + /** + * Update a column in MutableRow from ExprCode. + * + * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean, + isVectorized: Boolean = false): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (!isVectorized && dataType.isInstanceOf[DecimalType]) { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | ${setColumn(row, dataType, ordinal, "null")}; + |} + """.stripMargin + } else { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | $row.setNullAt($ordinal); + |} + """.stripMargin + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType`. + */ + def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => + s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" + case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + case _ => + throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType` + * that could potentially be nullable. + */ + def updateColumn( + vector: String, + rowId: String, + dataType: DataType, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + s""" + |if (!${ev.isNull}) { + | ${setValue(vector, rowId, dataType, ev.value)} + |} else { + | $vector.putNull($rowId); + |} + """.stripMargin + } else { + s"""${setValue(vector, rowId, dataType, ev.value)};""" + } + } + + /** + * Returns the specialized code to access a value from a column vector for a given `DataType`. + */ + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) + } + } + + /** + * Returns the name used in accessor and setter for a Java primitive type. + */ + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) + } + + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) + + /** + * Returns the Java type for a DataType. + */ + def javaType(dt: DataType): String = dt match { + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE + case _: DecimalType => "Decimal" + case BinaryType => "byte[]" + case StringType => "UTF8String" + case CalendarIntervalType => "CalendarInterval" + case _: StructType => "InternalRow" + case _: ArrayType => "ArrayData" + case _: MapType => "MapData" + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName + case _ => "Object" + } + + /** + * Returns the boxed type in Java. + */ + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other + } + + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + + /** + * Returns the representation of default value for a given Java Type. + * @param jt the string name of the Java type + * @param typedNull if true, for null literals, return a typed (with a cast) version + */ + def defaultValue(jt: String, typedNull: Boolean): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => if (typedNull) s"(($jt)null)" else "null" + } + + def defaultValue(dt: DataType, typedNull: Boolean = false): String = + defaultValue(javaType(dt), typedNull) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 0322d1dd6a9ff..e12420bb5dfdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -44,20 +44,21 @@ trait CodegenFallback extends Expression { } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) + val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); - ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; """, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b53c0087e7e2d..d35fd8ecb4d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) - val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; @@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => val ev = ExprCode("", isNull, value) - ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) + CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4a459571ed634..9a51be6ed5aeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR s""" ${ctx.INPUT_ROW} = a; boolean $isNullA; - ${ctx.javaType(order.child.dataType)} $primitiveA; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; @@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } ${ctx.INPUT_ROW} = b; boolean $isNullB; - ${ctx.javaType(order.child.dataType)} $primitiveB; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3dcbb518ba42a..f92f70ee71fef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, ctx.getValue(tmpInput, elementType, index), elementType) + ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] mutableRow.setNullAt($i); } else { ${converter.code} - ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)}; + ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd2b6..22717f5954a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) + ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } s""" @@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case other => other } - val jt = ctx.javaType(et) + val jt = CodeGenerator.javaType(et) val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => et.defaultSize + case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize case _ => 8 // we need 8 bytes to store offset and length } val tmpCursor = ctx.freshName("tmpCursor") - val element = ctx.getValue(tmpInput, et, index) + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" @@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val primitiveTypeName = + if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6de0..beb84694c44e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = "false") } } @@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - val getValue = ctx.getValue(arr, right.dataType, i) + val getValue = CodeGenerator.getValue(arr, right.dataType, i) s""" for (int $i = 0; $i < $arr.numElements(); $i ++) { if ($arr.isNullAt($i)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 047b80ac5289c..85facdad43db7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -90,7 +90,7 @@ private [sql] object GenArrayData { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length - if (!ctx.isPrimitiveType(elementType)) { + if (!CodeGenerator.isPrimitiveType(elementType)) { val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName @@ -124,7 +124,7 @@ private [sql] object GenArrayData { ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 7e53ca3908905..6cdad19168dce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; } """ } else { s""" - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; """ } }) @@ -205,7 +205,7 @@ case class GetArrayStructFields( } else { final InternalRow $row = $eval.getStruct($j, $numFields); $nullSafeEval { - $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; + $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)}; } } } @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } """ }) @@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression) } else { "" } + val keyJavaType = CodeGenerator.javaType(keyType) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression) int $index = 0; boolean $found = false; while ($index < $length && !$found) { - final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)}; + final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)}; if (${ctx.genEqual(keyType, key, eval2)}) { $found = true; } else { @@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression) if (!$found$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(values, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b444c3a7be92a..f4e9619bac59d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" |${condEval.code} |boolean ${ev.isNull} = false; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${condEval.isNull} && ${condEval.value}) { | ${trueEval.code} | ${ev.isNull} = ${trueEval.isNull}; @@ -191,7 +191,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) + ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -244,10 +244,10 @@ case class CaseWhen( val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BYTE, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); @@ -264,7 +264,7 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 424871f2047e9..1ae4e5a2f716b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -673,18 +673,19 @@ abstract class UnixTime } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", "true", ctx.defaultValue(dataType)) + ExprCode.forNullValue(dataType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; @@ -713,7 +714,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; }""") @@ -724,7 +725,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L; }""") @@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = UTF8String.fromString($formatterName.format( @@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { : ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.$truncFuncStr; }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0da54..b702422ed7a1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression { } } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", @@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression { ctx: CodegenContext): String = { val element = ctx.freshName("element") + val jt = CodeGenerator.javaType(elementType) ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" - final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${CodeGenerator.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} """ } @@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression { val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", @@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(ctx.JAVA_INT -> ev.value), - returnType = ctx.JAVA_INT, + extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$body |return ${ev.value}; """.stripMargin, @@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = s""" - |${ctx.JAVA_INT} ${ev.value} = $seed; - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes """.stripMargin) } @@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), - returnType = ctx.JAVA_INT, + arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childResult = 0; + |${CodeGenerator.JAVA_INT} $childResult = 0; |$body |return $result; """.stripMargin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 7a8edabed1757..07785e7448586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getInputFilePath();", isNull = "false") } } @@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getStartOffset();", isNull = "false") } } @@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getLength();", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index c1e65e34c2ea6..7395609a04ba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) if (value == null) { - val defaultValueLiteral = ctx.defaultValue(javaType) match { - case "null" => s"(($javaType)null)" - case lit => lit - } - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode.forNullValue(dataType) } else { dataType match { case BooleanType | IntegerType | DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index d8dc0862f1141..2c2cf3d2e6227 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1128,15 +1128,16 @@ abstract class RoundBase(child: Expression, scale: Expression, }""" } + val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 470d5da041ea5..b35fa72e95d1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", returnType = resultType, makeSplitFunction = func => s""" - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $func |} while (false); @@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $codes |} while (false); @@ -234,7 +234,7 @@ case class IsNaN(child: Expression) extends UnaryExpression case DoubleType | FloatType => ev.copy(code = s""" ${eval.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") } } @@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression) ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { @@ -416,8 +416,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "atLeastNNonNulls", - extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, + extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil, + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" |do { @@ -436,11 +436,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate ev.copy(code = s""" - |${ctx.JAVA_INT} $nonnull = 0; + |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes |} while (false); - |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 64da9bb9cdec1..80618af1e859f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -62,13 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression { def prepareArguments(ctx: CodegenContext): (String, String, String) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") + val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") resultIsNull } else { "false" } val argValues = arguments.map { e => - val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") + val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") argValue } @@ -137,7 +137,7 @@ case class StaticInvoke( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -151,7 +151,7 @@ case class StaticInvoke( } val evaluate = if (returnNullable) { - if (ctx.defaultValue(dataType) == "null") { + if (CodeGenerator.defaultValue(dataType) == "null") { s""" ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; @@ -159,7 +159,7 @@ case class StaticInvoke( } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -173,7 +173,7 @@ case class StaticInvoke( val code = s""" $argCode $prepareIsNull - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!$resultIsNull) { $evaluate } @@ -228,7 +228,7 @@ case class Invoke( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -255,11 +255,11 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. val assignResult = if (!returnNullable) { - s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" } else { s""" if ($funcResult != null) { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; } else { ${ev.isNull} = true; } @@ -275,7 +275,7 @@ case class Invoke( val code = s""" ${obj.code} boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { $argCode ${ev.isNull} = $resultIsNull; @@ -341,7 +341,7 @@ case class NewInstance( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -358,7 +358,8 @@ case class NewInstance( val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + final $javaType ${ev.value} = ${ev.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : $constructorCall; """ ev.copy(code = code) } @@ -385,15 +386,15 @@ case class UnwrapOption( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) val code = s""" ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : + (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) } @@ -546,7 +547,7 @@ case class MapObjects private( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = ctx.javaType(loopVarDataType) + val elementJavaType = CodeGenerator.javaType(loopVarDataType) ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -554,7 +555,7 @@ case class MapObjects private( val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(lambdaFunction.dataType) + val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -621,7 +622,7 @@ case class MapObjects private( ( s"${genInputData.value}.numElements()", "", - ctx.getValue(genInputData.value, et, loopIndex) + CodeGenerator.getValue(genInputData.value, et, loopIndex) ) case ObjectType(cls) if cls == classOf[Object] => val it = ctx.freshName("it") @@ -643,7 +644,8 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -695,7 +697,7 @@ case class MapObjects private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType @@ -806,10 +808,10 @@ case class CatalystToExternalMap private( } val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = ctx.javaType(mapType.keyType) + val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = ctx.javaType(mapType.valueType) + val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) @@ -825,10 +827,11 @@ case class CatalystToExternalMap private( val valueArray = ctx.freshName("valueArray") val getKeyArray = s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" - val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) val getValueArray = s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" - val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + val getValueLoopVar = CodeGenerator.getValue( + valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -844,7 +847,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { @@ -873,7 +876,7 @@ case class CatalystToExternalMap private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { int $dataLength = $getLength; @@ -993,8 +996,8 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") - val keyElementJavaType = ctx.javaType(keyType) - val valueElementJavaType = ctx.javaType(valueType) + val keyElementJavaType = CodeGenerator.javaType(keyType) + val valueElementJavaType = CodeGenerator.javaType(valueType) ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) @@ -1009,8 +1012,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -1024,22 +1027,24 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry._1(); - $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1047,12 +1052,12 @@ case class ExternalMapToCatalyst private( val arrayCls = classOf[GenericArrayData].getName val mapCls = classOf[ArrayBasedMapData].getName - val convertedKeyType = ctx.boxedType(keyConverter.dataType) - val convertedValueType = ctx.boxedType(valueConverter.dataType) + val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) + val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) val code = s""" ${inputMap.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); final Object[] $convertedKeys = new Object[$length]; @@ -1174,12 +1179,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // Code to serialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1223,13 +1229,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // Code to deserialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1254,7 +1261,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val instanceGen = beanInstance.genCode(ctx) val javaBeanInstance = ctx.freshName("javaBean") - val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) + val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) val initialize = setters.map { case (setterMethod, fieldValue) => @@ -1405,15 +1412,15 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: ArrayType => s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" case _ => - s"$obj instanceof ${ctx.boxedType(dataType)}" + s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } val code = s""" ${input.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { if ($typeCheck) { - ${ev.value} = (${ctx.boxedType(dataType)}) $obj; + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; } else { throw new RuntimeException($obj.getClass().getName() + $errMsgField); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a6d41ea7d00d4..4b85d9adbe311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -235,7 +235,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = ctx.javaType(value.dataType) + val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: @@ -263,8 +263,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, - returnType = ctx.JAVA_BYTE, + extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = body => s""" |do { @@ -348,8 +348,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = s""" |${childGen.code} - |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); | $setIsNull @@ -505,7 +505,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (ctx.isPrimitiveType(left.dataType) + if (CodeGenerator.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bc936fcbfc31..6c9937dacc70b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + isNull = "false") } } @@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f3e8f6de58975..ad0c0791d895f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } @@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { @@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } @@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d7612e30b4c57..22fbb8998ed89 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -102,11 +102,11 @@ case class Concat(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil) + extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) ev.copy(s""" $initCode $codes - ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } @@ -196,7 +196,7 @@ case class ConcatWs(children: Seq[Expression]) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") - val idxInVararg = ctx.freshName("idxInVararg") + val idxVararg = ctx.freshName("idxInVararg") val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => @@ -206,7 +206,7 @@ case class ConcatWs(children: Seq[Expression]) if (eval.isNull == "true") { "" } else { - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") @@ -222,7 +222,7 @@ case class ConcatWs(children: Seq[Expression]) if (!${eval.isNull}) { final int $size = ${eval.value}.numElements(); for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + $array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")}; } } """) @@ -247,20 +247,20 @@ case class ConcatWs(children: Seq[Expression]) val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" |$body - |return $idxInVararg; + |return $idxVararg; """.stripMargin, - foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( s""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; - int $idxInVararg = 0; + int $idxVararg = 0; $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; $varargBuilds @@ -333,7 +333,7 @@ case class Elt(children: Seq[Expression]) extends Expression { val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") + val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal") val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" @@ -350,10 +350,10 @@ case class Elt(children: Seq[Expression]) extends Expression { expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, - returnType = ctx.JAVA_BOOLEAN, + returnType = CodeGenerator.JAVA_BOOLEAN, makeSplitFunction = body => s""" - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |do { | $body |} while (false); @@ -372,12 +372,12 @@ case class Elt(children: Seq[Expression]) extends Expression { s""" |${index.code} |final int $indexVal = ${index.value}; - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |$inputVal = null; |do { | $codes |} while (false); - |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; + |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } @@ -1410,10 +1410,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val numArgLists = argListGen.length val argListCode = argListGen.zipWithIndex.map { case(v, index) => val value = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) { // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.value})" + s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " + + s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})" } else { s"(${v._2.isNull}) ? null : ${v._2.value}" } @@ -1434,7 +1434,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); @@ -2110,7 +2110,8 @@ case class FormatNumber(x: Expression, d: Expression) val usLocale = "US" val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 676ba3956ddc8..1e48c7b8df9da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-18016: define mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { - ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;") } assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) // When the number of primitive type mutable states is over the threshold, others are // allocated into an array - assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1) assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04f2619ed7541..392906a022903 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -49,15 +49,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValueFromVector(columnVar, dataType, ordinal) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { s""" boolean $isNullVar = $columnVar.isNullAt($ordinal); - $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value); + $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { s"$javaType $valueVar = $value;" @@ -85,12 +85,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 + val scanTimeTotalNs = + ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 + val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a7bd5ebf93ecd..12ae1ea4a7c13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -154,7 +154,8 @@ case class ExpandExec( val value = ctx.freshName("value") val code = s""" |boolean $isNull = true; - |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + |${CodeGenerator.javaType(firstExpr.dataType)} $value = + | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin ExprCode(code, isNull, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 0c2c4a1a9100d..384f0398a1ec0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -305,15 +305,15 @@ case class GenerateExec( nullable: Boolean, initialChecks: Seq[String]): ExprCode = { val value = ctx.freshName(name) - val javaType = ctx.javaType(dt) - val getter = ctx.getValue(source, dt, index) + val javaType = CodeGenerator.javaType(dt) + val getter = CodeGenerator.getValue(source, dt, index) val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = s""" |boolean $isNull = ${checks.mkString(" || ")}; - |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, isNull, value) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ac1c34d41c4f1..0dc16ba5ce281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -133,7 +133,8 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") + val needToSort = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index deb0a044c2fb2..f89e3fb0e536f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -234,7 +234,7 @@ trait CodegenSupport extends SparkPlan { variables.zipWithIndex.foreach { case (ev, i) => val paramName = ctx.freshName(s"expr_$i") - val paramType = ctx.javaType(attributes(i).dataType) + val paramType = CodeGenerator.javaType(attributes(i).dataType) arguments += ev.value parameters += s"$paramType $paramName" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index ce3c68810f3b6..1926e9373bc55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. ctx.INPUT_ROW = null @@ -186,8 +186,8 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -532,7 +532,7 @@ case class HashAggregateExec( */ private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = { val isSupported = - (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) || + (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) @@ -565,7 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -757,7 +757,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { @@ -832,7 +832,7 @@ case class HashAggregateExec( } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" |// common sub-expressions @@ -855,7 +855,7 @@ case class HashAggregateExec( } val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn( + CodeGenerator.updateColumn( fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 1c613b19c4ab1..6b60b414ffe5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -41,13 +41,13 @@ abstract class HashMapGenerator( val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) val groupingKeySignature = - groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") + groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ") val buffVars: Seq[ExprCode] = { val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index fd25707dd4ca6..8617be88f3570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.types._ /** @@ -114,7 +114,7 @@ class RowBasedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("row", + s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row", key.dataType, ordinal.toString()), key.name)})""" }.mkString(" && ") } @@ -147,7 +147,7 @@ class RowBasedHashMapGenerator( case t: DecimalType => s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})" case t: DataType => - if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) { + if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) { throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t") } s"agg_rowWriter.write(${ordinal}, ${key.name})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 633eeac180974..7b3580cecc60d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -127,7 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType, + "buckets[idx]") s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } @@ -182,14 +183,14 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) + CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, - buffVars(ordinal), nullable = true) + CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", + key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a15a8d11aa2a0..4707022f74547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -364,8 +364,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(ctx.JAVA_LONG, "number") + val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) @@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") + val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. - val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") + val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 4f28eeb725cbb..3b5655ba0582e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { - case t if ctx.isPrimitiveType(dt) => + case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 1918fcc5482db..487d6a2383318 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -182,9 +182,10 @@ case class BroadcastHashJoinExec( // the variables are needed even there is no matched rows val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) val code = s""" |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { | ${ev.code} | $isNull = ${ev.isNull}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2de2f30eb05d3..5a511b30e4fd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -516,9 +516,9 @@ case class SortMergeJoinExec( ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") - val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - val javaType = ctx.javaType(a.dataType) - val defaultValue = ctx.defaultValue(a.dataType) + val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) + val javaType = CodeGenerator.javaType(a.dataType) + val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cccee63bc0680..66bcda8913738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils @@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false + val stopEarly = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1; From 42cf48e20cd5e47e1b7557af9c71c4eea142f10f Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Mon, 5 Mar 2018 14:33:12 +0100 Subject: [PATCH 0424/1161] [SPARK-23496][CORE] Locality of coalesced partitions can be severely skewed by the order of input partitions ## What changes were proposed in this pull request? The algorithm in `DefaultPartitionCoalescer.setupGroups` is responsible for picking preferred locations for coalesced partitions. It analyzes the preferred locations of input partitions. It starts by trying to create one partition for each unique location in the input. However, if the the requested number of coalesced partitions is higher that the number of unique locations, it has to pick duplicate locations. Previously, the duplicate locations would be picked by iterating over the input partitions in order, and copying their preferred locations to coalesced partitions. If the input partitions were clustered by location, this could result in severe skew. With the fix, instead of iterating over the list of input partitions in order, we pick them at random. It's not perfectly balanced, but it's much better. ## How was this patch tested? Unit test reproducing the behavior was added. Author: Ala Luszczak Closes #20664 from ala/SPARK-23496. --- .../org/apache/spark/rdd/CoalescedRDD.scala | 8 ++-- .../scala/org/apache/spark/rdd/RDDSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 10451a324b0f4..94e7d0b38cba3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -266,17 +266,17 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) numCreated += 1 } } - tries = 0 // if we don't have enough partition groups, create duplicates while (numCreated < targetLen) { - val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) - tries += 1 + // Copy the preferred location from a random input partition. + // This helps in avoiding skew when the input partitions are clustered by preferred location. + val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs( + rnd.nextInt(partitionLocs.partsWithLocs.length)) val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup addPartToPGroup(nxt_part, pgroup) numCreated += 1 - if (tries >= partitionLocs.partsWithLocs.length) tries = 0 } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e994d724c462f..191c61250ce21 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -1129,6 +1129,35 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { }.collect() } + test("SPARK-23496: order of input partitions can result in severe skew in coalesce") { + val numInputPartitions = 100 + val numCoalescedPartitions = 50 + val locations = Array("locA", "locB") + + val inputRDD = sc.makeRDD(Range(0, numInputPartitions).toArray[Int], numInputPartitions) + assert(inputRDD.getNumPartitions == numInputPartitions) + + val locationPrefRDD = new LocationPrefRDD(inputRDD, { (p: Partition) => + if (p.index < numCoalescedPartitions) { + Seq(locations(0)) + } else { + Seq(locations(1)) + } + }) + val coalescedRDD = new CoalescedRDD(locationPrefRDD, numCoalescedPartitions) + + val numPartsPerLocation = coalescedRDD + .getPartitions + .map(coalescedRDD.getPreferredLocations(_).head) + .groupBy(identity) + .mapValues(_.size) + + // Make sure the coalesced partitions are distributed fairly evenly between the two locations. + // This should not become flaky since the DefaultPartitionsCoalescer uses a fixed seed. + assert(numPartsPerLocation(locations(0)) > 0.4 * numCoalescedPartitions) + assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions) + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because @@ -1210,3 +1239,16 @@ class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Seria groups.toArray } } + +/** Alters the preferred locations of the parent RDD using provided function. */ +class LocationPrefRDD[T: ClassTag]( + @transient var prev: RDD[T], + val locationPicker: Partition => Seq[String]) extends RDD[T](prev) { + override protected def getPartitions: Array[Partition] = prev.partitions + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = + null.asInstanceOf[Iterator[T]] + + override def getPreferredLocations(partition: Partition): Seq[String] = + locationPicker(partition) +} From 5ff72ffcf495d2823f7f1186078d1cb261667c3d Mon Sep 17 00:00:00 2001 From: Anirudh Date: Mon, 5 Mar 2018 23:17:16 +0900 Subject: [PATCH 0425/1161] [SPARK-23566][MINOR][DOC] Argument name mismatch fixed Argument name mismatch fixed. ## What changes were proposed in this pull request? `col` changed to `new` in doc string to match the argument list. Patch file added: https://issues.apache.org/jira/browse/SPARK-23566 Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Anirudh Closes #20716 from animenon/master. --- python/pyspark/sql/dataframe.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f37777e13ee12..9d8e85cde914f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -588,6 +588,8 @@ def coalesce(self, numPartitions): """ Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + :param numPartitions: int, to specify the target number of partitions + Similar to coalesce defined on an :class:`RDD`, this operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of the 100 new partitions will @@ -612,9 +614,10 @@ def repartition(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is hash partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. .. versionchanged:: 1.6 Added optional arguments to specify the partitioning columns. Also made numPartitions @@ -673,9 +676,10 @@ def repartitionByRange(self, numPartitions, *cols): Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting DataFrame is range partitioned. - ``numPartitions`` can be an int to specify the target number of partitions or a Column. - If it is a Column, it will be used as the first partitioning column. If not specified, - the default number of partitions is used. + :param numPartitions: + can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. At least one partition-by expression must be specified. When no explicit sort order is specified, "ascending nulls first" is assumed. @@ -892,6 +896,8 @@ def colRegex(self, colName): def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. + :param alias: string, an alias name to be set for the DataFrame. + >>> from pyspark.sql.functions import * >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") @@ -1900,7 +1906,7 @@ def withColumnRenamed(self, existing, new): This is a no-op if schema doesn't contain the given column name. :param existing: string, name of the existing column to rename. - :param col: string, new name of the column. + :param new: string, new name of the column. >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] From a366b950b90650693ad0eb1e5b9a988ad028d845 Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Mon, 5 Mar 2018 23:46:40 +0900 Subject: [PATCH 0426/1161] [SPARK-23329][SQL] Fix documentation of trigonometric functions ## What changes were proposed in this pull request? Provide more details in trigonometric function documentations. Referenced `java.lang.Math` for further details in the descriptions. ## How was this patch tested? Ran full build, checked generated documentation manually Author: Mihaly Toth Closes #20618 from misutoth/trigonometric-doc. --- R/pkg/R/functions.R | 34 ++-- python/pyspark/sql/functions.py | 62 ++++--- .../expressions/mathExpressions.scala | 99 ++++++++--- .../org/apache/spark/sql/functions.scala | 160 ++++++++++++------ 4 files changed, 248 insertions(+), 107 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f7c6317cd924..29ee146ab14f9 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -278,8 +278,8 @@ setMethod("abs", }) #' @details -#' \code{acos}: Computes the cosine inverse of the given value; the returned angle is in -#' the range 0.0 through pi. +#' \code{acos}: Returns the inverse cosine of the given value, +#' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions #' @export @@ -334,8 +334,8 @@ setMethod("ascii", }) #' @details -#' \code{asin}: Computes the sine inverse of the given value; the returned angle is in -#' the range -pi/2 through pi/2. +#' \code{asin}: Returns the inverse sine of the given value, +#' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions #' @export @@ -349,8 +349,8 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range -#' -pi/2 through pi/2. +#' \code{atan}: Returns the inverse tangent of the given value, +#' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions #' @export @@ -613,7 +613,8 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. Units in radians. +#' \code{cos}: Returns the cosine of the given value, +#' as if computed by \code{java.lang.Math.cos()}. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method @@ -627,7 +628,8 @@ setMethod("cos", }) #' @details -#' \code{cosh}: Computes the hyperbolic cosine of the given value. +#' \code{cosh}: Returns the hyperbolic cosine of the given value, +#' as if computed by \code{java.lang.Math.cosh()}. #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method @@ -1463,7 +1465,8 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. Units in radians. +#' \code{sin}: Returns the sine of the given value, +#' as if computed by \code{java.lang.Math.sin()}. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method @@ -1477,7 +1480,8 @@ setMethod("sin", }) #' @details -#' \code{sinh}: Computes the hyperbolic sine of the given value. +#' \code{sinh}: Returns the hyperbolic sine of the given value, +#' as if computed by \code{java.lang.Math.sinh()}. #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method @@ -1653,7 +1657,9 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. Units in radians. +#' \code{tan}: Returns the tangent of the given value, +#' as if computed by \code{java.lang.Math.tan()}. +#' Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method @@ -1667,7 +1673,8 @@ setMethod("tan", }) #' @details -#' \code{tanh}: Computes the hyperbolic tangent of the given value. +#' \code{tanh}: Returns the hyperbolic tangent of the given value, +#' as if computed by \code{java.lang.Math.tanh()}. #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method @@ -1973,7 +1980,8 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). Units in radians. +#' (x, y) to polar coordinates (r, theta), +#' as if computed by \code{java.lang.Math.atan2()}. Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bb9c323a5a60..b9c0c57262c5d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -106,18 +106,15 @@ def _(): _functions_1_4 = { # unary math functions - 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + - '0.0 through pi.', - 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + - '-pi/2 through pi/2', + 'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`', + 'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`', + 'atan': ':return: inverse tangent of `col`, as if computed by `java.lang.Math.atan()`', 'cbrt': 'Computes the cube-root of the given value.', 'ceil': 'Computes the ceiling of the given value.', - 'cos': """Computes the cosine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'cos': """:param col: angle in radians + :return: cosine of the angle, as if computed by `java.lang.Math.cos()`.""", + 'cosh': """:param col: hyperbolic angle + :return: hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh()`""", 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', 'floor': 'Computes the floor of the given value.', @@ -127,14 +124,16 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': """Computes the sine of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': """Computes the tangent of the given value. - - :param col: :class:`DoubleType` column, units in radians.""", - 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'sin': """:param col: angle in radians + :return: sine of the angle, as if computed by `java.lang.Math.sin()`""", + 'sinh': """:param col: hyperbolic angle + :return: hyperbolic sine of the given value, + as if computed by `java.lang.Math.sinh()`""", + 'tan': """:param col: angle in radians + :return: tangent of the given value, as if computed by `java.lang.Math.tan()`""", + 'tanh': """:param col: hyperbolic angle + :return: hyperbolic tangent of the given value, + as if computed by `java.lang.Math.tanh()`""", 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', @@ -173,16 +172,31 @@ def _(): _functions_2_1 = { # unary math functions - 'degrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', - 'radians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', + 'degrees': """ + Converts an angle measured in radians to an approximately equivalent angle + measured in degrees. + :param col: angle in radians + :return: angle in degrees, as if computed by `java.lang.Math.toDegrees()` + """, + 'radians': """ + Converts an angle measured in degrees to an approximately equivalent angle + measured in radians. + :param col: angle in degrees + :return: angle in radians, as if computed by `java.lang.Math.toRadians()` + """, } # math functions that take two arguments as input _binary_mathfunctions = { - 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta). Units in radians.', + 'atan2': """ + :param col1: coordinate on y-axis + :param col2: coordinate on x-axis + :return: the `theta` component of the point + (`r`, `theta`) + in polar coordinates that corresponds to the point + (`x`, `y`) in Cartesian coordinates, + as if computed by `java.lang.Math.atan2()` + """, 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 2c2cf3d2e6227..bc4cfcec47425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -168,9 +168,11 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse cosine (a.k.a. arccosine) of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -178,12 +180,13 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse sine (a.k.a. arcsine) the arc sin of `expr` if -1<=`expr`<=1 or NaN otherwise.", + usage = """ + _FUNC_(expr) - Returns the inverse sine (a.k.a. arc sine) the arc sin of `expr`, + as if computed by `java.lang.Math._FUNC_`. + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -191,18 +194,18 @@ case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS" > SELECT _FUNC_(2); NaN """) -// scalastyle:on line.size.limit case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the inverse tangent (a.k.a. arctangent).", + usage = """ + _FUNC_(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by + `java.lang.Math._FUNC_` + """, examples = """ Examples: > SELECT _FUNC_(0); 0.0 """) -// scalastyle:on line.size.limit case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") @ExpressionDescription( @@ -252,7 +255,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -261,7 +271,14 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic cosine of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -512,7 +529,11 @@ case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sine of `expr`.", + usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -521,7 +542,13 @@ case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "S case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic sine of `expr`.", + usage = """ + _FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -539,7 +566,13 @@ case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH" case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -548,7 +581,13 @@ case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT" case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the cotangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(1); @@ -562,7 +601,14 @@ case class Cot(child: Expression) } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the hyperbolic tangent of `expr`.", + usage = """ + _FUNC_(expr) - Returns the hyperbolic tangent of `expr`, as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * expr - hyperbolic angle + """, examples = """ Examples: > SELECT _FUNC_(0); @@ -572,6 +618,10 @@ case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH" @ExpressionDescription( usage = "_FUNC_(expr) - Converts radians to degrees.", + arguments = """ + Arguments: + * expr - angle in radians + """, examples = """ Examples: > SELECT _FUNC_(3.141592653589793); @@ -583,6 +633,10 @@ case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegre @ExpressionDescription( usage = "_FUNC_(expr) - Converts degrees to radians.", + arguments = """ + Arguments: + * expr - angle in degrees + """, examples = """ Examples: > SELECT _FUNC_(180); @@ -768,15 +822,22 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates (`expr1`, `expr2`).", + usage = """ + _FUNC_(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane + and the point given by the coordinates (`exprX`, `exprY`), as if computed by + `java.lang.Math._FUNC_`. + """, + arguments = """ + Arguments: + * exprY - coordinate on y-axis + * exprX - coordinate on x-axis + """, examples = """ Examples: > SELECT _FUNC_(0, 0); 0.0 """) -// scalastyle:on line.size.limit case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d54c02c3d06f..c9ca9a8996344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1313,8 +1313,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the cosine inverse of the given value; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1322,8 +1321,7 @@ object functions { def acos(e: Column): Column = withExpr { Acos(e.expr) } /** - * Computes the cosine inverse of the given column; the returned angle is in the range - * 0.0 through pi. + * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos` * * @group math_funcs * @since 1.4.0 @@ -1331,8 +1329,7 @@ object functions { def acos(columnName: String): Column = acos(Column(columnName)) /** - * Computes the sine inverse of the given value; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1340,8 +1337,7 @@ object functions { def asin(e: Column): Column = withExpr { Asin(e.expr) } /** - * Computes the sine inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2. + * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin` * * @group math_funcs * @since 1.4.0 @@ -1349,8 +1345,7 @@ object functions { def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `e`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1358,8 +1353,7 @@ object functions { def atan(e: Column): Column = withExpr { Atan(e.expr) } /** - * Computes the tangent inverse of the given column; the returned angle is in the range - * -pi/2 through pi/2 + * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan` * * @group math_funcs * @since 1.4.0 @@ -1367,77 +1361,117 @@ object functions { def atan(columnName: String): Column = atan(Column(columnName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). Units in radians. + * @param y coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } + def atan2(y: Column, x: Column): Column = withExpr { Atan2(y.expr, x.expr) } /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(y: Column, xName: String): Column = atan2(y, Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + def atan2(yName: String, x: Column): Column = atan2(Column(yName), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, rightName: String): Column = - atan2(Column(leftName), Column(rightName)) + def atan2(yName: String, xName: String): Column = + atan2(Column(yName), Column(xName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param y coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) + def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yName coordinate on y-axis + * @param xValue coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), xValue) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param x coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l), r) + def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). + * @param yValue coordinate on y-axis + * @param xName coordinate on x-axis + * @return the theta component of the point + * (r, theta) + * in polar coordinates that corresponds to the point + * (x, y) in Cartesian coordinates, + * as if computed by `java.lang.Math.atan2` * * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName)) /** * An expression that returns the string representation of the binary value of the given long @@ -1500,7 +1534,8 @@ object functions { } /** - * Computes the cosine of the given value. Units in radians. + * @param e angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1508,7 +1543,8 @@ object functions { def cos(e: Column): Column = withExpr { Cos(e.expr) } /** - * Computes the cosine of the given column. + * @param columnName angle in radians + * @return cosine of the angle, as if computed by `java.lang.Math.cos` * * @group math_funcs * @since 1.4.0 @@ -1516,7 +1552,8 @@ object functions { def cos(columnName: String): Column = cos(Column(columnName)) /** - * Computes the hyperbolic cosine of the given value. + * @param e hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1524,7 +1561,8 @@ object functions { def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** - * Computes the hyperbolic cosine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh` * * @group math_funcs * @since 1.4.0 @@ -1967,7 +2005,8 @@ object functions { def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes the sine of the given value. Units in radians. + * @param e angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1975,7 +2014,8 @@ object functions { def sin(e: Column): Column = withExpr { Sin(e.expr) } /** - * Computes the sine of the given column. + * @param columnName angle in radians + * @return sine of the angle, as if computed by `java.lang.Math.sin` * * @group math_funcs * @since 1.4.0 @@ -1983,7 +2023,8 @@ object functions { def sin(columnName: String): Column = sin(Column(columnName)) /** - * Computes the hyperbolic sine of the given value. + * @param e hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1991,7 +2032,8 @@ object functions { def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** - * Computes the hyperbolic sine of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh` * * @group math_funcs * @since 1.4.0 @@ -1999,7 +2041,8 @@ object functions { def sinh(columnName: String): Column = sinh(Column(columnName)) /** - * Computes the tangent of the given value. Units in radians. + * @param e angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2007,7 +2050,8 @@ object functions { def tan(e: Column): Column = withExpr { Tan(e.expr) } /** - * Computes the tangent of the given column. + * @param columnName angle in radians + * @return tangent of the given value, as if computed by `java.lang.Math.tan` * * @group math_funcs * @since 1.4.0 @@ -2015,7 +2059,8 @@ object functions { def tan(columnName: String): Column = tan(Column(columnName)) /** - * Computes the hyperbolic tangent of the given value. + * @param e hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2023,7 +2068,8 @@ object functions { def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** - * Computes the hyperbolic tangent of the given column. + * @param columnName hyperbolic angle + * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh` * * @group math_funcs * @since 1.4.0 @@ -2047,6 +2093,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param e angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2055,6 +2104,9 @@ object functions { /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. * + * @param columnName angle in radians + * @return angle in degrees, as if computed by `java.lang.Math.toDegrees` + * * @group math_funcs * @since 2.1.0 */ @@ -2077,6 +2129,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param e angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2085,6 +2140,9 @@ object functions { /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. * + * @param columnName angle in degrees + * @return angle in radians, as if computed by `java.lang.Math.toRadians` + * * @group math_funcs * @since 2.1.0 */ @@ -2873,7 +2931,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2929,7 +2987,7 @@ object functions { * or equal to the `windowDuration`. Check * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration * identifiers. This duration is likewise absolute, and does not vary - * according to a calendar. + * according to a calendar. * * @group datetime_funcs * @since 2.0.0 From 947b4e6f09db6aa5d92409344b6e273e9faeb24e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 5 Mar 2018 16:21:02 +0100 Subject: [PATCH 0427/1161] [SPARK-23510][DOC][FOLLOW-UP] Update spark.sql.hive.metastore.version ## What changes were proposed in this pull request? Update `spark.sql.hive.metastore.version` to 2.3.2, same as HiveUtils.scala: https://github.com/apache/spark/blob/ff1480189b827af0be38605d566a4ee71b4c36f6/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala#L63-L65 ## How was this patch tested? N/A Author: Yuming Wang Closes #20734 from wangyum/SPARK-23510-FOLLOW-UP. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d0f015f401bb..01e2076555ee6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 1.2.1. + options are 0.12.0 through 2.3.2. From 4586eada42d6a16bb78d1650d145531c51fa747f Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 5 Mar 2018 09:30:49 -0800 Subject: [PATCH 0428/1161] [SPARK-22430][R][DOCS] Unknown tag warnings when building R docs with Roxygen 6.0.1 ## What changes were proposed in this pull request? Removed export tag to get rid of unknown tag warnings ## How was this patch tested? Existing tests Author: Rekha Joshi Author: rjoshi2 Closes #20501 from rekhajoshm/SPARK-22430. --- R/pkg/R/DataFrame.R | 92 --------- R/pkg/R/SQLContext.R | 16 -- R/pkg/R/WindowSpec.R | 8 - R/pkg/R/broadcast.R | 3 - R/pkg/R/catalog.R | 18 -- R/pkg/R/column.R | 7 - R/pkg/R/context.R | 6 - R/pkg/R/functions.R | 181 ----------------- R/pkg/R/generics.R | 343 --------------------------------- R/pkg/R/group.R | 7 - R/pkg/R/install.R | 1 - R/pkg/R/jvm.R | 3 - R/pkg/R/mllib_classification.R | 20 -- R/pkg/R/mllib_clustering.R | 23 --- R/pkg/R/mllib_fpm.R | 6 - R/pkg/R/mllib_recommendation.R | 5 - R/pkg/R/mllib_regression.R | 17 -- R/pkg/R/mllib_stat.R | 4 - R/pkg/R/mllib_tree.R | 33 ---- R/pkg/R/mllib_utils.R | 3 - R/pkg/R/schema.R | 7 - R/pkg/R/sparkR.R | 7 - R/pkg/R/stats.R | 6 - R/pkg/R/streaming.R | 9 - R/pkg/R/utils.R | 1 - R/pkg/R/window.R | 4 - 26 files changed, 830 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 41c3c3a89fa72..c4852024c0f49 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -36,7 +36,6 @@ setOldClass("structType") #' @slot sdf A Java object reference to the backing Scala DataFrame #' @seealso \link{createDataFrame}, \link{read.json}, \link{table} #' @seealso \url{https://spark.apache.org/docs/latest/sparkr.html#sparkr-dataframes} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -77,7 +76,6 @@ setWriteMode <- function(write, mode) { write } -#' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached #' @noRd @@ -97,7 +95,6 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @rdname printSchema #' @name printSchema #' @aliases printSchema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -123,7 +120,6 @@ setMethod("printSchema", #' @rdname schema #' @name schema #' @aliases schema,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -146,7 +142,6 @@ setMethod("schema", #' @aliases explain,SparkDataFrame-method #' @rdname explain #' @name explain -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -178,7 +173,6 @@ setMethod("explain", #' @rdname isLocal #' @name isLocal #' @aliases isLocal,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -209,7 +203,6 @@ setMethod("isLocal", #' @aliases showDF,SparkDataFrame-method #' @rdname showDF #' @name showDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -241,7 +234,6 @@ setMethod("showDF", #' @rdname show #' @aliases show,SparkDataFrame-method #' @name show -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -269,7 +261,6 @@ setMethod("show", "SparkDataFrame", #' @rdname dtypes #' @name dtypes #' @aliases dtypes,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -296,7 +287,6 @@ setMethod("dtypes", #' @rdname columns #' @name columns #' @aliases columns,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -388,7 +378,6 @@ setMethod("colnames<-", #' @aliases coltypes,SparkDataFrame-method #' @name coltypes #' @family SparkDataFrame functions -#' @export #' @examples #'\dontrun{ #' irisDF <- createDataFrame(iris) @@ -445,7 +434,6 @@ setMethod("coltypes", #' @rdname coltypes #' @name coltypes<- #' @aliases coltypes<-,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -494,7 +482,6 @@ setMethod("coltypes<-", #' @rdname createOrReplaceTempView #' @name createOrReplaceTempView #' @aliases createOrReplaceTempView,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -521,7 +508,6 @@ setMethod("createOrReplaceTempView", #' @rdname registerTempTable-deprecated #' @name registerTempTable #' @aliases registerTempTable,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -552,7 +538,6 @@ setMethod("registerTempTable", #' @rdname insertInto #' @name insertInto #' @aliases insertInto,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -580,7 +565,6 @@ setMethod("insertInto", #' @aliases cache,SparkDataFrame-method #' @rdname cache #' @name cache -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -611,7 +595,6 @@ setMethod("cache", #' @rdname persist #' @name persist #' @aliases persist,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -641,7 +624,6 @@ setMethod("persist", #' @rdname unpersist #' @aliases unpersist,SparkDataFrame-method #' @name unpersist -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -669,7 +651,6 @@ setMethod("unpersist", #' @rdname storageLevel #' @aliases storageLevel,SparkDataFrame-method #' @name storageLevel -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -707,7 +688,6 @@ setMethod("storageLevel", #' @name coalesce #' @aliases coalesce,SparkDataFrame-method #' @seealso \link{repartition} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -744,7 +724,6 @@ setMethod("coalesce", #' @name repartition #' @aliases repartition,SparkDataFrame-method #' @seealso \link{coalesce} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -793,7 +772,6 @@ setMethod("repartition", #' @rdname toJSON #' @name toJSON #' @aliases toJSON,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -826,7 +804,6 @@ setMethod("toJSON", #' @rdname write.json #' @name write.json #' @aliases write.json,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -858,7 +835,6 @@ setMethod("write.json", #' @aliases write.orc,SparkDataFrame,character-method #' @rdname write.orc #' @name write.orc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -890,7 +866,6 @@ setMethod("write.orc", #' @rdname write.parquet #' @name write.parquet #' @aliases write.parquet,SparkDataFrame,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -911,7 +886,6 @@ setMethod("write.parquet", #' @rdname write.parquet #' @name saveAsParquetFile #' @aliases saveAsParquetFile,SparkDataFrame,character-method -#' @export #' @note saveAsParquetFile since 1.4.0 setMethod("saveAsParquetFile", signature(x = "SparkDataFrame", path = "character"), @@ -936,7 +910,6 @@ setMethod("saveAsParquetFile", #' @aliases write.text,SparkDataFrame,character-method #' @rdname write.text #' @name write.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -963,7 +936,6 @@ setMethod("write.text", #' @aliases distinct,SparkDataFrame-method #' @rdname distinct #' @name distinct -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1004,7 +976,6 @@ setMethod("unique", #' @aliases sample,SparkDataFrame-method #' @rdname sample #' @name sample -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1061,7 +1032,6 @@ setMethod("sample_frac", #' @rdname nrow #' @name nrow #' @aliases count,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1094,7 +1064,6 @@ setMethod("nrow", #' @rdname ncol #' @name ncol #' @aliases ncol,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1118,7 +1087,6 @@ setMethod("ncol", #' @rdname dim #' @aliases dim,SparkDataFrame-method #' @name dim -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1144,7 +1112,6 @@ setMethod("dim", #' @rdname collect #' @aliases collect,SparkDataFrame-method #' @name collect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1229,7 +1196,6 @@ setMethod("collect", #' @rdname limit #' @name limit #' @aliases limit,SparkDataFrame,numeric-method -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -1253,7 +1219,6 @@ setMethod("limit", #' @rdname take #' @name take #' @aliases take,SparkDataFrame,numeric-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1282,7 +1247,6 @@ setMethod("take", #' @aliases head,SparkDataFrame-method #' @rdname head #' @name head -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1307,7 +1271,6 @@ setMethod("head", #' @aliases first,SparkDataFrame-method #' @rdname first #' @name first -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -1359,7 +1322,6 @@ setMethod("toRDD", #' @aliases groupBy,SparkDataFrame-method #' @rdname groupBy #' @name groupBy -#' @export #' @examples #' \dontrun{ #' # Compute the average for all numeric columns grouped by department. @@ -1401,7 +1363,6 @@ setMethod("group_by", #' @aliases agg,SparkDataFrame-method #' @rdname summarize #' @name agg -#' @export #' @note agg since 1.4.0 setMethod("agg", signature(x = "SparkDataFrame"), @@ -1460,7 +1421,6 @@ setClassUnion("characterOrstructType", c("character", "structType")) #' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1519,7 +1479,6 @@ setMethod("dapply", #' @aliases dapplyCollect,SparkDataFrame,function-method #' @name dapplyCollect #' @seealso \link{dapply} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -1576,7 +1535,6 @@ setMethod("dapplyCollect", #' @rdname gapply #' @name gapply #' @seealso \link{gapplyCollect} -#' @export #' @examples #' #' \dontrun{ @@ -1673,7 +1631,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @name gapplyCollect #' @seealso \link{gapply} -#' @export #' @examples #' #' \dontrun{ @@ -1947,7 +1904,6 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param ... currently not used. #' @return A new SparkDataFrame containing only the rows that meet the condition with selected #' columns. -#' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method #' @seealso \link{withColumn} @@ -1992,7 +1948,6 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' If more than one column is assigned in \code{col}, \code{...} #' should be left empty. #' @return A new SparkDataFrame with selected columns. -#' @export #' @family SparkDataFrame functions #' @rdname select #' @aliases select,SparkDataFrame,character-method @@ -2024,7 +1979,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "character"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,Column-method #' @note select(SparkDataFrame, Column) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "Column"), @@ -2037,7 +1991,6 @@ setMethod("select", signature(x = "SparkDataFrame", col = "Column"), }) #' @rdname select -#' @export #' @aliases select,SparkDataFrame,list-method #' @note select(SparkDataFrame, list) since 1.4.0 setMethod("select", @@ -2066,7 +2019,6 @@ setMethod("select", #' @aliases selectExpr,SparkDataFrame,character-method #' @rdname selectExpr #' @name selectExpr -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2098,7 +2050,6 @@ setMethod("selectExpr", #' @rdname withColumn #' @name withColumn #' @seealso \link{rename} \link{mutate} \link{subset} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2137,7 +2088,6 @@ setMethod("withColumn", #' @rdname mutate #' @name mutate #' @seealso \link{rename} \link{withColumn} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2208,7 +2158,6 @@ setMethod("mutate", }) #' @param _data a SparkDataFrame. -#' @export #' @rdname mutate #' @aliases transform,SparkDataFrame-method #' @name transform @@ -2232,7 +2181,6 @@ setMethod("transform", #' @name withColumnRenamed #' @aliases withColumnRenamed,SparkDataFrame,character,character-method #' @seealso \link{mutate} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2258,7 +2206,6 @@ setMethod("withColumnRenamed", #' @rdname rename #' @name rename #' @aliases rename,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2304,7 +2251,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @aliases arrange,SparkDataFrame,Column-method #' @rdname arrange #' @name arrange -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2335,7 +2281,6 @@ setMethod("arrange", #' @rdname arrange #' @name arrange #' @aliases arrange,SparkDataFrame,character-method -#' @export #' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), @@ -2368,7 +2313,6 @@ setMethod("arrange", #' @rdname arrange #' @aliases orderBy,SparkDataFrame,characterOrColumn-method -#' @export #' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0 setMethod("orderBy", signature(x = "SparkDataFrame", col = "characterOrColumn"), @@ -2389,7 +2333,6 @@ setMethod("orderBy", #' @rdname filter #' @name filter #' @family subsetting functions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2432,7 +2375,6 @@ setMethod("where", #' @aliases dropDuplicates,SparkDataFrame-method #' @rdname dropDuplicates #' @name dropDuplicates -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2481,7 +2423,6 @@ setMethod("dropDuplicates", #' @rdname join #' @name join #' @seealso \link{merge} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2533,7 +2474,6 @@ setMethod("join", #' @rdname crossJoin #' @name crossJoin #' @seealso \link{merge} \link{join} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2581,7 +2521,6 @@ setMethod("crossJoin", #' @aliases merge,SparkDataFrame,SparkDataFrame-method #' @rdname merge #' @seealso \link{join} \link{crossJoin} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2721,7 +2660,6 @@ genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) { #' @name union #' @aliases union,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2742,7 +2680,6 @@ setMethod("union", #' @rdname union #' @name unionAll #' @aliases unionAll,SparkDataFrame,SparkDataFrame-method -#' @export #' @note unionAll since 1.4.0 setMethod("unionAll", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2769,7 +2706,6 @@ setMethod("unionAll", #' @name unionByName #' @aliases unionByName,SparkDataFrame,SparkDataFrame-method #' @seealso \link{rbind} \link{union} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2802,7 +2738,6 @@ setMethod("unionByName", #' @rdname rbind #' @name rbind #' @seealso \link{union} \link{unionByName} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2835,7 +2770,6 @@ setMethod("rbind", #' @aliases intersect,SparkDataFrame,SparkDataFrame-method #' @rdname intersect #' @name intersect -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2863,7 +2797,6 @@ setMethod("intersect", #' @aliases except,SparkDataFrame,SparkDataFrame-method #' @rdname except #' @name except -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2872,7 +2805,6 @@ setMethod("intersect", #' exceptDF <- except(df, df2) #' } #' @rdname except -#' @export #' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), @@ -2909,7 +2841,6 @@ setMethod("except", #' @aliases write.df,SparkDataFrame-method #' @rdname write.df #' @name write.df -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -2944,7 +2875,6 @@ setMethod("write.df", #' @rdname write.df #' @name saveDF #' @aliases saveDF,SparkDataFrame,character-method -#' @export #' @note saveDF since 1.4.0 setMethod("saveDF", signature(df = "SparkDataFrame", path = "character"), @@ -2978,7 +2908,6 @@ setMethod("saveDF", #' @aliases saveAsTable,SparkDataFrame,character-method #' @rdname saveAsTable #' @name saveAsTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3015,7 +2944,6 @@ setMethod("saveAsTable", #' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method #' @rdname describe #' @name describe -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3071,7 +2999,6 @@ setMethod("describe", #' @rdname summary #' @name summary #' @aliases summary,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3117,7 +3044,6 @@ setMethod("summary", #' @rdname nafunctions #' @aliases dropna,SparkDataFrame-method #' @name dropna -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3148,7 +3074,6 @@ setMethod("dropna", #' @rdname nafunctions #' @name na.omit #' @aliases na.omit,SparkDataFrame-method -#' @export #' @note na.omit since 1.5.0 setMethod("na.omit", signature(object = "SparkDataFrame"), @@ -3168,7 +3093,6 @@ setMethod("na.omit", #' @rdname nafunctions #' @name fillna #' @aliases fillna,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3399,7 +3323,6 @@ setMethod("str", #' @rdname drop #' @name drop #' @aliases drop,SparkDataFrame-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3427,7 +3350,6 @@ setMethod("drop", #' @name drop #' @rdname drop #' @aliases drop,ANY-method -#' @export setMethod("drop", signature(x = "ANY"), function(x) { @@ -3446,7 +3368,6 @@ setMethod("drop", #' @rdname histogram #' @aliases histogram,SparkDataFrame,characterOrColumn-method #' @family SparkDataFrame functions -#' @export #' @examples #' \dontrun{ #' @@ -3582,7 +3503,6 @@ setMethod("histogram", #' @rdname write.jdbc #' @name write.jdbc #' @aliases write.jdbc,SparkDataFrame,character,character-method -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3611,7 +3531,6 @@ setMethod("write.jdbc", #' @aliases randomSplit,SparkDataFrame,numeric-method #' @rdname randomSplit #' @name randomSplit -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3645,7 +3564,6 @@ setMethod("randomSplit", #' @aliases getNumPartitions,SparkDataFrame-method #' @rdname getNumPartitions #' @name getNumPartitions -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3672,7 +3590,6 @@ setMethod("getNumPartitions", #' @rdname isStreaming #' @name isStreaming #' @seealso \link{read.stream} \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3726,7 +3643,6 @@ setMethod("isStreaming", #' @aliases write.stream,SparkDataFrame-method #' @rdname write.stream #' @name write.stream -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -3819,7 +3735,6 @@ setMethod("write.stream", #' @rdname checkpoint #' @name checkpoint #' @seealso \link{setCheckpointDir} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") @@ -3847,7 +3762,6 @@ setMethod("checkpoint", #' @aliases localCheckpoint,SparkDataFrame-method #' @rdname localCheckpoint #' @name localCheckpoint -#' @export #' @examples #'\dontrun{ #' df <- localCheckpoint(df) @@ -3874,7 +3788,6 @@ setMethod("localCheckpoint", #' @aliases cube,SparkDataFrame-method #' @rdname cube #' @name cube -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3909,7 +3822,6 @@ setMethod("cube", #' @aliases rollup,SparkDataFrame-method #' @rdname rollup #' @name rollup -#' @export #' @examples #'\dontrun{ #' df <- createDataFrame(mtcars) @@ -3942,7 +3854,6 @@ setMethod("rollup", #' @aliases hint,SparkDataFrame,character-method #' @rdname hint #' @name hint -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -3966,7 +3877,6 @@ setMethod("hint", #' @family SparkDataFrame functions #' @rdname alias #' @name alias -#' @export #' @examples #' \dontrun{ #' df <- alias(createDataFrame(mtcars), "mtcars") @@ -3997,7 +3907,6 @@ setMethod("alias", #' @family SparkDataFrame functions #' @rdname broadcast #' @name broadcast -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) @@ -4041,7 +3950,6 @@ setMethod("broadcast", #' @family SparkDataFrame functions #' @rdname withWatermark #' @name withWatermark -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9d0a2d5e074e4..ebec0ce3d1920 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -123,7 +123,6 @@ infer_type <- function(x) { #' @return a list of config values with keys as their names #' @rdname sparkR.conf #' @name sparkR.conf -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -163,7 +162,6 @@ sparkR.conf <- function(key, defaultValue) { #' @return a character string of the Spark version #' @rdname sparkR.version #' @name sparkR.version -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -191,7 +189,6 @@ getDefaultSqlSource <- function() { #' limited by length of the list or number of rows of the data.frame #' @return A SparkDataFrame. #' @rdname createDataFrame -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -294,7 +291,6 @@ createDataFrame <- function(x, ...) { #' @rdname createDataFrame #' @aliases createDataFrame -#' @export #' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { @@ -304,7 +300,6 @@ as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPa #' @param ... additional argument(s). #' @rdname createDataFrame #' @aliases as.DataFrame -#' @export as.DataFrame <- function(data, ...) { dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) } @@ -342,7 +337,6 @@ setMethod("toDF", signature(x = "RDD"), #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.json -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -371,7 +365,6 @@ read.json <- function(x, ...) { #' @rdname read.json #' @name jsonFile -#' @export #' @method jsonFile default #' @note jsonFile since 1.4.0 jsonFile.default <- function(path) { @@ -423,7 +416,6 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.orc -#' @export #' @name read.orc #' @note read.orc since 2.0.0 read.orc <- function(path, ...) { @@ -444,7 +436,6 @@ read.orc <- function(path, ...) { #' @param path path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet -#' @export #' @name read.parquet #' @method read.parquet default #' @note read.parquet since 1.6.0 @@ -466,7 +457,6 @@ read.parquet <- function(x, ...) { #' @param ... argument(s) passed to the method. #' @rdname read.parquet #' @name parquetFile -#' @export #' @method parquetFile default #' @note parquetFile since 1.4.0 parquetFile.default <- function(...) { @@ -490,7 +480,6 @@ parquetFile <- function(x, ...) { #' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.text -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -522,7 +511,6 @@ read.text <- function(x, ...) { #' @param sqlQuery A character vector containing the SQL query #' @return SparkDataFrame #' @rdname sql -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -556,7 +544,6 @@ sql <- function(x, ...) { #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -591,7 +578,6 @@ tableToDF <- function(tableName) { #' @rdname read.df #' @name read.df #' @seealso \link{read.json} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -681,7 +667,6 @@ loadDF <- function(x = NULL, ...) { #' @return SparkDataFrame #' @rdname read.jdbc #' @name read.jdbc -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -734,7 +719,6 @@ read.jdbc <- function(url, tableName, #' @rdname read.stream #' @name read.stream #' @seealso \link{write.stream} -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index debc7cbde55e7..ee7f4adf726e6 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{windowPartitionBy}, \link{windowOrderBy} #' #' @param sws A Java object reference to the backing Scala WindowSpec -#' @export #' @note WindowSpec since 2.0.0 setClass("WindowSpec", slots = list(sws = "jobj")) @@ -44,7 +43,6 @@ windowSpec <- function(sws) { } #' @rdname show -#' @export #' @note show(WindowSpec) since 2.0.0 setMethod("show", "WindowSpec", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "WindowSpec", #' @name partitionBy #' @aliases partitionBy,WindowSpec-method #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' partitionBy(ws, "col1", "col2") @@ -97,7 +94,6 @@ setMethod("partitionBy", #' @aliases orderBy,WindowSpec,character-method #' @family windowspec_method #' @seealso See \link{arrange} for use in sorting a SparkDataFrame -#' @export #' @examples #' \dontrun{ #' orderBy(ws, "col1", "col2") @@ -113,7 +109,6 @@ setMethod("orderBy", #' @rdname orderBy #' @name orderBy #' @aliases orderBy,WindowSpec,Column-method -#' @export #' @note orderBy(WindowSpec, Column) since 2.0.0 setMethod("orderBy", signature(x = "WindowSpec", col = "Column"), @@ -142,7 +137,6 @@ setMethod("orderBy", #' @aliases rowsBetween,WindowSpec,numeric,numeric-method #' @name rowsBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rowsBetween(ws, 0, 3) @@ -174,7 +168,6 @@ setMethod("rowsBetween", #' @aliases rangeBetween,WindowSpec,numeric,numeric-method #' @name rangeBetween #' @family windowspec_method -#' @export #' @examples #' \dontrun{ #' rangeBetween(ws, 0, 3) @@ -202,7 +195,6 @@ setMethod("rangeBetween", #' @name over #' @aliases over,Column,WindowSpec-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(mtcars) diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 398dffc4ab1b4..282f8a6857738 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -32,14 +32,12 @@ # @seealso broadcast # # @param id Id of the backing Spark broadcast variable -# @export setClass("Broadcast", slots = list(id = "character")) # @rdname broadcast-class # @param value Value of the broadcast variable # @param jBroadcastRef reference to the backing Java broadcast object # @param objName name of broadcasted object -# @export Broadcast <- function(id, value, jBroadcastRef, objName) { .broadcastValues[[id]] <- value .broadcastNames[[as.character(objName)]] <- jBroadcastRef @@ -73,7 +71,6 @@ setMethod("value", # @param bcastId The id of broadcast variable to set # @param value The value to be set -# @export setBroadcastValue <- function(bcastId, value) { bcastIdStr <- as.character(bcastId) .broadcastValues[[bcastIdStr]] <- value diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R index e59a7024333ac..baf4d861fcf86 100644 --- a/R/pkg/R/catalog.R +++ b/R/pkg/R/catalog.R @@ -34,7 +34,6 @@ #' @return A SparkDataFrame. #' @rdname createExternalTable-deprecated #' @seealso \link{createTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -71,7 +70,6 @@ createExternalTable <- function(x, ...) { #' @return A SparkDataFrame. #' @rdname createTable #' @seealso \link{createExternalTable} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -110,7 +108,6 @@ createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, .. #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname cacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -140,7 +137,6 @@ cacheTable <- function(x, ...) { #' identifier is provided, it refers to a table in the current database. #' @return SparkDataFrame #' @rdname uncacheTable -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -167,7 +163,6 @@ uncacheTable <- function(x, ...) { #' Removes all cached tables from the in-memory cache. #' #' @rdname clearCache -#' @export #' @examples #' \dontrun{ #' clearCache() @@ -193,7 +188,6 @@ clearCache <- function() { #' @param tableName The name of the SparkSQL table to be dropped. #' @seealso \link{dropTempView} #' @rdname dropTempTable-deprecated -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -225,7 +219,6 @@ dropTempTable <- function(x, ...) { #' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +244,6 @@ dropTempView <- function(viewName) { #' @return a SparkDataFrame #' @rdname tables #' @seealso \link{listTables} -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -276,7 +268,6 @@ tables <- function(x, ...) { #' @param databaseName (optional) name of the database #' @return a list of table names #' @rdname tableNames -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -304,7 +295,6 @@ tableNames <- function(x, ...) { #' @return name of the current default database. #' @rdname currentDatabase #' @name currentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -324,7 +314,6 @@ currentDatabase <- function() { #' @param databaseName name of the database #' @rdname setCurrentDatabase #' @name setCurrentDatabase -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -347,7 +336,6 @@ setCurrentDatabase <- function(databaseName) { #' @return a SparkDataFrame of the list of databases. #' @rdname listDatabases #' @name listDatabases -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -370,7 +358,6 @@ listDatabases <- function() { #' @rdname listTables #' @name listTables #' @seealso \link{tables} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -403,7 +390,6 @@ listTables <- function(databaseName = NULL) { #' @return a SparkDataFrame of the list of column descriptions. #' @rdname listColumns #' @name listColumns -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -433,7 +419,6 @@ listColumns <- function(tableName, databaseName = NULL) { #' @return a SparkDataFrame of the list of function descriptions. #' @rdname listFunctions #' @name listFunctions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -463,7 +448,6 @@ listFunctions <- function(databaseName = NULL) { #' identifier is provided, it refers to a table in the current database. #' @rdname recoverPartitions #' @name recoverPartitions -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -490,7 +474,6 @@ recoverPartitions <- function(tableName) { #' identifier is provided, it refers to a table in the current database. #' @rdname refreshTable #' @name refreshTable -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -512,7 +495,6 @@ refreshTable <- function(tableName) { #' @param path the path of the data source. #' @rdname refreshByPath #' @name refreshByPath -#' @export #' @examples #' \dontrun{ #' sparkR.session() diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 3095adb918b67..9727efc354f10 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -29,7 +29,6 @@ setOldClass("jobj") #' @rdname column #' #' @slot jc reference to JVM SparkDataFrame column -#' @export #' @note Column since 1.4.0 setClass("Column", slots = list(jc = "jobj")) @@ -56,7 +55,6 @@ setMethod("column", #' @rdname show #' @name show #' @aliases show,Column-method -#' @export #' @note show(Column) since 1.4.0 setMethod("show", "Column", function(object) { @@ -134,7 +132,6 @@ createMethods() #' @name alias #' @aliases alias,Column-method #' @family colum_func -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(iris) @@ -270,7 +267,6 @@ setMethod("cast", #' @name %in% #' @aliases %in%,Column-method #' @return A matched values as a result of comparing with given values. -#' @export #' @examples #' \dontrun{ #' filter(df, "age in (10, 30)") @@ -296,7 +292,6 @@ setMethod("%in%", #' @name otherwise #' @family colum_func #' @aliases otherwise,Column-method -#' @export #' @note otherwise since 1.5.0 setMethod("otherwise", signature(x = "Column", value = "ANY"), @@ -318,7 +313,6 @@ setMethod("otherwise", #' @rdname eq_null_safe #' @name %<=>% #' @aliases %<=>%,Column-method -#' @export #' @examples #' \dontrun{ #' df1 <- createDataFrame(data.frame( @@ -348,7 +342,6 @@ setMethod("%<=>%", #' @rdname not #' @name not #' @aliases !,Column-method -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 443c2ff8f9ace..8ec727dd042bc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -308,7 +308,6 @@ setCheckpointDirSC <- function(sc, dirName) { #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. -#' @export #' @examples #'\dontrun{ #' spark.addFile("~/myfile") @@ -323,7 +322,6 @@ spark.addFile <- function(path, recursive = FALSE) { #' #' @rdname spark.getSparkFilesRootDirectory #' @return the root directory that contains files added through spark.addFile -#' @export #' @examples #'\dontrun{ #' spark.getSparkFilesRootDirectory() @@ -344,7 +342,6 @@ spark.getSparkFilesRootDirectory <- function() { # nolint #' @rdname spark.getSparkFiles #' @param fileName The name of the file added through spark.addFile #' @return the absolute path of a file added through spark.addFile. -#' @export #' @examples #'\dontrun{ #' spark.getSparkFiles("myfile") @@ -391,7 +388,6 @@ spark.getSparkFiles <- function(fileName) { #' @param list the list of elements #' @param func a function that takes one argument. #' @return a list of results (the exact type being determined by the function) -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -412,7 +408,6 @@ spark.lapply <- function(list, func) { #' #' @rdname setLogLevel #' @param level New log level -#' @export #' @examples #'\dontrun{ #' setLogLevel("ERROR") @@ -431,7 +426,6 @@ setLogLevel <- function(level) { #' @rdname setCheckpointDir #' @param directory Directory path to checkpoint to #' @seealso \link{checkpoint} -#' @export #' @examples #'\dontrun{ #' setCheckpointDir("/checkpoint") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 29ee146ab14f9..a527426b19674 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -244,7 +244,6 @@ NULL #' If the parameter is a Column, it is returned unchanged. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases lit lit,ANY-method #' @examples #' @@ -267,7 +266,6 @@ setMethod("lit", signature("ANY"), #' \code{abs}: Computes the absolute value. #' #' @rdname column_math_functions -#' @export #' @aliases abs abs,Column-method #' @note abs since 1.5.0 setMethod("abs", @@ -282,7 +280,6 @@ setMethod("abs", #' as if computed by \code{java.lang.Math.acos()} #' #' @rdname column_math_functions -#' @export #' @aliases acos acos,Column-method #' @note acos since 1.5.0 setMethod("acos", @@ -296,7 +293,6 @@ setMethod("acos", #' \code{approxCountDistinct}: Returns the approximate number of distinct items in a group. #' #' @rdname column_aggregate_functions -#' @export #' @aliases approxCountDistinct approxCountDistinct,Column-method #' @examples #' @@ -319,7 +315,6 @@ setMethod("approxCountDistinct", #' and returns the result as an int column. #' #' @rdname column_string_functions -#' @export #' @aliases ascii ascii,Column-method #' @examples #' @@ -338,7 +333,6 @@ setMethod("ascii", #' as if computed by \code{java.lang.Math.asin()} #' #' @rdname column_math_functions -#' @export #' @aliases asin asin,Column-method #' @note asin since 1.5.0 setMethod("asin", @@ -353,7 +347,6 @@ setMethod("asin", #' as if computed by \code{java.lang.Math.atan()} #' #' @rdname column_math_functions -#' @export #' @aliases atan atan,Column-method #' @note atan since 1.5.0 setMethod("atan", @@ -370,7 +363,6 @@ setMethod("atan", #' @rdname avg #' @name avg #' @family aggregate functions -#' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} #' @note avg since 1.4.0 @@ -386,7 +378,6 @@ setMethod("avg", #' a string column. This is the reverse of unbase64. #' #' @rdname column_string_functions -#' @export #' @aliases base64 base64,Column-method #' @examples #' @@ -410,7 +401,6 @@ setMethod("base64", #' of the given long column. For example, bin("12") returns "1100". #' #' @rdname column_math_functions -#' @export #' @aliases bin bin,Column-method #' @note bin since 1.5.0 setMethod("bin", @@ -424,7 +414,6 @@ setMethod("bin", #' \code{bitwiseNOT}: Computes bitwise NOT. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases bitwiseNOT bitwiseNOT,Column-method #' @examples #' @@ -442,7 +431,6 @@ setMethod("bitwiseNOT", #' \code{cbrt}: Computes the cube-root of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases cbrt cbrt,Column-method #' @note cbrt since 1.4.0 setMethod("cbrt", @@ -456,7 +444,6 @@ setMethod("cbrt", #' \code{ceil}: Computes the ceiling of the given value. #' #' @rdname column_math_functions -#' @export #' @aliases ceil ceil,Column-method #' @note ceil since 1.5.0 setMethod("ceil", @@ -471,7 +458,6 @@ setMethod("ceil", #' #' @rdname column_math_functions #' @aliases ceiling ceiling,Column-method -#' @export #' @note ceiling since 1.5.0 setMethod("ceiling", signature(x = "Column"), @@ -483,7 +469,6 @@ setMethod("ceiling", #' \code{coalesce}: Returns the first column that is not NA, or NA if all inputs are. #' #' @rdname column_nonaggregate_functions -#' @export #' @aliases coalesce,Column-method #' @note coalesce(Column) since 2.1.1 setMethod("coalesce", @@ -514,7 +499,6 @@ col <- function(x) { #' @rdname column #' @name column #' @family non-aggregate functions -#' @export #' @aliases column,character-method #' @examples \dontrun{column("name")} #' @note column since 1.6.0 @@ -533,7 +517,6 @@ setMethod("column", #' @rdname corr #' @name corr #' @family aggregate functions -#' @export #' @aliases corr,Column-method #' @examples #' \dontrun{ @@ -557,7 +540,6 @@ setMethod("corr", signature(x = "Column"), #' @rdname cov #' @name cov #' @family aggregate functions -#' @export #' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ @@ -598,7 +580,6 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' @rdname cov #' @name covar_pop -#' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), @@ -618,7 +599,6 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr #' #' @rdname column_math_functions #' @aliases cos cos,Column-method -#' @export #' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), @@ -633,7 +613,6 @@ setMethod("cos", #' #' @rdname column_math_functions #' @aliases cosh cosh,Column-method -#' @export #' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), @@ -651,7 +630,6 @@ setMethod("cosh", #' @name count #' @family aggregate functions #' @aliases count,Column-method -#' @export #' @examples \dontrun{count(df$c)} #' @note count since 1.4.0 setMethod("count", @@ -667,7 +645,6 @@ setMethod("count", #' #' @rdname column_misc_functions #' @aliases crc32 crc32,Column-method -#' @export #' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), @@ -682,7 +659,6 @@ setMethod("crc32", #' #' @rdname column_misc_functions #' @aliases hash hash,Column-method -#' @export #' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), @@ -701,7 +677,6 @@ setMethod("hash", #' #' @rdname column_datetime_functions #' @aliases dayofmonth dayofmonth,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -723,7 +698,6 @@ setMethod("dayofmonth", #' #' @rdname column_datetime_functions #' @aliases dayofweek dayofweek,Column-method -#' @export #' @note dayofweek since 2.3.0 setMethod("dayofweek", signature(x = "Column"), @@ -738,7 +712,6 @@ setMethod("dayofweek", #' #' @rdname column_datetime_functions #' @aliases dayofyear dayofyear,Column-method -#' @export #' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), @@ -756,7 +729,6 @@ setMethod("dayofyear", #' #' @rdname column_string_functions #' @aliases decode decode,Column,character-method -#' @export #' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), @@ -771,7 +743,6 @@ setMethod("decode", #' #' @rdname column_string_functions #' @aliases encode encode,Column,character-method -#' @export #' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), @@ -785,7 +756,6 @@ setMethod("encode", #' #' @rdname column_math_functions #' @aliases exp exp,Column-method -#' @export #' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), @@ -799,7 +769,6 @@ setMethod("exp", #' #' @rdname column_math_functions #' @aliases expm1 expm1,Column-method -#' @export #' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), @@ -813,7 +782,6 @@ setMethod("expm1", #' #' @rdname column_math_functions #' @aliases factorial factorial,Column-method -#' @export #' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), @@ -836,7 +804,6 @@ setMethod("factorial", #' @name first #' @aliases first,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' first(df$c) @@ -860,7 +827,6 @@ setMethod("first", #' #' @rdname column_math_functions #' @aliases floor floor,Column-method -#' @export #' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), @@ -874,7 +840,6 @@ setMethod("floor", #' #' @rdname column_math_functions #' @aliases hex hex,Column-method -#' @export #' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), @@ -888,7 +853,6 @@ setMethod("hex", #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -911,7 +875,6 @@ setMethod("hour", #' #' @rdname column_string_functions #' @aliases initcap initcap,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -946,7 +909,6 @@ setMethod("isnan", #' #' @rdname column_nonaggregate_functions #' @aliases is.nan is.nan,Column-method -#' @export #' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), @@ -959,7 +921,6 @@ setMethod("is.nan", #' #' @rdname column_aggregate_functions #' @aliases kurtosis kurtosis,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -988,7 +949,6 @@ setMethod("kurtosis", #' @name last #' @aliases last,characterOrColumn-method #' @family aggregate functions -#' @export #' @examples #' \dontrun{ #' last(df$c) @@ -1014,7 +974,6 @@ setMethod("last", #' #' @rdname column_datetime_functions #' @aliases last_day last_day,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1034,7 +993,6 @@ setMethod("last_day", #' #' @rdname column_string_functions #' @aliases length length,Column-method -#' @export #' @note length since 1.5.0 setMethod("length", signature(x = "Column"), @@ -1048,7 +1006,6 @@ setMethod("length", #' #' @rdname column_math_functions #' @aliases log log,Column-method -#' @export #' @note log since 1.5.0 setMethod("log", signature(x = "Column"), @@ -1062,7 +1019,6 @@ setMethod("log", #' #' @rdname column_math_functions #' @aliases log10 log10,Column-method -#' @export #' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), @@ -1076,7 +1032,6 @@ setMethod("log10", #' #' @rdname column_math_functions #' @aliases log1p log1p,Column-method -#' @export #' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), @@ -1090,7 +1045,6 @@ setMethod("log1p", #' #' @rdname column_math_functions #' @aliases log2 log2,Column-method -#' @export #' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), @@ -1104,7 +1058,6 @@ setMethod("log2", #' #' @rdname column_string_functions #' @aliases lower lower,Column-method -#' @export #' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), @@ -1119,7 +1072,6 @@ setMethod("lower", #' #' @rdname column_string_functions #' @aliases ltrim ltrim,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1143,7 +1095,6 @@ setMethod("ltrim", #' @param trimString a character string to trim with #' @rdname column_string_functions #' @aliases ltrim,Column,character-method -#' @export #' @note ltrim(Column, character) since 2.3.0 setMethod("ltrim", signature(x = "Column", trimString = "character"), @@ -1171,7 +1122,6 @@ setMethod("max", #' #' @rdname column_misc_functions #' @aliases md5 md5,Column-method -#' @export #' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), @@ -1185,7 +1135,6 @@ setMethod("md5", #' #' @rdname column_aggregate_functions #' @aliases mean mean,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1211,7 +1160,6 @@ setMethod("mean", #' #' @rdname column_aggregate_functions #' @aliases min min,Column-method -#' @export #' @note min since 1.5.0 setMethod("min", signature(x = "Column"), @@ -1225,7 +1173,6 @@ setMethod("min", #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method -#' @export #' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), @@ -1248,7 +1195,6 @@ setMethod("minute", #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, monotonically_increasing_id()))} @@ -1264,7 +1210,6 @@ setMethod("monotonically_increasing_id", #' #' @rdname column_datetime_functions #' @aliases month month,Column-method -#' @export #' @note month since 1.5.0 setMethod("month", signature(x = "Column"), @@ -1278,7 +1223,6 @@ setMethod("month", #' #' @rdname column_nonaggregate_functions #' @aliases negate negate,Column-method -#' @export #' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), @@ -1292,7 +1236,6 @@ setMethod("negate", #' #' @rdname column_datetime_functions #' @aliases quarter quarter,Column-method -#' @export #' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), @@ -1306,7 +1249,6 @@ setMethod("quarter", #' #' @rdname column_string_functions #' @aliases reverse reverse,Column-method -#' @export #' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), @@ -1321,7 +1263,6 @@ setMethod("reverse", #' #' @rdname column_math_functions #' @aliases rint rint,Column-method -#' @export #' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), @@ -1336,7 +1277,6 @@ setMethod("rint", #' #' @rdname column_math_functions #' @aliases round round,Column-method -#' @export #' @note round since 1.5.0 setMethod("round", signature(x = "Column"), @@ -1356,7 +1296,6 @@ setMethod("round", #' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method -#' @export #' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), @@ -1371,7 +1310,6 @@ setMethod("bround", #' #' @rdname column_string_functions #' @aliases rtrim rtrim,Column,missing-method -#' @export #' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column", trimString = "missing"), @@ -1382,7 +1320,6 @@ setMethod("rtrim", #' @rdname column_string_functions #' @aliases rtrim,Column,character-method -#' @export #' @note rtrim(Column, character) since 2.3.0 setMethod("rtrim", signature(x = "Column", trimString = "character"), @@ -1396,7 +1333,6 @@ setMethod("rtrim", #' #' @rdname column_aggregate_functions #' @aliases sd sd,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1414,7 +1350,6 @@ setMethod("sd", #' #' @rdname column_datetime_functions #' @aliases second second,Column-method -#' @export #' @note second since 1.5.0 setMethod("second", signature(x = "Column"), @@ -1429,7 +1364,6 @@ setMethod("second", #' #' @rdname column_misc_functions #' @aliases sha1 sha1,Column-method -#' @export #' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), @@ -1443,7 +1377,6 @@ setMethod("sha1", #' #' @rdname column_math_functions #' @aliases signum signum,Column-method -#' @export #' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), @@ -1457,7 +1390,6 @@ setMethod("signum", #' #' @rdname column_math_functions #' @aliases sign sign,Column-method -#' @export #' @note sign since 1.5.0 setMethod("sign", signature(x = "Column"), function(x) { @@ -1470,7 +1402,6 @@ setMethod("sign", signature(x = "Column"), #' #' @rdname column_math_functions #' @aliases sin sin,Column-method -#' @export #' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), @@ -1485,7 +1416,6 @@ setMethod("sin", #' #' @rdname column_math_functions #' @aliases sinh sinh,Column-method -#' @export #' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), @@ -1499,7 +1429,6 @@ setMethod("sinh", #' #' @rdname column_aggregate_functions #' @aliases skewness skewness,Column-method -#' @export #' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), @@ -1513,7 +1442,6 @@ setMethod("skewness", #' #' @rdname column_string_functions #' @aliases soundex soundex,Column-method -#' @export #' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), @@ -1530,7 +1458,6 @@ setMethod("soundex", #' #' @rdname column_nonaggregate_functions #' @aliases spark_partition_id spark_partition_id,missing-method -#' @export #' @examples #' #' \dontrun{head(select(df, spark_partition_id()))} @@ -1560,7 +1487,6 @@ setMethod("stddev", #' #' @rdname column_aggregate_functions #' @aliases stddev_pop stddev_pop,Column-method -#' @export #' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), @@ -1574,7 +1500,6 @@ setMethod("stddev_pop", #' #' @rdname column_aggregate_functions #' @aliases stddev_samp stddev_samp,Column-method -#' @export #' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), @@ -1588,7 +1513,6 @@ setMethod("stddev_samp", #' #' @rdname column_nonaggregate_functions #' @aliases struct struct,characterOrColumn-method -#' @export #' @examples #' #' \dontrun{ @@ -1614,7 +1538,6 @@ setMethod("struct", #' #' @rdname column_math_functions #' @aliases sqrt sqrt,Column-method -#' @export #' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), @@ -1628,7 +1551,6 @@ setMethod("sqrt", #' #' @rdname column_aggregate_functions #' @aliases sum sum,Column-method -#' @export #' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), @@ -1642,7 +1564,6 @@ setMethod("sum", #' #' @rdname column_aggregate_functions #' @aliases sumDistinct sumDistinct,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1663,7 +1584,6 @@ setMethod("sumDistinct", #' #' @rdname column_math_functions #' @aliases tan tan,Column-method -#' @export #' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), @@ -1678,7 +1598,6 @@ setMethod("tan", #' #' @rdname column_math_functions #' @aliases tanh tanh,Column-method -#' @export #' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), @@ -1693,7 +1612,6 @@ setMethod("tanh", #' #' @rdname column_math_functions #' @aliases toDegrees toDegrees,Column-method -#' @export #' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), @@ -1708,7 +1626,6 @@ setMethod("toDegrees", #' #' @rdname column_math_functions #' @aliases toRadians toRadians,Column-method -#' @export #' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), @@ -1728,7 +1645,6 @@ setMethod("toRadians", #' #' @rdname column_datetime_functions #' @aliases to_date to_date,Column,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -1749,7 +1665,6 @@ setMethod("to_date", #' @rdname column_datetime_functions #' @aliases to_date,Column,character-method -#' @export #' @note to_date(Column, character) since 2.2.0 setMethod("to_date", signature(x = "Column", format = "character"), @@ -1765,7 +1680,6 @@ setMethod("to_date", #' #' @rdname column_collection_functions #' @aliases to_json to_json,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -1803,7 +1717,6 @@ setMethod("to_json", signature(x = "Column"), #' #' @rdname column_datetime_functions #' @aliases to_timestamp to_timestamp,Column,missing-method -#' @export #' @note to_timestamp(Column) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "missing"), @@ -1814,7 +1727,6 @@ setMethod("to_timestamp", #' @rdname column_datetime_functions #' @aliases to_timestamp,Column,character-method -#' @export #' @note to_timestamp(Column, character) since 2.2.0 setMethod("to_timestamp", signature(x = "Column", format = "character"), @@ -1829,7 +1741,6 @@ setMethod("to_timestamp", #' #' @rdname column_string_functions #' @aliases trim trim,Column,missing-method -#' @export #' @note trim since 1.5.0 setMethod("trim", signature(x = "Column", trimString = "missing"), @@ -1840,7 +1751,6 @@ setMethod("trim", #' @rdname column_string_functions #' @aliases trim,Column,character-method -#' @export #' @note trim(Column, character) since 2.3.0 setMethod("trim", signature(x = "Column", trimString = "character"), @@ -1855,7 +1765,6 @@ setMethod("trim", #' #' @rdname column_string_functions #' @aliases unbase64 unbase64,Column-method -#' @export #' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), @@ -1870,7 +1779,6 @@ setMethod("unbase64", #' #' @rdname column_math_functions #' @aliases unhex unhex,Column-method -#' @export #' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), @@ -1884,7 +1792,6 @@ setMethod("unhex", #' #' @rdname column_string_functions #' @aliases upper upper,Column-method -#' @export #' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), @@ -1898,7 +1805,6 @@ setMethod("upper", #' #' @rdname column_aggregate_functions #' @aliases var var,Column-method -#' @export #' @examples #' #'\dontrun{ @@ -1913,7 +1819,6 @@ setMethod("var", #' @rdname column_aggregate_functions #' @aliases variance variance,Column-method -#' @export #' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), @@ -1927,7 +1832,6 @@ setMethod("variance", #' #' @rdname column_aggregate_functions #' @aliases var_pop var_pop,Column-method -#' @export #' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), @@ -1941,7 +1845,6 @@ setMethod("var_pop", #' #' @rdname column_aggregate_functions #' @aliases var_samp var_samp,Column-method -#' @export #' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), @@ -1955,7 +1858,6 @@ setMethod("var_samp", #' #' @rdname column_datetime_functions #' @aliases weekofyear weekofyear,Column-method -#' @export #' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), @@ -1969,7 +1871,6 @@ setMethod("weekofyear", #' #' @rdname column_datetime_functions #' @aliases year year,Column-method -#' @export #' @note year since 1.5.0 setMethod("year", signature(x = "Column"), @@ -1985,7 +1886,6 @@ setMethod("year", #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method -#' @export #' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { @@ -2001,7 +1901,6 @@ setMethod("atan2", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2025,7 +1924,6 @@ setMethod("datediff", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases hypot hypot,Column-method -#' @export #' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { @@ -2041,7 +1939,6 @@ setMethod("hypot", signature(y = "Column"), #' #' @rdname column_string_functions #' @aliases levenshtein levenshtein,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2064,7 +1961,6 @@ setMethod("levenshtein", signature(y = "Column"), #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method -#' @export #' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { @@ -2082,7 +1978,6 @@ setMethod("months_between", signature(y = "Column"), #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method -#' @export #' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { @@ -2099,7 +1994,6 @@ setMethod("nanvl", signature(y = "Column"), #' #' @rdname column_math_functions #' @aliases pmod pmod,Column-method -#' @export #' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { @@ -2114,7 +2008,6 @@ setMethod("pmod", signature(y = "Column"), #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method -#' @export #' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), @@ -2128,7 +2021,6 @@ setMethod("approxCountDistinct", #' #' @rdname column_aggregate_functions #' @aliases countDistinct countDistinct,Column-method -#' @export #' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), @@ -2148,7 +2040,6 @@ setMethod("countDistinct", #' #' @rdname column_string_functions #' @aliases concat concat,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2177,7 +2068,6 @@ setMethod("concat", #' #' @rdname column_nonaggregate_functions #' @aliases greatest greatest,Column-method -#' @export #' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), @@ -2197,7 +2087,6 @@ setMethod("greatest", #' #' @rdname column_nonaggregate_functions #' @aliases least least,Column-method -#' @export #' @note least since 1.5.0 setMethod("least", signature(x = "Column"), @@ -2216,7 +2105,6 @@ setMethod("least", #' #' @rdname column_aggregate_functions #' @aliases n_distinct n_distinct,Column-method -#' @export #' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { @@ -2226,7 +2114,6 @@ setMethod("n_distinct", signature(x = "Column"), #' @rdname count #' @name n #' @aliases n,Column-method -#' @export #' @examples \dontrun{n(df$c)} #' @note n since 1.4.0 setMethod("n", signature(x = "Column"), @@ -2245,7 +2132,6 @@ setMethod("n", signature(x = "Column"), #' @rdname column_datetime_diff_functions #' #' @aliases date_format date_format,Column,character-method -#' @export #' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { @@ -2263,7 +2149,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method -#' @export #' @examples #' #' \dontrun{ @@ -2306,7 +2191,6 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") #' @rdname column_datetime_diff_functions #' #' @aliases from_utc_timestamp from_utc_timestamp,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2328,7 +2212,6 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_string_functions #' @aliases instr instr,Column,character-method -#' @export #' @examples #' #' \dontrun{ @@ -2351,7 +2234,6 @@ setMethod("instr", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases next_day next_day,Column,character-method -#' @export #' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { @@ -2366,7 +2248,6 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases to_utc_timestamp to_utc_timestamp,Column,character-method -#' @export #' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { @@ -2379,7 +2260,6 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' #' @rdname column_datetime_diff_functions #' @aliases add_months add_months,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2400,7 +2280,6 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' #' @rdname column_datetime_diff_functions #' @aliases date_add date_add,Column,numeric-method -#' @export #' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2414,7 +2293,6 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @rdname column_datetime_diff_functions #' #' @aliases date_sub date_sub,Column,numeric-method -#' @export #' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2431,7 +2309,6 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' #' @rdname column_string_functions #' @aliases format_number format_number,Column,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2454,7 +2331,6 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' #' @rdname column_misc_functions #' @aliases sha2 sha2,Column,numeric-method -#' @export #' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2468,7 +2344,6 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftLeft shiftLeft,Column,numeric-method -#' @export #' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2484,7 +2359,6 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method -#' @export #' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2500,7 +2374,6 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' #' @rdname column_math_functions #' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method -#' @export #' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { @@ -2517,7 +2390,6 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method -#' @export #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -2533,7 +2405,6 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @param toBase base to convert to. #' @rdname column_math_functions #' @aliases conv conv,Column,numeric,numeric-method -#' @export #' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { @@ -2551,7 +2422,6 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' #' @rdname column_nonaggregate_functions #' @aliases expr expr,character-method -#' @export #' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { @@ -2566,7 +2436,6 @@ setMethod("expr", signature(x = "character"), #' @param format a character object of format strings. #' @rdname column_string_functions #' @aliases format_string format_string,character,Column-method -#' @export #' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { @@ -2587,7 +2456,6 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' @rdname column_datetime_functions #' #' @aliases from_unixtime from_unixtime,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2629,7 +2497,6 @@ setMethod("from_unixtime", signature(x = "Column"), #' \code{startTime} as \code{"15 minutes"}. #' @rdname column_datetime_functions #' @aliases window window,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2680,7 +2547,6 @@ setMethod("window", signature(x = "Column"), #' @param pos start position of search. #' @rdname column_string_functions #' @aliases locate locate,character,Column-method -#' @export #' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 1) { @@ -2697,7 +2563,6 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @param pad a character string to be padded with. #' @rdname column_string_functions #' @aliases lpad lpad,Column,numeric,character-method -#' @export #' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2714,7 +2579,6 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. #' @aliases rand rand,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -2729,7 +2593,6 @@ setMethod("rand", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases rand,numeric-method -#' @export #' @note rand(numeric) since 1.5.0 setMethod("rand", signature(seed = "numeric"), function(seed) { @@ -2743,7 +2606,6 @@ setMethod("rand", signature(seed = "numeric"), #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method -#' @export #' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { @@ -2753,7 +2615,6 @@ setMethod("randn", signature(seed = "missing"), #' @rdname column_nonaggregate_functions #' @aliases randn,numeric-method -#' @export #' @note randn(numeric) since 1.5.0 setMethod("randn", signature(seed = "numeric"), function(seed) { @@ -2770,7 +2631,6 @@ setMethod("randn", signature(seed = "numeric"), #' @param idx a group index. #' @rdname column_string_functions #' @aliases regexp_extract regexp_extract,Column,character,numeric-method -#' @export #' @examples #' #' \dontrun{ @@ -2799,7 +2659,6 @@ setMethod("regexp_extract", #' @param replacement a character string that a matched \code{pattern} is replaced with. #' @rdname column_string_functions #' @aliases regexp_replace regexp_replace,Column,character,character-method -#' @export #' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), @@ -2815,7 +2674,6 @@ setMethod("regexp_replace", #' #' @rdname column_string_functions #' @aliases rpad rpad,Column,numeric,character-method -#' @export #' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { @@ -2838,7 +2696,6 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' counting from the right. #' @rdname column_string_functions #' @aliases substring_index substring_index,Column,character,numeric-method -#' @export #' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), @@ -2861,7 +2718,6 @@ setMethod("substring_index", #' at the same location, if any. #' @rdname column_string_functions #' @aliases translate translate,Column,character,character-method -#' @export #' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), @@ -2876,7 +2732,6 @@ setMethod("translate", #' #' @rdname column_datetime_functions #' @aliases unix_timestamp unix_timestamp,missing,missing-method -#' @export #' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { @@ -2886,7 +2741,6 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,missing-method -#' @export #' @note unix_timestamp(Column) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "missing"), function(x, format) { @@ -2896,7 +2750,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "missing"), #' @rdname column_datetime_functions #' @aliases unix_timestamp,Column,character-method -#' @export #' @note unix_timestamp(Column, character) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "character"), function(x, format = "yyyy-MM-dd HH:mm:ss") { @@ -2912,7 +2765,6 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. #' @aliases when when,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -2941,7 +2793,6 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. #' @aliases ifelse ifelse,Column-method -#' @export #' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), @@ -2967,7 +2818,6 @@ setMethod("ifelse", #' #' @rdname column_window_functions #' @aliases cume_dist cume_dist,missing-method -#' @export #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2988,7 +2838,6 @@ setMethod("cume_dist", #' #' @rdname column_window_functions #' @aliases dense_rank dense_rank,missing-method -#' @export #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -3005,7 +2854,6 @@ setMethod("dense_rank", #' #' @rdname column_window_functions #' @aliases lag lag,characterOrColumn-method -#' @export #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -3030,7 +2878,6 @@ setMethod("lag", #' #' @rdname column_window_functions #' @aliases lead lead,characterOrColumn,numeric-method -#' @export #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -3054,7 +2901,6 @@ setMethod("lead", #' #' @rdname column_window_functions #' @aliases ntile ntile,numeric-method -#' @export #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3072,7 +2918,6 @@ setMethod("ntile", #' #' @rdname column_window_functions #' @aliases percent_rank percent_rank,missing-method -#' @export #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3093,7 +2938,6 @@ setMethod("percent_rank", #' #' @rdname column_window_functions #' @aliases rank rank,missing-method -#' @export #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3104,7 +2948,6 @@ setMethod("rank", #' @rdname column_window_functions #' @aliases rank,ANY-method -#' @export setMethod("rank", signature(x = "ANY"), function(x, ...) { @@ -3118,7 +2961,6 @@ setMethod("rank", #' #' @rdname column_window_functions #' @aliases row_number row_number,missing-method -#' @export #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), @@ -3136,7 +2978,6 @@ setMethod("row_number", #' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method -#' @export #' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), @@ -3150,7 +2991,6 @@ setMethod("array_contains", #' #' @rdname column_collection_functions #' @aliases map_keys map_keys,Column-method -#' @export #' @note map_keys since 2.3.0 setMethod("map_keys", signature(x = "Column"), @@ -3164,7 +3004,6 @@ setMethod("map_keys", #' #' @rdname column_collection_functions #' @aliases map_values map_values,Column-method -#' @export #' @note map_values since 2.3.0 setMethod("map_values", signature(x = "Column"), @@ -3178,7 +3017,6 @@ setMethod("map_values", #' #' @rdname column_collection_functions #' @aliases explode explode,Column-method -#' @export #' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), @@ -3192,7 +3030,6 @@ setMethod("explode", #' #' @rdname column_collection_functions #' @aliases size size,Column-method -#' @export #' @note size since 1.5.0 setMethod("size", signature(x = "Column"), @@ -3210,7 +3047,6 @@ setMethod("size", #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method -#' @export #' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), @@ -3225,7 +3061,6 @@ setMethod("sort_array", #' #' @rdname column_collection_functions #' @aliases posexplode posexplode,Column-method -#' @export #' @note posexplode since 2.1.0 setMethod("posexplode", signature(x = "Column"), @@ -3240,7 +3075,6 @@ setMethod("posexplode", #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method -#' @export #' @note create_array since 2.3.0 setMethod("create_array", signature(x = "Column"), @@ -3261,7 +3095,6 @@ setMethod("create_array", #' #' @rdname column_nonaggregate_functions #' @aliases create_map create_map,Column-method -#' @export #' @note create_map since 2.3.0 setMethod("create_map", signature(x = "Column"), @@ -3279,7 +3112,6 @@ setMethod("create_map", #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3299,7 +3131,6 @@ setMethod("collect_list", #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method -#' @export #' @note collect_set since 2.3.0 setMethod("collect_set", signature(x = "Column"), @@ -3314,7 +3145,6 @@ setMethod("collect_set", #' #' @rdname column_string_functions #' @aliases split_string split_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3337,7 +3167,6 @@ setMethod("split_string", #' @param n number of repetitions. #' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3360,7 +3189,6 @@ setMethod("repeat_string", #' #' @rdname column_collection_functions #' @aliases explode_outer explode_outer,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3385,7 +3213,6 @@ setMethod("explode_outer", #' #' @rdname column_collection_functions #' @aliases posexplode_outer posexplode_outer,Column-method -#' @export #' @note posexplode_outer since 2.3.0 setMethod("posexplode_outer", signature(x = "Column"), @@ -3406,7 +3233,6 @@ setMethod("posexplode_outer", #' @name not #' @aliases not,Column-method #' @family non-aggregate functions -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -3434,7 +3260,6 @@ setMethod("not", #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3467,7 +3292,6 @@ setMethod("grouping_bit", #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3502,7 +3326,6 @@ setMethod("grouping_id", #' #' @rdname column_nonaggregate_functions #' @aliases input_file_name input_file_name,missing-method -#' @export #' @examples #' #' \dontrun{ @@ -3520,7 +3343,6 @@ setMethod("input_file_name", signature("missing"), #' #' @rdname column_datetime_functions #' @aliases trunc trunc,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3540,7 +3362,6 @@ setMethod("trunc", #' #' @rdname column_datetime_functions #' @aliases date_trunc date_trunc,character,Column-method -#' @export #' @examples #' #' \dontrun{ @@ -3559,7 +3380,6 @@ setMethod("date_trunc", #' #' @rdname column_datetime_functions #' @aliases current_date current_date,missing-method -#' @export #' @examples #' \dontrun{ #' head(select(df, current_date(), current_timestamp()))} @@ -3576,7 +3396,6 @@ setMethod("current_date", #' #' @rdname column_datetime_functions #' @aliases current_timestamp current_timestamp,missing-method -#' @export #' @note current_timestamp since 2.3.0 setMethod("current_timestamp", signature("missing"), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e0dde3339fabc..6fba4b6c761dd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -19,7 +19,6 @@ # @rdname aggregateRDD # @seealso reduce -# @export setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) @@ -27,21 +26,17 @@ setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition -# @export setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coalesceRDD") }) # @rdname checkpoint-methods -# @export setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) # @rdname collect-methods -# @export setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") }) # @rdname collect-methods -# @export setGeneric("collectPartition", function(x, partitionId) { standardGeneric("collectPartition") @@ -52,19 +47,15 @@ setGeneric("countRDD", function(x) { standardGeneric("countRDD") }) setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") }) # @rdname countByValue -# @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @rdname crosstab -# @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) # @rdname freqItems -# @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) # @rdname approxQuantile -# @export setGeneric("approxQuantile", function(x, cols, probabilities, relativeError) { standardGeneric("approxQuantile") @@ -73,18 +64,15 @@ setGeneric("approxQuantile", setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") }) # @rdname filterRDD -# @export setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) setGeneric("firstRDD", function(x, ...) { standardGeneric("firstRDD") }) # @rdname flatMap -# @export setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @rdname fold # @seealso reduce -# @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) @@ -95,17 +83,14 @@ setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachParti setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") }) # @rdname glom -# @export setGeneric("glom", function(x) { standardGeneric("glom") }) # @rdname histogram -# @export setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") }) setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") }) # @rdname keyBy -# @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) @@ -123,47 +108,37 @@ setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) # @rdname maximum -# @export setGeneric("maximum", function(x) { standardGeneric("maximum") }) # @rdname minimum -# @export setGeneric("minimum", function(x) { standardGeneric("minimum") }) # @rdname sumRDD -# @export setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @rdname name -# @export setGeneric("name", function(x) { standardGeneric("name") }) # @rdname getNumPartitionsRDD -# @export setGeneric("getNumPartitionsRDD", function(x) { standardGeneric("getNumPartitionsRDD") }) # @rdname getNumPartitions -# @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") }) # @rdname pipeRDD -# @export setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) # @rdname pivot -# @export setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") }) # @rdname reduce -# @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) setGeneric("repartitionRDD", function(x, ...) { standardGeneric("repartitionRDD") }) # @rdname sampleRDD -# @export setGeneric("sampleRDD", function(x, withReplacement, fraction, seed) { standardGeneric("sampleRDD") @@ -171,21 +146,17 @@ setGeneric("sampleRDD", # @rdname saveAsObjectFile # @seealso objectFile -# @export setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") }) # @rdname saveAsTextFile -# @export setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") }) # @rdname setName -# @export setGeneric("setName", function(x, name) { standardGeneric("setName") }) setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") }) # @rdname sortBy -# @export setGeneric("sortBy", function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") @@ -194,88 +165,71 @@ setGeneric("sortBy", setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") }) # @rdname takeOrdered -# @export setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") }) # @rdname takeSample -# @export setGeneric("takeSample", function(x, withReplacement, num, seed) { standardGeneric("takeSample") }) # @rdname top -# @export setGeneric("top", function(x, num) { standardGeneric("top") }) # @rdname unionRDD -# @export setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") }) # @rdname zipRDD -# @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) # @rdname zipRDD -# @export setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, signature = "...") # @rdname zipWithIndex # @seealso zipWithUniqueId -# @export setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") }) # @rdname zipWithUniqueId # @seealso zipWithIndex -# @export setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }) ############ Binary Functions ############# # @rdname cartesian -# @export setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) # @rdname countByKey -# @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) # @rdname flatMapValues -# @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) # @rdname intersection -# @export setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) # @rdname keys -# @export setGeneric("keys", function(x) { standardGeneric("keys") }) # @rdname lookup -# @export setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) # @rdname mapValues -# @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) # @rdname sampleByKey -# @export setGeneric("sampleByKey", function(x, withReplacement, fractions, seed) { standardGeneric("sampleByKey") }) # @rdname values -# @export setGeneric("values", function(x) { standardGeneric("values") }) @@ -283,14 +237,12 @@ setGeneric("values", function(x) { standardGeneric("values") }) # @rdname aggregateByKey # @seealso foldByKey, combineByKey -# @export setGeneric("aggregateByKey", function(x, zeroValue, seqOp, combOp, numPartitions) { standardGeneric("aggregateByKey") }) # @rdname cogroup -# @export setGeneric("cogroup", function(..., numPartitions) { standardGeneric("cogroup") @@ -299,7 +251,6 @@ setGeneric("cogroup", # @rdname combineByKey # @seealso groupByKey, reduceByKey -# @export setGeneric("combineByKey", function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { standardGeneric("combineByKey") @@ -307,64 +258,53 @@ setGeneric("combineByKey", # @rdname foldByKey # @seealso aggregateByKey, combineByKey -# @export setGeneric("foldByKey", function(x, zeroValue, func, numPartitions) { standardGeneric("foldByKey") }) # @rdname join-methods -# @export setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") }) # @rdname groupByKey # @seealso reduceByKey -# @export setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") }) # @rdname join-methods -# @export setGeneric("join", function(x, y, ...) { standardGeneric("join") }) # @rdname join-methods -# @export setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") }) # @rdname reduceByKey # @seealso groupByKey -# @export setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")}) # @rdname reduceByKeyLocally # @seealso reduceByKey -# @export setGeneric("reduceByKeyLocally", function(x, combineFunc) { standardGeneric("reduceByKeyLocally") }) # @rdname join-methods -# @export setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") }) # @rdname sortByKey -# @export setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) # @rdname subtract -# @export setGeneric("subtract", function(x, other, numPartitions = 1) { standardGeneric("subtract") }) # @rdname subtractByKey -# @export setGeneric("subtractByKey", function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") @@ -374,7 +314,6 @@ setGeneric("subtractByKey", ################### Broadcast Variable Methods ################# # @rdname broadcast -# @export setGeneric("value", function(bcast) { standardGeneric("value") }) @@ -384,7 +323,6 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @param ... further arguments to be passed to or from other methods. #' @return A SparkDataFrame. #' @rdname summarize -#' @export setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias @@ -399,11 +337,9 @@ setGeneric("agg", function(x, ...) { standardGeneric("agg") }) NULL #' @rdname arrange -#' @export setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @rdname as.data.frame -#' @export setGeneric("as.data.frame", function(x, row.names = NULL, optional = FALSE, ...) { standardGeneric("as.data.frame") @@ -411,52 +347,41 @@ setGeneric("as.data.frame", # Do not document the generic because of signature changes across R versions #' @noRd -#' @export setGeneric("attach") #' @rdname cache -#' @export setGeneric("cache", function(x) { standardGeneric("cache") }) #' @rdname checkpoint -#' @export setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) #' @rdname coalesce #' @param x a SparkDataFrame. #' @param ... additional argument(s). -#' @export setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) #' @rdname collect -#' @export setGeneric("collect", function(x, ...) { standardGeneric("collect") }) #' @param do.NULL currently not used. #' @param prefix currently not used. #' @rdname columns -#' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) #' @rdname columns -#' @export setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") }) #' @rdname coltypes -#' @export setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) #' @rdname coltypes -#' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) #' @rdname columns -#' @export setGeneric("columns", function(x) {standardGeneric("columns") }) #' @param x a GroupedData or Column. #' @rdname count -#' @export setGeneric("count", function(x) { standardGeneric("count") }) #' @rdname cov @@ -464,7 +389,6 @@ setGeneric("count", function(x) { standardGeneric("count") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @rdname corr @@ -472,294 +396,229 @@ setGeneric("cov", function(x, ...) {standardGeneric("cov") }) #' @param ... additional argument(s). If \code{x} is a Column, a Column #' should be provided. If \code{x} is a SparkDataFrame, two column names should #' be provided. -#' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) #' @rdname cov -#' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) #' @rdname cov -#' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) #' @rdname createOrReplaceTempView -#' @export setGeneric("createOrReplaceTempView", function(x, viewName) { standardGeneric("createOrReplaceTempView") }) # @rdname crossJoin -# @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) #' @rdname cube -#' @export setGeneric("cube", function(x, ...) { standardGeneric("cube") }) #' @rdname dapply -#' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) #' @rdname dapplyCollect -#' @export setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapply -#' @export setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) #' @param x a SparkDataFrame or GroupedData. #' @param ... additional argument(s) passed to the method. #' @rdname gapplyCollect -#' @export setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) # @rdname getNumPartitions -# @export setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) #' @rdname describe -#' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) #' @rdname distinct -#' @export setGeneric("distinct", function(x) { standardGeneric("distinct") }) #' @rdname drop -#' @export setGeneric("drop", function(x, ...) { standardGeneric("drop") }) #' @rdname dropDuplicates -#' @export setGeneric("dropDuplicates", function(x, ...) { standardGeneric("dropDuplicates") }) #' @rdname nafunctions -#' @export setGeneric("dropna", function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { standardGeneric("dropna") }) #' @rdname nafunctions -#' @export setGeneric("na.omit", function(object, ...) { standardGeneric("na.omit") }) #' @rdname dtypes -#' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain -#' @export #' @param x a SparkDataFrame or a StreamingQuery. #' @param extended Logical. If extended is FALSE, prints only the physical plan. #' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except -#' @export setGeneric("except", function(x, y) { standardGeneric("except") }) #' @rdname nafunctions -#' @export setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) #' @rdname filter -#' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @rdname first -#' @export setGeneric("first", function(x, ...) { standardGeneric("first") }) #' @rdname groupBy -#' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @rdname groupBy -#' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) #' @rdname hint -#' @export setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) #' @rdname insertInto -#' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) #' @rdname intersect -#' @export setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @rdname isLocal -#' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @rdname isStreaming -#' @export setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) #' @rdname limit -#' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) #' @rdname localCheckpoint -#' @export setGeneric("localCheckpoint", function(x, eager = TRUE) { standardGeneric("localCheckpoint") }) #' @rdname merge -#' @export setGeneric("merge") #' @rdname mutate -#' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname orderBy -#' @export setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") }) #' @rdname persist -#' @export setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) #' @rdname printSchema -#' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) #' @rdname registerTempTable-deprecated -#' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) #' @rdname rename -#' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition -#' @export setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample -#' @export setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample") }) #' @rdname rollup -#' @export setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) #' @rdname sample -#' @export setGeneric("sample_frac", function(x, withReplacement = FALSE, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname sampleBy -#' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) #' @rdname saveAsTable -#' @export setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", ...) { standardGeneric("saveAsTable") }) -#' @export setGeneric("str") #' @rdname take -#' @export setGeneric("take", function(x, num) { standardGeneric("take") }) #' @rdname mutate -#' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) { standardGeneric("write.df") }) #' @rdname write.df -#' @export setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { standardGeneric("saveDF") }) #' @rdname write.jdbc -#' @export setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { standardGeneric("write.jdbc") }) #' @rdname write.json -#' @export setGeneric("write.json", function(x, path, ...) { standardGeneric("write.json") }) #' @rdname write.orc -#' @export setGeneric("write.orc", function(x, path, ...) { standardGeneric("write.orc") }) #' @rdname write.parquet -#' @export setGeneric("write.parquet", function(x, path, ...) { standardGeneric("write.parquet") }) #' @rdname write.parquet -#' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) #' @rdname write.stream -#' @export setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { standardGeneric("write.stream") }) #' @rdname write.text -#' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) #' @rdname schema -#' @export setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select -#' @export setGeneric("select", function(x, col, ...) { standardGeneric("select") }) #' @rdname selectExpr -#' @export setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) #' @rdname showDF -#' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) # @rdname storageLevel -# @export setGeneric("storageLevel", function(x) { standardGeneric("storageLevel") }) #' @rdname subset -#' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname summarize -#' @export setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary -#' @export setGeneric("summary", function(object, ...) { standardGeneric("summary") }) setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) @@ -767,830 +626,660 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname union -#' @export setGeneric("union", function(x, y) { standardGeneric("union") }) #' @rdname union -#' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) #' @rdname unionByName -#' @export setGeneric("unionByName", function(x, y) { standardGeneric("unionByName") }) #' @rdname unpersist -#' @export setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) #' @rdname filter -#' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @rdname with -#' @export setGeneric("with") #' @rdname withColumn -#' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) #' @rdname rename -#' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) #' @rdname withWatermark -#' @export setGeneric("withWatermark", function(x, eventTime, delayThreshold) { standardGeneric("withWatermark") }) #' @rdname write.df -#' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) #' @rdname randomSplit -#' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) #' @rdname broadcast -#' @export setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) ###################### Column Methods ########################## #' @rdname columnfunctions -#' @export setGeneric("asc", function(x) { standardGeneric("asc") }) #' @rdname between -#' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) #' @rdname cast -#' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) #' @rdname columnfunctions #' @param x a Column object. #' @param ... additional argument(s). -#' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) #' @rdname columnfunctions -#' @export setGeneric("desc", function(x) { standardGeneric("desc") }) #' @rdname endsWith -#' @export setGeneric("endsWith", function(x, suffix) { standardGeneric("endsWith") }) #' @rdname columnfunctions -#' @export setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @rdname columnfunctions -#' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) #' @rdname columnfunctions -#' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) #' @rdname columnfunctions -#' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) #' @rdname columnfunctions -#' @export setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) #' @rdname columnfunctions -#' @export setGeneric("like", function(x, ...) { standardGeneric("like") }) #' @rdname columnfunctions -#' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) #' @rdname startsWith -#' @export setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("when", function(condition, value) { standardGeneric("when") }) #' @rdname otherwise -#' @export setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @rdname over -#' @export setGeneric("over", function(x, window) { standardGeneric("over") }) #' @rdname eq_null_safe -#' @export setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) ###################### WindowSpec Methods ########################## #' @rdname partitionBy -#' @export setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) #' @rdname rowsBetween -#' @export setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") }) #' @rdname rangeBetween -#' @export setGeneric("rangeBetween", function(x, start, end) { standardGeneric("rangeBetween") }) #' @rdname windowPartitionBy -#' @export setGeneric("windowPartitionBy", function(col, ...) { standardGeneric("windowPartitionBy") }) #' @rdname windowOrderBy -#' @export setGeneric("windowOrderBy", function(col, ...) { standardGeneric("windowOrderBy") }) ###################### Expression Function Methods ########################## #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @param x Column to compute on or a GroupedData object. #' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg -#' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("base64", function(x) { standardGeneric("base64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bin", function(x) { standardGeneric("bin") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("bround", function(x, ...) { standardGeneric("bround") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("ceil", function(x) { standardGeneric("ceil") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column -#' @export setGeneric("column", function(x) { standardGeneric("column") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_date", function(x = "missing") { standardGeneric("current_date") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("current_timestamp", function(x = "missing") { standardGeneric("current_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("date_trunc", function(format, x) { standardGeneric("date_trunc") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofweek", function(x) { standardGeneric("dayofweek") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode", function(x) { standardGeneric("explode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hex", function(x) { standardGeneric("hex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("hour", function(x) { standardGeneric("hour") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("input_file_name", function(x = "missing") { standardGeneric("input_file_name") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last -#' @export setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("least", function(x, ...) { standardGeneric("least") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lower", function(x) { standardGeneric("lower") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("map_values", function(x) { standardGeneric("map_values") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("md5", function(x) { standardGeneric("md5") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("minute", function(x) { standardGeneric("minute") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("monotonically_increasing_id", function(x = "missing") { standardGeneric("monotonically_increasing_id") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("month", function(x) { standardGeneric("month") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) #' @rdname count -#' @export setGeneric("n", function(x) { standardGeneric("n") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname not -#' @export setGeneric("not", function(x) { standardGeneric("not") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("rank", function(x, ...) { standardGeneric("rank") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) #' @rdname column_window_functions -#' @export #' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("rtrim", function(x, trimString) { standardGeneric("rtrim") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("second", function(x) { standardGeneric("second") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @rdname column_misc_functions -#' @export #' @name NULL setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("size", function(x) { standardGeneric("size") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev", function(x) { standardGeneric("stddev") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) #' @rdname column_nonaggregate_functions -#' @export #' @name NULL setGeneric("struct", function(x, ...) { standardGeneric("struct") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) #' @rdname column_collection_functions -#' @export #' @name NULL setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) #' @rdname column_datetime_diff_functions -#' @export #' @name NULL setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("trim", function(x, trimString) { standardGeneric("trim") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) #' @rdname column_math_functions -#' @export #' @name NULL setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) #' @rdname column_string_functions -#' @export #' @name NULL setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("variance", function(x) { standardGeneric("variance") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) #' @rdname column_aggregate_functions -#' @export #' @name NULL setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions -#' @export #' @name NULL setGeneric("year", function(x) { standardGeneric("year") }) @@ -1598,142 +1287,110 @@ setGeneric("year", function(x) { standardGeneric("year") }) ###################### Spark.ML Methods ########################## #' @rdname fitted -#' @export setGeneric("fitted") # Do not carry stats::glm usage and param here, and do not document the generic -#' @export #' @noRd setGeneric("glm") #' @param object a fitted ML model object. #' @param ... additional argument(s) passed to the method. #' @rdname predict -#' @export setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind -#' @export setGeneric("rbind", signature = "...") #' @rdname spark.als -#' @export setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) #' @rdname spark.bisectingKmeans -#' @export setGeneric("spark.bisectingKmeans", function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") }) #' @rdname spark.gaussianMixture -#' @export setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) #' @rdname spark.gbt -#' @export setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) #' @rdname spark.glm -#' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) #' @rdname spark.isoreg -#' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) #' @rdname spark.kmeans -#' @export setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") }) #' @rdname spark.kstest -#' @export setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") }) #' @rdname spark.lda -#' @export setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) #' @rdname spark.logit -#' @export setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") }) #' @rdname spark.mlp -#' @export setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes -#' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) #' @rdname spark.decisionTree -#' @export setGeneric("spark.decisionTree", function(data, formula, ...) { standardGeneric("spark.decisionTree") }) #' @rdname spark.randomForest -#' @export setGeneric("spark.randomForest", function(data, formula, ...) { standardGeneric("spark.randomForest") }) #' @rdname spark.survreg -#' @export setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear -#' @export setGeneric("spark.svmLinear", function(data, formula, ...) { standardGeneric("spark.svmLinear") }) #' @rdname spark.lda -#' @export setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") }) #' @rdname spark.lda -#' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) #' @rdname spark.fpGrowth -#' @export setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. #' @rdname write.ml -#' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) ###################### Streaming Methods ########################## #' @rdname awaitTermination -#' @export setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive -#' @export setGeneric("isActive", function(x) { standardGeneric("isActive") }) #' @rdname lastProgress -#' @export setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) #' @rdname queryName -#' @export setGeneric("queryName", function(x) { standardGeneric("queryName") }) #' @rdname status -#' @export setGeneric("status", function(x) { standardGeneric("status") }) #' @rdname stopQuery -#' @export setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 54ef9f07d6fae..f751b952f3915 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -30,7 +30,6 @@ setOldClass("jobj") #' @seealso groupBy #' #' @param sgd A Java object reference to the backing Scala GroupedData -#' @export #' @note GroupedData since 1.4.0 setClass("GroupedData", slots = list(sgd = "jobj")) @@ -48,7 +47,6 @@ groupedData <- function(sgd) { #' @rdname show #' @aliases show,GroupedData-method -#' @export #' @note show(GroupedData) since 1.4.0 setMethod("show", "GroupedData", function(object) { @@ -63,7 +61,6 @@ setMethod("show", "GroupedData", #' @return A SparkDataFrame. #' @rdname count #' @aliases count,GroupedData-method -#' @export #' @examples #' \dontrun{ #' count(groupBy(df, "name")) @@ -87,7 +84,6 @@ setMethod("count", #' @aliases agg,GroupedData-method #' @name agg #' @family agg_funcs -#' @export #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' @@ -150,7 +146,6 @@ methods <- c("avg", "max", "mean", "min", "sum") #' @rdname pivot #' @aliases pivot,GroupedData,character-method #' @name pivot -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(data.frame( @@ -202,7 +197,6 @@ createMethods() #' @rdname gapply #' @aliases gapply,GroupedData-method #' @name gapply -#' @export #' @note gapply(GroupedData) since 2.0.0 setMethod("gapply", signature(x = "GroupedData"), @@ -216,7 +210,6 @@ setMethod("gapply", #' @rdname gapplyCollect #' @aliases gapplyCollect,GroupedData-method #' @name gapplyCollect -#' @export #' @note gapplyCollect(GroupedData) since 2.0.0 setMethod("gapplyCollect", signature(x = "GroupedData"), diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 04dc7562e5346..6d1edf6b6f3cf 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -58,7 +58,6 @@ #' @rdname install.spark #' @name install.spark #' @aliases install.spark -#' @export #' @examples #'\dontrun{ #' install.spark() diff --git a/R/pkg/R/jvm.R b/R/pkg/R/jvm.R index bb5c77544a3da..9a1b26b0fa3c5 100644 --- a/R/pkg/R/jvm.R +++ b/R/pkg/R/jvm.R @@ -35,7 +35,6 @@ #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJStatic}, \link{sparkR.newJObject} #' @rdname sparkR.callJMethod #' @examples @@ -69,7 +68,6 @@ sparkR.callJMethod <- function(x, methodName, ...) { #' @param ... parameters to pass to the Java method. #' @return the return value of the Java method. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.newJObject} #' @rdname sparkR.callJStatic #' @examples @@ -100,7 +98,6 @@ sparkR.callJStatic <- function(x, methodName, ...) { #' @param ... arguments to be passed to the constructor. #' @return the object created. Either returned as a R object #' if it can be deserialized or returned as a "jobj". See details section for more. -#' @export #' @seealso \link{sparkR.callJMethod}, \link{sparkR.callJStatic} #' @rdname sparkR.newJObject #' @examples diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index f6e9b1357561b..2964fdeff0957 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -21,28 +21,24 @@ #' S4 class that represents an LinearSVCModel #' #' @param jobj a Java object reference to the backing Scala LinearSVCModel -#' @export #' @note LinearSVCModel since 2.2.0 setClass("LinearSVCModel", representation(jobj = "jobj")) #' S4 class that represents an LogisticRegressionModel #' #' @param jobj a Java object reference to the backing Scala LogisticRegressionModel -#' @export #' @note LogisticRegressionModel since 2.1.0 setClass("LogisticRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a MultilayerPerceptronClassificationModel #' #' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper -#' @export #' @note MultilayerPerceptronClassificationModel since 2.1.0 setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a NaiveBayesModel #' #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper -#' @export #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) @@ -82,7 +78,6 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @rdname spark.svmLinear #' @aliases spark.svmLinear,SparkDataFrame,formula-method #' @name spark.svmLinear -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -131,7 +126,6 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu #' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method -#' @export #' @note predict(LinearSVCModel) since 2.2.0 setMethod("predict", signature(object = "LinearSVCModel"), function(object, newData) { @@ -146,7 +140,6 @@ setMethod("predict", signature(object = "LinearSVCModel"), #' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method -#' @export #' @note summary(LinearSVCModel) since 2.2.0 setMethod("summary", signature(object = "LinearSVCModel"), function(object) { @@ -169,7 +162,6 @@ setMethod("summary", signature(object = "LinearSVCModel"), #' #' @rdname spark.svmLinear #' @aliases write.ml,LinearSVCModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.2.0 setMethod("write.ml", signature(object = "LinearSVCModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -257,7 +249,6 @@ function(object, path, overwrite = FALSE) { #' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -374,7 +365,6 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") #' The list includes \code{coefficients} (coefficients matrix of the fitted model). #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method -#' @export #' @note summary(LogisticRegressionModel) since 2.1.0 setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { @@ -402,7 +392,6 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. #' @rdname spark.logit #' @aliases predict,LogisticRegressionModel,SparkDataFrame-method -#' @export #' @note predict(LogisticRegressionModel) since 2.1.0 setMethod("predict", signature(object = "LogisticRegressionModel"), function(object, newData) { @@ -417,7 +406,6 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), #' #' @rdname spark.logit #' @aliases write.ml,LogisticRegressionModel,character-method -#' @export #' @note write.ml(LogisticRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -458,7 +446,6 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} -#' @export #' @examples #' \dontrun{ #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") @@ -517,7 +504,6 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), #' For \code{weights}, it is a numeric vector with length equal to the expected #' given the architecture (i.e., for 8-10-2 network, 112 connection weights). #' @rdname spark.mlp -#' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method #' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), @@ -538,7 +524,6 @@ setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel #' "prediction". #' @rdname spark.mlp #' @aliases predict,MultilayerPerceptronClassificationModel-method -#' @export #' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), function(object, newData) { @@ -553,7 +538,6 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel #' #' @rdname spark.mlp #' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method -#' @export #' @seealso \link{write.ml} #' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", @@ -585,7 +569,6 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @aliases spark.naiveBayes,SparkDataFrame,formula-method #' @name spark.naiveBayes #' @seealso e1071: \url{https://cran.r-project.org/package=e1071} -#' @export #' @examples #' \dontrun{ #' data <- as.data.frame(UCBAdmissions) @@ -624,7 +607,6 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form #' The list includes \code{apriori} (the label distribution) and #' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes -#' @export #' @note summary(NaiveBayesModel) since 2.0.0 setMethod("summary", signature(object = "NaiveBayesModel"), function(object) { @@ -648,7 +630,6 @@ setMethod("summary", signature(object = "NaiveBayesModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named #' "prediction". #' @rdname spark.naiveBayes -#' @export #' @note predict(NaiveBayesModel) since 2.0.0 setMethod("predict", signature(object = "NaiveBayesModel"), function(object, newData) { @@ -662,7 +643,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.naiveBayes -#' @export #' @seealso \link{write.ml} #' @note write.ml(NaiveBayesModel, character) since 2.0.0 setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index a25bf81c6d977..900be685824da 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -20,28 +20,24 @@ #' S4 class that represents a BisectingKMeansModel #' #' @param jobj a Java object reference to the backing Scala BisectingKMeansModel -#' @export #' @note BisectingKMeansModel since 2.2.0 setClass("BisectingKMeansModel", representation(jobj = "jobj")) #' S4 class that represents a GaussianMixtureModel #' #' @param jobj a Java object reference to the backing Scala GaussianMixtureModel -#' @export #' @note GaussianMixtureModel since 2.1.0 setClass("GaussianMixtureModel", representation(jobj = "jobj")) #' S4 class that represents a KMeansModel #' #' @param jobj a Java object reference to the backing Scala KMeansModel -#' @export #' @note KMeansModel since 2.0.0 setClass("KMeansModel", representation(jobj = "jobj")) #' S4 class that represents an LDAModel #' #' @param jobj a Java object reference to the backing Scala LDAWrapper -#' @export #' @note LDAModel since 2.1.0 setClass("LDAModel", representation(jobj = "jobj")) @@ -68,7 +64,6 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @rdname spark.bisectingKmeans #' @aliases spark.bisectingKmeans,SparkDataFrame,formula-method #' @name spark.bisectingKmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -117,7 +112,6 @@ setMethod("spark.bisectingKmeans", signature(data = "SparkDataFrame", formula = #' (cluster centers of the transformed data; cluster is NULL if is.loaded is TRUE), #' and \code{is.loaded} (whether the model is loaded from a saved file). #' @rdname spark.bisectingKmeans -#' @export #' @note summary(BisectingKMeansModel) since 2.2.0 setMethod("summary", signature(object = "BisectingKMeansModel"), function(object) { @@ -144,7 +138,6 @@ setMethod("summary", signature(object = "BisectingKMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a bisecting k-means model. #' @rdname spark.bisectingKmeans -#' @export #' @note predict(BisectingKMeansModel) since 2.2.0 setMethod("predict", signature(object = "BisectingKMeansModel"), function(object, newData) { @@ -160,7 +153,6 @@ setMethod("predict", signature(object = "BisectingKMeansModel"), #' or \code{"classes"} for assigned classes. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname spark.bisectingKmeans -#' @export #' @note fitted since 2.2.0 setMethod("fitted", signature(object = "BisectingKMeansModel"), function(object, method = c("centers", "classes")) { @@ -181,7 +173,6 @@ setMethod("fitted", signature(object = "BisectingKMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.bisectingKmeans -#' @export #' @note write.ml(BisectingKMeansModel, character) since 2.2.0 setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -208,7 +199,6 @@ setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "charact #' @rdname spark.gaussianMixture #' @name spark.gaussianMixture #' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -251,7 +241,6 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = #' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture -#' @export #' @note summary(GaussianMixtureModel) since 2.1.0 setMethod("summary", signature(object = "GaussianMixtureModel"), function(object) { @@ -291,7 +280,6 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), #' "prediction". #' @aliases predict,GaussianMixtureModel,SparkDataFrame-method #' @rdname spark.gaussianMixture -#' @export #' @note predict(GaussianMixtureModel) since 2.1.0 setMethod("predict", signature(object = "GaussianMixtureModel"), function(object, newData) { @@ -306,7 +294,6 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' @aliases write.ml,GaussianMixtureModel,character-method #' @rdname spark.gaussianMixture -#' @export #' @note write.ml(GaussianMixtureModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -336,7 +323,6 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @rdname spark.kmeans #' @aliases spark.kmeans,SparkDataFrame,formula-method #' @name spark.kmeans -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -385,7 +371,6 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula" #' (the actual number of cluster centers. When using initMode = "random", #' \code{clusterSize} may not equal to \code{k}). #' @rdname spark.kmeans -#' @export #' @note summary(KMeansModel) since 2.0.0 setMethod("summary", signature(object = "KMeansModel"), function(object) { @@ -413,7 +398,6 @@ setMethod("summary", signature(object = "KMeansModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns the predicted values based on a k-means model. #' @rdname spark.kmeans -#' @export #' @note predict(KMeansModel) since 2.0.0 setMethod("predict", signature(object = "KMeansModel"), function(object, newData) { @@ -431,7 +415,6 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param ... additional argument(s) passed to the method. #' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname fitted -#' @export #' @examples #' \dontrun{ #' model <- spark.kmeans(trainingData, ~ ., 2) @@ -458,7 +441,6 @@ setMethod("fitted", signature(object = "KMeansModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.kmeans -#' @export #' @note write.ml(KMeansModel, character) since 2.0.0 setMethod("write.ml", signature(object = "KMeansModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -496,7 +478,6 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} -#' @export #' @examples #' \dontrun{ #' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") @@ -558,7 +539,6 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' It is only for distributed LDA model (i.e., optimizer = "em")} #' @rdname spark.lda #' @aliases summary,LDAModel-method -#' @export #' @note summary(LDAModel) since 2.1.0 setMethod("summary", signature(object = "LDAModel"), function(object, maxTermsPerTopic) { @@ -596,7 +576,6 @@ setMethod("summary", signature(object = "LDAModel"), #' perplexity of the training data if missing argument "data". #' @rdname spark.lda #' @aliases spark.perplexity,LDAModel-method -#' @export #' @note spark.perplexity(LDAModel) since 2.1.0 setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), function(object, data) { @@ -611,7 +590,6 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr #' vectors named "topicDistribution". #' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method -#' @export #' @note spark.posterior(LDAModel) since 2.1.0 setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), function(object, newData) { @@ -626,7 +604,6 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' #' @rdname spark.lda #' @aliases write.ml,LDAModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(LDAModel, character) since 2.1.0 setMethod("write.ml", signature(object = "LDAModel", path = "character"), diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R index dfcb45a1b66c9..e2394906d8012 100644 --- a/R/pkg/R/mllib_fpm.R +++ b/R/pkg/R/mllib_fpm.R @@ -20,7 +20,6 @@ #' S4 class that represents a FPGrowthModel #' #' @param jobj a Java object reference to the backing Scala FPGrowthModel -#' @export #' @note FPGrowthModel since 2.2.0 setClass("FPGrowthModel", slots = list(jobj = "jobj")) @@ -45,7 +44,6 @@ setClass("FPGrowthModel", slots = list(jobj = "jobj")) #' @rdname spark.fpGrowth #' @name spark.fpGrowth #' @aliases spark.fpGrowth,SparkDataFrame-method -#' @export #' @examples #' \dontrun{ #' raw_data <- read.df( @@ -109,7 +107,6 @@ setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), #' and \code{freq} (frequency of the itemset). #' @rdname spark.fpGrowth #' @aliases freqItemsets,FPGrowthModel-method -#' @export #' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), function(object) { @@ -125,7 +122,6 @@ setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), #' and \code{condfidence} (confidence). #' @rdname spark.fpGrowth #' @aliases associationRules,FPGrowthModel-method -#' @export #' @note spark.associationRules(FPGrowthModel) since 2.2.0 setMethod("spark.associationRules", signature(object = "FPGrowthModel"), function(object) { @@ -138,7 +134,6 @@ setMethod("spark.associationRules", signature(object = "FPGrowthModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.fpGrowth #' @aliases predict,FPGrowthModel-method -#' @export #' @note predict(FPGrowthModel) since 2.2.0 setMethod("predict", signature(object = "FPGrowthModel"), function(object, newData) { @@ -153,7 +148,6 @@ setMethod("predict", signature(object = "FPGrowthModel"), #' if the output path exists. #' @rdname spark.fpGrowth #' @aliases write.ml,FPGrowthModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(FPGrowthModel, character) since 2.2.0 setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R index 5441c4a4022a9..9a77b07462585 100644 --- a/R/pkg/R/mllib_recommendation.R +++ b/R/pkg/R/mllib_recommendation.R @@ -20,7 +20,6 @@ #' S4 class that represents an ALSModel #' #' @param jobj a Java object reference to the backing Scala ALSWrapper -#' @export #' @note ALSModel since 2.1.0 setClass("ALSModel", representation(jobj = "jobj")) @@ -55,7 +54,6 @@ setClass("ALSModel", representation(jobj = "jobj")) #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als -#' @export #' @examples #' \dontrun{ #' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -118,7 +116,6 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), #' and \code{rank} (rank of the matrix factorization model). #' @rdname spark.als #' @aliases summary,ALSModel-method -#' @export #' @note summary(ALSModel) since 2.1.0 setMethod("summary", signature(object = "ALSModel"), function(object) { @@ -139,7 +136,6 @@ setMethod("summary", signature(object = "ALSModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.als #' @aliases predict,ALSModel-method -#' @export #' @note predict(ALSModel) since 2.1.0 setMethod("predict", signature(object = "ALSModel"), function(object, newData) { @@ -155,7 +151,6 @@ setMethod("predict", signature(object = "ALSModel"), #' #' @rdname spark.als #' @aliases write.ml,ALSModel,character-method -#' @export #' @seealso \link{read.ml} #' @note write.ml(ALSModel, character) since 2.1.0 setMethod("write.ml", signature(object = "ALSModel", path = "character"), diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 545be5e1d89f0..95c1a29905197 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -21,21 +21,18 @@ #' S4 class that represents a AFTSurvivalRegressionModel #' #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper -#' @export #' @note AFTSurvivalRegressionModel since 2.0.0 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a generalized linear model #' #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper -#' @export #' @note GeneralizedLinearRegressionModel since 2.0.0 setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' S4 class that represents an IsotonicRegressionModel #' #' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel -#' @export #' @note IsotonicRegressionModel since 2.1.0 setClass("IsotonicRegressionModel", representation(jobj = "jobj")) @@ -85,7 +82,6 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -211,7 +207,6 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @aliases glm -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -244,7 +239,6 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in #' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm -#' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), function(object) { @@ -290,7 +284,6 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), #' @rdname spark.glm #' @param x summary object of fitted generalized linear model returned by \code{summary} function. -#' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { if (x$is.loaded) { @@ -324,7 +317,6 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named #' "prediction". #' @rdname spark.glm -#' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { @@ -338,7 +330,6 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' which means throw exception if the output path exists. #' #' @rdname spark.glm -#' @export #' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -363,7 +354,6 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat #' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg -#' @export #' @examples #' \dontrun{ #' sparkR.session() @@ -412,7 +402,6 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" #' and \code{predictions} (predictions associated with the boundaries at the same index). #' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method -#' @export #' @note summary(IsotonicRegressionModel) since 2.1.0 setMethod("summary", signature(object = "IsotonicRegressionModel"), function(object) { @@ -429,7 +418,6 @@ setMethod("summary", signature(object = "IsotonicRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method -#' @export #' @note predict(IsotonicRegressionModel) since 2.1.0 setMethod("predict", signature(object = "IsotonicRegressionModel"), function(object, newData) { @@ -444,7 +432,6 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), #' #' @rdname spark.isoreg #' @aliases write.ml,IsotonicRegressionModel,character-method -#' @export #' @note write.ml(IsotonicRegression, character) since 2.1.0 setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -477,7 +464,6 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} -#' @export #' @examples #' \dontrun{ #' df <- createDataFrame(ovarian) @@ -517,7 +503,6 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' The list includes the model's \code{coefficients} (features, coefficients, #' intercept and log(scale)). #' @rdname spark.survreg -#' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), function(object) { @@ -537,7 +522,6 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @return \code{predict} returns a SparkDataFrame containing predicted values #' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg -#' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { @@ -550,7 +534,6 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), #' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @rdname spark.survreg -#' @export #' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 #' @seealso \link{write.ml} setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), diff --git a/R/pkg/R/mllib_stat.R b/R/pkg/R/mllib_stat.R index 3e013f1d45e38..f8c3329359961 100644 --- a/R/pkg/R/mllib_stat.R +++ b/R/pkg/R/mllib_stat.R @@ -20,7 +20,6 @@ #' S4 class that represents an KSTest #' #' @param jobj a Java object reference to the backing Scala KSTestWrapper -#' @export #' @note KSTest since 2.1.0 setClass("KSTest", representation(jobj = "jobj")) @@ -52,7 +51,6 @@ setClass("KSTest", representation(jobj = "jobj")) #' @name spark.kstest #' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ #' MLlib: Hypothesis Testing} -#' @export #' @examples #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) @@ -94,7 +92,6 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), #' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method -#' @export #' @note summary(KSTest) since 2.1.0 setMethod("summary", signature(object = "KSTest"), function(object) { @@ -117,7 +114,6 @@ setMethod("summary", signature(object = "KSTest"), #' @rdname spark.kstest #' @param x summary object of KSTest returned by \code{summary}. -#' @export #' @note print.summary.KSTest since 2.1.0 print.summary.KSTest <- function(x, ...) { jobj <- x$jobj diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 4e5ddf22ee16d..6769be038efa9 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -20,42 +20,36 @@ #' S4 class that represents a GBTRegressionModel #' #' @param jobj a Java object reference to the backing Scala GBTRegressionModel -#' @export #' @note GBTRegressionModel since 2.1.0 setClass("GBTRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a GBTClassificationModel #' #' @param jobj a Java object reference to the backing Scala GBTClassificationModel -#' @export #' @note GBTClassificationModel since 2.1.0 setClass("GBTClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestRegressionModel #' #' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel -#' @export #' @note RandomForestRegressionModel since 2.1.0 setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a RandomForestClassificationModel #' #' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel -#' @export #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeRegressionModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel -#' @export #' @note DecisionTreeRegressionModel since 2.3.0 setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) #' S4 class that represents a DecisionTreeClassificationModel #' #' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel -#' @export #' @note DecisionTreeClassificationModel since 2.3.0 setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) @@ -179,7 +173,6 @@ print.summary.decisionTree <- function(x) { #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. #' @rdname spark.gbt #' @name spark.gbt -#' @export #' @examples #' \dontrun{ #' # fit a Gradient Boosted Tree Regression Model @@ -261,7 +254,6 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method -#' @export #' @note summary(GBTRegressionModel) since 2.1.0 setMethod("summary", signature(object = "GBTRegressionModel"), function(object) { @@ -275,7 +267,6 @@ setMethod("summary", signature(object = "GBTRegressionModel"), #' @param x summary object of Gradient Boosted Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.gbt -#' @export #' @note print.summary.GBTRegressionModel since 2.1.0 print.summary.GBTRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -285,7 +276,6 @@ print.summary.GBTRegressionModel <- function(x, ...) { #' @rdname spark.gbt #' @aliases summary,GBTClassificationModel-method -#' @export #' @note summary(GBTClassificationModel) since 2.1.0 setMethod("summary", signature(object = "GBTClassificationModel"), function(object) { @@ -297,7 +287,6 @@ setMethod("summary", signature(object = "GBTClassificationModel"), # Prints the summary of Gradient Boosted Tree Classification Model #' @rdname spark.gbt -#' @export #' @note print.summary.GBTClassificationModel since 2.1.0 print.summary.GBTClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -310,7 +299,6 @@ print.summary.GBTClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.gbt #' @aliases predict,GBTRegressionModel-method -#' @export #' @note predict(GBTRegressionModel) since 2.1.0 setMethod("predict", signature(object = "GBTRegressionModel"), function(object, newData) { @@ -319,7 +307,6 @@ setMethod("predict", signature(object = "GBTRegressionModel"), #' @rdname spark.gbt #' @aliases predict,GBTClassificationModel-method -#' @export #' @note predict(GBTClassificationModel) since 2.1.0 setMethod("predict", signature(object = "GBTClassificationModel"), function(object, newData) { @@ -334,7 +321,6 @@ setMethod("predict", signature(object = "GBTClassificationModel"), #' which means throw exception if the output path exists. #' @aliases write.ml,GBTRegressionModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -343,7 +329,6 @@ setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character #' @aliases write.ml,GBTClassificationModel,character-method #' @rdname spark.gbt -#' @export #' @note write.ml(GBTClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -402,7 +387,6 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' @return \code{spark.randomForest} returns a fitted Random Forest model. #' @rdname spark.randomForest #' @name spark.randomForest -#' @export #' @examples #' \dontrun{ #' # fit a Random Forest Regression Model @@ -480,7 +464,6 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method -#' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { @@ -494,7 +477,6 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -504,7 +486,6 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method -#' @export #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { @@ -516,7 +497,6 @@ setMethod("summary", signature(object = "RandomForestClassificationModel"), # Prints the summary of Random Forest Classification Model #' @rdname spark.randomForest -#' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { print.summary.treeEnsemble(x) @@ -529,7 +509,6 @@ print.summary.RandomForestClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method -#' @export #' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { @@ -538,7 +517,6 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method -#' @export #' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { @@ -554,7 +532,6 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), #' #' @aliases write.ml,RandomForestRegressionModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -563,7 +540,6 @@ setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = " #' @aliases write.ml,RandomForestClassificationModel,character-method #' @rdname spark.randomForest -#' @export #' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -617,7 +593,6 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. #' @rdname spark.decisionTree #' @name spark.decisionTree -#' @export #' @examples #' \dontrun{ #' # fit a Decision Tree Regression Model @@ -690,7 +665,6 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method -#' @export #' @note summary(DecisionTreeRegressionModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeRegressionModel"), function(object) { @@ -704,7 +678,6 @@ setMethod("summary", signature(object = "DecisionTreeRegressionModel"), #' @param x summary object of Decision Tree regression model or classification model #' returned by \code{summary}. #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeRegressionModel since 2.3.0 print.summary.DecisionTreeRegressionModel <- function(x, ...) { print.summary.decisionTree(x) @@ -714,7 +687,6 @@ print.summary.DecisionTreeRegressionModel <- function(x, ...) { #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeClassificationModel-method -#' @export #' @note summary(DecisionTreeClassificationModel) since 2.3.0 setMethod("summary", signature(object = "DecisionTreeClassificationModel"), function(object) { @@ -726,7 +698,6 @@ setMethod("summary", signature(object = "DecisionTreeClassificationModel"), # Prints the summary of Decision Tree Classification Model #' @rdname spark.decisionTree -#' @export #' @note print.summary.DecisionTreeClassificationModel since 2.3.0 print.summary.DecisionTreeClassificationModel <- function(x, ...) { print.summary.decisionTree(x) @@ -739,7 +710,6 @@ print.summary.DecisionTreeClassificationModel <- function(x, ...) { #' "prediction". #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeRegressionModel-method -#' @export #' @note predict(DecisionTreeRegressionModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeRegressionModel"), function(object, newData) { @@ -748,7 +718,6 @@ setMethod("predict", signature(object = "DecisionTreeRegressionModel"), #' @rdname spark.decisionTree #' @aliases predict,DecisionTreeClassificationModel-method -#' @export #' @note predict(DecisionTreeClassificationModel) since 2.3.0 setMethod("predict", signature(object = "DecisionTreeClassificationModel"), function(object, newData) { @@ -764,7 +733,6 @@ setMethod("predict", signature(object = "DecisionTreeClassificationModel"), #' #' @aliases write.ml,DecisionTreeRegressionModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { @@ -773,7 +741,6 @@ setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = " #' @aliases write.ml,DecisionTreeClassificationModel,character-method #' @rdname spark.decisionTree -#' @export #' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), function(object, path, overwrite = FALSE) { diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c2c4815..7d04bffcba3a4 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -31,7 +31,6 @@ #' MLlib model below. #' @rdname write.ml #' @name write.ml -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -48,7 +47,6 @@ NULL #' MLlib model below. #' @rdname predict #' @name predict -#' @export #' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, #' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, #' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, @@ -75,7 +73,6 @@ predict_internal <- function(object, newData) { #' @return A fitted MLlib model. #' @rdname read.ml #' @name read.ml -#' @export #' @seealso \link{write.ml} #' @examples #' \dontrun{ diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 65f418740c643..9831fc3cc6d01 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -29,7 +29,6 @@ #' @param ... additional structField objects #' @return a structType object #' @rdname structType -#' @export #' @examples #'\dontrun{ #' schema <- structType(structField("a", "integer"), structField("c", "string"), @@ -49,7 +48,6 @@ structType <- function(x, ...) { #' @rdname structType #' @method structType jobj -#' @export structType.jobj <- function(x, ...) { obj <- structure(list(), class = "structType") obj$jobj <- x @@ -59,7 +57,6 @@ structType.jobj <- function(x, ...) { #' @rdname structType #' @method structType structField -#' @export structType.structField <- function(x, ...) { fields <- list(x, ...) if (!all(sapply(fields, inherits, "structField"))) { @@ -76,7 +73,6 @@ structType.structField <- function(x, ...) { #' @rdname structType #' @method structType character -#' @export structType.character <- function(x, ...) { if (!is.character(x)) { stop("schema must be a DDL-formatted string.") @@ -119,7 +115,6 @@ print.structType <- function(x, ...) { #' @param ... additional argument(s) passed to the method. #' @return A structField object. #' @rdname structField -#' @export #' @examples #'\dontrun{ #' field1 <- structField("a", "integer") @@ -137,7 +132,6 @@ structField <- function(x, ...) { #' @rdname structField #' @method structField jobj -#' @export structField.jobj <- function(x, ...) { obj <- structure(list(), class = "structField") obj$jobj <- x @@ -212,7 +206,6 @@ checkType <- function(type) { #' @param type The data type of the field #' @param nullable A logical vector indicating whether or not the field is nullable #' @rdname structField -#' @export structField.character <- function(x, type, nullable = TRUE, ...) { if (class(x) != "character") { stop("Field name must be a string.") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 965471f3b07a0..a480ac606f10d 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -35,7 +35,6 @@ connExists <- function(env) { #' Also terminates the backend this R session is connected to. #' @rdname sparkR.session.stop #' @name sparkR.session.stop -#' @export #' @note sparkR.session.stop since 2.0.0 sparkR.session.stop <- function() { env <- .sparkREnv @@ -84,7 +83,6 @@ sparkR.session.stop <- function() { #' @rdname sparkR.session.stop #' @name sparkR.stop -#' @export #' @note sparkR.stop since 1.4.0 sparkR.stop <- function() { sparkR.session.stop() @@ -103,7 +101,6 @@ sparkR.stop <- function() { #' @param sparkPackages Character vector of package coordinates #' @seealso \link{sparkR.session} #' @rdname sparkR.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init("local[2]", "SparkR", "/home/spark") @@ -270,7 +267,6 @@ sparkR.sparkContext <- function( #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRSQL.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -298,7 +294,6 @@ sparkRSQL.init <- function(jsc = NULL) { #' @param jsc The existing JavaSparkContext created with SparkR.init() #' @seealso \link{sparkR.session} #' @rdname sparkRHive.init-deprecated -#' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -347,7 +342,6 @@ sparkRHive.init <- function(jsc = NULL) { #' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once #' set, this cannot be turned off on an existing session #' @param ... named Spark properties passed to the method. -#' @export #' @examples #'\dontrun{ #' sparkR.session() @@ -442,7 +436,6 @@ sparkR.session <- function( #' @return the SparkUI URL, or NA if it is disabled, or not started. #' @rdname sparkR.uiWebUrl #' @name sparkR.uiWebUrl -#' @export #' @examples #'\dontrun{ #' sparkR.session() diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index c8af798830b30..497f18c763048 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -37,7 +37,6 @@ setOldClass("jobj") #' @name crosstab #' @aliases crosstab,SparkDataFrame,character,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -63,7 +62,6 @@ setMethod("crosstab", #' @rdname cov #' @aliases cov,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -92,7 +90,6 @@ setMethod("cov", #' @name corr #' @aliases corr,SparkDataFrame-method #' @family stat functions -#' @export #' @examples #' #' \dontrun{ @@ -124,7 +121,6 @@ setMethod("corr", #' @name freqItems #' @aliases freqItems,SparkDataFrame,character-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -168,7 +164,6 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @name approxQuantile #' @aliases approxQuantile,SparkDataFrame,character,numeric,numeric-method #' @family stat functions -#' @export #' @examples #' \dontrun{ #' df <- read.json("/path/to/file.json") @@ -205,7 +200,6 @@ setMethod("approxQuantile", #' @aliases sampleBy,SparkDataFrame,character,list,numeric-method #' @name sampleBy #' @family stat functions -#' @export #' @examples #'\dontrun{ #' df <- read.json("/path/to/file.json") diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index 8390bd5e6de72..fc83463f72cd4 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -28,7 +28,6 @@ NULL #' @seealso \link{read.stream} #' #' @param ssq A Java object reference to the backing Scala StreamingQuery -#' @export #' @note StreamingQuery since 2.2.0 #' @note experimental setClass("StreamingQuery", @@ -45,7 +44,6 @@ streamingQuery <- function(ssq) { } #' @rdname show -#' @export #' @note show(StreamingQuery) since 2.2.0 setMethod("show", "StreamingQuery", function(object) { @@ -70,7 +68,6 @@ setMethod("show", "StreamingQuery", #' @aliases queryName,StreamingQuery-method #' @family StreamingQuery methods #' @seealso \link{write.stream} -#' @export #' @examples #' \dontrun{ queryName(sq) } #' @note queryName(StreamingQuery) since 2.2.0 @@ -85,7 +82,6 @@ setMethod("queryName", #' @name explain #' @aliases explain,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ explain(sq) } #' @note explain(StreamingQuery) since 2.2.0 @@ -104,7 +100,6 @@ setMethod("explain", #' @name lastProgress #' @aliases lastProgress,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ lastProgress(sq) } #' @note lastProgress(StreamingQuery) since 2.2.0 @@ -129,7 +124,6 @@ setMethod("lastProgress", #' @name status #' @aliases status,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ status(sq) } #' @note status(StreamingQuery) since 2.2.0 @@ -150,7 +144,6 @@ setMethod("status", #' @name isActive #' @aliases isActive,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ isActive(sq) } #' @note isActive(StreamingQuery) since 2.2.0 @@ -177,7 +170,6 @@ setMethod("isActive", #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ awaitTermination(sq, 10000) } #' @note awaitTermination(StreamingQuery) since 2.2.0 @@ -202,7 +194,6 @@ setMethod("awaitTermination", #' @name stopQuery #' @aliases stopQuery,StreamingQuery-method #' @family StreamingQuery methods -#' @export #' @examples #' \dontrun{ stopQuery(sq) } #' @note stopQuery(StreamingQuery) since 2.2.0 diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 164cd6d01a347..f1b5ecaa017df 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -108,7 +108,6 @@ isRDD <- function(name, env) { #' #' @param key the object to be hashed #' @return the hash code as an integer -#' @export #' @examples #'\dontrun{ #' hashCode(1L) # 1 diff --git a/R/pkg/R/window.R b/R/pkg/R/window.R index 0799d841e5dc9..396b27bee80c6 100644 --- a/R/pkg/R/window.R +++ b/R/pkg/R/window.R @@ -29,7 +29,6 @@ #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- orderBy(windowPartitionBy("key1", "key2"), "key3") @@ -52,7 +51,6 @@ setMethod("windowPartitionBy", #' @rdname windowPartitionBy #' @name windowPartitionBy #' @aliases windowPartitionBy,Column-method -#' @export #' @note windowPartitionBy(Column) since 2.0.0 setMethod("windowPartitionBy", signature(col = "Column"), @@ -78,7 +76,6 @@ setMethod("windowPartitionBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,character-method -#' @export #' @examples #' \dontrun{ #' ws <- windowOrderBy("key1", "key2") @@ -101,7 +98,6 @@ setMethod("windowOrderBy", #' @rdname windowOrderBy #' @name windowOrderBy #' @aliases windowOrderBy,Column-method -#' @export #' @note windowOrderBy(Column) since 2.0.0 setMethod("windowOrderBy", signature(col = "Column"), From 98a5c0a35f0a24730f5074522939acf57ef95422 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 5 Mar 2018 10:50:00 -0800 Subject: [PATCH 0429/1161] [SPARK-22882][ML][TESTS] ML test for structured streaming: ml.classification ## What changes were proposed in this pull request? adding Structured Streaming tests for all Models/Transformers in spark.ml.classification ## How was this patch tested? N/A Author: WeichenXu Closes #20121 from WeichenXu123/ml_stream_test_classification. --- .../DecisionTreeClassifierSuite.scala | 29 +-- .../classification/GBTClassifierSuite.scala | 77 ++---- .../ml/classification/LinearSVCSuite.scala | 15 +- .../LogisticRegressionSuite.scala | 229 +++++++----------- .../MultilayerPerceptronClassifierSuite.scala | 44 ++-- .../ml/classification/NaiveBayesSuite.scala | 47 ++-- .../ml/classification/OneVsRestSuite.scala | 21 +- .../ProbabilisticClassifierSuite.scala | 29 +-- .../RandomForestClassifierSuite.scala | 16 +- 9 files changed, 202 insertions(+), 305 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 38b265d62611b..eeb0324187c5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -23,15 +23,14 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -class DecisionTreeClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { import DecisionTreeClassifierSuite.compareAPIs import testImplicits._ @@ -251,20 +250,18 @@ class DecisionTreeClassifierSuite MLTestingUtils.checkCopyAndUids(dt, newTree) - val predictions = newTree.transform(newData) - .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => - assert(pred === rawPred.argmax, - s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") - val sum = rawPred.toArray.sum - assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, - "probability prediction mismatch") + testTransformer[(Vector, Double)](newData, newTree, + "prediction", "rawPrediction", "probability") { + case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, DecisionTreeClassificationModel](newTree, newData) + Vector, DecisionTreeClassificationModel](this, newTree, newData) } test("training with 1-category categorical feature") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 978f89c459f0a..092b4a01d5b0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -26,13 +26,12 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.Utils @@ -40,8 +39,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ import GBTClassifierSuite.compareAPIs @@ -126,14 +124,15 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // should predict all zeros binaryModel.setThresholds(Array(0.0, 1.0)) - val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } // should predict all ones binaryModel.setThresholds(Array(1.0, 0.0)) - val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](df, binaryModel, "prediction") { + case Row(prediction: Double) => prediction === 1.0 + } val gbtBase = new GBTClassifier val model = gbtBase.fit(df) @@ -141,15 +140,18 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext // constant threshold scaling is the same as no thresholds binaryModel.setThresholds(Array(1.0, 1.0)) - val scaledPredictions = binaryModel.transform(df).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](df, binaryModel, "prediction") { + scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1)) - val predictionsWithPredict = model.transform(df).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](df, model, "prediction") { + case Row(prediction: Double) => prediction === 0.0 + } } test("GBTClassifier: Predictor, Classifier methods") { @@ -169,61 +171,30 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val blas = BLAS.getInstance() val validationDataset = validationData.toDF(labelCol, featuresCol) - val results = gbtModel.transform(validationDataset) - // check that raw prediction is tree predictions dot tree weights - results.select(rawPredictionCol, featuresCol).collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](validationDataset, gbtModel, + "rawPrediction", "features", "probability", "prediction") { + case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) => assert(raw.size === 2) + // check that raw prediction is tree predictions dot tree weights val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) - } - // Compare rawPrediction with probability - results.select(rawPredictionCol, probabilityCol).collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 2) + // Compare rawPrediction with probability assert(prob.size === 2) // Note: we should check other loss types for classification if they are added val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value)) assert(prob(0) ~== predFromRaw(0) relTol eps) assert(prob(1) ~== predFromRaw(1) relTol eps) assert(prob(0) + prob(1) ~== 1.0 absTol absEps) - } - // Compare prediction with probability - results.select(predictionCol, probabilityCol).collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("") - val resultsUsingRaw2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol) - val resultsUsingProb2Predict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - gbtModel.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() - resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, GBTClassificationModel](gbtModel, validationDataset) + Vector, GBTClassificationModel](this, gbtModel, validationDataset) } test("GBT parameter stepSize should be in interval (0, 1]") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index 41a5d22dd6283..a93825b8a812d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -21,20 +21,18 @@ import scala.util.Random import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.HingeAggregator import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.udf -class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LinearSVCSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -141,10 +139,11 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau threshold: Double, expected: Set[(Int, Double)]): Unit = { model.setThreshold(threshold) - val results = model.transform(df).select("id", "prediction").collect() - .map(r => (r.getInt(0), r.getDouble(1))) - .toSet - assert(results === expected, s"Failed for threshold = $threshold") + testTransformerByGlobalCheckFunc[(Int, Vector)](df, model, "id", "prediction") { + rows: Seq[Row] => + val results = rows.map(r => (r.getInt(0), r.getDouble(1))).toSet + assert(results === expected, s"Failed for threshold = $threshold") + } } def checkResults(threshold: Double, expected: Set[(Int, Double)]): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a5f81a38face9..9987cbf6ba116 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -22,22 +22,20 @@ import scala.language.existentials import scala.util.Random import scala.util.control.Breaks._ -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.optim.aggregator.LogisticAggregator import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, lit, rand} import org.apache.spark.sql.types.LongType -class LogisticRegressionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -332,15 +330,14 @@ class LogisticRegressionSuite val binaryModel = blr.fit(smallBinaryDataset) binaryModel.setThreshold(1.0) - val binaryZeroPredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } binaryModel.setThreshold(0.0) - val binaryOnePredictions = - binaryModel.transform(smallBinaryDataset).select("prediction").collect() - assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) - + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), binaryModel, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(smallMultinomialDataset) @@ -348,31 +345,36 @@ class LogisticRegressionSuite // should predict all zeros model.setThresholds(Array(1, 1000, 1000)) - val zeroPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(zeroPredictions.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } // should predict all ones model.setThresholds(Array(1000, 1, 1000)) - val onePredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(onePredictions.forall(_.getDouble(0) === 1.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 1.0) + } // should predict all twos model.setThresholds(Array(1000, 1000, 1)) - val twoPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(twoPredictions.forall(_.getDouble(0) === 2.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 2.0) + } // constant threshold scaling is the same as no thresholds model.setThresholds(Array(1000, 1000, 1000)) - val scaledPredictions = model.transform(smallMultinomialDataset).select("prediction").collect() - assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => - scaled.getDouble(0) === base.getDouble(0) - }) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { scaledPredictions: Seq[Row] => + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + } // force it to use the predict method model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1, 1)) - val predictionsWithPredict = - model.transform(smallMultinomialDataset).select("prediction").collect() - assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), model, "prediction") { + row => assert(row.getDouble(0) === 0.0) + } } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -403,21 +405,19 @@ class LogisticRegressionSuite // Modify model params, and check that the params worked. model.setThreshold(1.0) - val predAllZero = model.transform(smallBinaryDataset) - .select("prediction", "myProbability") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predAllZero.forall(_ === 0), - s"With threshold=1.0, expected predictions to be all 0, but only" + - s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model, "prediction", "myProbability") { rows => + val predAllZero = rows.map(_.getDouble(0)) + assert(predAllZero.forall(_ === 0), + s"With threshold=1.0, expected predictions to be all 0, but only" + + s" ${predAllZero.count(_ === 0)} of ${smallBinaryDataset.count()} were 0.") + } // Call transform with params, and check that the params worked. - val predNotAllZero = - model.transform(smallBinaryDataset, model.threshold -> 0.0, - model.probabilityCol -> "myProb") - .select("prediction", "myProb") - .collect() - .map { case Row(pred: Double, prob: Vector) => pred } - assert(predNotAllZero.exists(_ !== 0.0)) + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), + model.copy(ParamMap(model.threshold -> 0.0, + model.probabilityCol -> "myProb")), "prediction", "myProb") { + rows => assert(rows.map(_.getDouble(0)).exists(_ !== 0.0)) + } // Call fit() with new params, and check as many params as we can. lr.setThresholds(Array(0.6, 0.4)) @@ -441,10 +441,10 @@ class LogisticRegressionSuite val numFeatures = smallMultinomialDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallMultinomialDataset) - // check that raw prediction is coefficients dot features + intercept - results.select("rawPrediction", "features").collect().foreach { - case Row(raw: Vector, features: Vector) => + testTransformer[(Double, Vector)](smallMultinomialDataset.toDF(), + model, "rawPrediction", "features", "probability") { + case Row(raw: Vector, features: Vector, prob: Vector) => + // check that raw prediction is coefficients dot features + intercept assert(raw.size === 3) val margins = Array.tabulate(3) { k => var margin = 0.0 @@ -455,12 +455,7 @@ class LogisticRegressionSuite margin } assert(raw ~== Vectors.dense(margins) relTol eps) - } - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => - assert(raw.size === 3) + // Compare rawPrediction with probability assert(prob.size === 3) val max = raw.toArray.max val subtract = if (max > 0) max else 0.0 @@ -472,39 +467,8 @@ class LogisticRegressionSuite assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps) } - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => - val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 - assert(pred == predFromProb) - } - - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallMultinomialDataset) + Vector, LogisticRegressionModel](this, model, smallMultinomialDataset) } test("binary logistic regression: Predictor, Classifier methods") { @@ -517,51 +481,22 @@ class LogisticRegressionSuite val numFeatures = smallBinaryDataset.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val results = model.transform(smallBinaryDataset) - - // Compare rawPrediction with probability - results.select("rawPrediction", "probability").collect().foreach { - case Row(raw: Vector, prob: Vector) => + testTransformer[(Double, Vector)](smallBinaryDataset.toDF(), + model, "rawPrediction", "probability", "prediction") { + case Row(raw: Vector, prob: Vector, pred: Double) => + // Compare rawPrediction with probability assert(raw.size === 2) assert(prob.size === 2) val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1))) assert(prob(1) ~== probFromRaw1 relTol eps) assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps) - } - - // Compare prediction with probability - results.select("prediction", "probability").collect().foreach { - case Row(pred: Double, prob: Vector) => + // Compare prediction with probability val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 assert(pred == predFromProb) } - // force it to use raw2prediction - model.setRawPredictionCol("rawPrediction").setProbabilityCol("") - val resultsUsingRaw2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use probability2prediction - model.setRawPredictionCol("").setProbabilityCol("probability") - val resultsUsingProb2Predict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - - // force it to use predict - model.setRawPredictionCol("").setProbabilityCol("") - val resultsUsingPredict = - model.transform(smallBinaryDataset).select("prediction").as[Double].collect() - resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach { - case (pred1, pred2) => assert(pred1 === pred2) - } - ProbabilisticClassifierSuite.testPredictMethods[ - Vector, LogisticRegressionModel](model, smallBinaryDataset) + Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } test("coefficients and intercept methods") { @@ -616,19 +551,21 @@ class LogisticRegressionSuite LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)), LabeledPoint(1.0, Vectors.dense(0.0, -1.0)) ).toDF() - val results = model.transform(overFlowData).select("rawPrediction", "probability").collect() - - // probabilities are correct when margins have to be adjusted - val raw1 = results(0).getAs[Vector](0) - val prob1 = results(0).getAs[Vector](1) - assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) - assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) - - // probabilities are correct when margins don't have to be adjusted - val raw2 = results(1).getAs[Vector](0) - val prob2 = results(1).getAs[Vector](1) - assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) - assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + + testTransformerByGlobalCheckFunc[(Double, Vector)](overFlowData.toDF(), + model, "rawPrediction", "probability") { results: Seq[Row] => + // probabilities are correct when margins have to be adjusted + val raw1 = results(0).getAs[Vector](0) + val prob1 = results(0).getAs[Vector](1) + assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0)) + assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps) + + // probabilities are correct when margins don't have to be adjusted + val raw2 = results(1).getAs[Vector](0) + val prob2 = results(1).getAs[Vector](1) + assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0)) + assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps) + } } test("MultiClassSummarizer") { @@ -2567,10 +2504,13 @@ class LogisticRegressionSuite val model1 = lr.fit(smallBinaryDataset) val lr2 = new LogisticRegression().setInitialModel(model1).setMaxIter(5).setFamily("binomial") val model2 = lr2.fit(smallBinaryDataset) - val predictions1 = model1.transform(smallBinaryDataset).select("prediction").collect() - val predictions2 = model2.transform(smallBinaryDataset).select("prediction").collect() - predictions1.zip(predictions2).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val binaryExpected = model1.transform(smallBinaryDataset).select("prediction").collect() + .map(_.getDouble(0)) + for (model <- Seq(model1, model2)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallBinaryDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === binaryExpected + } } assert(model2.summary.totalIterations === 1) @@ -2579,10 +2519,13 @@ class LogisticRegressionSuite val lr4 = new LogisticRegression() .setInitialModel(model3).setMaxIter(5).setFamily("multinomial") val model4 = lr4.fit(smallMultinomialDataset) - val predictions3 = model3.transform(smallMultinomialDataset).select("prediction").collect() - val predictions4 = model4.transform(smallMultinomialDataset).select("prediction").collect() - predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) => - assert(p1 === p2) + val multinomialExpected = model3.transform(smallMultinomialDataset).select("prediction") + .collect().map(_.getDouble(0)) + for (model <- Seq(model3, model4)) { + testTransformerByGlobalCheckFunc[(Double, Vector)](smallMultinomialDataset.toDF(), model, + "prediction") { rows: Seq[Row] => + rows.map(_.getDouble(0)).toArray === multinomialExpected + } } assert(model4.summary.totalIterations === 1) } @@ -2638,8 +2581,8 @@ class LogisticRegressionSuite LabeledPoint(4.0, Vectors.dense(2.0))).toDF() val mlr = new LogisticRegression().setFamily("multinomial") val model = mlr.fit(constantData) - val results = model.transform(constantData) - results.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantData, model, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0))) @@ -2653,8 +2596,8 @@ class LogisticRegressionSuite LabeledPoint(0.0, Vectors.dense(1.0)), LabeledPoint(0.0, Vectors.dense(2.0))).toDF() val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData) - val resultsZero = modelZeroLabel.transform(constantZeroData) - resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantZeroData, modelZeroLabel, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(prob === Vectors.dense(Array(1.0))) assert(pred === 0.0) @@ -2666,8 +2609,8 @@ class LogisticRegressionSuite val constantDataWithMetadata = constantData .select(constantData("label").as("label", labelMeta), constantData("features")) val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata) - val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata) - resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach { + testTransformer[(Double, Vector)](constantDataWithMetadata, modelWithMetadata, + "rawPrediction", "probability", "prediction") { case Row(raw: Vector, prob: Vector, pred: Double) => assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0))) assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index d3141ec708560..daa58a56896d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -17,22 +17,17 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} -import org.apache.spark.sql.functions._ -class MultilayerPerceptronClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -75,11 +70,9 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(dataset) - val result = model.transform(dataset) MLTestingUtils.checkCopyAndUids(trainer, model) - val predictionAndLabels = result.select("prediction", "label").collect() - predictionAndLabels.foreach { case Row(p: Double, l: Double) => - assert(p == l) + testTransformer[(Vector, Double)](dataset.toDF(), model, "prediction", "label") { + case Row(p: Double, l: Double) => assert(p == l) } } @@ -99,13 +92,12 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(100) .setSolver("l-bfgs") val model = trainer.fit(strongDataset) - val result = model.transform(strongDataset) - result.select("probability", "expectedProbability").collect().foreach { - case Row(p: Vector, e: Vector) => - assert(p ~== e absTol 1e-3) + testTransformer[(Vector, Double, Vector)](strongDataset.toDF(), model, + "probability", "expectedProbability") { + case Row(p: Vector, e: Vector) => assert(p ~== e absTol 1e-3) } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, MultilayerPerceptronClassificationModel](model, strongDataset) + Vector, MultilayerPerceptronClassificationModel](this, model, strongDataset) } test("test model probability") { @@ -118,11 +110,10 @@ class MultilayerPerceptronClassifierSuite .setSolver("l-bfgs") val model = trainer.fit(dataset) model.setProbabilityCol("probability") - val result = model.transform(dataset) - val features2prob = udf { features: Vector => model.mlpModel.predict(features) } - result.select(features2prob(col("features")), col("probability")).collect().foreach { - case Row(p1: Vector, p2: Vector) => - assert(p1 ~== p2 absTol 1e-3) + testTransformer[(Vector, Double)](dataset.toDF(), model, "features", "probability") { + case Row(features: Vector, prob: Vector) => + val prob2 = model.mlpModel.predict(features) + assert(prob ~== prob2 absTol 1e-3) } } @@ -175,9 +166,6 @@ class MultilayerPerceptronClassifierSuite val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size assert(model.numFeatures === numFeatures) - val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label").rdd.map { - case Row(p: Double, l: Double) => (p, l) - } // train multinomial logistic regression val lr = new LogisticRegressionWithLBFGS() .setIntercept(true) @@ -189,8 +177,12 @@ class MultilayerPerceptronClassifierSuite lrModel.predict(data.rdd.map(p => OldVectors.fromML(p.features))).zip(data.rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) - val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) - assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + testTransformerByGlobalCheckFunc[(Double, Vector)](dataFrame, model, "prediction", "label") { + rows: Seq[Row] => + val mlpPredictionAndLabels = rows.map(x => (x.getDouble(0), x.getDouble(1))) + val mlpMetrics = new MulticlassMetrics(sc.makeRDD(mlpPredictionAndLabels)) + assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) + } } test("read/write: MultilayerPerceptronClassifier") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 0d3adf993383f..49115c8a4db30 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -28,12 +28,11 @@ import org.apache.spark.ml.classification.NaiveBayesSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -56,13 +55,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF() } - def validatePrediction(predictionAndLabels: DataFrame): Unit = { - val numOfErrorPredictions = predictionAndLabels.collect().count { + def validatePrediction(predictionAndLabels: Seq[Row]): Unit = { + val numOfErrorPredictions = predictionAndLabels.filter { case Row(prediction: Double, label: Double) => prediction != label - } + }.length // At least 80% of the predictions should be on. - assert(numOfErrorPredictions < predictionAndLabels.count() / 5) + assert(numOfErrorPredictions < predictionAndLabels.length / 5) } def validateModelFit( @@ -92,10 +91,10 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } def validateProbabilities( - featureAndProbabilities: DataFrame, + featureAndProbabilities: Seq[Row], model: NaiveBayesModel, modelType: String): Unit = { - featureAndProbabilities.collect().foreach { + featureAndProbabilities.foreach { case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { @@ -154,15 +153,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "multinomial") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "multinomial") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("Naive Bayes with weighted samples") { @@ -210,15 +212,18 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validationDataset = generateNaiveBayesInput(piArray, thetaArray, nPoints, 20, "bernoulli").toDF() - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") - validatePrediction(predictionAndLabels) + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "prediction", "label") { predictionAndLabels: Seq[Row] => + validatePrediction(predictionAndLabels) + } - val featureAndProbabilities = model.transform(validationDataset) - .select("features", "probability") - validateProbabilities(featureAndProbabilities, model, "bernoulli") + testTransformerByGlobalCheckFunc[(Double, Vector)](validationDataset, model, + "features", "probability") { featureAndProbabilities: Seq[Row] => + validateProbabilities(featureAndProbabilities, model, "bernoulli") + } ProbabilisticClassifierSuite.testPredictMethods[ - Vector, NaiveBayesModel](model, testDataset) + Vector, NaiveBayesModel](this, model, testDataset) } test("detect negative values") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 25bad59b9c9cf..11e88367108b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.ml.classification -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneVsRestSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -85,10 +83,6 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) - val ovaResults = transformedDataset.select("prediction", "label").rdd.map { - row => (row.getDouble(0), row.getDouble(1)) - } - val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) @@ -97,8 +91,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau // determine the #confusion matrix in each class. // bound how much error we allow compared to multinomial logistic regression. val expectedMetrics = new MulticlassMetrics(results) - val ovaMetrics = new MulticlassMetrics(ovaResults) - assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), ovaModel, + "prediction", "label") { rows => + val ovaResults = rows.map { row => (row.getDouble(0), row.getDouble(1)) } + val ovaMetrics = new MulticlassMetrics(sc.makeRDD(ovaResults)) + assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400) + } } test("one-vs-rest: tuning parallelism does not change output") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index d649ceac949c4..1c8c9829f18d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.MLTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} @@ -122,13 +123,15 @@ object ProbabilisticClassifierSuite { def testPredictMethods[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]]( - model: M, testData: Dataset[_]): Unit = { + mlTest: MLTest, model: M, testData: Dataset[_]): Unit = { val allColModel = model.copy(ParamMap.empty) .setRawPredictionCol("rawPredictionAll") .setProbabilityCol("probabilityAll") .setPredictionCol("predictionAll") - val allColResult = allColModel.transform(testData) + + val allColResult = allColModel.transform(testData.select(allColModel.getFeaturesCol)) + .select(allColModel.getFeaturesCol, "rawPredictionAll", "probabilityAll", "predictionAll") for (rawPredictionCol <- Seq("", "rawPredictionSingle")) { for (probabilityCol <- Seq("", "probabilitySingle")) { @@ -138,22 +141,14 @@ object ProbabilisticClassifierSuite { .setProbabilityCol(probabilityCol) .setPredictionCol(predictionCol) - val result = newModel.transform(allColResult) - - import org.apache.spark.sql.functions._ - - val resultRawPredictionCol = - if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol) - val resultProbabilityCol = - if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol) - val resultPredictionCol = - if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol) + import allColResult.sparkSession.implicits._ - result.select( - resultRawPredictionCol, col("rawPredictionAll"), - resultProbabilityCol, col("probabilityAll"), - resultPredictionCol, col("predictionAll") - ).collect().foreach { + mlTest.testTransformer[(Vector, Vector, Vector, Double)](allColResult, newModel, + if (rawPredictionCol.isEmpty) "rawPredictionAll" else rawPredictionCol, + "rawPredictionAll", + if (probabilityCol.isEmpty) "probabilityAll" else probabilityCol, "probabilityAll", + if (predictionCol.isEmpty) "predictionAll" else predictionCol, "predictionAll" + ) { case Row( rawPredictionSingle: Vector, rawPredictionAll: Vector, probabilitySingle: Vector, probabilityAll: Vector, diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2cca2e6c04698..02a9d5c2a18c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -23,11 +23,10 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -35,8 +34,7 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs import testImplicits._ @@ -143,11 +141,8 @@ class RandomForestClassifierSuite MLTestingUtils.checkCopyAndUids(rf, model) - val predictions = model.transform(df) - .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol) - .collect() - - predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction", + "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => assert(pred === rawPred.argmax, s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") val sum = rawPred.toArray.sum @@ -155,8 +150,9 @@ class RandomForestClassifierSuite "probability prediction mismatch") assert(probPred.toArray.sum ~== 1.0 relTol 1E-5) } + ProbabilisticClassifierSuite.testPredictMethods[ - Vector, RandomForestClassificationModel](model, df) + Vector, RandomForestClassificationModel](this, model, df) } test("Fitting without numClasses in metadata") { From ba622f45caa808a9320c1f7ba4a4f344365dcf90 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 5 Mar 2018 20:43:03 +0100 Subject: [PATCH 0430/1161] [SPARK-23585][SQL] Add interpreted execution to UnwrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to UnwrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20736 from mgaido91/SPARK-23586. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++++-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 80618af1e859f..03cc8eaceb4e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -382,8 +382,14 @@ case class UnwrapOption( override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val inputObject = child.eval(input) + if (inputObject == null) { + null + } else { + inputObject.asInstanceOf[Option[_]].orNull + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 3edcc02f15264..d95db5867b19c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -66,4 +66,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvalutionWithUnsafeProjection( mapEncoder.serializer.head, mapExpected, mapInputRow) } + + test("SPARK-23585: UnwrapOption should support interpreted execution") { + val cls = classOf[Option[Int]] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val unwrapObject = UnwrapOption(IntegerType, inputObject) + Seq((Some(1), 1), (None, null), (null, null)).foreach { case (input, expected) => + checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From b0f422c3861a5a3831e481b8ffac08f6fa085d00 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 5 Mar 2018 13:23:01 -0800 Subject: [PATCH 0431/1161] [SPARK-23559][SS] Add epoch ID to DataWriterFactory. ## What changes were proposed in this pull request? Add an epoch ID argument to DataWriterFactory for use in streaming. As a side effect of passing in this value, DataWriter will now have a consistent lifecycle; commit() or abort() ends the lifecycle of a DataWriter instance in any execution mode. I considered making a separate streaming interface and adding the epoch ID only to that one, but I think it requires a lot of extra work for no real gain. I think it makes sense to define epoch 0 as the one and only epoch of a non-streaming query. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #20710 from jose-torres/api2. --- .../sql/kafka010/KafkaStreamWriter.scala | 5 +++- .../sql/sources/v2/writer/DataWriter.java | 12 ++++++--- .../sources/v2/writer/DataWriterFactory.java | 5 +++- .../v2/writer/streaming/StreamWriter.java | 19 +++++++------- .../datasources/v2/WriteToDataSourceV2.scala | 25 +++++++++++++------ .../streaming/MicroBatchExecution.scala | 7 ++++++ .../sources/PackedRowWriterFactory.scala | 5 +++- .../streaming/sources/memoryV2.scala | 5 +++- .../sources/v2/SimpleWritableDataSource.scala | 10 ++++++-- 9 files changed, 65 insertions(+), 28 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index 9307bfc001c03..ae5b5c52d514e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -65,7 +65,10 @@ case class KafkaStreamWriterFactory( topic: Option[String], producerParams: Map[String, String], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 53941a89ba94e..39bf458298862 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -31,13 +31,17 @@ * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will * not be processed. If all records are successfully written, {@link #commit()} is called. * + * Once a data writer returns successfully from {@link #commit()} or {@link #abort()}, its lifecycle + * is over and Spark will not use it again. + * * If this data writer succeeds(all records are successfully written and {@link #commit()} * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an - * exception will be sent to the driver side, and Spark will retry this writing task for some times, - * each time {@link DataWriterFactory#createDataWriter(int, int)} gets a different `attemptNumber`, - * and finally call {@link DataSourceWriter#abort(WriterCommitMessage[])} if all retry fail. + * exception will be sent to the driver side, and Spark may retry this writing task a few times. + * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a + * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index ea95442511ce5..c2c2ab73257e8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -48,6 +48,9 @@ public interface DataWriterFactory extends Serializable { * same task id but different attempt number, which means there are multiple * tasks with the same task id running at the same time. Implementations can * use this attempt number to distinguish writers of different task attempts. + * @param epochId A monotonically increasing id for streaming queries that are split in to + * discrete periods of execution. For non-streaming queries, + * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber); + DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java index 4913341bd505d..a316b2a4c1d82 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamWriter.java @@ -23,11 +23,10 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * A {@link DataSourceWriter} for use with structured streaming. This writer handles commits and - * aborts relative to an epoch ID determined by the execution engine. + * A {@link DataSourceWriter} for use with structured streaming. * - * {@link DataWriter} implementations generated by a StreamWriter may be reused for multiple epochs, - * and so must reset any internal state after a successful commit. + * Streaming queries are divided into intervals of data called epochs, with a monotonically + * increasing numeric ID. This writer handles commits and aborts for each successive epoch. */ @InterfaceStability.Evolving public interface StreamWriter extends DataSourceWriter { @@ -39,21 +38,21 @@ public interface StreamWriter extends DataSourceWriter { * If this method fails (by throwing an exception), this writing job is considered to have been * failed, and the execution engine will attempt to call {@link #abort(WriterCommitMessage[])}. * - * To support exactly-once processing, writer implementations should ensure that this method is - * idempotent. The execution engine may call commit() multiple times for the same epoch - * in some circumstances. + * The execution engine may call commit() multiple times for the same epoch in some circumstances. + * To support exactly-once data semantics, implementations must ensure that multiple commits for + * the same epoch are idempotent. */ void commit(long epochId, WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or + * Aborts this writing job because some data writers are failed and keep failing when retried, or * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. * - * Unless the abort is triggered by the failure of commit, the given messages should have some - * null slots as there maybe only a few data writers that are committed before the abort + * Unless the abort is triggered by the failure of commit, the given messages will have some + * null slots, as there may be only a few data writers that were committed before the abort * happens, or some data writers were committed but their commit messages haven't reached the * driver when the abort is triggered. So this is just a "best effort" for data sources to * clean up the data left by data writers. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 41cdfc80d8a19..e80b44c1cdc66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -132,7 +132,8 @@ object DataWritingSparkTask extends Logging { val stageId = context.stageId() val partId = context.partitionId() val attemptId = context.attemptNumber() - val dataWriter = writeTask.createDataWriter(partId, attemptId) + val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") + val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -172,7 +173,6 @@ object DataWritingSparkTask extends Logging { writeTask: DataWriterFactory[InternalRow], context: TaskContext, iter: Iterator[InternalRow]): WriterCommitMessage = { - val dataWriter = writeTask.createDataWriter(context.partitionId(), context.attemptNumber()) val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) @@ -180,10 +180,15 @@ object DataWritingSparkTask extends Logging { var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong do { + var dataWriter: DataWriter[InternalRow] = null // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { - iter.foreach(dataWriter.write) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } logInfo(s"Writer for partition ${context.partitionId()} is committing.") val msg = dataWriter.commit() logInfo(s"Writer for partition ${context.partitionId()} committed.") @@ -196,9 +201,10 @@ object DataWritingSparkTask extends Logging { // Continuous shutdown always involves an interrupt. Just finish the task. } })(catchBlock = { - // If there is an error, abort this writer + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. logError(s"Writer for partition ${context.partitionId()} is aborting.") - dataWriter.abort() + if (dataWriter != null) dataWriter.abort() logError(s"Writer for partition ${context.partitionId()} aborted.") }) } while (!context.isInterrupted()) @@ -211,9 +217,12 @@ class InternalRowDataWriterFactory( rowWriterFactory: DataWriterFactory[Row], schema: StructType) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber), + rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6bd03972c301d..ff4be9c7ab874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -469,6 +469,9 @@ class MicroBatchExecution( case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } + sparkSessionToRunBatch.sparkContext.setLocalProperty( + MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionToRunBatch, @@ -518,3 +521,7 @@ class MicroBatchExecution( Optional.ofNullable(scalaOption.orNull) } } + +object MicroBatchExecution { + val BATCH_ID_KEY = "streaming.sql.batchId" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index 248295e401a0d..e07355aa37dba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -31,7 +31,10 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat * for production-quality sinks. It's intended for use in tests. */ case object PackedRowWriterFactory extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index f960208155e3b..5f58246083bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -147,7 +147,10 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) } case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { - def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 36dd2a350a055..a5007fa321359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -207,7 +207,10 @@ private[v2] object SimpleCounter { class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) @@ -240,7 +243,10 @@ class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[InternalRow] { - override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") val fs = filePath.getFileSystem(conf.value) From f2cab56ca22ed5db5ff604cd78cdb55aaa58f651 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Mon, 5 Mar 2018 14:57:32 -0800 Subject: [PATCH 0432/1161] [SPARK-23040][CORE] Returns interruptible iterator for shuffle reader ## What changes were proposed in this pull request? Before this commit, a non-interruptible iterator is returned if aggregator or ordering is specified. This commit also ensures that sorter is closed even when task is cancelled(killed) in the middle of sorting. ## How was this patch tested? Add a unit test in JobCancellationSuite Author: Xianjin YE Closes #20449 from advancedxy/SPARK-23040. --- .../shuffle/BlockStoreShuffleReader.scala | 9 ++- .../apache/spark/JobCancellationSuite.scala | 65 ++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index edd69715c9602..85e7e56a04a7d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -94,7 +94,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Sort the output if there is a sort ordering defined. - dep.keyOrdering match { + val resultIter = dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. val sorter = @@ -103,9 +103,16 @@ private[spark] class BlockStoreShuffleReader[K, C]( context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener(_ => { + sorter.stop() + }) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } + // Use another interruptible iterator here to support task cancellation as aggregator or(and) + // sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) } } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 8a77aea75a992..3b793bb231cf3 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -26,7 +27,7 @@ import scala.concurrent.duration._ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.ThreadUtils /** @@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft override def afterEach() { try { resetSparkContext() + JobCancellationSuite.taskStartedSemaphore.drainPermits() + JobCancellationSuite.taskCancelledSemaphore.drainPermits() + JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits() + JobCancellationSuite.executionOfInterruptibleCounter.set(0) } finally { super.afterEach() } @@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft f2.get() } + test("interruptible iterator of shuffle reader") { + // In this test case, we create a Spark job of two stages. The second stage is cancelled during + // execution and a counter is used to make sure that the corresponding tasks are indeed + // cancelled. + import JobCancellationSuite._ + sc = new SparkContext("local[2]", "test interruptible iterator") + + val taskCompletedSem = new Semaphore(0) + + sc.addSparkListener(new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + // release taskCancelledSemaphore when cancelTasks event has been posted + if (stageCompleted.stageInfo.stageId == 1) { + taskCancelledSemaphore.release(1000) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.stageId == 1) { // make sure tasks are completed + taskCompletedSem.release() + } + } + }) + + val f = sc.parallelize(1 to 1000).map { i => (i, i) } + .repartitionAndSortWithinPartitions(new HashPartitioner(1)) + .mapPartitions { iter => + taskStartedSemaphore.release() + iter + }.foreachAsync { x => + if (x._1 >= 10) { + // This block of code is partially executed. It will be blocked when x._1 >= 10 and the + // next iteration will be cancelled if the source iterator is interruptible. Then in this + // case, the maximum num of increment would be 10(|1...10|) + taskCancelledSemaphore.acquire() + } + executionOfInterruptibleCounter.getAndIncrement() + } + + taskStartedSemaphore.acquire() + // Job is cancelled when: + // 1. task in reduce stage has been started, guaranteed by previous line. + // 2. task in reduce stage is blocked after processing at most 10 records as + // taskCancelledSemaphore is not released until cancelTasks event is posted + // After job being cancelled, task in reduce stage will be cancelled and no more iteration are + // executed. + f.cancel() + + val e = intercept[SparkException](f.get()).getCause + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + + // Make sure tasks are indeed completed. + taskCompletedSem.acquire() + assert(executionOfInterruptibleCounter.get() <= 10) + } + def testCount() { // Cancel before launching any tasks { @@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft object JobCancellationSuite { + // To avoid any headaches, reset these global variables in the companion class's afterEach block val taskStartedSemaphore = new Semaphore(0) val taskCancelledSemaphore = new Semaphore(0) val twoJobsSharingStageSemaphore = new Semaphore(0) + val executionOfInterruptibleCounter = new AtomicInteger(0) } From 508573958dc9b6402e684cd6dd37202deaaa97f6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 5 Mar 2018 15:03:27 -0800 Subject: [PATCH 0433/1161] [SPARK-23538][CORE] Remove custom configuration for SSL client. These options were used to configure the built-in JRE SSL libraries when downloading files from HTTPS servers. But because they were also used to set up the now (long) removed internal HTTPS file server, their default configuration chose convenience over security by having overly lenient settings. This change removes the configuration options that affect the JRE SSL libraries. The JRE trust store can still be configured via system properties (or globally in the JRE security config). The only lost functionality is not being able to disable the default hostname verifier when using spark-submit, which should be fine since Spark itself is not using https for any internal functionality anymore. I also removed the HTTP-related code from the REPL class loader, since we haven't had a HTTP server for REPL-generated classes for a while. Author: Marcelo Vanzin Closes #20723 from vanzin/SPARK-23538. --- .../org/apache/spark/SecurityManager.scala | 45 ------------ .../scala/org/apache/spark/util/Utils.scala | 15 ---- .../org/apache/spark/SSLSampleConfigs.scala | 68 ------------------- .../apache/spark/SecurityManagerSuite.scala | 45 ------------ docs/security.md | 4 -- .../spark/repl/ExecutorClassLoader.scala | 53 ++------------- 6 files changed, 7 insertions(+), 223 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2519d266879aa..da1c89cd78901 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -256,51 +256,6 @@ private[spark] class SecurityManager( // the default SSL configuration - it will be used by all communication layers unless overwritten private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) - // SSL configuration for the file server. This is used by Utils.setupSecureURLConnection(). - val fileServerSSLOptions = getSSLOptions("fs") - val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) { - val trustStoreManagers = - for (trustStore <- fileServerSSLOptions.trustStore) yield { - val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream() - - try { - val ks = KeyStore.getInstance(KeyStore.getDefaultType) - ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray) - - val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm) - tmf.init(ks) - tmf.getTrustManagers - } finally { - input.close() - } - } - - lazy val credulousTrustStoreManagers = Array({ - logWarning("Using 'accept-all' trust manager for SSL connections.") - new X509TrustManager { - override def getAcceptedIssuers: Array[X509Certificate] = null - - override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} - - override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} - }: TrustManager - }) - - require(fileServerSSLOptions.protocol.isDefined, - "spark.ssl.protocol is required when enabling SSL connections.") - - val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.get) - sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null) - - val hostVerifier = new HostnameVerifier { - override def verify(s: String, sslSession: SSLSession): Boolean = true - } - - (Some(sslContext.getSocketFactory), Some(hostVerifier)) - } else { - (None, None) - } - def getSSLOptions(module: String): SSLOptions = { val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d493663f0b168..2e2a4a259e9af 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -673,7 +673,6 @@ private[spark] object Utils extends Logging { logDebug("fetchFile not using security") uc = new URL(url).openConnection() } - Utils.setupSecureURLConnection(uc, securityMgr) val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 @@ -2363,20 +2362,6 @@ private[spark] object Utils extends Logging { PropertyConfigurator.configure(pro) } - /** - * If the given URL connection is HttpsURLConnection, it sets the SSL socket factory and - * the host verifier from the given security manager. - */ - def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = { - urlConnection match { - case https: HttpsURLConnection => - sm.sslSocketFactory.foreach(https.setSSLSocketFactory) - sm.hostnameVerifier.foreach(https.setHostnameVerifier) - https - case connection => connection - } - } - def invoke( clazz: Class[_], obj: AnyRef, diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala deleted file mode 100644 index 33270bec6247c..0000000000000 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.File - -object SSLSampleConfigs { - val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath - val untrustedKeyStorePath = new File( - this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath - val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath - - val enabledAlgorithms = - // A reasonable set of TLSv1.2 Oracle security provider suites - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "TLS_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + - // and their equivalent names in the IBM Security provider - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + - "SSL_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" - - def sparkSSLConfig(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", keyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - - def sparkSSLConfigUntrusted(): SparkConf = { - val conf = new SparkConf(loadDefaults = false) - conf.set("spark.ssl.enabled", "true") - conf.set("spark.ssl.keyStore", untrustedKeyStorePath) - conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ssl.keyPassword", "password") - conf.set("spark.ssl.trustStore", trustStorePath) - conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) - conf.set("spark.ssl.protocol", "TLSv1.2") - conf - } - -} diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 106ece7aed0a4..e357299770a2e 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -370,51 +370,6 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkModifyPermissions("user1") === false) } - test("ssl on setup") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val expectedAlgorithms = Set( - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "TLS_RSA_WITH_AES_256_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", - "SSL_RSA_WITH_AES_256_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", - "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", - "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === true) - - assert(securityManager.sslSocketFactory.isDefined === true) - assert(securityManager.hostnameVerifier.isDefined === true) - - assert(securityManager.fileServerSSLOptions.trustStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.trustStore.get.getName === "truststore") - assert(securityManager.fileServerSSLOptions.keyStore.isDefined === true) - assert(securityManager.fileServerSSLOptions.keyStore.get.getName === "keystore") - assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) - assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) - } - - test("ssl off setup") { - val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir()) - - System.setProperty("spark.ssl.configFile", file.getAbsolutePath) - val conf = new SparkConf() - - val securityManager = new SecurityManager(conf) - - assert(securityManager.fileServerSSLOptions.enabled === false) - assert(securityManager.sslSocketFactory.isDefined === false) - assert(securityManager.hostnameVerifier.isDefined === false) - } - test("missing secret authentication key") { val conf = new SparkConf().set("spark.authenticate", "true") val mgr = new SecurityManager(conf) diff --git a/docs/security.md b/docs/security.md index 0f384b411812a..913d9df50eb1c 100644 --- a/docs/security.md +++ b/docs/security.md @@ -44,10 +44,6 @@ component-specific configuration namespaces used to override the default setting Config Namespace Component - - spark.ssl.fs - File download client (used to download jars and files from HTTPS-enabled servers). - spark.ssl.ui Spark application Web UI diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 127f67329f266..4dc399827ffed 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,12 +17,10 @@ package org.apache.spark.repl -import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream, IOException} -import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, InputStream} +import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels -import scala.util.control.NonFatal - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm5._ import org.apache.xbean.asm5.Opcodes._ @@ -30,13 +28,13 @@ import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.util.{ParentClassLoader, Utils} +import org.apache.spark.util.ParentClassLoader /** - * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, used to load classes - * defined by the interpreter when the REPL is used. Allows the user to specify if user class path - * should be first. This class loader delegates getting/finding resources to parent loader, which - * makes sense until REPL never provide resource dynamically. + * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load + * classes defined by the interpreter when the REPL is used. Allows the user to specify if user + * class path should be first. This class loader delegates getting/finding resources to parent + * loader, which makes sense until REPL never provide resource dynamically. * * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or * the load failed, that it will call the overridden `findClass` function. To avoid the potential @@ -60,7 +58,6 @@ class ExecutorClassLoader( private val fetchFn: (String) => InputStream = uri.getScheme() match { case "spark" => getClassFileInputStreamFromSparkRPC - case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer case _ => val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) getClassFileInputStreamFromFileSystem(fileSystem) @@ -113,42 +110,6 @@ class ExecutorClassLoader( } } - private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { - val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { - val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)) - } - val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(), - SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection] - // Set the connection timeouts (for testing purposes) - if (httpUrlConnectionTimeoutMillis != -1) { - connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) - connection.setReadTimeout(httpUrlConnectionTimeoutMillis) - } - connection.connect() - try { - if (connection.getResponseCode != 200) { - // Close the error stream so that the connection is eligible for re-use - try { - connection.getErrorStream.close() - } catch { - case ioe: IOException => - logError("Exception while closing error stream", ioe) - } - throw new ClassNotFoundException(s"Class file not found at URL $url") - } else { - connection.getInputStream - } - } catch { - case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] => - connection.disconnect() - throw e - } - } - private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)( pathInDirectory: String): InputStream = { val path = new Path(directory, pathInDirectory) From 7706eea6a8bdcd73e9dde5212368f8825e2f1801 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 5 Mar 2018 15:53:10 -0800 Subject: [PATCH 0434/1161] [SPARK-18630][PYTHON][ML] Move del method from JavaParams to JavaWrapper; add tests The `__del__` method that explicitly detaches the object was moved from `JavaParams` to `JavaWrapper` class, this way model summaries could also be garbage collected in Java. A test case was added to make sure that relevant error messages are thrown after the objects are deleted. I ran pyspark tests agains `pyspark-ml` module `./python/run-tests --python-executables=$(which python) --modules=pyspark-ml` Author: Yogesh Garg Closes #20724 from yogeshg/java_wrapper_memory. --- python/pyspark/ml/tests.py | 39 ++++++++++++++++++++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 ++++---- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 116885969345c..6dee6938d8916 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake): pass +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0f846fbc5b5ef..5061f6434794a 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -36,6 +36,10 @@ def __init__(self, java_obj=None): super(JavaWrapper, self).__init__() self._java_obj = java_obj + def __del__(self): + if SparkContext._active_spark_context and self._java_obj is not None: + SparkContext._active_spark_context._gateway.detach(self._java_obj) + @classmethod def _create_from_java_class(cls, java_class, *args): """ @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params): __metaclass__ = ABCMeta - def __del__(self): - if SparkContext._active_spark_context: - SparkContext._active_spark_context._gateway.detach(self._java_obj) - def _make_java_param_pair(self, param, value): """ Makes a Java param pair. From f6b49f9d1b6f218408197f7272c1999fe3d94328 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 01:37:51 +0100 Subject: [PATCH 0435/1161] [SPARK-23586][SQL] Add interpreted execution to WrapOption ## What changes were proposed in this pull request? The PR adds interpreted execution to WrapOption. ## How was this patch tested? added UT Author: Marco Gaido Closes #20741 from mgaido91/SPARK-23586_2. --- .../sql/catalyst/expressions/objects/objects.scala | 3 +-- .../catalyst/expressions/ObjectExpressionsSuite.scala | 11 ++++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 03cc8eaceb4e6..d832fe0a6857c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -422,8 +422,7 @@ case class WrapOption(child: Expression, optType: DataType) override def inputTypes: Seq[AbstractDataType] = optType :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = Option(child.eval(input)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d95db5867b19c..d535578a7eb06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, UnwrapOption} +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{IntegerType, ObjectType} @@ -75,4 +75,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(unwrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23586: WrapOption should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val wrapObject = WrapOption(inputObject, cls) + Seq((1, Some(1)), (null, None)).foreach { case (input, expected) => + checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) + } + } } From 8c5b34c425bda2079a1ff969b12c067f2bb3f18f Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 5 Mar 2018 16:49:24 -0800 Subject: [PATCH 0436/1161] =?UTF-8?q?[SPARK-23604][SQL]=20Change=20Statist?= =?UTF-8?q?ics.isEmpty=20to=20!Statistics.hasNonNul=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …lValue ## What changes were proposed in this pull request? Parquet 1.9 will change the semantics of Statistics.isEmpty slightly to reflect if the null value count has been set. That breaks a timestamp interoperability test that cares only about whether there are column values present in the statistics of a written file for an INT96 column. Fix by using Statistics.hasNonNullValue instead. ## How was this patch tested? Unit tests continue to pass against Parquet 1.8, and also pass against a Parquet build including PARQUET-1217. Author: Henry Robinson Closes #20740 from henryr/spark-23604. --- .../datasources/parquet/ParquetInteroperabilitySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index fbd83a0fa425a..9c75965639d8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -184,7 +184,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS // when the data is read back as mentioned above, b/c int96 is unsigned. This // assert makes sure this holds even if we change parquet versions (if eg. there // were ever statistics even on unsigned columns). - assert(columnStats.isEmpty) + assert(!columnStats.hasNonNullValue) } // These queries should return the entire dataset with the conversion applied, From ad640a5affceaaf3979e25848628fb1dfcdf932a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 5 Mar 2018 20:35:14 -0800 Subject: [PATCH 0437/1161] [SPARK-23303][SQL] improve the explain result for data source v2 relations ## What changes were proposed in this pull request? The proposed explain format: **[streaming header] [RelationV2/ScanV2] [data source name] [output] [pushed filters] [options]** **streaming header**: if it's a streaming relation, put a "Streaming" at the beginning. **RelationV2/ScanV2**: if it's a logical plan, put a "RelationV2", else, put a "ScanV2" **data source name**: the simple class name of the data source implementation **output**: a string of the plan output attributes **pushed filters**: a string of all the filters that have been pushed to this data source **options**: all the options to create the data source reader. The current explain result for data source v2 relation is unreadable: ``` == Parsed Logical Plan == 'Filter ('i > 6) +- AnalysisBarrier +- Project [j#1] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Analyzed Logical Plan == j: int Project [j#1] +- Filter (i#0 > 6) +- Project [j#1, i#0] +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Optimized Logical Plan == Project [j#1] +- Filter isnotnull(i#0) +- DataSourceV2Relation [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 == Physical Plan == *(1) Project [j#1] +- *(1) Filter isnotnull(i#0) +- *(1) DataSourceV2Scan [i#0, j#1], org.apache.spark.sql.sources.v2.AdvancedDataSourceV2$Reader3b415940 ``` after this PR ``` == Parsed Logical Plan == 'Project [unresolvedalias('j, None)] +- AnalysisBarrier +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Analyzed Logical Plan == j: int Project [j#1] +- RelationV2 AdvancedDataSourceV2[i#0, j#1] == Optimized Logical Plan == RelationV2 AdvancedDataSourceV2[j#1] == Physical Plan == *(1) ScanV2 AdvancedDataSourceV2[j#1] ``` ------- ``` == Analyzed Logical Plan == i: int, j: int Filter (i#88 > 3) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] == Optimized Logical Plan == Filter isnotnull(i#88) +- RelationV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) == Physical Plan == *(1) Filter isnotnull(i#88) +- *(1) ScanV2 JavaAdvancedDataSourceV2[i#88, j#89] (Pushed Filters: [GreaterThan(i,3)]) ``` an example for streaming query ``` == Parsed Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Analyzed Logical Plan == value: string, count(1): bigint Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject cast(value#25 as string).toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Optimized Logical Plan == Aggregate [value#6], [value#6, count(1) AS count(1)#11L] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- MapElements , class java.lang.String, [StructField(value,StringType,true)], obj#5: java.lang.String +- DeserializeToObject value#25.toString, obj#4: java.lang.String +- Streaming RelationV2 MemoryStreamDataSource[value#25] == Physical Plan == *(4) HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#11L]) +- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5], Complete, 0 +- *(3) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/cloud/dev/spark/target/tmp/temporary-549f264b-2531-4fcb-a52f-433c77347c12/state, runId = f84d9da9-2f8c-45c1-9ea1-70791be684de, opId = 0, ver = 0, numPartitions = 5] +- *(2) HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#16L]) +- Exchange hashpartitioning(value#6, 5) +- *(1) HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#16L]) +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6] +- *(1) MapElements , obj#5: java.lang.String +- *(1) DeserializeToObject value#25.toString, obj#4: java.lang.String +- *(1) ScanV2 MemoryStreamDataSource[value#25] ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #20647 from cloud-fan/explain. --- .../kafka010/KafkaContinuousSourceSuite.scala | 2 +- .../sql/kafka010/KafkaContinuousTest.scala | 2 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../v2/DataSourceReaderHolder.scala | 64 ------------- .../datasources/v2/DataSourceV2Relation.scala | 34 +++++-- .../datasources/v2/DataSourceV2ScanExec.scala | 18 +++- .../datasources/v2/DataSourceV2Strategy.scala | 8 +- .../v2/DataSourceV2StringFormat.scala | 94 +++++++++++++++++++ .../streaming/MicroBatchExecution.scala | 29 +++++- .../continuous/ContinuousExecution.scala | 8 +- .../spark/sql/streaming/StreamSuite.scala | 12 ++- .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 11 +-- 13 files changed, 183 insertions(+), 105 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala index f679e9bfc0450..aab8ec42189fb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala @@ -60,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists { r => // Ensure the new topic is present and the old topic is gone. r.knownPartitions.exists(_.topic == topic2) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index 48ac3fc1e8f9d..fa1468a3943c8 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest { eventually(timeout(streamingTimeout)) { assert( query.lastExecution.logical.collectFirst { - case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: KafkaContinuousReader) => r }.exists(_.knownPartitions.size == newCount), s"query never reconfigured to $newCount partitions") } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index f2b3ff7615e74..e017fd9b84d21 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -124,7 +124,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { } ++ (query.get.lastExecution match { case null => Seq() case e => e.logical.collect { - case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader + case StreamingDataSourceV2Relation(_, _, _, reader: KafkaContinuousReader) => reader } }) }.distinct diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala deleted file mode 100644 index 81219e9771bd8..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util.Objects - -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.sources.v2.reader._ - -/** - * A base class for data source reader holder with customized equals/hashCode methods. - */ -trait DataSourceReaderHolder { - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The held data source reader. - */ - def reader: DataSourceReader - - /** - * The metadata of this data source reader that can be used for equality test. - */ - private def metadata: Seq[Any] = { - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Nil - } - Seq(output, reader.getClass, filters) - } - - def canEqual(other: Any): Boolean - - override def equals(other: Any): Boolean = other match { - case other: DataSourceReaderHolder => - canEqual(other) && metadata.length == other.metadata.length && - metadata.zip(other.metadata).forall { case (l, r) => l == r } - case _ => false - } - - override def hashCode(): Int = { - metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index cc6cb631e3f06..2b282ffae2390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -35,15 +35,12 @@ case class DataSourceV2Relation( options: Map[String, String], projection: Seq[AttributeReference], filters: Option[Seq[Expression]] = None, - userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation { + userSpecifiedSchema: Option[StructType] = None) + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = { - s"DataSourceV2Relation(source=${source.name}, " + - s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " + - s"filters=[${pushedFilters.mkString(", ")}], options=$options)" - } + override def simpleString: String = "RelationV2 " + metadataString override lazy val schema: StructType = reader.readSchema() @@ -107,19 +104,36 @@ case class DataSourceV2Relation( } /** - * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical - * to the non-streaming relation. + * A specialization of [[DataSourceV2Relation]] with the streaming bit set to true. + * + * Note that, this plan has a mutable reader, so Spark won't apply operator push-down for this plan, + * to avoid making the plan mutable. We should consolidate this plan and [[DataSourceV2Relation]] + * after we figure out how to apply operator push-down for streaming data sources. */ case class StreamingDataSourceV2Relation( output: Seq[AttributeReference], + source: DataSourceV2, + options: Map[String, String], reader: DataSourceReader) - extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation { + extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { + override def isStreaming: Boolean = true - override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation] + override def simpleString: String = "Streaming RelationV2 " + metadataString override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: StreamingDataSourceV2Relation => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7d9581be4db89..cb691ba297076 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType @@ -36,10 +37,23 @@ import org.apache.spark.sql.types.StructType */ case class DataSourceV2ScanExec( output: Seq[AttributeReference], + @transient source: DataSourceV2, + @transient options: Map[String, String], @transient reader: DataSourceReader) - extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { + extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { - override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] + override def simpleString: String = "ScanV2 " + metadataString + + // TODO: unify the equal/hashCode implementation for all data source v2 query plans. + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2ScanExec => + output == other.output && reader.getClass == other.reader.getClass && options == other.options + case _ => false + } + + override def hashCode(): Int = { + Seq(output, source, options).hashCode() + } override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c4e7644683c36..1ac9572de6412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case relation: DataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: DataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil - case relation: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(relation.output, relation.reader) :: Nil + case r: StreamingDataSourceV2Relation => + DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala new file mode 100644 index 0000000000000..aed55a429bfd7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.DataSourceV2 +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.util.Utils + +/** + * A trait that can be used by data source v2 related query plans(both logical and physical), to + * provide a string format of the data source information for explain. + */ +trait DataSourceV2StringFormat { + + /** + * The instance of this data source implementation. Note that we only consider its class in + * equals/hashCode, not the instance itself. + */ + def source: DataSourceV2 + + /** + * The output of the data source reader, w.r.t. column pruning. + */ + def output: Seq[Attribute] + + /** + * The options for this data source reader. + */ + def options: Map[String, String] + + /** + * The created data source reader. Here we use it to get the filters that has been pushed down + * so far, itself doesn't take part in the equals/hashCode. + */ + def reader: DataSourceReader + + private lazy val filters = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Set.empty + } + + private def sourceName: String = source match { + case registered: DataSourceRegister => registered.shortName() + case _ => source.getClass.getSimpleName.stripSuffix("$") + } + + def metadataString: String = { + val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] + + if (filters.nonEmpty) { + entries += "Filters" -> filters.mkString("[", ", ", "]") + } + + // TODO: we should only display some standard options like path, table, etc. + if (options.nonEmpty) { + entries += "Options" -> Utils.redact(options).map { + case (k, v) => s"$k=$v" + }.mkString("[", ",", "]") + } + + val outputStr = Utils.truncatedString(output, "[", ", ", "]") + + val entriesStr = if (entries.nonEmpty) { + Utils.truncatedString(entries.map { + case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) + }, " (", ", ", ")") + } else { + "" + } + + s"$sourceName$outputStr$entriesStr" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index ff4be9c7ab874..6e231970f4a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.execution.streaming import java.util.Optional import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, Map => MutableMap} +import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -52,6 +52,9 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + private val readerToDataSourceMap = + MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) case OneTimeTrigger => OneTimeExecutor() @@ -97,6 +100,7 @@ class MicroBatchExecution( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 + readerToDataSourceMap(reader) = dataSourceV2 -> options logInfo(s"Using MicroBatchReader [$reader] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") StreamingExecutionRelation(reader, output)(sparkSession) @@ -419,8 +423,19 @@ class MicroBatchExecution( toJava(current), Optional.of(availableV2)) logDebug(s"Retrieving data from $reader: $current -> $availableV2") - Some(reader -> - new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader)) + + val (source, options) = reader match { + // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` + // implementation. We provide a fake one here for explain. + case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] + // Provide a fake value here just in case something went wrong, e.g. the reader gives + // a wrong `equals` implementation. + case _ => readerToDataSourceMap.getOrElse(reader, { + FakeDataSourceV2 -> Map.empty[String, String] + }) + } + Some(reader -> StreamingDataSourceV2Relation( + reader.readSchema().toAttributes, source, options, reader)) case _ => None } } @@ -525,3 +540,7 @@ class MicroBatchExecution( object MicroBatchExecution { val BATCH_ID_KEY = "streaming.sql.batchId" } + +object MemoryStreamDataSource extends DataSourceV2 + +object FakeDataSourceV2 extends DataSourceV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index daebd1dd010ac..1758b3844bd62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} @@ -167,7 +167,7 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { - case ContinuousExecutionRelation(_, _, output) => + case ContinuousExecutionRelation(source, options, output) => val reader = continuousSources(insertedSourceId) insertedSourceId += 1 val newOutput = reader.readSchema().toAttributes @@ -180,7 +180,7 @@ class ContinuousExecution( val loggedOffset = offsets.offsets(0) val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - new StreamingDataSourceV2Relation(newOutput, reader) + StreamingDataSourceV2Relation(newOutput, source, options, reader) } // Rewire the plan to use the new attributes that were returned by the source. @@ -201,7 +201,7 @@ class ContinuousExecution( val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r }.head reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index d1a04833390f5..c1ec1eba69fb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -492,16 +492,20 @@ class StreamSuite extends StreamTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 0) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithoutExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithoutExtended.contains("StateStoreRestore")) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. - assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3) - assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1) + assert("Streaming RelationV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 3) + assert("ScanV2 MemoryStreamDataSource".r + .findAllMatchIn(explainWithExtended).size === 1) // Use "StateStoreRestore" to verify that it does output a streaming physical plan assert(explainWithExtended.contains("StateStoreRestore")) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 08f722ecb10e5..e44aef09f1f3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -629,8 +629,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def findSourceIndex(plan: LogicalPlan): Option[Int] = { plan .collect { - case StreamingExecutionRelation(s, _) => s - case StreamingDataSourceV2Relation(_, r) => r + case r: StreamingExecutionRelation => r.source + case r: StreamingDataSourceV2Relation => r.reader } .zipWithIndex .find(_._1 == source) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 4b4ed82dc6520..f5884b9c8de12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming.continuous -import java.util.UUID - -import org.apache.spark.{SparkContext, SparkEnv, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -43,7 +40,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From e8a259d66dda0d4c76f3af8933676bade8a7451d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 6 Mar 2018 13:55:13 +0100 Subject: [PATCH 0438/1161] [SPARK-23594][SQL] GetExternalRowField should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `GetExternalRowField`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20746 from maropu/SPARK-23594. --- .../expressions/objects/objects.scala | 14 ++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d832fe0a6857c..97e3ff88858d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1358,11 +1358,19 @@ case class GetExternalRowField( override def dataType: DataType = ObjectType(classOf[Object]) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." + override def eval(input: InternalRow): Any = { + val inputRow = child.eval(input).asInstanceOf[Row] + if (inputRow == null) { + throw new RuntimeException("The input external row cannot be null.") + } + if (inputRow.isNullAt(index)) { + throw new RuntimeException(errMsg) + } + inputRow.get(index) + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d535578a7eb06..0f376c4b63c15 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ @@ -84,4 +85,23 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(wrapObject, expected, InternalRow.fromSeq(Seq(input))) } } + + test("SPARK-23594 GetExternalRowField should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") + Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => + checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + } + + // If an input row or a field are null, a runtime exception will be thrown + val errMsg1 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(null))) + }.getMessage + assert(errMsg1 === "The input external row cannot be null.") + + val errMsg2 = intercept[RuntimeException] { + evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) + }.getMessage + assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + } } From 8bceb899dc3220998a4ea4021f3b477f78faaca8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 6 Mar 2018 08:52:28 -0600 Subject: [PATCH 0439/1161] [SPARK-23601][BUILD] Remove .md5 files from release ## What changes were proposed in this pull request? Remove .md5 files from release artifacts ## How was this patch tested? N/A Author: Sean Owen Closes #20737 from srowen/SPARK-23601. --- dev/create-release/release-build.sh | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index a3579f21fc539..c00b00b845401 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -164,8 +164,6 @@ if [[ "$1" == "package" ]]; then tar cvzf spark-$SPARK_VERSION.tgz spark-$SPARK_VERSION echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour --output spark-$SPARK_VERSION.tgz.asc \ --detach-sig spark-$SPARK_VERSION.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md MD5 spark-$SPARK_VERSION.tgz > \ - spark-$SPARK_VERSION.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION.tgz > spark-$SPARK_VERSION.tgz.sha512 rm -rf spark-$SPARK_VERSION @@ -215,9 +213,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $R_DIST_NAME.asc \ --detach-sig $R_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $R_DIST_NAME > \ - $R_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $R_DIST_NAME > \ $R_DIST_NAME.sha512 @@ -234,9 +229,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output $PYTHON_DIST_NAME.asc \ --detach-sig $PYTHON_DIST_NAME - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 $PYTHON_DIST_NAME > \ - $PYTHON_DIST_NAME.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 $PYTHON_DIST_NAME > \ $PYTHON_DIST_NAME.sha512 @@ -247,9 +239,6 @@ if [[ "$1" == "package" ]]; then echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz - echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ - MD5 spark-$SPARK_VERSION-bin-$NAME.tgz > \ - spark-$SPARK_VERSION-bin-$NAME.tgz.md5 echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ SHA512 spark-$SPARK_VERSION-bin-$NAME.tgz > \ spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 @@ -382,18 +371,11 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there + # this must have .asc and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; - if [ $(command -v md5) ]; then - # Available on OS X; -q to keep only hash - md5 -q $file > $file.md5 - else - # Available on Linux; cut to keep only hash - md5sum $file | cut -f1 -d' ' > $file.md5 - fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done From 4c587eb4887623c839854c1505f495de42898229 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 6 Mar 2018 17:42:17 +0100 Subject: [PATCH 0440/1161] [SPARK-23590][SQL] Add interpreted execution to CreateExternalRow ## What changes were proposed in this pull request? The PR adds interpreted execution to CreateExternalRow ## How was this patch tested? added UT Author: Marco Gaido Closes #20749 from mgaido91/SPARK-23590. --- .../spark/sql/catalyst/expressions/objects/objects.scala | 6 ++++-- .../sql/catalyst/expressions/ExpressionEvalHelper.scala | 4 +++- .../sql/catalyst/expressions/ObjectExpressionsSuite.scala | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 97e3ff88858d0..721d589709131 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1111,8 +1111,10 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def eval(input: InternalRow): Any = { + val values = children.map(_.eval(input)).toArray + new GenericRowWithSchema(values, schema) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4c8eab19c5cc..29f0cc0d991aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,6 +24,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -60,7 +61,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte], Spread[Double], and MapData. + * Array[Byte], Spread[Double], MapData and Row. */ protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { (result, expected) match { @@ -88,6 +89,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0f376c4b63c15..50e57737a4612 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} -import org.apache.spark.sql.types.{IntegerType, ObjectType} +import org.apache.spark.sql.types._ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -86,6 +86,12 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23590: CreateExternalRow should support interpreted execution") { + val schema = new StructType().add("a", IntegerType).add("b", StringType) + val createExternalRow = CreateExternalRow(Seq(Literal(1), Literal("x")), schema) + checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From 04e71c31603af3a13bc13300df799f003fe185f7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 7 Mar 2018 17:01:29 +0800 Subject: [PATCH 0441/1161] [MINOR][YARN] Add disable yarn.nodemanager.vmem-check-enabled option to memLimitExceededLogMessage My spark application sometimes will throw `Container killed by YARN for exceeding memory limits`. Even I increased `spark.yarn.executor.memoryOverhead` to 10G, this error still happen. The latest config: memory-config And error message: ``` ExecutorLostFailure (executor 121 exited caused by one of the running tasks) Reason: Container killed by YARN for exceeding memory limits. 30.7 GB of 30 GB physical memory used. Consider boosting spark.yarn.executor.memoryOverhead. ``` This is because of [Linux glibc >= 2.10 (RHEL 6) malloc may show excessive virtual memory usage](https://www.ibm.com/developerworks/community/blogs/kevgrig/entry/linux_glibc_2_10_rhel_6_malloc_may_show_excessive_virtual_memory_usage?lang=en). So disable `yarn.nodemanager.vmem-check-enabled` looks like a good option as [MapR mentioned ](https://mapr.com/blog/best-practices-yarn-resource-management). This PR add disable `yarn.nodemanager.vmem-check-enabled` option to memLimitExceededLogMessage. More details: https://issues.apache.org/jira/browse/YARN-4714 https://stackoverflow.com/a/31450291 https://stackoverflow.com/a/42091255 After this PR: yarn N/A Author: Yuming Wang Author: Yuming Wang Closes #20735 from wangyum/YARN-4714. Change-Id: Ie10836e2c07b6384d228c3f9e89f802823bd9f16 --- .../scala/org/apache/spark/deploy/yarn/YarnAllocator.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 506adb363aa90..a537243d641cb 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -736,7 +736,8 @@ private object YarnAllocator { def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) val diag = if (matcher.find()) " " + matcher.group() + "." else "" - ("Container killed by YARN for exceeding memory limits." + diag - + " Consider boosting spark.yarn.executor.memoryOverhead.") + s"Container killed by YARN for exceeding memory limits. $diag " + + "Consider boosting spark.yarn.executor.memoryOverhead or " + + "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714." } } From 33c2cb22b3b246a413717042a5f741da04ded69d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 7 Mar 2018 13:10:51 +0100 Subject: [PATCH 0442/1161] [SPARK-23611][SQL] Add a helper function to check exception for expr evaluation ## What changes were proposed in this pull request? This pr added a helper function in `ExpressionEvalHelper` to check exceptions in all the path of expression evaluation. ## How was this patch tested? Modified the existing tests. Author: Takeshi Yamamuro Closes #20748 from maropu/SPARK-23611. --- .../expressions/ExpressionEvalHelper.scala | 83 ++++++++++++++----- .../expressions/MathExpressionsSuite.scala | 2 +- .../expressions/MiscExpressionsSuite.scala | 2 +- .../expressions/NullExpressionsSuite.scala | 2 +- .../expressions/ObjectExpressionsSuite.scala | 17 ++-- .../expressions/RegexpExpressionsSuite.scala | 8 +- .../expressions/StringExpressionsSuite.scala | 2 +- .../expressions/TimeWindowSuite.scala | 2 +- 8 files changed, 79 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 29f0cc0d991aa..58d0c07622eb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.scalacheck.Gen import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.exceptions.TestFailedException @@ -45,11 +47,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } - protected def checkEvaluation( - expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + private def prepareEvaluation(expression: Expression): Expression = { val serializer = new JavaSerializer(new SparkConf()).newInstance val resolver = ResolveTimeZone(new SQLConf) - val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + } + + protected def checkEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -95,7 +101,31 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + inputRow: InternalRow, + expectedErrMsg: String): Unit = { + + def checkException(eval: => Unit, testMode: String): Unit = { + withClue(s"($testMode)") { + val errMsg = intercept[T] { + eval + }.getMessage + if (errMsg != expectedErrMsg) { + fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") + } + } + } + val expr = prepareEvaluation(expression) + checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") + checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkException(evaluateWithUnsafeProjection(expr, inputRow), "unsafe mode") + } + } + + protected def evaluateWithoutCodegen( + expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { case n: Nondeterministic => n.initialize(0) case _ => @@ -124,7 +154,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!checkResult(actual, expected, expression.dataType)) { @@ -139,33 +169,29 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val actual = evaluateWithGeneratedMutableProjection(expression, inputRow) + if (!checkResult(actual, expected, expression.dataType)) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + private def evaluateWithGeneratedMutableProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) plan.initialize(0) - val actual = plan(inputRow).get(0, expression.dataType) - if (!checkResult(actual, expected, expression.dataType)) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") - } + plan(inputRow).get(0, expression.dataType) } protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // SPARK-16489 Explicitly doing code generation twice so code gen will fail if - // some expression is reusing variable names across different instances. - // This behavior is tested in ExpressionEvalHelperSuite. - val plan = generateProject( - UnsafeProjection.create( - Alias(expression, s"Optimized($expression)1")() :: - Alias(expression, s"Optimized($expression)2")() :: Nil), - expression) - - val unsafeRow = plan(inputRow) + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -185,6 +211,21 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + private def evaluateWithUnsafeProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): InternalRow = { + // SPARK-16489 Explicitly doing code generation twice so code gen will fail if + // some expression is reusing variable names across different instances. + // This behavior is tested in ExpressionEvalHelperSuite. + val plan = generateProject( + UnsafeProjection.create( + Alias(expression, s"Optimized($expression)1")() :: + Alias(expression, s"Optimized($expression)2")() :: Nil), + expression) + + plan(inputRow) + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, @@ -294,7 +335,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = { val interpret = try { - evaluate(expr, inputRow) + evaluateWithoutCodegen(expr, inputRow) } catch { case e: Exception => fail(s"Exception evaluating $expr", e) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 39e0060d41dd4..3a094079380fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -124,7 +124,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { private def checkNaNWithoutCodegen( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { + val actual = try evaluateWithoutCodegen(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } if (!actual.asInstanceOf[Double].isNaN) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index facc863081303..a21c139fe71d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -41,6 +41,6 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("uuid") { checkEvaluation(Length(Uuid()), 36) - assert(evaluate(Uuid()) !== evaluate(Uuid())) + assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index cc6c15cb2c909..424c3a4696077 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -51,7 +51,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("AssertNotNUll") { val ex = intercept[RuntimeException] { - evaluate(AssertNotNull(Literal(null), Seq.empty[String])) + evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String])) }.getMessage assert(ex.contains("Null value appeared in non-nullable field")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 50e57737a4612..cbfbb6573ae8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -100,14 +100,13 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // If an input row or a field are null, a runtime exception will be thrown - val errMsg1 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(null))) - }.getMessage - assert(errMsg1 === "The input external row cannot be null.") - - val errMsg2 = intercept[RuntimeException] { - evaluate(getRowField, InternalRow.fromSeq(Seq(Row(null)))) - }.getMessage - assert(errMsg2 === "The 0th field 'c0' of input row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(null)), + "The input external row cannot be null.") + checkExceptionInExpression[RuntimeException]( + getRowField, + InternalRow.fromSeq(Seq(Row(null))), + "The 0th field 'c0' of input row cannot be null.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 2a0a42c65b086..d532dc4f77198 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -100,12 +100,12 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // invalid escaping val invalidEscape = intercept[AnalysisException] { - evaluate("""a""" like """\a""") + evaluateWithoutCodegen("""a""" like """\a""") } assert(invalidEscape.getMessage.contains("pattern")) val endEscape = intercept[AnalysisException] { - evaluate("""a""" like """a\""") + evaluateWithoutCodegen("""a""" like """a\""") } assert(endEscape.getMessage.contains("pattern")) @@ -147,11 +147,11 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkLiteralRow("abc" rlike _, "^bc", false) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") + evaluateWithoutCodegen("abbbbc" rlike "**") } intercept[java.util.regex.PatternSyntaxException] { val regex = 'a.string.at(0) - evaluate("abbbbc" rlike regex, create_row("**")) + evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 97ddbeba2c5ca..9a1a4da074ce3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -756,7 +756,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // exceptional cases intercept[java.util.regex.PatternSyntaxException] { - evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), + evaluateWithoutCodegen(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), Literal("QUERY"), Literal("???")))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index d6c8fcf291842..351d4d0c2eac9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -27,7 +27,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva test("time window is unevaluable") { intercept[UnsupportedOperationException] { - evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + evaluateWithoutCodegen(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) } } From aff7d81cb73133483fc2256ca10e21b4b8101647 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 7 Mar 2018 18:31:59 +0100 Subject: [PATCH 0443/1161] [SPARK-23591][SQL] Add interpreted execution to EncodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to EncodeUsingSerializer. ## How was this patch tested? added UT Author: Marco Gaido Closes #20751 from mgaido91/SPARK-23591. --- .../expressions/objects/objects.scala | 114 ++++++++++-------- .../expressions/ObjectExpressionsSuite.scala | 16 ++- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 721d589709131..7bbc3c732e782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -105,6 +105,61 @@ trait InvokeLike extends Expression with NonSQLExpression { } } +/** + * Common trait for [[DecodeUsingSerializer]] and [[EncodeUsingSerializer]] + */ +trait SerializerSupport { + /** + * If true, Kryo serialization is used, otherwise the Java one is used + */ + val kryo: Boolean + + /** + * The serializer instance to be used for serialization/deserialization in interpreted execution + */ + lazy val serializerInstance: SerializerInstance = SerializerSupport.newSerializer(kryo) + + /** + * Adds a immutable state to the generated class containing a reference to the serializer. + * @return a string containing the name of the variable referencing the serializer + */ + def addImmutableSerializerIfNeeded(ctx: CodegenContext): String = { + val (serializerInstance, serializerInstanceClass) = { + if (kryo) { + ("kryoSerializer", + classOf[KryoSerializerInstance].getName) + } else { + ("javaSerializer", + classOf[JavaSerializerInstance].getName) + } + } + val newSerializerMethod = s"${classOf[SerializerSupport].getName}$$.MODULE$$.newSerializer" + // Code to initialize the serializer + ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializerInstance, v => + s""" + |$v = ($serializerInstanceClass) $newSerializerMethod($kryo); + """.stripMargin) + serializerInstance + } +} + +object SerializerSupport { + /** + * It creates a new `SerializerInstance` which is either a `KryoSerializerInstance` (is + * `useKryo` is set to `true`) or a `JavaSerializerInstance`. + */ + def newSerializer(useKryo: Boolean): SerializerInstance = { + // try conf from env, otherwise create a new one + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) + val s = if (useKryo) { + new KryoSerializer(conf) + } else { + new JavaSerializer(conf) + } + s.newInstance() + } +} + /** * Invokes a static function, returning the result. By default, any of the arguments being null * will result in returning null instead of calling the function. @@ -1154,36 +1209,14 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override def nullSafeEval(input: Any): Any = { + serializerInstance.serialize(input).array() + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to serialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) @@ -1207,33 +1240,10 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) - extends UnaryExpression with NonSQLExpression { + extends UnaryExpression with NonSQLExpression with SerializerSupport { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // Code to initialize the serializer. - val (serializer, serializerClass, serializerInstanceClass) = { - if (kryo) { - ("kryoSerializer", - classOf[KryoSerializer].getName, - classOf[KryoSerializerInstance].getName) - } else { - ("javaSerializer", - classOf[JavaSerializer].getName, - classOf[JavaSerializerInstance].getName) - } - } - // try conf from env, otherwise create a new one - val env = s"${classOf[SparkEnv].getName}.get()" - val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v => - s""" - |if ($env == null) { - | $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); - |} else { - | $v = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); - |} - """.stripMargin) - + val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. val input = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index cbfbb6573ae8e..346b13277c709 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -109,4 +110,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(null))), "The 0th field 'c0' of input row cannot be null.") } + + test("SPARK-23591: EncodeUsingSerializer should support interpreted execution") { + val cls = ObjectType(classOf[java.lang.Integer]) + val inputObject = BoundReference(0, cls, nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val expected = serializer.newInstance().serialize(new Integer(1)).array() + val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo) + checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1))) + checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 53561d27c45db31893bcabd4aca2387fde869b72 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Mar 2018 09:37:42 -0800 Subject: [PATCH 0444/1161] [SPARK-23291][SQL][R] R's substr should not reduce starting position by 1 when calling Scala API ## What changes were proposed in this pull request? Seems R's substr API treats Scala substr API as zero based and so subtracts the given starting position by 1. Because Scala's substr API also accepts zero-based starting position (treated as the first element), so the current R's substr test results are correct as they all use 1 as starting positions. ## How was this patch tested? Modified tests. Author: Liang-Chi Hsieh Closes #20464 from viirya/SPARK-23291. --- R/pkg/R/column.R | 10 ++++++++-- R/pkg/tests/fulltests/test_sparkSQL.R | 1 + docs/sparkr.md | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9727efc354f10..7926a9a2467ee 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -161,12 +161,18 @@ setMethod("alias", #' @aliases substr,Column-method #' #' @param x a Column. -#' @param start starting position. +#' @param start starting position. It should be 1-base. #' @param stop ending position. +#' @examples +#' \dontrun{ +#' df <- createDataFrame(list(list(a="abcdef"))) +#' collect(select(df, substr(df$a, 1, 4))) # the result is `abcd`. +#' collect(select(df, substr(df$a, 2, 4))) # the result is `bcd`. +#' } #' @note substr since 1.4.0 setMethod("substr", signature(x = "Column"), function(x, start, stop) { - jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) + jc <- callJMethod(x@jc, "substr", as.integer(start), as.integer(stop - start + 1)) column(jc) }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index bd0a0dcd0674c..439191adb23ea 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1651,6 +1651,7 @@ test_that("string operators", { expect_false(first(select(df, startsWith(df$name, "m")))[[1]]) expect_true(first(select(df, endsWith(df$name, "el")))[[1]]) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + expect_equal(first(select(df, substr(df$name, 4, 6)))[[1]], "hae") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { expect_true(startsWith("Hello World", "Hello")) expect_false(endsWith("Hello World", "a")) diff --git a/docs/sparkr.md b/docs/sparkr.md index 6685b585a393a..2909247e79e95 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -663,3 +663,7 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - The `stringsAsFactors` parameter was previously ignored with `collect`, for example, in `collect(createDataFrame(iris), stringsAsFactors = TRUE))`. It has been corrected. - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. + +## Upgrading to Spark 2.4.0 + + - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. From c99fc9ad9b600095baba003053dbf84304ca392b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 7 Mar 2018 13:42:06 -0800 Subject: [PATCH 0445/1161] [SPARK-23550][CORE] Cleanup `Utils`. A few different things going on: - Remove unused methods. - Move JSON methods to the only class that uses them. - Move test-only methods to TestUtils. - Make getMaxResultSize() a config constant. - Reuse functionality from existing libraries (JRE or JavaUtils) where possible. The change also includes changes to a few tests to call `Utils.createTempFile` correctly, so that temp dirs are created under the designated top-level temp dir instead of potentially polluting git index. Author: Marcelo Vanzin Closes #20706 from vanzin/SPARK-23550. --- .../scala/org/apache/spark/TestUtils.scala | 26 +++- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 4 +- .../spark/internal/config/ConfigBuilder.scala | 3 +- .../spark/internal/config/package.scala | 5 + .../spark/scheduler/TaskSetManager.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 124 +++++++++-------- .../scala/org/apache/spark/util/Utils.scala | 131 +----------------- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../scala/org/apache/spark/DriverSuite.scala | 2 +- .../spark/deploy/SparkSubmitSuite.scala | 18 +-- .../spark/scheduler/ReplayListenerSuite.scala | 12 +- .../org/apache/spark/util/UtilsSuite.scala | 1 + scalastyle-config.xml | 2 +- .../datasources/orc/OrcSourceSuite.scala | 4 +- .../metric/SQLMetricsTestUtils.scala | 6 +- .../sql/sources/PartitionedWriteSuite.scala | 15 +- .../HiveThriftServer2Suites.scala | 2 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 20 +-- .../spark/streaming/CheckpointSuite.scala | 2 +- .../spark/streaming/MapWithStateSuite.scala | 2 +- 21 files changed, 152 insertions(+), 236 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 93e7ee3d2a404..b5c4c705dcbc7 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,7 +22,7 @@ import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets import java.security.SecureRandom import java.security.cert.X509Certificate -import java.util.Arrays +import java.util.{Arrays, Properties} import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ @@ -35,6 +35,7 @@ import scala.sys.process.{Process, ProcessLogger} import scala.util.Try import com.google.common.io.{ByteStreams, Files} +import org.apache.log4j.PropertyConfigurator import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -256,6 +257,29 @@ private[spark] object TestUtils { s"Can't find $numExecutors executors before $timeout milliseconds elapsed") } + /** + * config a log4j properties used for testsuite + */ + def configTestLog4j(level: String): Unit = { + val pro = new Properties() + pro.put("log4j.rootLogger", s"$level, console") + pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") + pro.put("log4j.appender.console.target", "System.err") + pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") + pro.put("log4j.appender.console.layout.ConversionPattern", + "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") + PropertyConfigurator.configure(pro) + } + + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9db7a1fe3106d..e7796d4ddbe34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{ByteArrayOutputStream, PrintStream} +import java.io.{ByteArrayOutputStream, File, PrintStream} import java.lang.reflect.InvocationTargetException import java.net.URI import java.nio.charset.StandardCharsets @@ -233,7 +233,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Set name from main class if not given name = Option(name).orElse(Option(mainClass)).orNull if (name == null && primaryResource != null) { - name = Utils.stripDirectory(primaryResource) + name = new File(primaryResource).getName() } // Action should be SUBMIT unless otherwise specified diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 2c3a8ef74800b..dcec3ec21b546 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -35,6 +35,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} @@ -141,8 +142,7 @@ private[spark] class Executor( conf.getSizeAsBytes("spark.task.maxDirectResultSize", 1L << 20), RpcUtils.maxMessageSizeBytes(conf)) - // Limit of bytes for total size of results (default is 1GB) - private val maxResultSize = Utils.getMaxResultSize(conf) + private val maxResultSize = conf.get(MAX_RESULT_SIZE) // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index b0cd7110a3b47..f27aca03773a9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -23,6 +23,7 @@ import java.util.regex.PatternSyntaxException import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} +import org.apache.spark.util.Utils private object ConfigHelpers { @@ -45,7 +46,7 @@ private object ConfigHelpers { } def stringToSeq[T](str: String, converter: String => T): Seq[T] = { - str.split(",").map(_.trim()).filter(_.nonEmpty).map(converter) + Utils.stringToSeq(str).map(converter) } def seqToString[T](v: Seq[T], stringConverter: T => String): String = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bbfcfbaa7363c..a313ad0554a3a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -520,4 +520,9 @@ package object config { .checkValue(v => v > 0, "The threshold should be positive.") .createWithDefault(10000000) + private[spark] val MAX_RESULT_SIZE = ConfigBuilder("spark.driver.maxResultSize") + .doc("Size limit for results.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1g") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 886c2c99f1ff3..d958658527f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -64,8 +64,7 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) - // Limit of bytes for total size of results (default is 1GB) - val maxResultSize = Utils.getMaxResultSize(conf) + val maxResultSize = conf.get(config.MAX_RESULT_SIZE) val speculationEnabled = conf.getBoolean("spark.speculation", false) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ff83301d631c4..40383fe05026b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -48,7 +48,7 @@ import org.apache.spark.storage._ * To ensure that we provide these guarantees, follow these rules when modifying these methods: * * - Never delete any JSON fields. - * - Any new JSON fields should be optional; use `Utils.jsonOption` when reading these fields + * - Any new JSON fields should be optional; use `jsonOption` when reading these fields * in `*FromJson` methods. */ private[spark] object JsonProtocol { @@ -408,7 +408,7 @@ private[spark] object JsonProtocol { ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => ("Kill Reason" -> taskKilled.reason) - case _ => Utils.emptyJson + case _ => emptyJson } ("Reason" -> reason) ~ json } @@ -422,7 +422,7 @@ private[spark] object JsonProtocol { def jobResultToJson(jobResult: JobResult): JValue = { val result = Utils.getFormattedClassName(jobResult) val json = jobResult match { - case JobSucceeded => Utils.emptyJson + case JobSucceeded => emptyJson case jobFailed: JobFailed => JObject("Exception" -> exceptionToJson(jobFailed.exception)) } @@ -573,7 +573,7 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -586,7 +586,7 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] val stageAttemptId = - Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -597,11 +597,11 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] val submissionTime = - Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 - val stageInfos = Utils.jsonOption(json \ "Stage Infos") + val stageInfos = jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map { id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown") @@ -613,7 +613,7 @@ private[spark] object JsonProtocol { def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] val completionTime = - Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) + jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") SparkListenerJobEnd(jobId, completionTime, jobResult) } @@ -630,15 +630,15 @@ private[spark] object JsonProtocol { def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) - val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val maxOnHeapMem = jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") - val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + val time = jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) SparkListenerBlockManagerRemoved(time, blockManagerId) } @@ -648,11 +648,11 @@ private[spark] object JsonProtocol { def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { val appName = (json \ "App Name").extract[String] - val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String]) + val appId = jsonOption(json \ "App ID").map(_.extract[String]) val time = (json \ "Timestamp").extract[Long] val sparkUser = (json \ "User").extract[String] - val appAttemptId = Utils.jsonOption(json \ "App Attempt ID").map(_.extract[String]) - val driverLogs = Utils.jsonOption(json \ "Driver Logs").map(mapFromJson) + val appAttemptId = jsonOption(json \ "App Attempt ID").map(_.extract[String]) + val driverLogs = jsonOption(json \ "Driver Logs").map(mapFromJson) SparkListenerApplicationStart(appName, appId, time, sparkUser, appAttemptId, driverLogs) } @@ -703,19 +703,19 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) + val attemptId = jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") - val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) - val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) - val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) + val details = jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") + val submissionTime = jsonOption(json \ "Submission Time").map(_.extract[Long]) + val completionTime = jsonOption(json \ "Completion Time").map(_.extract[Long]) + val failureReason = jsonOption(json \ "Failure Reason").map(_.extract[String]) val accumulatedValues = { - Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -735,17 +735,17 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) + val attempt = jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] val executorId = (json \ "Executor ID").extract[String].intern() val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) + val speculative = jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) - val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { + val killed = jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq.empty[AccumulableInfo] } @@ -762,13 +762,13 @@ private[spark] object JsonProtocol { def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = Utils.jsonOption(json \ "Name").map(_.extract[String]) - val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } - val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean]) + val name = jsonOption(json \ "Name").map(_.extract[String]) + val update = jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } + val value = jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } + val internal = jsonOption(json \ "Internal").exists(_.extract[Boolean]) val countFailedValues = - Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) - val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) + jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -821,49 +821,49 @@ private[spark] object JsonProtocol { metrics.incDiskBytesSpilled((json \ "Disk Bytes Spilled").extract[Long]) // Shuffle read metrics - Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => + jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => val readMetrics = metrics.createTempShuffleReadMetrics() readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - Utils.jsonOption(readJson \ "Remote Bytes Read To Disk") + jsonOption(readJson \ "Remote Bytes Read To Disk") .foreach { v => readMetrics.incRemoteBytesReadToDisk(v.extract[Long])} readMetrics.incLocalBytesRead( - Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) readMetrics.incRecordsRead( - Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } // Shuffle write metrics // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. - Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => + jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) writeMetrics.incRecordsWritten( - Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } // Output metrics - Utils.jsonOption(json \ "Output Metrics").foreach { outJson => + jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) outputMetrics.setRecordsWritten( - Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) + jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics - Utils.jsonOption(json \ "Input Metrics").foreach { inJson => + jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) inputMetrics.incRecordsRead( - Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) + jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks - Utils.jsonOption(json \ "Updated Blocks").foreach { blocksJson => + jsonOption(json \ "Updated Blocks").foreach { blocksJson => metrics.setUpdatedBlockStatuses(blocksJson.extract[List[JValue]].map { blockJson => val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") @@ -897,7 +897,7 @@ private[spark] object JsonProtocol { val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] - val message = Utils.jsonOption(json \ "Message").map(_.extract[String]) + val message = jsonOption(json \ "Message").map(_.extract[String]) new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, message.getOrElse("Unknown reason")) case `exceptionFailure` => @@ -905,9 +905,9 @@ private[spark] object JsonProtocol { val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") val fullStackTrace = - Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull + jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x - val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") + val accumUpdates = jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => { acc.toInfo(Some(acc.value), None) @@ -915,21 +915,21 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost case `taskKilled` => - val killReason = Utils.jsonOption(json \ "Kill Reason") + val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility // for reading those logs, we need to provide default values for all the fields. - val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) - val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) - val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + val jobId = jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => - val exitCausedByApp = Utils.jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) - val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String]) - val reason = Utils.jsonOption(json \ "Loss Reason").map(_.extract[String]) + val exitCausedByApp = jsonOption(json \ "Exit Caused By App").map(_.extract[Boolean]) + val executorId = jsonOption(json \ "Executor ID").map(_.extract[String]) + val reason = jsonOption(json \ "Loss Reason").map(_.extract[String]) ExecutorLostFailure( executorId.getOrElse("Unknown"), exitCausedByApp.getOrElse(true), @@ -968,11 +968,11 @@ private[spark] object JsonProtocol { def rddInfoFromJson(json: JValue): RDDInfo = { val rddId = (json \ "RDD ID").extract[Int] val name = (json \ "Name").extract[String] - val scope = Utils.jsonOption(json \ "Scope") + val scope = jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) - val callsite = Utils.jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") - val parentIds = Utils.jsonOption(json \ "Parent IDs") + val callsite = jsonOption(json \ "Callsite").map(_.extract[String]).getOrElse("") + val parentIds = jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) val storageLevel = storageLevelFromJson(json \ "Storage Level") @@ -1029,7 +1029,7 @@ private[spark] object JsonProtocol { } def propertiesFromJson(json: JValue): Properties = { - Utils.jsonOption(json).map { value => + jsonOption(json).map { value => val properties = new Properties mapFromJson(json).foreach { case (k, v) => properties.setProperty(k, v) } properties @@ -1058,4 +1058,14 @@ private[spark] object JsonProtocol { e } + /** Return an option that translates JNothing to None */ + private def jsonOption(json: JValue): Option[JValue] = { + json match { + case JNothing => None + case value: JValue => Some(value) + } + } + + private def emptyJson: JObject = JObject(List[JField]()) + } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2e2a4a259e9af..29d26ea2c85df 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -25,7 +25,7 @@ import java.net._ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} +import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -51,9 +51,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException -import org.json4s._ import org.slf4j.Logger import org.apache.spark._ @@ -1017,70 +1015,18 @@ private[spark] object Utils extends Logging { " " + (System.currentTimeMillis - startTimeMs) + " ms" } - private def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - /** - * Lists files recursively. - */ - def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val current = f.listFiles - current ++ current.filter(_.isDirectory).flatMap(recursiveList) - } - /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. * Throws an exception if deletion is unsuccessful. */ - def deleteRecursively(file: File) { + def deleteRecursively(file: File): Unit = { if (file != null) { - try { - if (file.isDirectory && !isSymlink(file)) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - ShutdownHookManager.removeShutdownDeleteDir(file) - } - } finally { - if (file.delete()) { - logTrace(s"${file.getAbsolutePath} has been deleted") - } else { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } + JavaUtils.deleteRecursively(file) + ShutdownHookManager.removeShutdownDeleteDir(file) } } - /** - * Check to see if file is a symbolic link. - */ - def isSymlink(file: File): Boolean = { - return Files.isSymbolicLink(Paths.get(file.toURI)) - } - /** * Determines if a directory contains any files newer than cutoff seconds. * @@ -1828,7 +1774,7 @@ private[spark] object Utils extends Logging { * [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower * in the current version of Scala. */ - def getIteratorSize[T](iterator: Iterator[T]): Long = { + def getIteratorSize(iterator: Iterator[_]): Long = { var count = 0L while (iterator.hasNext) { count += 1L @@ -1875,17 +1821,6 @@ private[spark] object Utils extends Logging { obj.getClass.getSimpleName.replace("$", "") } - /** Return an option that translates JNothing to None */ - def jsonOption(json: JValue): Option[JValue] = { - json match { - case JNothing => None - case value: JValue => Some(value) - } - } - - /** Return an empty JSON object */ - def emptyJson: JsonAST.JObject = JObject(List[JField]()) - /** * Return a Hadoop FileSystem with the scheme encoded in the given path. */ @@ -1900,15 +1835,6 @@ private[spark] object Utils extends Logging { getHadoopFileSystem(new URI(path), conf) } - /** - * Return the absolute path of a file in the given directory. - */ - def getFilePath(dir: File, fileName: String): Path = { - assert(dir.isDirectory) - val path = new File(dir, fileName).getAbsolutePath - new Path(path) - } - /** * Whether the underlying operating system is Windows. */ @@ -1931,13 +1857,6 @@ private[spark] object Utils extends Logging { sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing") } - /** - * Strip the directory from a path name - */ - def stripDirectory(path: String): String = { - new File(path).getName - } - /** * Terminates a process waiting for at most the specified duration. * @@ -2348,36 +2267,6 @@ private[spark] object Utils extends Logging { org.apache.log4j.Logger.getRootLogger().setLevel(l) } - /** - * config a log4j properties used for testsuite - */ - def configTestLog4j(level: String): Unit = { - val pro = new Properties() - pro.put("log4j.rootLogger", s"$level, console") - pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender") - pro.put("log4j.appender.console.target", "System.err") - pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout") - pro.put("log4j.appender.console.layout.ConversionPattern", - "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n") - PropertyConfigurator.configure(pro) - } - - def invoke( - clazz: Class[_], - obj: AnyRef, - methodName: String, - args: (Class[_], AnyRef)*): AnyRef = { - val (types, values) = args.unzip - val method = clazz.getDeclaredMethod(methodName, types: _*) - method.setAccessible(true) - method.invoke(obj, values.toSeq: _*) - } - - // Limit of bytes for total size of results (default is 1GB) - def getMaxResultSize(conf: SparkConf): Long = { - memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20 - } - /** * Return the current system LD_LIBRARY_PATH name */ @@ -2610,16 +2499,6 @@ private[spark] object Utils extends Logging { SignalUtils.registerLogger(log) } - /** - * Unions two comma-separated lists of files and filters out empty strings. - */ - def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { - var allFiles = Set.empty[String] - leftList.foreach { value => allFiles ++= value.split(",") } - rightList.foreach { value => allFiles ++= value.split(",") } - allFiles.filter { _.nonEmpty } - } - /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 24a55df84a240..0d5c5ea7903e9 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -95,7 +95,7 @@ public void tearDown() { @SuppressWarnings("unchecked") public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - tempDir = Utils.createTempDir("test", "test"); + tempDir = Utils.createTempDir(null, "test"); mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 962945e5b6bb1..896cd2e80aaef 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -51,7 +51,7 @@ class DriverSuite extends SparkFunSuite with TimeLimits { */ object DriverWithoutCleanup { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 803a38d77fb82..d265643a80b4e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.TestUtils import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ @@ -761,18 +762,6 @@ class SparkSubmitSuite } } - test("comma separated list of files are unioned correctly") { - val left = Option("/tmp/a.jar,/tmp/b.jar") - val right = Option("/tmp/c.jar,/tmp/a.jar") - val emptyString = Option("") - Utils.unionFileLists(left, right) should be (Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar")) - Utils.unionFileLists(emptyString, emptyString) should be (Set.empty) - Utils.unionFileLists(Option("/tmp/a.jar"), emptyString) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(emptyString, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) - Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) - } - test("support glob path") { val tmpJarDir = Utils.createTempDir() val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir) @@ -1042,6 +1031,7 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { @@ -1076,7 +1066,7 @@ object SparkSubmitSuite extends SparkFunSuite with TimeLimits { object JarCreationTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => @@ -1100,7 +1090,7 @@ object JarCreationTest extends Logging { object SimpleApplicationTest { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val configs = Seq("spark.master", "spark.app.name") diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 73e7b3fe8c1de..e24d550a62665 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -47,7 +47,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Simple replay") { - val logFilePath = Utils.getFilePath(testDir, "events.txt") + val logFilePath = getFilePath(testDir, "events.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, @@ -97,7 +97,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp // scalastyle:on println } - val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress") + val logFilePath = getFilePath(testDir, "events.lz4.inprogress") val bytes = buffered.toByteArray Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream => fstream.write(bytes, 0, buffered.size) @@ -129,7 +129,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } test("Replay incompatible event log") { - val logFilePath = Utils.getFilePath(testDir, "incompatible.txt") + val logFilePath = getFilePath(testDir, "incompatible.txt") val fstream = fileSystem.create(logFilePath) val writer = new PrintWriter(fstream) val applicationStart = SparkListenerApplicationStart("Incompatible App", None, @@ -226,6 +226,12 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp } } + private def getFilePath(dir: File, fileName: String): Path = { + assert(dir.isDirectory) + val path = new File(dir, fileName).getAbsolutePath + new Path(path) + } + /** * A simple listener that buffers all the events it receives. * diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index eaea6b030c154..3b4273184f1e9 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -648,6 +648,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("fetch hcfs dir") { val tempDir = Utils.createTempDir() val sourceDir = new File(tempDir, "source-dir") + sourceDir.mkdir() val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") diff --git a/scalastyle-config.xml b/scalastyle-config.xml index e2fa5754afaee..e65e3aafe5b5b 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -229,7 +229,7 @@ This file is divided into 3 sections: extractOpt - Use Utils.jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter + Use jsonOption(x).map(.extract[T]) instead of .extractOpt[T], as the latter is slower. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 523f7cf77e103..8a3bbd03a26dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -39,8 +39,8 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { protected override def beforeAll(): Unit = { super.beforeAll() - orcTableAsDir = Utils.createTempDir("orctests", "sparksql") - orcTableDir = Utils.createTempDir("orctests", "sparksql") + orcTableAsDir = Utils.createTempDir(namePrefix = "orctests") + orcTableDir = Utils.createTempDir(namePrefix = "orctests") sparkContext .makeRDD(1 to 10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 122d28798136f..534d8bb629b8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -21,13 +21,13 @@ import java.io.File import scala.collection.mutable.HashMap +import org.apache.spark.TestUtils import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore} import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.util.Utils trait SQLMetricsTestUtils extends SQLTestUtils { @@ -91,7 +91,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2) .write.format(dataFormat).mode("overwrite").insertInto(tableName) } - assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) + assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2) } } @@ -121,7 +121,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { .mode("overwrite") .insertInto(tableName) } - assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) + assert(TestUtils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 0fe33e87318a5..27c983f270bf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.TestUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils @@ -86,15 +87,15 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempDir { f => spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) spark.range(start = 0, end = 4, step = 1, numPartitions = 1) .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) } } @@ -106,7 +107,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { .option("maxRecordsPerFile", 1) .mode("overwrite") .parquet(f.getAbsolutePath) - assert(Utils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + assert(TestUtils.recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) } } @@ -133,14 +134,14 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val df = Seq((1, ts)).toDF("i", "ts") withTempPath { f => df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) checkPartitionValues(files.head, "2016-12-01 00:00:00") } withTempPath { f => df.write.option(DateTimeUtils.TIMEZONE_OPTION, "GMT") .partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // use timeZone option "GMT" to format partition value. checkPartitionValues(files.head, "2016-12-01 08:00:00") @@ -148,7 +149,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { df.write.partitionBy("ts").parquet(f.getAbsolutePath) - val files = Utils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + val files = TestUtils.recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) assert(files.length == 1) // if there isn't timeZone option, then use session local timezone. checkPartitionValues(files.head, "2016-12-01 08:00:00") diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 496f8c82a6c61..b32c547cefefe 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -804,7 +804,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected var metastorePath: File = _ protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - private val pidDir: File = Utils.createTempDir("thriftserver-pid") + private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ private var logTailingProcess: Process = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 2d31781132edc..079fe45860544 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -330,7 +330,7 @@ class HiveSparkSubmitSuite object SetMetastoreURLTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true) val builder = SparkSession.builder() @@ -368,7 +368,7 @@ object SetMetastoreURLTest extends Logging { object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkConf = new SparkConf(loadDefaults = true).set("spark.ui.enabled", "false") val providedExpectedWarehouseLocation = @@ -447,7 +447,7 @@ object SetWarehouseLocationTest extends Logging { // can load the jar defined with the function. object TemporaryHiveUDFTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -485,7 +485,7 @@ object TemporaryHiveUDFTest extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest1 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -523,7 +523,7 @@ object PermanentHiveUDFTest1 extends Logging { // can load the jar defined with the function. object PermanentHiveUDFTest2 extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() conf.set("spark.ui.enabled", "false") val sc = new SparkContext(conf) @@ -558,7 +558,7 @@ object PermanentHiveUDFTest2 extends Logging { // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val conf = new SparkConf() val hiveWarehouseLocation = Utils.createTempDir() conf.set("spark.ui.enabled", "false") @@ -628,7 +628,7 @@ object SparkSubmitClassLoaderTest extends Logging { // We test if we can correctly set spark sql configurations when HiveContext is used. object SparkSQLConfTest extends Logging { def main(args: Array[String]) { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") // We override the SparkConf to add spark.sql.hive.metastore.version and // spark.sql.hive.metastore.jars to the beginning of the conf entry array. // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but @@ -669,7 +669,7 @@ object SPARK_9757 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val hiveWarehouseLocation = Utils.createTempDir() val sparkContext = new SparkContext( @@ -718,7 +718,7 @@ object SPARK_11009 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() @@ -749,7 +749,7 @@ object SPARK_14244 extends QueryTest { protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") + TestUtils.configTestLog4j("INFO") val sparkContext = new SparkContext( new SparkConf() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index ee2fd45a7e851..19b621f11759d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -97,7 +97,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => val batchDurationMillis = batchDuration.milliseconds // Setup the stream computation - val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + val checkpointDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString logDebug(s"Using checkpoint directory $checkpointDir") val ssc = createContextForCheckpointOperation(batchDuration) require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 3b662ec1833aa..06c0c2aa97ee1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -39,7 +39,7 @@ class MapWithStateSuite extends SparkFunSuite before { StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } - checkpointDir = Utils.createTempDir("checkpoint") + checkpointDir = Utils.createTempDir(namePrefix = "checkpoint") } after { From ac76eff6a88f6358a321b84cb5e60fb9d6403419 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Wed, 7 Mar 2018 13:51:44 -0800 Subject: [PATCH 0446/1161] [SPARK-23525][SQL] Support ALTER TABLE CHANGE COLUMN COMMENT for external hive table ## What changes were proposed in this pull request? The following query doesn't work as expected: ``` CREATE EXTERNAL TABLE ext_table(a STRING, b INT, c STRING) PARTITIONED BY (d STRING) LOCATION 'sql/core/spark-warehouse/ext_table'; ALTER TABLE ext_table CHANGE a a STRING COMMENT "new comment"; DESC ext_table; ``` The comment of column `a` is not updated, that's because `HiveExternalCatalog.doAlterTable` ignores table schema changes. To fix the issue, we should call `doAlterTableDataSchema` instead of `doAlterTable`. ## How was this patch tested? Updated `DDLSuite.testChangeColumn`. Author: Xingbo Jiang Closes #20696 from jiangxb1987/alterColumnComment. --- .../spark/sql/execution/command/ddl.scala | 12 ++++++------ .../sql-tests/inputs/change-column.sql | 1 + .../sql-tests/results/change-column.sql.out | 19 ++++++++++++++----- .../sql/execution/command/DDLSuite.scala | 1 + 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 964cbca049b27..bf4d96fa18d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -314,8 +314,8 @@ case class AlterTableChangeColumnCommand( val resolver = sparkSession.sessionState.conf.resolver DDLUtils.verifyAlterTableType(catalog, table, isView = false) - // Find the origin column from schema by column name. - val originColumn = findColumnByName(table.schema, columnName, resolver) + // Find the origin column from dataSchema by column name. + val originColumn = findColumnByName(table.dataSchema, columnName, resolver) // Throw an AnalysisException if the column name/dataType is changed. if (!columnEqual(originColumn, newColumn, resolver)) { throw new AnalysisException( @@ -324,7 +324,7 @@ case class AlterTableChangeColumnCommand( s"'${newColumn.name}' with type '${newColumn.dataType}'") } - val newSchema = table.schema.fields.map { field => + val newDataSchema = table.dataSchema.fields.map { field => if (field.name == originColumn.name) { // Create a new column from the origin column with the new comment. addComment(field, newColumn.getComment) @@ -332,8 +332,7 @@ case class AlterTableChangeColumnCommand( field } } - val newTable = table.copy(schema = StructType(newSchema)) - catalog.alterTable(newTable) + catalog.alterTableDataSchema(tableName, StructType(newDataSchema)) Seq.empty[Row] } @@ -345,7 +344,8 @@ case class AlterTableChangeColumnCommand( schema.fields.collectFirst { case field if resolver(field.name, name) => field }.getOrElse(throw new AnalysisException( - s"Invalid column reference '$name', table schema is '${schema}'")) + s"Can't find column `$name` given table data columns " + + s"${schema.fieldNames.mkString("[`", "`, `", "`]")}")) } // Add the comment to a column, if comment is empty, return the original column. diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index ad0f885f63d3d..2909024e4c9f7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -49,6 +49,7 @@ ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column -- Change column in partition spec (not supported yet) CREATE TABLE partition_table(a INT, b STRING, c INT, d STRING) USING parquet PARTITIONED BY (c, d); ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT; +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'; -- DROP TEST TABLE DROP TABLE test_change; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ba8bc936f0c79..ff1ecbcc44c23 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 33 -- !query 0 @@ -154,7 +154,7 @@ ALTER TABLE test_change CHANGE invalid_col invalid_col INT struct<> -- !query 15 output org.apache.spark.sql.AnalysisException -Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))'; +Can't find column `invalid_col` given table data columns [`a`, `b`, `c`]; -- !query 16 @@ -291,16 +291,25 @@ ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT -- !query 30 -DROP TABLE test_change +ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C' -- !query 30 schema struct<> -- !query 30 output - +org.apache.spark.sql.AnalysisException +Can't find column `c` given table data columns [`a`, `b`]; -- !query 31 -DROP TABLE partition_table +DROP TABLE test_change -- !query 31 schema struct<> -- !query 31 output + + +-- !query 32 +DROP TABLE partition_table +-- !query 32 schema +struct<> +-- !query 32 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index db9023b7ec8b6..4041176262426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1597,6 +1597,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // Ensure that change column will preserve other metadata fields. sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'") assert(getMetadata("col1").getString("key") == "value") + assert(getMetadata("col1").getString("comment") == "this is col1") } test("drop build-in function") { From 77c91cc746f93e609c412f3a220495d9e931f696 Mon Sep 17 00:00:00 2001 From: jx158167 Date: Wed, 7 Mar 2018 20:08:32 -0800 Subject: [PATCH 0447/1161] [SPARK-23524] Big local shuffle blocks should not be checked for corruption. ## What changes were proposed in this pull request? In current code, all local blocks will be checked for corruption no matter it's big or not. The reasons are as below: Size in FetchResult for local block is set to be 0 (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L327) SPARK-4105 meant to only check the small blocks(size Closes #20685 from jinxing64/SPARK-23524. --- .../storage/ShuffleBlockFetcherIterator.scala | 14 +++--- .../ShuffleBlockFetcherIteratorSuite.scala | 45 +++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 98b5a735a4529..dd9df74689a13 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -90,7 +90,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -316,6 +316,7 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { val blockId = iter.next() @@ -324,7 +325,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, + buf.size(), buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -397,7 +399,9 @@ final class ShuffleBlockFetcherIterator( } shuffleMetrics.incRemoteBlocksFetched(1) } - bytesInFlight -= size + if (!localBlocks.contains(blockId)) { + bytesInFlight -= size + } if (isNetworkReqDone) { reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) @@ -583,8 +587,8 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block successfully. * @param blockId block id * @param address BlockManager that the block was fetched from. - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5bfe9905ff17b..692ae3bf597e0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -352,6 +352,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + test("big blocks are not checked for corruption") { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + doReturn(10000L).when(corruptBuffer).size() + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + ) + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + ) + + val transfer = createMockTransfer( + Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId, localBlockLengths), + (remoteBmId, remoteBlockLengths) + ) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 10000), + 2048, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + // Blocks should be returned without exceptions. + assert(Set(iterator.next()._1, iterator.next()._1) === + Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + } + test("retry corrupt blocks (disabled)") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) From fe22f32041572596a22e5f7441fa0bfbd9608648 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 8 Mar 2018 10:50:09 +0100 Subject: [PATCH 0448/1161] [SPARK-23620] Splitting thread dump lines by using the br tag ## What changes were proposed in this pull request? I propose to replace `'\n'` by the `
    ` tag in generated html of thread dump page. The `
    ` tag will split thread lines in more reliable way. For now it could look like on the screen shot if the html is proxied and `'\n'` is replaced by another whitespace. The changes allow to more easily read and copy stack traces. ## How was this patch tested? I tested it manually by checking the thread dump page and its source. Author: Maxim Gekk Closes #20762 from MaxGekk/br-thread-dump. --- .../org/apache/spark/status/api/v1/api.scala | 24 ++++++++++++++++++- .../ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../scala/org/apache/spark/util/Utils.scala | 6 ++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 369e98b683b1a..971d7e90fa7b8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -19,6 +19,8 @@ package org.apache.spark.status.api.v1 import java.lang.{Long => JLong} import java.util.Date +import scala.xml.{NodeSeq, Text} + import com.fasterxml.jackson.annotation.JsonIgnoreProperties import com.fasterxml.jackson.databind.annotation.JsonDeserialize @@ -317,11 +319,31 @@ class RuntimeInfo private[spark]( val javaHome: String, val scalaVersion: String) +case class StackTrace(elems: Seq[String]) { + override def toString: String = elems.mkString + + def html: NodeSeq = { + val withNewLine = elems.foldLeft(NodeSeq.Empty) { (acc, elem) => + if (acc.isEmpty) { + acc :+ Text(elem) + } else { + acc :+
    :+ Text(elem) + } + } + + withNewLine + } + + def mkString(start: String, sep: String, end: String): String = { + elems.mkString(start, sep, end) + } +} + case class ThreadStackTrace( val threadId: Long, val threadName: String, val threadState: Thread.State, - val stackTrace: String, + val stackTrace: StackTrace, val blockedByThreadId: Option[Long], val blockedByLock: String, val holdingLocks: Seq[String]) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 7a9aaf29a8b05..9bb026c60565e 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -60,7 +60,7 @@ private[ui] class ExecutorThreadDumpPage( {thread.threadName} {thread.threadState} {blockedBy}{heldLocks} - {thread.stackTrace} + {thread.stackTrace.html} } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 29d26ea2c85df..5caedeb526469 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -61,7 +61,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.status.api.v1.ThreadStackTrace +import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -2118,14 +2118,14 @@ private[spark] object Utils extends Logging { private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap - val stackTrace = threadInfo.getStackTrace.map { frame => + val stackTrace = StackTrace(threadInfo.getStackTrace.map { frame => monitors.get(frame) match { case Some(monitor) => monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" case None => frame.toString } - }.mkString("\n") + }) // use a set to dedup re-entrant locks that are held at multiple places val heldLocks = From 9bb239c8b174d31981dfff63baa38bb8cecfe913 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 8 Mar 2018 20:19:55 +0900 Subject: [PATCH 0449/1161] [SPARK-23159][PYTHON] Update cloudpickle to v0.4.3 ## What changes were proposed in this pull request? The version of cloudpickle in PySpark was close to version 0.4.0 with some additional backported fixes and some minor additions for Spark related things. This update removes Spark related changes and matches cloudpickle [v0.4.3](https://github.com/cloudpipe/cloudpickle/releases/tag/v0.4.3): Changes by updating to 0.4.3 include: * Fix pickling of named tuples https://github.com/cloudpipe/cloudpickle/pull/113 * Built in type constructors for PyPy compatibility [here](https://github.com/cloudpipe/cloudpickle/commit/d84980ccaafc7982a50d4e04064011f401f17d1b) * Fix memoryview support https://github.com/cloudpipe/cloudpickle/pull/122 * Improved compatibility with other cloudpickle versions https://github.com/cloudpipe/cloudpickle/pull/128 * Several cleanups https://github.com/cloudpipe/cloudpickle/pull/121 and [here](https://github.com/cloudpipe/cloudpickle/commit/c91aaf110441991307f5097f950764079d0f9652) * [MRG] Regression on pickling classes from the __main__ module https://github.com/cloudpipe/cloudpickle/pull/149 * BUG: Handle instance methods of builtin types https://github.com/cloudpipe/cloudpickle/pull/154 * Fix #129 : do not silence RuntimeError in dump() https://github.com/cloudpipe/cloudpickle/pull/153 ## How was this patch tested? Existing pyspark.tests using python 2.7.14, 3.5.2, 3.6.3 Author: Bryan Cutler Closes #20373 from BryanCutler/pyspark-update-cloudpickle-42-SPARK-23159. --- python/pyspark/accumulators.py | 1 - python/pyspark/cloudpickle.py | 320 ++++++++++++++------------------- python/pyspark/serializers.py | 14 +- 3 files changed, 151 insertions(+), 184 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 6ef8cf53cc747..7def676b89a24 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,7 +94,6 @@ else: import socketserver as SocketServer import threading -from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 40e91a2d0655d..ea845b98b3db2 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -57,7 +57,6 @@ import types import weakref -from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -181,6 +180,32 @@ def _builtin_type(name): return getattr(types, name) +def _make__new__factory(type_): + def _factory(): + return type_.__new__ + return _factory + + +# NOTE: These need to be module globals so that they're pickleable as globals. +_get_dict_new = _make__new__factory(dict) +_get_frozenset_new = _make__new__factory(frozenset) +_get_list_new = _make__new__factory(list) +_get_set_new = _make__new__factory(set) +_get_tuple_new = _make__new__factory(tuple) +_get_object_new = _make__new__factory(object) + +# Pre-defined set of builtin_function_or_method instances that can be +# serialized. +_BUILTIN_TYPE_CONSTRUCTORS = { + dict.__new__: _get_dict_new, + frozenset.__new__: _get_frozenset_new, + set.__new__: _get_set_new, + list.__new__: _get_list_new, + tuple.__new__: _get_tuple_new, + object.__new__: _get_object_new, +} + + if sys.version_info < (3, 4): def _walk_global_ops(code): """ @@ -237,28 +262,16 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - except pickle.PickleError: - raise - except Exception as e: - emsg = _exception_message(e) - if "'i' format requires" in emsg: - msg = "Object too large to serialize: %s" % emsg else: - msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) - print_exec(sys.stderr) - raise pickle.PicklingError(msg) - + raise def save_memoryview(self, obj): - """Fallback to save_string""" - Pickler.save_string(self, str(obj)) + self.save(obj.tobytes()) + dispatch[memoryview] = save_memoryview - def save_buffer(self, obj): - """Fallback to save_string""" - Pickler.save_string(self,str(obj)) - if PY3: - dispatch[memoryview] = save_memoryview - else: + if not PY3: + def save_buffer(self, obj): + self.save(str(obj)) dispatch[buffer] = save_buffer def save_unsupported(self, obj): @@ -318,6 +331,24 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ + try: + should_special_case = obj in _BUILTIN_TYPE_CONSTRUCTORS + except TypeError: + # Methods of builtin types aren't hashable in python 2. + should_special_case = False + + if should_special_case: + # We keep a special-cased cache of built-in type constructors at + # global scope, because these functions are structured very + # differently in different python versions and implementations (for + # example, they're instances of types.BuiltinFunctionType in + # CPython, but they're ordinary types.FunctionType instances in + # PyPy). + # + # If the function we've received is in that cache, we just + # serialize it as a lookup into the cache. + return self.save_reduce(_BUILTIN_TYPE_CONSTRUCTORS[obj], (), obj=obj) + write = self.write if name is None: @@ -344,7 +375,7 @@ def save_function(self, obj, name=None): return self.save_global(obj, name) # a builtin_function_or_method which comes in as an attribute of some - # object (e.g., object.__new__, itertools.chain.from_iterable) will end + # object (e.g., itertools.chain.from_iterable) will end # up with modname "__main__" and so end up here. But these functions # have no __code__ attribute in CPython, so the handling for # user-defined functions below will fail. @@ -352,16 +383,13 @@ def save_function(self, obj, name=None): # for different python versions. if not hasattr(obj, '__code__'): if PY3: - if sys.version_info < (3, 4): - raise pickle.PicklingError("Can't pickle %r" % obj) - else: - rv = obj.__reduce_ex__(self.proto) + rv = obj.__reduce_ex__(self.proto) else: if hasattr(obj, '__self__'): rv = (getattr, (obj.__self__, name)) else: raise pickle.PicklingError("Can't pickle %r" % obj) - return Pickler.save_reduce(self, obj=obj, *rv) + return self.save_reduce(obj=obj, *rv) # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a @@ -420,20 +448,18 @@ def save_dynamic_class(self, obj): from global modules. """ clsdict = dict(obj.__dict__) # copy dict proxy to a dict - if not isinstance(clsdict.get('__dict__', None), property): - # don't extract dict that are properties - clsdict.pop('__dict__', None) - clsdict.pop('__weakref__', None) - - # hack as __new__ is stored differently in the __dict__ - new_override = clsdict.get('__new__', None) - if new_override: - clsdict['__new__'] = obj.__new__ - - # namedtuple is a special case for Spark where we use the _load_namedtuple function - if getattr(obj, '_is_namedtuple_', False): - self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) - return + clsdict.pop('__weakref__', None) + + # On PyPy, __doc__ is a readonly attribute, so we need to include it in + # the initial skeleton class. This is safe because we know that the + # doc can't participate in a cycle with the original class. + type_kwargs = {'__doc__': clsdict.pop('__doc__', None)} + + # If type overrides __dict__ as a property, include it in the type kwargs. + # In Python 2, we can't set this attribute after construction. + __dict__ = clsdict.pop('__dict__', None) + if isinstance(__dict__, property): + type_kwargs['__dict__'] = __dict__ save = self.save write = self.write @@ -453,23 +479,12 @@ def save_dynamic_class(self, obj): # Push the rehydration function. save(_rehydrate_skeleton_class) - # Mark the start of the args for the rehydration function. + # Mark the start of the args tuple for the rehydration function. write(pickle.MARK) - # On PyPy, __doc__ is a readonly attribute, so we need to include it in - # the initial skeleton class. This is safe because we know that the - # doc can't participate in a cycle with the original class. - doc_dict = {'__doc__': clsdict.pop('__doc__', None)} - - # Create and memoize an empty class with obj's name and bases. - save(type(obj)) - save(( - obj.__name__, - obj.__bases__, - doc_dict, - )) - write(pickle.REDUCE) - self.memoize(obj) + # Create and memoize an skeleton class with obj's name and bases. + tp = type(obj) + self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. @@ -522,17 +537,22 @@ def save_function_tuple(self, func): self.memoize(func) # save the rest of the func data needed by _fill_function - save(f_globals) - save(defaults) - save(dct) - save(func.__module__) - save(closure_values) + state = { + 'globals': f_globals, + 'defaults': defaults, + 'dict': dct, + 'module': func.__module__, + 'closure_values': closure_values, + } + if hasattr(func, '__qualname__'): + state['qualname'] = func.__qualname__ + save(state) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple _extract_code_globals_cache = ( weakref.WeakKeyDictionary() - if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_version_info") + if not hasattr(sys, "pypy_version_info") else {}) @classmethod @@ -608,37 +628,22 @@ def save_global(self, obj, name=None, pack=struct.pack): The name of this method is somewhat misleading: all types get dispatched here. """ - if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": - if obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - - if name is None: - name = obj.__name__ - - modname = getattr(obj, "__module__", None) - if modname is None: - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = '__main__' + if obj.__module__ == "__main__": + return self.save_dynamic_class(obj) - if modname == '__main__': - themodule = None - else: - __import__(modname) - themodule = sys.modules[modname] - self.modules.add(themodule) + try: + return Pickler.save_global(self, obj, name=name) + except Exception: + if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce( + _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - if hasattr(themodule, name) and getattr(themodule, name) is obj: - return Pickler.save_global(self, obj, name) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + return self.save_dynamic_class(obj) - typ = type(obj) - if typ is not obj and isinstance(obj, (type, types.ClassType)): - self.save_dynamic_class(obj) - else: - raise pickle.PicklingError("Can't pickle %r" % obj) + raise dispatch[type] = save_global dispatch[types.ClassType] = save_global @@ -709,12 +714,7 @@ def save_property(self, obj): dispatch[property] = save_property def save_classmethod(self, obj): - try: - orig_func = obj.__func__ - except AttributeError: # Python 2.6 - orig_func = obj.__get__(None, object) - if isinstance(obj, classmethod): - orig_func = orig_func.__func__ # Unbind + orig_func = obj.__func__ self.save_reduce(type(obj), (orig_func,), obj=obj) dispatch[classmethod] = save_classmethod dispatch[staticmethod] = save_classmethod @@ -754,64 +754,6 @@ def __getattribute__(self, item): if type(operator.attrgetter) is type: dispatch[operator.attrgetter] = save_attrgetter - def save_reduce(self, func, args, state=None, - listitems=None, dictitems=None, obj=None): - # Assert that args is a tuple or None - if not isinstance(args, tuple): - raise pickle.PicklingError("args from reduce() should be a tuple") - - # Assert that func is callable - if not hasattr(func, '__call__'): - raise pickle.PicklingError("func from reduce should be callable") - - save = self.save - write = self.write - - # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ - if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - cls = args[0] - if not hasattr(cls, "__new__"): - raise pickle.PicklingError( - "args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise pickle.PicklingError( - "args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - - save(args) - write(pickle.NEWOBJ) - else: - save(func) - save(args) - write(pickle.REDUCE) - - if obj is not None: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - save(state) - write(pickle.BUILD) - - def save_partial(self, obj): - """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" - self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) - - if sys.version_info < (2,7): # 2.7 supports partial pickling - dispatch[partial] = save_partial - - def save_file(self, obj): """Save a file""" try: @@ -867,23 +809,21 @@ def save_not_implemented(self, obj): dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented - # WeakSet was added in 2.7. - if hasattr(weakref, 'WeakSet'): - def save_weakset(self, obj): - self.save_reduce(weakref.WeakSet, (list(obj),)) - - dispatch[weakref.WeakSet] = save_weakset + def save_weakset(self, obj): + self.save_reduce(weakref.WeakSet, (list(obj),)) - """Special functions for Add-on libraries""" - def inject_addons(self): - """Plug in system. Register additional pickling functions if modules already loaded""" - pass + dispatch[weakref.WeakSet] = save_weakset def save_logger(self, obj): self.save_reduce(logging.getLogger, (obj.name,), obj=obj) dispatch[logging.Logger] = save_logger + """Special functions for Add-on libraries""" + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + pass + # Tornado support @@ -913,11 +853,12 @@ def dump(obj, file, protocol=2): def dumps(obj, protocol=2): file = StringIO() - - cp = CloudPickler(file,protocol) - cp.dump(obj) - - return file.getvalue() + try: + cp = CloudPickler(file,protocol) + cp.dump(obj) + return file.getvalue() + finally: + file.close() # including pickles unloading functions in this namespace load = pickle.load @@ -1019,18 +960,40 @@ def __reduce__(cls): return cls.__name__ -def _fill_function(func, globals, defaults, dict, module, closure_values): - """ Fills in the rest of function data into the skeleton function object - that were created via _make_skel_func(). +def _fill_function(*args): + """Fills in the rest of function data into the skeleton function object + + The skeleton itself is create by _make_skel_func(). """ - func.__globals__.update(globals) - func.__defaults__ = defaults - func.__dict__ = dict - func.__module__ = module + if len(args) == 2: + func = args[0] + state = args[1] + elif len(args) == 5: + # Backwards compat for cloudpickle v0.4.0, after which the `module` + # argument was introduced + func = args[0] + keys = ['globals', 'defaults', 'dict', 'closure_values'] + state = dict(zip(keys, args[1:])) + elif len(args) == 6: + # Backwards compat for cloudpickle v0.4.1, after which the function + # state was passed as a dict to the _fill_function it-self. + func = args[0] + keys = ['globals', 'defaults', 'dict', 'module', 'closure_values'] + state = dict(zip(keys, args[1:])) + else: + raise ValueError('Unexpected _fill_value arguments: %r' % (args,)) + + func.__globals__.update(state['globals']) + func.__defaults__ = state['defaults'] + func.__dict__ = state['dict'] + if 'module' in state: + func.__module__ = state['module'] + if 'qualname' in state: + func.__qualname__ = state['qualname'] cells = func.__closure__ if cells is not None: - for cell, value in zip(cells, closure_values): + for cell, value in zip(cells, state['closure_values']): if value is not _empty_cell_value: cell_set(cell, value) @@ -1087,13 +1050,6 @@ def _find_module(mod_name): file.close() return path, description -def _load_namedtuple(name, fields): - """ - Loads a class generated by namedtuple - """ - from collections import namedtuple - return namedtuple(name, fields) - """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 91a7f093cec19..917e258d8a602 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -68,6 +68,7 @@ xrange = range from pyspark import cloudpickle +from pyspark.util import _exception_message __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] @@ -565,7 +566,18 @@ def loads(self, obj, encoding=None): class CloudPickleSerializer(PickleSerializer): def dumps(self, obj): - return cloudpickle.dumps(obj, 2) + try: + return cloudpickle.dumps(obj, 2) + except pickle.PickleError: + raise + except Exception as e: + emsg = _exception_message(e) + if "'i' format requires" in emsg: + msg = "Object too large to serialize: %s" % emsg + else: + msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) + cloudpickle.print_exec(sys.stderr) + raise pickle.PicklingError(msg) class MarshalSerializer(FramedSerializer): From d6632d185e147fcbe6724545488ad80dce20277e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 8 Mar 2018 20:22:07 +0900 Subject: [PATCH 0450/1161] [SPARK-23380][PYTHON] Adds a conf for Arrow fallback in toPandas/createDataFrame with Pandas DataFrame ## What changes were proposed in this pull request? This PR adds a configuration to control the fallback of Arrow optimization for `toPandas` and `createDataFrame` with Pandas DataFrame. ## How was this patch tested? Manually tested and unit tests added. You can test this by: **`createDataFrame`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame(pdf, "a: map") ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", False) pdf = spark.createDataFrame([[{'a': 1}]]).toPandas() spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame(pdf, "a: map") ``` **`toPandas`** ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", True) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` ```python spark.conf.set("spark.sql.execution.arrow.enabled", True) spark.conf.set("spark.sql.execution.arrow.fallback.enabled", False) spark.createDataFrame([[{'a': 1}]]).toPandas() ``` Author: hyukjinkwon Closes #20678 from HyukjinKwon/SPARK-23380-conf. --- docs/sql-programming-guide.md | 5 + python/pyspark/sql/dataframe.py | 120 ++++++++++++------ python/pyspark/sql/session.py | 22 +++- python/pyspark/sql/tests.py | 84 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 13 +- 5 files changed, 186 insertions(+), 58 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 01e2076555ee6..451b814ab6c53 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1689,6 +1689,10 @@ using the call `toPandas()` and when creating a Spark DataFrame from a Pandas Da `createDataFrame(pandas_df)`. To use Arrow when executing these calls, users need to first set the Spark configuration 'spark.sql.execution.arrow.enabled' to 'true'. This is disabled by default. +In addition, optimizations enabled by 'spark.sql.execution.arrow.enabled' could fallback automatically +to non-Arrow optimization implementation if an error occurs before the actual computation within Spark. +This can be controlled by 'spark.sql.execution.arrow.fallback.enabled'. +
    {% include_example dataframe_with_arrow python/sql/arrow.py %} @@ -1800,6 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9d8e85cde914f..8f90a367e8bf8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1992,55 +1992,91 @@ def toPandas(self): timezone = None if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + use_arrow = True try: - from pyspark.sql.types import _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps, to_arrow_schema + from pyspark.sql.types import to_arrow_schema from pyspark.sql.utils import require_minimum_pyarrow_version + require_minimum_pyarrow_version() - import pyarrow to_arrow_schema(self.schema) - tables = self._collectAsArrow() - if tables: - table = pyarrow.concat_tables(tables) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) - return _check_dataframe_localize_timestamps(pdf, timezone) - else: - return pd.DataFrame.from_records([], columns=self.columns) except Exception as e: - msg = ( - "Note: toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true. Please set it to false " - "to disable this.") - raise RuntimeError("%s\n%s" % (_exception_message(e), msg)) - else: - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - dtype = {} + if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + use_arrow = False + else: + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) + + # Try to use Arrow optimization when the schema is supported and the required version + # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. + if use_arrow: + try: + from pyspark.sql.types import _check_dataframe_convert_date, \ + _check_dataframe_localize_timestamps + import pyarrow + + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + pdf = table.to_pandas() + pdf = _check_dataframe_convert_date(pdf, self.schema) + return _check_dataframe_localize_timestamps(pdf, timezone) + else: + return pd.DataFrame.from_records([], columns=self.columns) + except Exception as e: + # We might have to allow fallback here as well but multiple Spark jobs can + # be executed. So, simply fail in this case for now. + msg = ( + "toPandas attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed unexpectedly:\n %s\n" + "Note that 'spark.sql.execution.arrow.fallback.enabled' does " + "not have an effect in such failure in the middle of " + "computation." % _exception_message(e)) + raise RuntimeError(msg) + + # Below is toPandas without Arrow optimization. + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + # SPARK-21766: if an integer field is nullable and has null values, it can be + # inferred by pandas as float column. Once we convert the column with NaN back + # to integer type e.g., np.int16, we will hit exception. So we use the inferred + # float type, not the corrected type from the schema in this case. + if pandas_type is not None and \ + not(isinstance(field.dataType, IntegralType) and field.nullable and + pdf[field.name].isnull().any()): + dtype[field.name] = pandas_type + + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + + if timezone is None: + return pdf + else: + from pyspark.sql.types import _check_series_convert_timestamps_local_tz for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - # SPARK-21766: if an integer field is nullable and has null values, it can be - # inferred by pandas as float column. Once we convert the column with NaN back - # to integer type e.g., np.int16, we will hit exception. So we use the inferred - # float type, not the corrected type from the schema in this case. - if pandas_type is not None and \ - not(isinstance(field.dataType, IntegralType) and field.nullable and - pdf[field.name].isnull().any()): - dtype[field.name] = pandas_type - - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - - if timezone is None: - return pdf - else: - from pyspark.sql.types import _check_series_convert_timestamps_local_tz - for field in self.schema: - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if isinstance(field.dataType, TimestampType): - pdf[field.name] = \ - _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) - return pdf + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + pdf[field.name] = \ + _check_series_convert_timestamps_local_tz(pdf[field.name], timezone) + return pdf def _collectAsArrow(self): """ diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b3af9b82953f3..215bb3e5c5173 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -666,8 +666,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr try: return self._create_from_pandas_with_arrow(data, schema, timezone) except Exception as e: - warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e)) - # Fallback to create DataFrame without arrow if raise some exception + from pyspark.util import _exception_message + + if self.conf.get("spark.sql.execution.arrow.fallback.enabled", "true") \ + .lower() == "true": + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "Attempts non-optimization as " + "'spark.sql.execution.arrow.fallback.enabled' is set to " + "true." % _exception_message(e)) + warnings.warn(msg) + else: + msg = ( + "createDataFrame attempted Arrow optimization because " + "'spark.sql.execution.arrow.enabled' is set to true; however, " + "failed by the reason below:\n %s\n" + "For fallback to non-optimization automatically, please set true to " + "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) + raise RuntimeError(msg) data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fa3b7203e10ac..a9fe0b425ad3e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -32,7 +32,9 @@ import datetime import array import ctypes +import warnings import py4j +from contextlib import contextmanager try: import xmlrunner @@ -48,12 +50,13 @@ else: import unittest +from pyspark.util import _exception_message + _pandas_requirement_message = None try: from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() except ImportError as e: - from pyspark.util import _exception_message # If Pandas version requirement is not satisfied, skip related tests. _pandas_requirement_message = _exception_message(e) @@ -62,7 +65,6 @@ from pyspark.sql.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() except ImportError as e: - from pyspark.util import _exception_message # If Arrow version requirement is not satisfied, skip related tests. _pyarrow_requirement_message = _exception_message(e) @@ -195,6 +197,28 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert isinstance(pairs, dict), "pairs should be a dictionary." + + keys = pairs.keys() + new_values = pairs.values() + old_values = [self.spark.conf.get(key, None) for key in keys] + for key, new_value in zip(keys, new_values): + self.spark.conf.set(key, new_value) + try: + yield + finally: + for key, old_value in zip(keys, old_values): + if old_value is None: + self.spark.conf.unset(key) + else: + self.spark.conf.set(key, old_value) + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3458,6 +3482,8 @@ def setUpClass(cls): cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + # Disable fallback by default to easily detect the failures. + cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3493,20 +3519,30 @@ def create_pandas_data_frame(self): data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) return pd.DataFrame(data=data_dict) - def test_unsupported_datatype(self): + def test_toPandas_fallback_enabled(self): + import pandas as pd + + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + df = self.spark.createDataFrame([({u'a': 1},)], schema=schema) + with QuietTest(self.sc): + with warnings.catch_warnings(record=True) as warns: + pdf = df.toPandas() + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + + def test_toPandas_fallback_disabled(self): schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.assertRaisesRegexp(Exception, 'Unsupported type'): df.toPandas() - df = self.spark.createDataFrame([(None,)], schema="a binary") - with QuietTest(self.sc): - with self.assertRaisesRegexp( - Exception, - 'Unsupported type.*\nNote: toPandas attempted Arrow optimization because'): - df.toPandas() - def test_null_conversion(self): df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + self.data) @@ -3625,7 +3661,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*field.*can.not.accept.*type"): + with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3650,7 +3686,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(TypeError, ".*IntegerType.*tuple"): + with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3705,6 +3741,30 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df.columns) self.assertEqual(pdf_col_names, df_arrow.columns) + def test_createDataFrame_fallback_enabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}): + with warnings.catch_warnings(record=True) as warns: + df = self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Catch and check the last UserWarning. + user_warns = [ + warn.message for warn in warns if isinstance(warn.message, UserWarning)] + self.assertTrue(len(user_warns) > 0) + self.assertTrue( + "Attempts non-optimization" in _exception_message(user_warns[-1])) + self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + + def test_createDataFrame_fallback_disabled(self): + import pandas as pd + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported type'): + self.spark.createDataFrame( + pd.DataFrame([[{u'a': 1}]]), "a: map") + # Regression test for SPARK-23314 def test_timestamp_dst(self): import pandas as pd diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ce3f94618edeb..3f96112659c11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1058,7 +1058,7 @@ object SQLConf { .intConf .createWithDefault(100) - val ARROW_EXECUTION_ENABLE = + val ARROW_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers. Currently available " + "for use with pyspark.sql.DataFrame.toPandas, and " + @@ -1068,6 +1068,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ARROW_FALLBACK_ENABLED = + buildConf("spark.sql.execution.arrow.fallback.enabled") + .doc("When true, optimizations enabled by 'spark.sql.execution.arrow.enabled' will " + + "fallback automatically to non-optimized implementations if an error occurs.") + .booleanConf + .createWithDefault(true) + val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.arrow.maxRecordsPerBatch") .doc("When using Apache Arrow, limit the maximum number of records that can be written " + @@ -1518,7 +1525,9 @@ class SQLConf extends Serializable with Logging { def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) - def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) + def arrowEnabled: Boolean = getConf(ARROW_EXECUTION_ENABLED) + + def arrowFallbackEnabled: Boolean = getConf(ARROW_FALLBACK_ENABLED) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) From 2cb23a8f51a151970c121015fcbad9beeafa8295 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 8 Mar 2018 20:29:07 +0900 Subject: [PATCH 0451/1161] [SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF ## What changes were proposed in this pull request? This PR proposes to support an alternative function from with group aggregate pandas UDF. The current form: ``` def foo(pdf): return ... ``` Takes a single arg that is a pandas DataFrame. With this PR, an alternative form is supported: ``` def foo(key, pdf): return ... ``` The alternative form takes two argument - a tuple that presents the grouping key, and a pandas DataFrame represents the data. ## How was this patch tested? GroupbyApplyTests Author: Li Jin Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key. --- python/pyspark/serializers.py | 18 ++- python/pyspark/sql/functions.py | 25 ++++ python/pyspark/sql/tests.py | 121 ++++++++++++++++-- python/pyspark/sql/types.py | 45 ++++++- python/pyspark/sql/udf.py | 19 +-- python/pyspark/util.py | 16 +++ python/pyspark/worker.py | 49 +++++-- .../python/FlatMapGroupsInPandasExec.scala | 56 +++++++- 8 files changed, 294 insertions(+), 55 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 917e258d8a602..ebf549396f463 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -250,6 +250,15 @@ def __init__(self, timezone): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone + def arrow_to_pandas(self, arrow_column): + from pyspark.sql.types import from_arrow_type, \ + _check_series_convert_date, _check_series_localize_timestamps + + s = arrow_column.to_pandas() + s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _check_series_localize_timestamps(s, self._timezone) + return s + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or @@ -272,16 +281,11 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) - schema = from_arrow_schema(reader.schema) + for batch in reader: - pdf = batch.to_pandas() - pdf = _check_dataframe_convert_date(pdf, schema) - pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) - yield [c for _, c in pdf.iteritems()] + yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b9c0c57262c5d..dc1341ac74d3d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 1.1094003924504583| +---+-------------------+ + Alternatively, the user can define a function that takes two arguments. + In this case, the grouping key will be passed as the first argument and the data will + be passed as the second argument. The grouping key will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key in the function. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def mean_udf(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a9fe0b425ad3e..480815d27333f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3903,7 +3903,7 @@ def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) - def foo(k, v): + def foo(k, v, w): return k @@ -4476,20 +4476,45 @@ def test_supported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col df = self.data.withColumn("arr", array(col("id"))) - foo_udf = pandas_udf( + # Different forms of group map pandas UDF, results of these are the same + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), + StructField('v1', DoubleType()), + StructField('v2', LongType())]) + + udf1 = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]), + output_schema, PandasUDFType.GROUPED_MAP ) - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) + udf2 = pandas_udf( + lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf3 = pandas_udf( + lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result1 = df.groupby('id').apply(udf1).sort('id').toPandas() + expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) + + result2 = df.groupby('id').apply(udf2).sort('id').toPandas() + expected2 = expected1 + + result3 = df.groupby('id').apply(udf3).sort('id').toPandas() + expected3 = expected1 + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4648,6 +4673,80 @@ def test_timestamp_dst(self): result = df.groupby('time').apply(foo_udf).sort('time') self.assertPandasEqual(df.toPandas(), result.toPandas()) + def test_udf_with_key(self): + from pyspark.sql.functions import pandas_udf, col, PandasUDFType + df = self.data + pdf = df.toPandas() + + def foo1(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + + return pdf.assign(v1=key[0], + v2=pdf.v * key[0], + v3=pdf.v * pdf.id, + v4=pdf.v * pdf.id.mean()) + + def foo2(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + assert type(key[1]) == np.int32 + + return pdf.assign(v1=key[0], + v2=key[1], + v3=pdf.v * key[0], + v4=pdf.v + key[1]) + + def foo3(key, pdf): + assert type(key) == tuple + assert len(key) == 0 + return pdf.assign(v1=pdf.v * pdf.id) + + # v2 is int because numpy.int64 * pd.Series results in pd.Series + # v3 is long because pd.Series * pd.Series results in pd.Series + udf1 = pandas_udf( + foo1, + 'id long, v int, v1 long, v2 int, v3 long, v4 double', + PandasUDFType.GROUPED_MAP) + + udf2 = pandas_udf( + foo2, + 'id long, v int, v1 long, v2 int, v3 int, v4 int', + PandasUDFType.GROUPED_MAP) + + udf3 = pandas_udf( + foo3, + 'id long, v int, v1 long', + PandasUDFType.GROUPED_MAP) + + # Test groupby column + result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() + expected1 = pdf.groupby('id')\ + .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected1, result1) + + # Test groupby expression + result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() + expected2 = pdf.groupby(pdf.id % 2)\ + .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected2, result2) + + # Test complex groupby + result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() + expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ + .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected3, result3) + + # Test empty groupby + result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() + expected4 = udf3.func((), pdf) + self.assertPandasEqual(expected4, result4) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index cd857402db8f7..1632862d3f1ba 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_series_convert_date(series, data_type): + """ + Cast the series to datetime.date if it's a date type, otherwise returns the original series. + + :param series: pandas.Series + :param data_type: a Spark data type for the series + """ + if type(data_type) == DateType: + return series.dt.date + else: + return series + + def _check_dataframe_convert_date(pdf, schema): """ Correct date type value to use datetime.date. @@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema): :param schema: a Spark schema of the pandas.DataFrame """ for field in schema: - if type(field.dataType) == DateType: - pdf[field.name] = pdf[field.name].dt.date + pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) return pdf @@ -1725,6 +1737,29 @@ def _get_local_timezone(): return os.environ.get('TZ', 'dateutil/:') +def _check_series_localize_timestamps(s, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. + + If the input series is not a timestamp series, then the same series is returned. If the input + series is a timestamp series, then a converted series is returned. + + :param s: pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series that have been converted to tz-naive + """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype + tz = timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(tz).dt.tz_localize(None) + else: + return s + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone): from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - from pandas.api.types import is_datetime64tz_dtype - tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) + pdf[column] = _check_series_localize_timestamps(series, timezone) return pdf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index b9b490874f4fb..ce804c18e9b14 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,6 +17,8 @@ """ User-defined function related classes and functions """ +import sys +import inspect import functools from pyspark import SparkContext, since @@ -24,6 +26,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - import inspect - import sys from pyspark.sql.utils import require_minimum_pyarrow_version - require_minimum_pyarrow_version() - if sys.version_info[0] < 3: - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - argspec = inspect.getargspec(f) - else: - argspec = inspect.getfullargspec(f) + argspec = _get_argspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: @@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType): "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ + and len(argspec.args) not in (1, 2): raise ValueError( "Invalid function: pandas_udfs with function type GROUPED_MAP " - "must take a single arg that is a pandas DataFrame." - ) + "must take either one argument (data) or two arguments (key, data).") # Set the name of the UserDefinedFunction object to be the name of function f udf_obj = UserDefinedFunction( diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ad4a0bc68ef41..6837b18b7d7a5 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +import sys +import inspect from py4j.protocol import Py4JJavaError __all__ = [] @@ -45,6 +48,19 @@ def _exception_message(excp): return str(excp) +def _get_argspec(f): + """ + Get argspec of a function. Supports both Python 2 and Python 3. + """ + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + if sys.version_info[0] < 3: + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) + return argspec + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 89a3a92bc66d6..202cac350aafc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type +from pyspark.util import _get_argspec from pyspark import shuffle pickleSer = PickleSerializer() @@ -91,10 +92,16 @@ def verify_result_length(*a): def wrap_grouped_map_pandas_udf(f, return_type): - def wrapped(*series): + def wrapped(key_series, value_series): import pandas as pd + argspec = _get_argspec(f) + + if len(argspec.args) == 1: + result = f(pd.concat(value_series, axis=1)) + elif len(argspec.args) == 2: + key = tuple(s[0] for s in key_series) + result = f(key, pd.concat(value_series, axis=1)) - result = f(pd.concat(series, axis=1)) if not isinstance(result, pd.DataFrame): raise TypeError("Return type of the user-defined function should be " "pandas.DataFrame, but is {}".format(type(result))) @@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) udfs = {} call_udf = [] - for i in range(num_udfs): + mapper_str = "" + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + # Create function like this: + # lambda a: f([a[0]], [a[0], a[1]]) + + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) - udfs['f%d' % i] = udf - args = ["a[%d]" % o for o in arg_offsets] - call_udf.append("f%d(%s)" % (i, ", ".join(args))) - # Create function like this: - # lambda a: (f0(a0), f1(a1, a2), f2(a3)) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) - mapper = eval(mapper_str, udfs) + udfs['f'] = udf + split_offset = arg_offsets[0] + 1 + arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] + arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] + mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + else: + # Create function like this: + # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. + for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c798fe5a92c54..513e174c7733e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) - val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... + // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) } else { val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) } } @@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context) From 7013eea11cb32b1e0038dc751c485da5c94a484b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Thu, 8 Mar 2018 20:38:34 +0900 Subject: [PATCH 0452/1161] [SPARK-23522][PYTHON] always use sys.exit over builtin exit The exit() builtin is only for interactive use. applications should use sys.exit(). ## What changes were proposed in this pull request? All usage of the builtin `exit()` function is replaced by `sys.exit()`. ## How was this patch tested? I ran `python/run-tests`. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Benjamin Peterson Closes #20682 from benjaminp/sys-exit. --- dev/merge_spark_pr.py | 2 +- dev/run-tests.py | 2 +- examples/src/main/python/avro_inputformat.py | 2 +- examples/src/main/python/kmeans.py | 2 +- examples/src/main/python/logistic_regression.py | 2 +- examples/src/main/python/ml/dataframe_example.py | 2 +- examples/src/main/python/mllib/correlations.py | 2 +- examples/src/main/python/mllib/kmeans.py | 2 +- examples/src/main/python/mllib/logistic_regression.py | 2 +- examples/src/main/python/mllib/random_rdd_generation.py | 2 +- examples/src/main/python/mllib/sampled_rdds.py | 4 ++-- .../python/mllib/streaming_linear_regression_example.py | 2 +- examples/src/main/python/pagerank.py | 2 +- examples/src/main/python/parquet_inputformat.py | 2 +- examples/src/main/python/sort.py | 2 +- .../main/python/sql/streaming/structured_kafka_wordcount.py | 2 +- .../python/sql/streaming/structured_network_wordcount.py | 2 +- .../sql/streaming/structured_network_wordcount_windowed.py | 2 +- .../src/main/python/streaming/direct_kafka_wordcount.py | 2 +- examples/src/main/python/streaming/flume_wordcount.py | 2 +- examples/src/main/python/streaming/hdfs_wordcount.py | 2 +- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- examples/src/main/python/streaming/network_wordcount.py | 2 +- .../src/main/python/streaming/network_wordjoinsentiments.py | 2 +- .../main/python/streaming/recoverable_network_wordcount.py | 2 +- examples/src/main/python/streaming/sql_network_wordcount.py | 2 +- .../src/main/python/streaming/stateful_network_wordcount.py | 2 +- examples/src/main/python/wordcount.py | 2 +- python/pyspark/accumulators.py | 2 +- python/pyspark/broadcast.py | 2 +- python/pyspark/conf.py | 2 +- python/pyspark/context.py | 2 +- python/pyspark/daemon.py | 2 +- python/pyspark/find_spark_home.py | 2 +- python/pyspark/heapq3.py | 3 ++- python/pyspark/ml/classification.py | 3 ++- python/pyspark/ml/clustering.py | 4 +++- python/pyspark/ml/evaluation.py | 3 ++- python/pyspark/ml/feature.py | 2 +- python/pyspark/ml/image.py | 4 +++- python/pyspark/ml/linalg/__init__.py | 2 +- python/pyspark/ml/recommendation.py | 4 +++- python/pyspark/ml/regression.py | 3 ++- python/pyspark/ml/stat.py | 4 +++- python/pyspark/ml/tuning.py | 6 ++++-- python/pyspark/mllib/classification.py | 3 ++- python/pyspark/mllib/clustering.py | 2 +- python/pyspark/mllib/evaluation.py | 3 ++- python/pyspark/mllib/feature.py | 2 +- python/pyspark/mllib/fpm.py | 4 +++- python/pyspark/mllib/linalg/__init__.py | 2 +- python/pyspark/mllib/linalg/distributed.py | 2 +- python/pyspark/mllib/random.py | 3 ++- python/pyspark/mllib/recommendation.py | 3 ++- python/pyspark/mllib/regression.py | 6 ++++-- python/pyspark/mllib/stat/_statistics.py | 2 +- python/pyspark/mllib/tree.py | 3 ++- python/pyspark/mllib/util.py | 2 +- python/pyspark/profiler.py | 3 ++- python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 2 +- python/pyspark/shuffle.py | 3 ++- python/pyspark/sql/catalog.py | 3 ++- python/pyspark/sql/column.py | 2 +- python/pyspark/sql/conf.py | 4 +++- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/group.py | 4 +++- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/session.py | 2 +- python/pyspark/sql/streaming.py | 2 +- python/pyspark/sql/types.py | 2 +- python/pyspark/sql/udf.py | 3 ++- python/pyspark/sql/window.py | 2 +- python/pyspark/streaming/util.py | 3 ++- python/pyspark/util.py | 4 +++- python/pyspark/worker.py | 6 +++--- python/setup.py | 6 +++--- 79 files changed, 120 insertions(+), 86 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 6b244d8184b2c..5ea205fbed4aa 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -510,7 +510,7 @@ def main(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) try: main() except: diff --git a/dev/run-tests.py b/dev/run-tests.py index fe75ef4411c8c..164c1e2200aa9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -621,7 +621,7 @@ def _test(): import doctest failure_count = doctest.testmod()[0] if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 6286ba6541fbd..a18722c687f8b 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -61,7 +61,7 @@ Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 92e0a3ae2ee60..a42d711fc505f 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -49,7 +49,7 @@ def closestPoint(p, centers): if len(sys.argv) != 4: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 01c938454b108..bcc4e0f4e8eae 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -48,7 +48,7 @@ def readPointBatch(iterator): if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index 109f901012c9c..d62cf2338a1fe 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -33,7 +33,7 @@ if __name__ == "__main__": if len(sys.argv) > 2: print("Usage: dataframe_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) elif len(sys.argv) == 2: input = sys.argv[1] else: diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 0e13546b88e67..089504fa7064b 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -31,7 +31,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: correlations ()", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: filepath = sys.argv[1] diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 002fc75799648..1bdb3e9b4a2af 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -36,7 +36,7 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kmeans ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index d4f1d34e2d8cf..87efe17375226 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -42,7 +42,7 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: logistic_regression ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 729bae30b152c..9a429b5f8abdf 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: random_rdd_generation", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index b7033ab7daeb3..00e7cf4bbcdbf 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -29,7 +29,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: print("Usage: sampled_rdds ", file=sys.stderr) - exit(-1) + sys.exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] else: @@ -43,7 +43,7 @@ numExamples = examples.count() if numExamples == 0: print("Error: Data file had no samples to load.", file=sys.stderr) - exit(1) + sys.exit(1) print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() diff --git a/examples/src/main/python/mllib/streaming_linear_regression_example.py b/examples/src/main/python/mllib/streaming_linear_regression_example.py index f600496867c11..714c9a0de7217 100644 --- a/examples/src/main/python/mllib/streaming_linear_regression_example.py +++ b/examples/src/main/python/mllib/streaming_linear_regression_example.py @@ -36,7 +36,7 @@ if len(sys.argv) != 3: print("Usage: streaming_linear_regression_example.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonLogisticRegressionWithLBFGSExample") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0d6c253d397a0..2c19e8700ab16 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -47,7 +47,7 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: pagerank ", file=sys.stderr) - exit(-1) + sys.exit(-1) print("WARN: This is a naive implementation of PageRank and is given as an example!\n" + "Please refer to PageRank implementation provided by graphx", diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index a3f86cf8999cf..83041f0040a0c 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -45,7 +45,7 @@ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . """, file=sys.stderr) - exit(-1) + sys.exit(-1) path = sys.argv[1] diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 81898cf6d5ce6..d3cd985d197e3 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -25,7 +25,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: sort ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py index 9e8a552b3b10b..921067891352a 100644 --- a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py @@ -49,7 +49,7 @@ print(""" Usage: structured_kafka_wordcount.py """, file=sys.stderr) - exit(-1) + sys.exit(-1) bootstrapServers = sys.argv[1] subscribeType = sys.argv[2] diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py index c3284c1d01017..9ac392164735b 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -38,7 +38,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: structured_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index db672551504b5..c4e3bbf44cd5a 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -53,7 +53,7 @@ msg = ("Usage: structured_network_wordcount_windowed.py " " []") print(msg, file=sys.stderr) - exit(-1) + sys.exit(-1) host = sys.argv[1] port = int(sys.argv[2]) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 425df309011a0..c5c186c11f79a 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") ssc = StreamingContext(sc, 2) diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index 5d6e6dc36d6f9..c8ea92b61ca6e 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: flume_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingFlumeWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py index f815dd26823d1..f9a5c43a8eaa9 100644 --- a/examples/src/main/python/streaming/hdfs_wordcount.py +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: hdfs_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingHDFSWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 704f6602e2297..e9ee08b9fd228 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: kafka_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py index 9010fafb425e6..f3099d2517cd5 100644 --- a/examples/src/main/python/streaming/network_wordcount.py +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -35,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index d51a380a5d5f9..2b5434c0c845a 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -47,7 +47,7 @@ def print_happiest_words(rdd): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: network_wordjoinsentiments.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") ssc = StreamingContext(sc, 5) diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index 52b2639cdf55c..60167dc772544 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -101,7 +101,7 @@ def filterFunc(wordCount): if len(sys.argv) != 5: print("Usage: recoverable_network_wordcount.py " " ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port, checkpoint, output = sys.argv[1:] ssc = StreamingContext.getOrCreate(checkpoint, lambda: createContext(host, int(port), output)) diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 7f12281c0e3fe..ab3cfc067994d 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -48,7 +48,7 @@ def getSparkSessionInstance(sparkConf): if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: sql_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) host, port = sys.argv[1:] sc = SparkContext(appName="PythonSqlNetworkWordCount") ssc = StreamingContext(sc, 1) diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py index d7bb61e729f18..d5d1eba6c5969 100644 --- a/examples/src/main/python/streaming/stateful_network_wordcount.py +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: print("Usage: stateful_network_wordcount.py ", file=sys.stderr) - exit(-1) + sys.exit(-1) sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") ssc = StreamingContext(sc, 1) ssc.checkpoint("checkpoint") diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index 3d5e44d5b2df1..a05e24ff3ff95 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -26,7 +26,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: wordcount ", file=sys.stderr) - exit(-1) + sys.exit(-1) spark = SparkSession\ .builder\ diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 7def676b89a24..f730d290273fe 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -265,4 +265,4 @@ def _start_update_server(): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 02fc515fb824a..b3dfc99962a35 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -162,4 +162,4 @@ def clear(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 491b3a81972bc..ab429d9ab10de 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -217,7 +217,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 24905f1c97b21..7c664966ed74e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1035,7 +1035,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7f06d4288c872..7bed5216eabf3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -89,7 +89,7 @@ def shutdown(code): signal.signal(SIGTERM, SIG_DFL) # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) - exit(code) + sys.exit(code) def handle_sigterm(*args): shutdown(1) diff --git a/python/pyspark/find_spark_home.py b/python/pyspark/find_spark_home.py index 212a618b767ab..9cf0e8c8d2fe9 100755 --- a/python/pyspark/find_spark_home.py +++ b/python/pyspark/find_spark_home.py @@ -68,7 +68,7 @@ def is_spark_home(path): return next(path for path in paths if is_spark_home(path)) except StopIteration: print("Could not find valid SPARK_HOME while searching {0}".format(paths), file=sys.stderr) - exit(-1) + sys.exit(-1) if __name__ == "__main__": print(_find_spark_home()) diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index b27e91a4cc251..6af084adcf373 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -884,6 +884,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": import doctest + import sys (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 27ad1e80aa0d3..fbbe3d0307c81 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -16,6 +16,7 @@ # import operator +import sys from multiprocessing.pool import ThreadPool from pyspark import since, keyword_only @@ -2043,4 +2044,4 @@ def _to_java(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6448b76a0da88..b3d5fb17f6b81 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper @@ -1181,4 +1183,4 @@ def getKeepLastCheckpoint(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 695d8ab27cc96..8eaf07645a37f 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys from abc import abstractmethod, ABCMeta from pyspark import since, keyword_only @@ -446,4 +447,4 @@ def getDistanceMeasure(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b07e6a05481..f2e357f0bede5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3717,4 +3717,4 @@ def setSize(self, value): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 45c936645f2a8..96d702f844839 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -24,6 +24,8 @@ :members: """ +import sys + import numpy as np from pyspark import SparkContext from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string @@ -251,7 +253,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index ad1b487676fa7..6a611a2b5b59d 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1158,7 +1158,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index e8bcbe4cd34cb..a8eae9bd268d3 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel @@ -480,4 +482,4 @@ def recommendForItemSubset(self, dataset, numUsers): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f0812bd1d4a39..de0a0fa9f3bf8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since, keyword_only @@ -1812,4 +1813,4 @@ def __repr__(self): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 079b0833e1c6d..0eeb5e528434a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java from pyspark.ml.wrapper import _jvm @@ -151,4 +153,4 @@ def corr(dataset, column, method="pearson"): failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 6c0cad6cbaaa1..545e24ca05aa5 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -15,9 +15,11 @@ # limitations under the License. # import itertools -import numpy as np +import sys from multiprocessing.pool import ThreadPool +import numpy as np + from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java @@ -727,4 +729,4 @@ def _to_java(self): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index cce703d432b5a..bb281981fd56b 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -16,6 +16,7 @@ # from math import exp +import sys import warnings import numpy @@ -761,7 +762,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index bb687a7da6ffd..0cbabab13a896 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1048,7 +1048,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 2cd1da3fbf9aa..36cb03369b8c0 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from pyspark import since @@ -542,7 +543,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index e5231dc3a27a8..40ecd2e0ff4be 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -819,7 +819,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": sys.path.pop(0) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index f58ea5dfb0874..de18dad1f675d 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + import numpy from numpy import array from collections import namedtuple @@ -197,7 +199,7 @@ def _test(): except OSError: pass if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 7b24b3c74a9fa..60d96d8d5ceb8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1370,7 +1370,7 @@ def _test(): import doctest (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 4cb802514be52..bba88542167ad 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1377,7 +1377,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 61213ddf62e8b..a8833cb446923 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -19,6 +19,7 @@ Python package for random data generation. """ +import sys from functools import wraps from pyspark import since @@ -421,7 +422,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 81182881352bb..3d4eae85132bb 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -16,6 +16,7 @@ # import array +import sys from collections import namedtuple from pyspark import SparkContext, since @@ -326,7 +327,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index ea107d400621d..6be45f51862c9 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -15,9 +15,11 @@ # limitations under the License. # +import sys +import warnings + import numpy as np from numpy import array -import warnings from pyspark import RDD, since from pyspark.streaming.dstream import DStream @@ -837,7 +839,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 49b26446dbc32..3c75b132ecad2 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -313,7 +313,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 619fa16d463f5..b05734ce489d9 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -17,6 +17,7 @@ from __future__ import absolute_import +import sys import random from pyspark import SparkContext, RDD, since @@ -654,7 +655,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 97755807ef262..fc7809387b13a 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -521,7 +521,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 44d17bd629473..3c7656ab5758c 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -19,6 +19,7 @@ import pstats import os import atexit +import sys from pyspark.accumulators import AccumulatorParam @@ -173,4 +174,4 @@ def stats(self): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 93b8974a7e64a..4b44f76747264 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2498,7 +2498,7 @@ def _test(): globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ebf549396f463..15753f77bd903 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -715,4 +715,4 @@ def write_with_length(obj, stream): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e974cda9fc3e1..02c773302e9da 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -23,6 +23,7 @@ import itertools import operator import random +import sys import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ @@ -810,4 +811,4 @@ def load_partition(j): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 6aef0f22340be..b0d8357f4feec 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -15,6 +15,7 @@ # limitations under the License. # +import sys import warnings from collections import namedtuple @@ -306,7 +307,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 43b38a2cd477c..e05a7b33c11a7 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -660,7 +660,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 792c420ca6386..d929834aeeaa5 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix @@ -80,7 +82,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index cc1cd1a5842d9..6cb90399dd616 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -543,7 +543,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8f90a367e8bf8..3fc194d8ec1d1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2231,7 +2231,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dc1341ac74d3d..dff590983b4d9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2404,7 +2404,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ab646535c864c..35cac406e0965 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -15,6 +15,8 @@ # limitations under the License. # +import sys + from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal @@ -299,7 +301,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9d05ac7cb39be..803f561ece67b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -970,7 +970,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) sc.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 215bb3e5c5173..e82a9750a0014 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -830,7 +830,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": _test() diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index cc622decfd682..e8966c20a8f42 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -930,7 +930,7 @@ def _test(): globs['spark'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1632862d3f1ba..826aab97e58db 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1890,7 +1890,7 @@ def _test(): (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index ce804c18e9b14..24dd06c26089c 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -20,6 +20,7 @@ import sys import inspect import functools +import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix @@ -397,7 +398,7 @@ def _test(): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) spark.stop() if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index bb841a9b9ff7c..e667fba099fb9 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -264,7 +264,7 @@ def _test(): SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) if failure_count: - exit(-1) + sys.exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..df184471993ff 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -18,6 +18,7 @@ import time from datetime import datetime import traceback +import sys from pyspark import SparkContext, RDD @@ -147,4 +148,4 @@ def rddToFileName(prefix, suffix, timestamp): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 6837b18b7d7a5..ed1bdd0e4be83 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,6 +22,8 @@ __all__ = [] +import sys + def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both @@ -65,4 +67,4 @@ def _get_argspec(f): import doctest (failure_count, test_count) = doctest.testmod() if failure_count: - exit(-1) + sys.exit(-1) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 202cac350aafc..a1a4336b1e8de 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -205,7 +205,7 @@ def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests - exit(-1) + sys.exit(-1) version = utf8_deserializer.loads(infile) if version != "%d.%d" % sys.version_info[:2]: @@ -279,7 +279,7 @@ def process(): # Write the error to stderr if it happened while serializing print("PySpark worker failed with exception:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - exit(-1) + sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) @@ -297,7 +297,7 @@ def process(): else: # write a different value to tell JVM to not reuse this worker write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - exit(-1) + sys.exit(-1) if __name__ == '__main__': diff --git a/python/setup.py b/python/setup.py index 6a98401941d8d..794ceceae3008 100644 --- a/python/setup.py +++ b/python/setup.py @@ -26,7 +26,7 @@ if sys.version_info < (2, 7): print("Python versions prior to 2.7 are not supported for pip installed PySpark.", file=sys.stderr) - exit(-1) + sys.exit(-1) try: exec(open('pyspark/version.py').read()) @@ -98,7 +98,7 @@ def _supports_symlinks(): except: print("Temp path for symlink to parent already exists {0}".format(TEMP_PATH), file=sys.stderr) - exit(-1) + sys.exit(-1) # If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and # ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml. @@ -140,7 +140,7 @@ def _supports_symlinks(): if not os.path.isdir(SCRIPTS_TARGET): print(incorrect_invocation_message, file=sys.stderr) - exit(-1) + sys.exit(-1) # Scripts directive requires a list of each script path and does not take wild cards. script_names = os.listdir(SCRIPTS_TARGET) From 92e7ecbbbd6817378abdbd56541a9c13dcea8659 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 14:18:14 +0100 Subject: [PATCH 0453/1161] [SPARK-23592][SQL] Add interpreted execution to DecodeUsingSerializer ## What changes were proposed in this pull request? The PR adds interpreted execution to DecodeUsingSerializer. ## How was this patch tested? added UT Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Marco Gaido Closes #20760 from mgaido91/SPARK-23592. --- .../catalyst/expressions/objects/objects.scala | 5 +++++ .../expressions/ObjectExpressionsSuite.scala | 15 +++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 7bbc3c732e782..adf9ddf327c96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1242,6 +1242,11 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression with NonSQLExpression with SerializerSupport { + override def nullSafeEval(input: Any): Any = { + val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]]) + serializerInstance.deserialize(inputBytes) + } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val serializer = addImmutableSerializerIfNeeded(ctx) // Code to deserialize. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 346b13277c709..ffeec2a38c532 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import scala.reflect.ClassTag + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row @@ -123,4 +125,17 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { + val cls = classOf[java.lang.Integer] + val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) + val conf = new SparkConf() + Seq(true, false).foreach { useKryo => + val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf) + val input = serializer.newInstance().serialize(new Integer(1)).array() + val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo) + checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input))) + checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) + } + } } From 3be4adf6485ca19cdc5db23394c3f5a660d7dc6f Mon Sep 17 00:00:00 2001 From: lucio <576632108@qq.com> Date: Thu, 8 Mar 2018 08:03:24 -0600 Subject: [PATCH 0454/1161] [SPARK-22751][ML] Improve ML RandomForest shuffle performance ## What changes were proposed in this pull request? As I mentioned in [SPARK-22751](https://issues.apache.org/jira/browse/SPARK-22751?jql=project%20%3D%20SPARK%20AND%20component%20%3D%20ML%20AND%20text%20~%20randomforest), there is a shuffle performance problem in ML Randomforest when train a RF in high dimensional data. The reason is that, in _org.apache.spark.tree.impl.RandomForest_, the function _findSplitsBySorting_ will actually flatmap a sparse vector into a dense vector, then in groupByKey there will be a huge shuffle write size. To avoid this, we can add a filter in flatmap, to filter out zero value. And in function _findSplitsForContinuousFeature_, we can infer the number of zero value by _metadata_. In addition, if a feature only contains zero value, _continuousSplits_ will not has the key of feature id. So I add a check when using _continuousSplits_. ## How was this patch tested? Ran model locally using spark-submit. Author: lucio <576632108@qq.com> Closes #20472 from lucio-yz/master. --- .../spark/ml/tree/impl/RandomForest.scala | 52 ++++++++++++++----- .../ml/tree/impl/RandomForestSuite.scala | 23 ++++---- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 8e514f11e78ea..16f32d76b9984 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -892,13 +892,7 @@ private[spark] object RandomForest extends Logging { // Sample the input only if there are continuous features. val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) val sampledInput = if (continuousFeatures.nonEmpty) { - // Calculate the number of samples for approximate quantile calculation. - val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) - val fraction = if (requiredSamples < metadata.numExamples) { - requiredSamples.toDouble / metadata.numExamples - } else { - 1.0 - } + val fraction = samplesFractionForFindSplits(metadata) logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) } else { @@ -920,8 +914,9 @@ private[spark] object RandomForest extends Logging { val numPartitions = math.min(continuousFeatures.length, input.partitions.length) input - .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) - .groupByKey(numPartitions) + .flatMap { point => + continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0) + }.groupByKey(numPartitions) .map { case (idx, samples) => val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) @@ -933,7 +928,8 @@ private[spark] object RandomForest extends Logging { val numFeatures = metadata.numFeatures val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { case i if metadata.isContinuous(i) => - val split = continuousSplits(i) + // some features may contain only zero, so continuousSplits will not have a record + val split = continuousSplits.getOrElse(i, Array.empty[Split]) metadata.setNumSplits(i, split.length) split @@ -1003,11 +999,22 @@ private[spark] object RandomForest extends Logging { } else { val numSplits = metadata.numSplits(featureIndex) - // get count for each distinct value - val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { - case ((m, cnt), x) => - (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + // get count for each distinct value except zero value + val partNumSamples = featureSamples.size + val partValueCountMap = scala.collection.mutable.Map[Double, Int]() + featureSamples.foreach { x => + partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1 } + + // Calculate the expected number of samples for finding splits + val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt + // add expected zero value count and get complete statistics + val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) { + partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples)) + } else { + partValueCountMap.toMap + } + // sort distinct values val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray @@ -1149,4 +1156,21 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + + /** + * Calculate the subsample fraction for finding splits + * + * @param metadata decision tree metadata + * @return subsample fraction + */ + private def samplesFractionForFindSplits( + metadata: DecisionTreeMetadata): Double = { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 5f0d26eb5c058..743dacf146fe7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -93,12 +93,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array.fill(200000)(math.random) + val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 5) assert(fakeMetadata.numSplits(0) === 5) @@ -109,7 +109,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -117,7 +117,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits <= numSplits { - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -125,7 +125,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { - val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -135,7 +135,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -150,7 +150,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 @@ -164,12 +164,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, Map(), Set(), Array(2), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + .map(_.toDouble).filter(_ != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2) assert(splits === expectedSplits) @@ -177,12 +178,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits for constant feature { - val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, 0, 0, 0.0, 0, 0 ) - val featureSamples = Array(0, 0, 0).map(_.toDouble) + val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) val featureSamplesEmpty = Array.empty[Double] val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits === Array.empty[Double]) From ea480990e726aed59750f1cea8d40adba56d991a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 11:09:15 -0800 Subject: [PATCH 0455/1161] [SPARK-23628][SQL] calculateParamLength should not return 1 + num of epressions ## What changes were proposed in this pull request? There was a bug in `calculateParamLength` which caused it to return always 1 + the number of expressions. This could lead to Exceptions especially with expressions of type long. ## How was this patch tested? added UT + fixed previous UT Author: Marco Gaido Closes #20772 from mgaido91/SPARK-23628. --- .../expressions/codegen/CodeGenerator.scala | 51 ++++++++++--------- .../expressions/CodeGenerationSuite.scala | 6 +++ .../sql/execution/WholeStageCodegenExec.scala | 5 +- .../execution/WholeStageCodegenSuite.scala | 16 +++--- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 793824b0b0a2f..fe5e63ec0a2bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1063,31 +1063,6 @@ class CodegenContext { "" } } - - /** - * Returns the length of parameters for a Java method descriptor. `this` contributes one unit - * and a parameter of type long or double contributes two units. Besides, for nullable parameter, - * we also need to pass a boolean parameter for the null status. - */ - def calculateParamLength(params: Seq[Expression]): Int = { - def paramLengthForExpr(input: Expression): Int = { - // For a nullable expression, we need to pass in an extra boolean parameter. - (if (input.nullable) 1 else 0) + javaType(input.dataType) match { - case JAVA_LONG | JAVA_DOUBLE => 2 - case _ => 1 - } - } - // Initial value is 1 for `this`. - 1 + params.map(paramLengthForExpr(_)).sum - } - - /** - * In Java, a method descriptor is valid only if it represents method parameters with a total - * length less than a pre-defined constant. - */ - def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH - } } /** @@ -1538,4 +1513,30 @@ object CodeGenerator extends Logging { def defaultValue(dt: DataType, typedNull: Boolean = false): String = defaultValue(javaType(dt), typedNull) + + /** + * Returns the length of parameters for a Java method descriptor. `this` contributes one unit + * and a parameter of type long or double contributes two units. Besides, for nullable parameter, + * we also need to pass a boolean parameter for the null status. + */ + def calculateParamLength(params: Seq[Expression]): Int = { + def paramLengthForExpr(input: Expression): Int = { + val javaParamLength = javaType(input.dataType) match { + case JAVA_LONG | JAVA_DOUBLE => 2 + case _ => 1 + } + // For a nullable expression, we need to pass in an extra boolean parameter. + (if (input.nullable) 1 else 0) + javaParamLength + } + // Initial value is 1 for `this`. + 1 + params.map(paramLengthForExpr).sum + } + + /** + * In Java, a method descriptor is valid only if it represents method parameters with a total + * length less than a pre-defined constant. + */ + def isValidParamLength(paramLength: Int): Boolean = { + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1e48c7b8df9da..64c13e8972036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -436,4 +436,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.addImmutableStateIfNotExists("String", mutableState2) assert(ctx.inlinedMutableStates.length == 2) } + + test("SPARK-23628: calculateParamLength should compute properly the param length") { + assert(CodeGenerator.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101) + assert(CodeGenerator.calculateParamLength( + Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index f89e3fb0e536f..6ddaacfee1a40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -174,8 +174,9 @@ trait CodegenSupport extends SparkPlan { // declaration. val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator val requireAllOutput = output.forall(parent.usedInputs.contains(_)) - val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) - val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { + val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else 0) + val consumeFunc = if (confEnabled && requireAllOutput + && CodeGenerator.isValidParamLength(paramLength)) { constructDoConsumeFunction(ctx, inputVars, row) } else { parent.doConsume(ctx, inputVars, rowVar) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ef16292a8e75c..0fb9dd2017a09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed -import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -249,12 +249,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Skip splitting consume function when parameter number exceeds JVM limit") { - import testImplicits._ - - Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => + // since every field is nullable we have 2 params for each input column (one for the value + // and one for the isNull variable) + Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) => withTempPath { dir => val path = dir.getCanonicalPath - spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) + spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*) .write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", @@ -263,10 +263,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val df = spark.read.parquet(path).selectExpr(projection: _*) val plan = df.queryExecution.executedPlan - val wholeStageCodeGenExec = plan.find(p => p match { - case wp: WholeStageCodegenExec => true + val wholeStageCodeGenExec = plan.find { + case _: WholeStageCodegenExec => true case _ => false - }) + } assert(wholeStageCodeGenExec.isDefined) val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 assert(code.body.contains("project_doConsume") == hasSplit) From e7bbca88964d95593fa15eb94643ba519801e352 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 8 Mar 2018 22:02:28 +0100 Subject: [PATCH 0456/1161] [SPARK-23602][SQL] PrintToStderr prints value also in interpreted mode ## What changes were proposed in this pull request? `PrintToStderr` was doing what is it supposed to only when code generation is enabled. The PR adds the same behavior in interpreted mode too. ## How was this patch tested? added UT Author: Marco Gaido Closes #20773 from mgaido91/SPARK-23602. --- .../spark/sql/catalyst/expressions/misc.scala | 7 +++++- .../expressions/MiscExpressionsSuite.scala | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4b9006ab5b423..38e4fe44b15ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,12 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType - protected override def nullSafeEval(input: Any): Any = input + protected override def nullSafeEval(input: Any): Any = { + // scalastyle:off println + System.err.println(outputPrefix + input) + // scalastyle:on println + input + } private val outputPrefix = s"Result of ${child.simpleString} is " diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index a21c139fe71d0..c3d08bf68c7bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.io.PrintStream + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -43,4 +45,27 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Uuid()), 36) assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) } + + test("PrintToStderr") { + val inputExpr = Literal(1) + val systemErr = System.err + + val (outputEval, outputCodegen) = try { + val errorStream = new java.io.ByteArrayOutputStream() + System.setErr(new PrintStream(errorStream)) + // check without codegen + checkEvaluationWithoutCodegen(PrintToStderr(inputExpr), 1) + val outputEval = errorStream.toString + errorStream.reset() + // check with codegen + checkEvaluationWithGeneratedMutableProjection(PrintToStderr(inputExpr), 1) + val outputCodegen = errorStream.toString + (outputEval, outputCodegen) + } finally { + System.setErr(systemErr) + } + + assert(outputCodegen.contains(s"Result of $inputExpr is 1")) + assert(outputEval.contains(s"Result of $inputExpr is 1")) + } } From d90e77bd0ec19f8ba9198a24ec2ab3db7708eca8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 8 Mar 2018 14:58:40 -0800 Subject: [PATCH 0457/1161] [SPARK-23271][SQL] Parquet output contains only _SUCCESS file after writing an empty dataframe ## What changes were proposed in this pull request? Below are the two cases. ``` SQL case 1 scala> List.empty[String].toDF().rdd.partitions.length res18: Int = 1 ``` When we write the above data frame as parquet, we create a parquet file containing just the schema of the data frame. Case 2 ``` SQL scala> val anySchema = StructType(StructField("anyName", StringType, nullable = false) :: Nil) anySchema: org.apache.spark.sql.types.StructType = StructType(StructField(anyName,StringType,false)) scala> spark.read.schema(anySchema).csv("/tmp/empty_folder").rdd.partitions.length res22: Int = 0 ``` For the 2nd case, since number of partitions = 0, we don't call the write task (the task has logic to create the empty metadata only parquet file) The fix is to create a dummy single partition RDD and set up the write task based on it to ensure the metadata-only file. ## How was this patch tested? A new test is added to DataframeReaderWriterSuite. Author: Dilip Biswal Closes #20525 from dilipbiswal/spark-23271. --- docs/sql-programming-guide.md | 1 + .../datasources/FileFormatWriter.scala | 15 ++++++++++++--- .../spark/sql/FileBasedDataSourceSuite.scala | 18 ++++++++++++++++++ .../sql/test/DataFrameReaderWriterSuite.scala | 1 - 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 451b814ab6c53..d2132d2ae7441 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,6 +1805,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 1d80a69bc5a1d..401597f967218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -190,9 +190,18 @@ object FileFormatWriter extends Logging { global = false, child = plan).execute() } - val ret = new Array[WriteTaskResult](rdd.partitions.length) + + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. + val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + sparkSession.sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + rdd + } + + val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) sparkSession.sparkContext.runJob( - rdd, + rddWithNonEmptyPartitions, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, @@ -202,7 +211,7 @@ object FileFormatWriter extends Logging { committer, iterator = iter) }, - 0 until rdd.partitions.length, + rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { committer.onTaskCommit(res.commitMsg) ret(index) = res diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 73e3df3b6202e..bd3071bcf9010 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -89,6 +90,23 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + Seq("orc", "parquet").foreach { format => + test(s"SPARK-23271 empty RDD when saved should write a metadata only file - $format") { + withTempPath { outputPath => + val df = spark.emptyDataFrame.select(lit(1).as("i")) + df.write.format(format).save(outputPath.toString) + val partFiles = outputPath.listFiles() + .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) + assert(partFiles.length === 1) + + // Now read the file. + val df1 = spark.read.format(format).load(outputPath.toString) + checkAnswer(df1, Seq.empty[Row]) + assert(df1.schema.equals(df.schema.asNullable)) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 8c9bb7d56a35f..a707a88dfa670 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -301,7 +301,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be intercept[AnalysisException] { spark.range(10).write.format("csv").mode("overwrite").partitionBy("id").save(path) } - spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) } } From 2c3673680e16f88f1d1cd73a3f7445ded5b3daa8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 9 Mar 2018 10:36:38 -0800 Subject: [PATCH 0458/1161] [SPARK-23630][YARN] Allow user's hadoop conf customizations to take effect. This change restores functionality that was inadvertently removed as part of the fix for SPARK-22372. Also modified an existing unit test to make sure the feature works as intended. Author: Marcelo Vanzin Closes #20776 from vanzin/SPARK-23630. --- .../apache/spark/deploy/SparkHadoopUtil.scala | 11 +++++- .../org/apache/spark/deploy/yarn/Client.scala | 14 ++++---- .../spark/deploy/yarn/YarnClusterSuite.scala | 34 ++++++++++++++----- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index e14f9845e6db6..177295fb7af0f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -111,7 +111,9 @@ class SparkHadoopUtil extends Logging { * subsystems. */ def newConfiguration(conf: SparkConf): Configuration = { - SparkHadoopUtil.newConfiguration(conf) + val hadoopConf = SparkHadoopUtil.newConfiguration(conf) + hadoopConf.addResource(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE) + hadoopConf } /** @@ -435,6 +437,13 @@ object SparkHadoopUtil { */ private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000 + /** + * Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the + * cluster's Hadoop config. It is up to the Spark code launching the application to create + * this file if it's desired. If the file doesn't exist, it will just be ignored. + */ + private[spark] val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" + def get: SparkHadoopUtil = instance /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 8cd3cd9746a3a..28087dee831d1 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -696,7 +696,13 @@ private[spark] class Client( } } - Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + // SPARK-23630: during testing, Spark scripts filter out hadoop conf dirs so that user's + // environments do not interfere with tests. This allows a special env variable during + // tests so that custom conf dirs can be used by unit tests. + val confDirs = Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR") ++ + (if (Utils.isTesting) Seq("SPARK_TEST_HADOOP_CONF_DIR") else Nil) + + confDirs.foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) if (dir.isDirectory()) { @@ -753,7 +759,7 @@ private[spark] class Client( // Save the YARN configuration into a separate file that will be overlayed on top of the // cluster's Hadoop conf. - confStream.putNextEntry(new ZipEntry(SPARK_HADOOP_CONF_FILE)) + confStream.putNextEntry(new ZipEntry(SparkHadoopUtil.SPARK_HADOOP_CONF_FILE)) hadoopConf.writeXml(confStream) confStream.closeEntry() @@ -1220,10 +1226,6 @@ private object Client extends Logging { // Name of the file in the conf archive containing Spark configuration. val SPARK_CONF_FILE = "__spark_conf__.properties" - // Name of the file containing the gateway's Hadoop configuration, to be overlayed on top of the - // cluster's Hadoop config. - val SPARK_HADOOP_CONF_FILE = "__spark_hadoop_conf__.xml" - // Subdirectory where the user's python files (not archives) will be placed. val LOCALIZED_PYTHON_DIR = "__pyfiles__" diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 5003326b440bf..33d400a5b1b2e 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -114,12 +114,25 @@ class YarnClusterSuite extends BaseYarnClusterSuite { )) } - test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414)") { + test("yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414, SPARK-23630)") { + // Create a custom hadoop config file, to make sure it's contents are propagated to the driver. + val customConf = Utils.createTempDir() + val coreSite = """ + | + | + | spark.test.key + | testvalue + | + | + |""".stripMargin + Files.write(coreSite, new File(customConf, "core-site.xml"), StandardCharsets.UTF_8) + val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(false, mainClassName(YarnClusterDriverUseSparkHadoopUtilConf.getClass), - appArgs = Seq("key=value", result.getAbsolutePath()), - extraConf = Map("spark.hadoop.key" -> "value")) + appArgs = Seq("key=value", "spark.test.key=testvalue", result.getAbsolutePath()), + extraConf = Map("spark.hadoop.key" -> "value"), + extraEnv = Map("SPARK_TEST_HADOOP_CONF_DIR" -> customConf.getAbsolutePath())) checkResult(finalState, result) } @@ -319,13 +332,13 @@ private object YarnClusterDriverWithFailure extends Logging with Matchers { private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matchers { def main(args: Array[String]): Unit = { - if (args.length != 2) { + if (args.length < 2) { // scalastyle:off println System.err.println( s""" |Invalid command line: ${args.mkString(" ")} | - |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value] [result file] + |Usage: YarnClusterDriverUseSparkHadoopUtilConf [hadoopConfKey=value]+ [result file] """.stripMargin) // scalastyle:on println System.exit(1) @@ -335,11 +348,16 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn test using SparkHadoopUtil's conf")) - val kv = args(0).split("=") - val status = new File(args(1)) + val kvs = args.take(args.length - 1).map { kv => + val parsed = kv.split("=") + (parsed(0), parsed(1)) + } + val status = new File(args.last) var result = "failure" try { - SparkHadoopUtil.get.conf.get(kv(0)) should be (kv(1)) + kvs.foreach { case (k, v) => + SparkHadoopUtil.get.conf.get(k) should be (v) + } result = "success" } finally { Files.write(result, status, StandardCharsets.UTF_8) From 2ca9bb083c515511d2bfee271fc3e0269aceb9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20=C5=9Awitakowski?= Date: Fri, 9 Mar 2018 14:29:31 -0800 Subject: [PATCH 0459/1161] [SPARK-23173][SQL] Avoid creating corrupt parquet files when loading data from JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The from_json() function accepts an additional parameter, where the user might specify the schema. The issue is that the specified schema might not be compatible with data. In particular, the JSON data might be missing data for fields declared as non-nullable in the schema. The from_json() function does not verify the data against such errors. When data with missing fields is sent to the parquet encoder, there is no verification either. The end results is a corrupt parquet file. To avoid corruptions, make sure that all fields in the user-specified schema are set to be nullable. Since this changes the behavior of a public function, we need to include it in release notes. The behavior can be reverted by setting `spark.sql.fromJsonForceNullableSchema=false` ## How was this patch tested? Added two new tests. Author: Michał Świtakowski Closes #20694 from mswit-databricks/SPARK-23173. --- .../expressions/jsonExpressions.scala | 22 +++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 8 +++++ .../expressions/JsonExpressionsSuite.scala | 30 ++++++++++++++++++- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++++++++ 4 files changed, 70 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 18b4fed597447..fdd672c416a03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -515,10 +516,15 @@ case class JsonToStructs( child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - override def nullable: Boolean = true - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, None) + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + + // The JSON input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder + // can generate incorrect files if values are missing in columns declared as non-nullable. + val nullableSchema = if (forceNullableSchema) schema.asNullable else schema + + override def nullable: Boolean = true // Used in `FunctionRegistry` def this(child: Expression, schema: Expression) = @@ -535,22 +541,22 @@ case class JsonToStructs( child = child, timeZoneId = None) - override def checkInputDataTypes(): TypeCheckResult = schema match { + override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( - s"Input schema ${schema.simpleString} must be a struct or an array of structs.") + s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") } @transient - lazy val rowSchema = schema match { + lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st } // This converts parsed rows to the desired output by the given schema. @transient - lazy val converter = schema match { + lazy val converter = nullableSchema match { case _: StructType => (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => @@ -563,7 +569,7 @@ case class JsonToStructs( rowSchema, new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get)) - override def dataType: DataType = schema + override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3f96112659c11..11864bd1b1847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -493,6 +493,14 @@ object SQLConf { .stringConf .createWithDefault("_corrupt_record") + val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema") + .internal() + .doc("When true, force the output schema of the from_json() function to be nullable " + + "(including all the fields). Otherwise, the schema might not be compatible with" + + "actual data, which leads to curruptions.") + .booleanConf + .createWithDefault(true) + val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout") .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") .timeConf(TimeUnit.SECONDS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index a0bbe02f92354..7812319756eae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -22,11 +22,13 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { val json = """ |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}], @@ -680,4 +682,30 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } } + + test("from_json missing fields") { + for (forceJsonNullableSchema <- Seq(false, true)) { + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + checkEvaluation( + JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), + output + ) + val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + .dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 3af80930ec807..0b3e8ca060d87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -780,6 +781,24 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(option.compressionCodecClassName == "UNCOMPRESSED") } } + + test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") { + withTempPath { file => + val jsonData = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input") + .write.parquet(file.getAbsolutePath) + checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo")))) + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) From 10b0657b035641ce735055bba2c8459e71bc2400 Mon Sep 17 00:00:00 2001 From: Wang Gengliang Date: Fri, 9 Mar 2018 15:41:19 -0800 Subject: [PATCH 0460/1161] [SPARK-23624][SQL] Revise doc of method pushFilters in Datasource V2 ## What changes were proposed in this pull request? Revise doc of method pushFilters in SupportsPushDownFilters/SupportsPushDownCatalystFilters In `FileSourceStrategy`, except `partitionKeyFilters`(the references of which is subset of partition keys), all filters needs to be evaluated after scanning. Otherwise, Spark will get wrong result from data sources like Orc/Parquet. This PR is to improve the doc. Author: Wang Gengliang Closes #20769 from gengliangwang/revise_pushdown_doc. --- .../sql/sources/v2/reader/SupportsPushDownCatalystFilters.java | 2 +- .../spark/sql/sources/v2/reader/SupportsPushDownFilters.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 98224102374aa..290d614805ac7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -34,7 +34,7 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Expression[] pushCatalystFilters(Expression[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index f35c711b0387a..1cff024232a44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -32,7 +32,7 @@ public interface SupportsPushDownFilters extends DataSourceReader { /** - * Pushes down filters, and returns unsupported filters. + * Pushes down filters, and returns filters that need to be evaluated after scanning. */ Filter[] pushFilters(Filter[] filters); From 1a54f48b6744032b16543594651ee6d5e3ad4bda Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 9 Mar 2018 15:54:55 -0800 Subject: [PATCH 0461/1161] [SPARK-23510][SQL][FOLLOW-UP] Support Hive 2.2 and Hive 2.3 metastore ## What changes were proposed in this pull request? In the PR https://github.com/apache/spark/pull/20671, I forgot to update the doc about this new support. ## How was this patch tested? N/A Author: gatorsmile Closes #20789 from gatorsmile/docUpdate. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d2132d2ae7441..0e092e0e37ccf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2229,7 +2229,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.1.1. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From b6f837c9d3cb0f76f0a52df37e34aea8944f6867 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Sat, 10 Mar 2018 19:48:29 +0900 Subject: [PATCH 0462/1161] [PYTHON] Changes input variable to not conflict with built-in function Signed-off-by: DylanGuedes ## What changes were proposed in this pull request? Changes variable name conflict: [input is a built-in python function](https://stackoverflow.com/questions/20670732/is-input-a-keyword-in-python). ## How was this patch tested? I runned the example and it works fine. Author: DylanGuedes Closes #20775 from DylanGuedes/input_variable. --- examples/src/main/python/ml/dataframe_example.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index d62cf2338a1fe..cabc3de68f2f4 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -17,7 +17,7 @@ """ An example of how to use DataFrame for ML. Run with:: - bin/spark-submit examples/src/main/python/ml/dataframe_example.py + bin/spark-submit examples/src/main/python/ml/dataframe_example.py """ from __future__ import print_function @@ -35,18 +35,18 @@ print("Usage: dataframe_example.py ", file=sys.stderr) sys.exit(-1) elif len(sys.argv) == 2: - input = sys.argv[1] + input_path = sys.argv[1] else: - input = "data/mllib/sample_libsvm_data.txt" + input_path = "data/mllib/sample_libsvm_data.txt" spark = SparkSession \ .builder \ .appName("DataFrameExample") \ .getOrCreate() - # Load input data - print("Loading LIBSVM file with UDT from " + input + ".") - df = spark.read.format("libsvm").load(input).cache() + # Load an input file + print("Loading LIBSVM file with UDT from " + input_path + ".") + df = spark.read.format("libsvm").load(input_path).cache() print("Schema from LIBSVM:") df.printSchema() print("Loaded training data as a DataFrame with " + From b304e07e0671faf96530f9d8f49c55a83b07fa15 Mon Sep 17 00:00:00 2001 From: Xiayun Sun Date: Mon, 12 Mar 2018 22:13:28 +0900 Subject: [PATCH 0463/1161] [SPARK-23462][SQL] improve missing field error message in `StructType` ## What changes were proposed in this pull request? The error message ```s"""Field "$name" does not exist."""``` is thrown when looking up an unknown field in StructType. In the error message, we should also contain the information about which columns/fields exist in this struct. ## How was this patch tested? Added new unit tests. Note: I created a new `StructTypeSuite.scala` as I couldn't find an existing suite that's suitable to place these tests. I may be missing something so feel free to propose new locations. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xiayun Sun Closes #20649 from xysun/SPARK-23462. --- .../apache/spark/sql/types/StructType.scala | 11 +++-- .../spark/sql/types/StructTypeSuite.scala | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d5011c3cb87e9..362676b252126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -271,7 +271,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def apply(name: String): StructField = { nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } /** @@ -284,7 +286,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru val nonExistFields = names -- fieldNamesSet if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") + s"""Nonexistent field(s): ${nonExistFields.mkString(", ")}. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin) } // Preserve the original order of fields. StructType(fields.filter(f => names.contains(f.name))) @@ -297,7 +300,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru */ def fieldIndex(name: String): Int = { nameToIndex.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + throw new IllegalArgumentException( + s"""Field "$name" does not exist. + |Available fields: ${fieldNames.mkString(", ")}""".stripMargin)) } private[sql] def getFieldIndex(name: String): Option[Int] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala new file mode 100644 index 0000000000000..c6ca8bb005429 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.SparkFunSuite + +class StructTypeSuite extends SparkFunSuite { + + val s = StructType.fromDDL("a INT, b STRING") + + test("lookup a single missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s("c")).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup a set of missing fields should output existing fields") { + val e = intercept[IllegalArgumentException](s(Set("a", "c"))).getMessage + assert(e.contains("Available fields: a, b")) + } + + test("lookup fieldIndex for missing field should output existing fields") { + val e = intercept[IllegalArgumentException](s.fieldIndex("c")).getMessage + assert(e.contains("Available fields: a, b")) + } +} From d5b41aea62201cd5b1baad2f68f5fc7eb99c62c5 Mon Sep 17 00:00:00 2001 From: Jooseong Kim Date: Mon, 12 Mar 2018 11:31:34 -0700 Subject: [PATCH 0464/1161] [SPARK-23618][K8S][BUILD] Initialize BUILD_ARGS in docker-image-tool.sh ## What changes were proposed in this pull request? This change initializes BUILD_ARGS to an empty array when $SPARK_HOME/RELEASE exists. In function build, "local BUILD_ARGS" effectively creates an array of one element where the first and only element is an empty string, so "${BUILD_ARGS[]}" expands to "" and passes an extra argument to docker. Setting BUILD_ARGS to an empty array makes "${BUILD_ARGS[]}" expand to nothing. ## How was this patch tested? Manually tested. $ cat RELEASE Spark 2.3.0 (git revision a0d7949896) built for Hadoop 2.7.3 Build flags: -Phadoop-2.7 -Phive -Phive-thriftserver -Pkafka-0-8 -Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr -DzincPort=3036 $ ./bin/docker-image-tool.sh -m t testing build Sending build context to Docker daemon 256.4MB ... vanzin Author: Jooseong Kim Closes #20791 from jooseong/SPARK-23618. --- bin/docker-image-tool.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 071406336d1b1..0d0f564bb8b9b 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -57,6 +57,7 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" + BUILD_ARGS=() fi if [ ! -d "$IMG_PATH" ]; then From 567bd31e0ae8b632357baa93e1469b666fb06f3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 12 Mar 2018 14:53:15 -0500 Subject: [PATCH 0465/1161] [SPARK-23412][ML] Add cosine distance to BisectingKMeans ## What changes were proposed in this pull request? The PR adds the option to specify a distance measure in BisectingKMeans. Moreover, it introduces the ability to use the cosine distance measure in it. ## How was this patch tested? added UTs + existing UTs Author: Marco Gaido Closes #20600 from mgaido91/SPARK-23412. --- .../spark/ml/clustering/BisectingKMeans.scala | 16 +- .../apache/spark/ml/clustering/KMeans.scala | 11 +- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 19 ++ .../mllib/clustering/BisectingKMeans.scala | 139 ++++---- .../clustering/BisectingKMeansModel.scala | 115 +++++-- .../mllib/clustering/DistanceMeasure.scala | 303 ++++++++++++++++++ .../spark/mllib/clustering/KMeans.scala | 196 +---------- .../ml/clustering/BisectingKMeansSuite.scala | 44 ++- project/MimaExcludes.scala | 6 + 10 files changed, 557 insertions(+), 298 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 4c20e6563bad1..f7c422dc0faea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} +import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, + BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD @@ -38,8 +39,8 @@ import org.apache.spark.sql.types.{IntegerType, StructType} /** * Common params for BisectingKMeans and BisectingKMeansModel */ -private[clustering] trait BisectingKMeansParams extends Params - with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { +private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter + with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure { /** * The desired number of leaf clusters. Must be > 1. Default: 4. @@ -104,6 +105,10 @@ class BisectingKMeansModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -248,6 +253,10 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) + /** @group expertSetParam */ + @Since("2.4.0") + def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) @@ -263,6 +272,7 @@ class BisectingKMeans @Since("2.0.0") ( .setMaxIterations($(maxIter)) .setMinDivisibleClusterSize($(minDivisibleClusterSize)) .setSeed($(seed)) + .setDistanceMeasure($(distanceMeasure)) val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index c8145de564cbe..987a4285ebad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion * Common params for KMeans and KMeansModel */ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol - with HasSeed with HasPredictionCol with HasTol { + with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure { /** * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than @@ -71,15 +71,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitMode: String = $(initMode) - @Since("2.4.0") - final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - (value: String) => MLlibKMeans.validateDistanceMeasure(value)) - - /** @group expertGetParam */ - @Since("2.4.0") - def getDistanceMeasure: String = $(distanceMeasure) - /** * Param for the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 6ad44af9ef7eb..b9c3170cc3c28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -91,7 +91,11 @@ private[shared] object SharedParamsCodeGen { "after fitting. If set to true, then all sub-models will be available. Warning: For " + "large models, collecting all sub-models can cause OOMs on the Spark driver", Some("false"), isExpertParam = true), - ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false) + ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false), + ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), + isValid = "(value: String) => " + + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index be8b2f273164b..282ea6ebcbf7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -504,4 +504,23 @@ trait HasLoss extends Params { /** @group getParam */ final def getLoss: String = $(loss) } + +/** + * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasDistanceMeasure extends Params { + + /** + * Param for The distance measure. Supported options: 'euclidean' and 'cosine'. + * @group param + */ + final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)) + + setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN) + + /** @group getParam */ + final def getDistanceMeasure: String = $(distanceMeasure) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 2221f4c0edc17..98af487306dcc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -25,7 +25,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -57,7 +57,8 @@ class BisectingKMeans private ( private var k: Int, private var maxIterations: Int, private var minDivisibleClusterSize: Double, - private var seed: Long) extends Logging { + private var seed: Long, + private var distanceMeasure: String) extends Logging { import BisectingKMeans._ @@ -65,7 +66,7 @@ class BisectingKMeans private ( * Constructs with the default configuration */ @Since("1.6.0") - def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN) /** * Sets the desired number of leaf clusters (default: 4). @@ -134,6 +135,22 @@ class BisectingKMeans private ( @Since("1.6.0") def getSeed: Long = this.seed + /** + * The distance suite used by the algorithm. + */ + @Since("2.4.0") + def getDistanceMeasure: String = distanceMeasure + + /** + * Set the distance suite used by the algorithm. + */ + @Since("2.4.0") + def setDistanceMeasure(distanceMeasure: String): this.type = { + DistanceMeasure.validateDistanceMeasure(distanceMeasure) + this.distanceMeasure = distanceMeasure + this + } + /** * Runs the bisecting k-means algorithm. * @param input RDD of vectors @@ -147,11 +164,13 @@ class BisectingKMeans private ( } val d = input.map(_.size).first() logInfo(s"Feature dimension: $d.") + + val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure) // Compute and cache vector norms for fast distance computation. val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } var assignments = vectors.map(v => (ROOT_INDEX, v)) - var activeClusters = summarize(d, assignments) + var activeClusters = summarize(d, assignments, dMeasure) val rootSummary = activeClusters(ROOT_INDEX) val n = rootSummary.size logInfo(s"Number of points: $n.") @@ -184,24 +203,25 @@ class BisectingKMeans private ( val divisibleIndices = divisibleClusters.keys.toSet logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => - val (left, right) = splitCenter(summary.center, random) + val (left, right) = splitCenter(summary.center, random, dMeasure) Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map var newClusters: Map[Long, ClusterSummary] = null var newAssignments: RDD[(Long, VectorWithNorm)] = null for (iter <- 0 until maxIterations) { - newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters, + dMeasure) .filter { case (index, _) => divisibleIndices.contains(parentIndex(index)) } - newClusters = summarize(d, newAssignments) + newClusters = summarize(d, newAssignments, dMeasure) newClusterCenters = newClusters.mapValues(_.center).map(identity) } if (preIndices != null) { preIndices.unpersist(false) } preIndices = indices - indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys .persist(StorageLevel.MEMORY_AND_DISK) assignments = indices.zip(vectors) inactiveClusters ++= activeClusters @@ -222,8 +242,8 @@ class BisectingKMeans private ( } norms.unpersist(false) val clusters = activeClusters ++ inactiveClusters - val root = buildTree(clusters) - new BisectingKMeansModel(root) + val root = buildTree(clusters, dMeasure) + new BisectingKMeansModel(root, this.distanceMeasure) } /** @@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable { */ private def summarize( d: Int, - assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { - assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + assignments: RDD[(Long, VectorWithNorm)], + distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))( seqOp = (agg, v) => agg.add(v), combOp = (agg1, agg2) => agg1.merge(agg2) ).mapValues(_.summary) @@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable { * Cluster summary aggregator. * @param d feature dimension */ - private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure) + extends Serializable { private var n: Long = 0L private val sum: Vector = Vectors.zeros(d) private var sumSq: Double = 0.0 @@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable { n += 1L // TODO: use a numerically stable approach to estimate cost sumSq += v.norm * v.norm - BLAS.axpy(1.0, v.vector, sum) + distanceMeasure.updateClusterSum(v, sum) this } @@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable { def merge(other: ClusterSummaryAggregator): this.type = { n += other.n sumSq += other.sumSq - BLAS.axpy(1.0, other.sum, sum) + distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum) this } /** Returns the summary. */ def summary: ClusterSummary = { - val mean = sum.copy - if (n > 0L) { - BLAS.scal(1.0 / n, mean) - } - val center = new VectorWithNorm(mean) - val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) - new ClusterSummary(n, center, cost) + val center = distanceMeasure.centroid(sum.copy, n) + val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq) + ClusterSummary(n, center, cost) } } @@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable { */ private def splitCenter( center: VectorWithNorm, - random: Random): (VectorWithNorm, VectorWithNorm) = { + random: Random, + distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = { val d = center.vector.size val norm = center.norm val level = 1e-4 * norm val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) - val left = center.vector.copy - BLAS.axpy(-level, noise, left) - val right = center.vector.copy - BLAS.axpy(level, noise, right) - (new VectorWithNorm(left), new VectorWithNorm(right)) + distanceMeasure.symmetricCentroids(level, noise, center.vector) } /** @@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable { private def updateAssignments( assignments: RDD[(Long, VectorWithNorm)], divisibleIndices: Set[Long], - newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + newClusterCenters: Map[Long, VectorWithNorm], + distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = { assignments.map { case (index, v) => if (divisibleIndices.contains(index)) { val children = Seq(leftChildIndex(index), rightChildIndex(index)) - val newClusterChildren = children.filter(newClusterCenters.contains(_)) + val newClusterChildren = children.filter(newClusterCenters.contains) + val newClusterChildrenCenterToId = + newClusterChildren.map(id => newClusterCenters(id) -> id).toMap + val newClusterChildrenCenters = newClusterChildrenCenterToId.keys.toArray if (newClusterChildren.nonEmpty) { - val selected = newClusterChildren.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v) - } - (selected, v) + val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1 + val center = newClusterChildrenCenters(selected) + val id = newClusterChildrenCenterToId(center) + (id, v) } else { (index, v) } @@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable { * @param clusters a map from cluster indices to corresponding cluster summaries * @return the root node of the clustering tree */ - private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + private def buildTree( + clusters: Map[Long, ClusterSummary], + distanceMeasure: DistanceMeasure): ClusteringTreeNode = { var leafIndex = 0 var internalIndex = -1 @@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable { internalIndex -= 1 val leftIndex = leftChildIndex(rawIndex) val rightIndex = rightChildIndex(rawIndex) - val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_)) - val height = math.sqrt(indexes.map { childIndex => - EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center) - }.max) - val children = indexes.map(buildSubTree(_)).toArray + val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains) + val height = indexes.map { childIndex => + distanceMeasure.distance(center, clusters(childIndex).center) + }.max + val children = indexes.map(buildSubTree).toArray new ClusteringTreeNode(index, size, center, cost, height, children) } else { val index = leafIndex @@ -441,42 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] ( def center: Vector = centerWithNorm.vector /** Predicts the leaf cluster node index that the input point belongs to. */ - def predict(point: Vector): Int = { - val (index, _) = predict(new VectorWithNorm(point)) + def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = { + val (index, _) = predict(new VectorWithNorm(point), distanceMeasure) index } /** Returns the full prediction path from root to leaf. */ - def predictPath(point: Vector): Array[ClusteringTreeNode] = { - predictPath(new VectorWithNorm(point)).toArray + def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point), distanceMeasure).toArray } /** Returns the full prediction path from root to leaf. */ - private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + private def predictPath( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = { if (isLeaf) { this :: Nil } else { val selected = children.minBy { child => - EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + distanceMeasure.distance(child.centerWithNorm, pointWithNorm) } - selected :: selected.predictPath(pointWithNorm) + selected :: selected.predictPath(pointWithNorm, distanceMeasure) } } /** - * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + * Computes the cost of the input point. */ - def computeCost(point: Vector): Double = { - val (_, cost) = predict(new VectorWithNorm(point)) + def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = { + val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure) cost } /** * Predicts the cluster index and the cost of the input point. */ - private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { - predict(pointWithNorm, - EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm)) + private def predict( + pointWithNorm: VectorWithNorm, + distanceMeasure: DistanceMeasure): (Int, Double) = { + predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure) } /** @@ -486,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] ( * @return (predicted leaf cluster index, cost) */ @tailrec - private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + private def predict( + pointWithNorm: VectorWithNorm, + cost: Double, + distanceMeasure: DistanceMeasure): (Int, Double) = { if (isLeaf) { (index, cost) } else { val (selectedChild, minCost) = children.map { child => - (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm)) }.minBy(_._2) - selectedChild.predict(pointWithNorm, minCost) + selectedChild.predict(pointWithNorm, minCost, distanceMeasure) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 633bda6aac804..9d115afcea75d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession} */ @Since("1.6.0") class BisectingKMeansModel private[clustering] ( - private[clustering] val root: ClusteringTreeNode + private[clustering] val root: ClusteringTreeNode, + @Since("2.4.0") val distanceMeasure: String ) extends Serializable with Saveable with Logging { + @Since("1.6.0") + def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN) + + private val distanceMeasureInstance: DistanceMeasure = + DistanceMeasure.decodeFromString(distanceMeasure) + /** * Leaf cluster centers. */ @@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(point: Vector): Int = { - root.predict(point) + root.predict(point, distanceMeasureInstance) } /** @@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def predict(points: RDD[Vector]): RDD[Int] = { - points.map { p => root.predict(p) } + points.map { p => root.predict(p, distanceMeasureInstance) } } /** @@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(point: Vector): Double = { - root.computeCost(point) + root.computeCost(point, distanceMeasureInstance) } /** @@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(data: RDD[Vector]): Double = { - data.map(root.computeCost).sum() + data.map(root.computeCost(_, distanceMeasureInstance)).sum() } /** @@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { @Since("2.0.0") override def load(sc: SparkContext, path: String): BisectingKMeansModel = { - val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) - implicit val formats = DefaultFormats - val rootId = (metadata \ "rootId").extract[Int] - val classNameV1_0 = SaveLoadV1_0.thisClassName + val (loadedClassName, formatVersion, __) = Loader.loadMetadata(sc, path) (loadedClassName, formatVersion) match { - case (classNameV1_0, "1.0") => - val model = SaveLoadV1_0.load(sc, path, rootId) + case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) + model + case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => + val model = SaveLoadV1_0.load(sc, path) model case _ => throw new Exception( s"BisectingKMeansModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $formatVersion). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" + + s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})") } } @@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) } + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes) ++ Array(node) + } + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes(rootId) + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + private[clustering] object SaveLoadV1_0 { - private val thisFormatVersion = "1.0" + private[clustering] val thisFormatVersion = "1.0" private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" @@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) } - private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { - if (node.children.isEmpty) { - Array(node) - } else { - node.children.flatMap(getNodes(_)) ++ Array(node) - } - } - - def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap val rootNode = buildTree(rootId, nodes) - new BisectingKMeansModel(rootNode) + new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN) } + } + + private[clustering] object SaveLoadV2_0 { + private[clustering] val thisFormatVersion = "2.0" - private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { - val root = nodes.get(rootId).get - if (root.children.isEmpty) { - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, new Array[ClusteringTreeNode](0)) - } else { - val children = root.children.map(c => buildTree(c, nodes)) - new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), - root.cost, root.height, children.toArray) - } + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): BisectingKMeansModel = { + implicit val formats: DefaultFormats = DefaultFormats + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + val rootId = (metadata \ "rootId").extract[Int] + val distanceMeasure = (metadata \ "distanceMeasure").extract[String] + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val rows = spark.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode, distanceMeasure) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala new file mode 100644 index 0000000000000..683360efabc76 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} +import org.apache.spark.mllib.util.MLUtils + +private[spark] abstract class DistanceMeasure extends Serializable { + + /** + * @return the index of the closest center to the given point, as well as the cost. + */ + def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + val currentDistance = distance(center, point) + if (currentDistance < bestDistance) { + bestDistance = currentDistance + bestIndex = i + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return the K-means cost of a given point against the given cluster centers. + */ + def pointCost( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): Double = { + findClosest(centers, point)._2 + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + distance(oldCenter, newCenter) <= epsilon + } + + /** + * @return the distance between two points. + */ + def distance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double + + /** + * @return the total cost of the cluster from its aggregated properties + */ + def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + axpy(1.0, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + new VectorWithNorm(sum) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val left = centroid.copy + axpy(-level, noise, left) + val right = centroid.copy + axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = distance(point, centroid) +} + +@Since("2.4.0") +object DistanceMeasure { + + @Since("2.4.0") + val EUCLIDEAN = "euclidean" + @Since("2.4.0") + val COSINE = "cosine" + + private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = + distanceMeasure match { + case EUCLIDEAN => new EuclideanDistanceMeasure + case COSINE => new CosineDistanceMeasure + case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + + s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") + } + + private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { + distanceMeasure match { + case DistanceMeasure.EUCLIDEAN => true + case DistanceMeasure.COSINE => true + case _ => false + } + } +} + +private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { + /** + * @return the index of the closest center to the given point, as well as the squared distance. + */ + override def findClosest( + centers: TraversableOnce[VectorWithNorm], + point: VectorWithNorm): (Int, Double) = { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + var i = 0 + centers.foreach { center => + // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary + // distance computation. + var lowerBoundOfSqDist = center.norm - point.norm + lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist + if (lowerBoundOfSqDist < bestDistance) { + val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + } + i += 1 + } + (bestIndex, bestDistance) + } + + /** + * @return whether a center converged or not, given the epsilon parameter. + */ + override def isCenterConverged( + oldCenter: VectorWithNorm, + newCenter: VectorWithNorm, + epsilon: Double): Boolean = { + EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon + } + + /** + * @param v1: first vector + * @param v2: second vector + * @return the Euclidean distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0) + } + + /** + * @return the cost of a point to be assigned to the cluster centroid + */ + override def cost( + point: VectorWithNorm, + centroid: VectorWithNorm): Double = { + EuclideanDistanceMeasure.fastSquaredDistance(point, centroid) + } +} + + +private[spark] object EuclideanDistanceMeasure { + /** + * @return the squared Euclidean distance between two vectors computed by + * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. + */ + private[clustering] def fastSquaredDistance( + v1: VectorWithNorm, + v2: VectorWithNorm): Double = { + MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) + } +} + +private[spark] class CosineDistanceMeasure extends DistanceMeasure { + /** + * @param v1: first vector + * @param v2: second vector + * @return the cosine distance between the two input vectors + */ + override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { + assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") + 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm + } + + /** + * Updates the value of `sum` adding the `point` vector. + * @param point a `VectorWithNorm` to be added to `sum` of a cluster + * @param sum the `sum` for a cluster to be updated + */ + override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { + assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.") + axpy(1.0 / point.norm, point.vector, sum) + } + + /** + * Returns a centroid for a cluster given its `sum` vector and its `count` of points. + * + * @param sum the `sum` for a cluster + * @param count the number of points in the cluster + * @return the centroid of the cluster + */ + override def centroid(sum: Vector, count: Long): VectorWithNorm = { + scal(1.0 / count, sum) + val norm = Vectors.norm(sum, 2) + scal(1.0 / norm, sum) + new VectorWithNorm(sum, 1) + } + + /** + * @return the total cost of the cluster from its aggregated properties + */ + override def clusterCost( + centroid: VectorWithNorm, + pointsSum: VectorWithNorm, + numberOfPoints: Long, + pointsSquaredNorm: Double): Double = { + val costVector = pointsSum.vector.copy + math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0) + } + + /** + * Returns two new centroids symmetric to the specified centroid applying `noise` with the + * with the specified `level`. + * + * @param level the level of `noise` to apply to the given centroid. + * @param noise a noise vector + * @param centroid the parent centroid + * @return a left and right centroid symmetric to `centroid` + */ + override def symmetricCentroids( + level: Double, + noise: Vector, + centroid: Vector): (VectorWithNorm, VectorWithNorm) = { + val (left, right) = super.symmetricCentroids(level, noise, centroid) + val leftVector = left.vector + val rightVector = right.vector + scal(1.0 / left.norm, leftVector) + scal(1.0 / right.norm, rightVector) + (new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 3c4ba0bc60c7f..b5b1be3490497 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.clustering.{KMeans => NewKMeans} import org.apache.spark.ml.util.Instrumentation import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -204,7 +203,7 @@ class KMeans private ( */ @Since("2.4.0") def setDistanceMeasure(distanceMeasure: String): this.type = { - KMeans.validateDistanceMeasure(distanceMeasure) + DistanceMeasure.validateDistanceMeasure(distanceMeasure) this.distanceMeasure = distanceMeasure this } @@ -582,14 +581,6 @@ object KMeans { case _ => false } } - - private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = { - distanceMeasure match { - case DistanceMeasure.EUCLIDEAN => true - case DistanceMeasure.COSINE => true - case _ => false - } - } } /** @@ -605,186 +596,3 @@ private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double) /** Converts the vector to a dense vector. */ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm) } - - -private[spark] abstract class DistanceMeasure extends Serializable { - - /** - * @return the index of the closest center to the given point, as well as the cost. - */ - def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - val currentDistance = distance(center, point) - if (currentDistance < bestDistance) { - bestDistance = currentDistance - bestIndex = i - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return the K-means cost of a given point against the given cluster centers. - */ - def pointCost( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): Double = { - findClosest(centers, point)._2 - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - distance(oldCenter, newCenter) <= epsilon - } - - /** - * @return the cosine distance between two points. - */ - def distance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - new VectorWithNorm(sum) - } -} - -@Since("2.4.0") -object DistanceMeasure { - - @Since("2.4.0") - val EUCLIDEAN = "euclidean" - @Since("2.4.0") - val COSINE = "cosine" - - private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = - distanceMeasure match { - case EUCLIDEAN => new EuclideanDistanceMeasure - case COSINE => new CosineDistanceMeasure - case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " + - s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.") - } -} - -private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { - /** - * @return the index of the closest center to the given point, as well as the squared distance. - */ - override def findClosest( - centers: TraversableOnce[VectorWithNorm], - point: VectorWithNorm): (Int, Double) = { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - var i = 0 - centers.foreach { center => - // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary - // distance computation. - var lowerBoundOfSqDist = center.norm - point.norm - lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist - if (lowerBoundOfSqDist < bestDistance) { - val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point) - if (distance < bestDistance) { - bestDistance = distance - bestIndex = i - } - } - i += 1 - } - (bestIndex, bestDistance) - } - - /** - * @return whether a center converged or not, given the epsilon parameter. - */ - override def isCenterConverged( - oldCenter: VectorWithNorm, - newCenter: VectorWithNorm, - epsilon: Double): Boolean = { - EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon - } - - /** - * @param v1: first vector - * @param v2: second vector - * @return the Euclidean distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2)) - } -} - - -private[spark] object EuclideanDistanceMeasure { - /** - * @return the squared Euclidean distance between two vectors computed by - * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]]. - */ - private[clustering] def fastSquaredDistance( - v1: VectorWithNorm, - v2: VectorWithNorm): Double = { - MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) - } -} - -private[spark] class CosineDistanceMeasure extends DistanceMeasure { - /** - * @param v1: first vector - * @param v2: second vector - * @return the cosine distance between the two input vectors - */ - override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { - assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.") - 1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm - } - - /** - * Updates the value of `sum` adding the `point` vector. - * @param point a `VectorWithNorm` to be added to `sum` of a cluster - * @param sum the `sum` for a cluster to be updated - */ - override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = { - axpy(1.0 / point.norm, point.vector, sum) - } - - /** - * Returns a centroid for a cluster given its `sum` vector and its `count` of points. - * - * @param sum the `sum` for a cluster - * @param count the number of points in the cluster - * @return the centroid of the cluster - */ - override def centroid(sum: Vector, count: Long): VectorWithNorm = { - scal(1.0 / count, sum) - val norm = Vectors.norm(sum, 2) - scal(1.0 / norm, sum) - new VectorWithNorm(sum, 1) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index fa7471fa2d658..02880f96ae6d9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.clustering -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -140,6 +142,46 @@ class BisectingKMeansSuite testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, BisectingKMeansSuite.allParamSettings, checkModelData) } + + test("BisectingKMeans with cosine distance is not supported for 0-length vectors") { + val model = new BisectingKMeans().setK(2).setDistanceMeasure(DistanceMeasure.COSINE).setSeed(1) + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(0.0, 0.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5) + )).map(v => TestRow(v))) + val e = intercept[SparkException](model.fit(df)) + assert(e.getCause.isInstanceOf[AssertionError]) + assert(e.getCause.getMessage.contains("Cosine distance is not defined")) + } + + test("BisectingKMeans with cosine distance") { + val df = spark.createDataFrame(spark.sparkContext.parallelize(Array( + Vectors.dense(1.0, 1.0), + Vectors.dense(10.0, 10.0), + Vectors.dense(1.0, 0.5), + Vectors.dense(10.0, 4.4), + Vectors.dense(-1.0, 1.0), + Vectors.dense(-100.0, 90.0) + )).map(v => TestRow(v))) + val model = new BisectingKMeans() + .setK(3) + .setDistanceMeasure(DistanceMeasure.COSINE) + .setSeed(1) + .fit(df) + val predictionDf = model.transform(df) + assert(predictionDf.select("prediction").distinct().count() == 3) + val predictionsMap = predictionDf.collect().map(row => + row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap + assert(predictionsMap(Vectors.dense(1.0, 1.0)) == + predictionsMap(Vectors.dense(10.0, 10.0))) + assert(predictionsMap(Vectors.dense(1.0, 0.5)) == + predictionsMap(Vectors.dense(10.0, 4.4))) + assert(predictionsMap(Vectors.dense(-1.0, 1.0)) == + predictionsMap(Vectors.dense(-100.0, 90.0))) + + model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) + } } object BisectingKMeansSuite { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 381f7b5be1ddf..1b6d1dec69d49 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,12 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), From 23370554d0f88b82154d4232744b874cc58c7848 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 15:20:09 +0100 Subject: [PATCH 0466/1161] [SPARK-23656][TEST] Perform assertions in XXH64Suite.testKnownByteArrayInputs() on big endian platform, too ## What changes were proposed in this pull request? This PR enables assertions in `XXH64Suite.testKnownByteArrayInputs()` on big endian platform, too. The current implementation performs them only on little endian platform. This PR increase test coverage of big endian platform. ## How was this patch tested? Updated `XXH64Suite` Tested on big endian platform using JIT compiler or interpreter `-Xint`. Author: Kazuaki Ishizaki Closes #20804 from kiszk/SPARK-23656. --- .../sql/catalyst/expressions/XXH64Suite.java | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 711887f02832a..1baee91b3439c 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -74,9 +74,6 @@ public void testKnownByteArrayInputs() { Assert.assertEquals(0x739840CB819FA723L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 1, PRIME)); - // These tests currently fail in a big endian environment because the test data and expected - // answers are generated with little endian the assumptions. We could revisit this when Platform - // becomes endian aware. if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { Assert.assertEquals(0x9256E58AA397AEF1L, hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); @@ -94,6 +91,23 @@ public void testKnownByteArrayInputs() { hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); Assert.assertEquals(0xCAA65939306F1E21L, XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); + } else { + Assert.assertEquals(0x7F875412350ADDDCL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4)); + Assert.assertEquals(0x564D279F524D8516L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 4, PRIME)); + Assert.assertEquals(0x7D9F07E27E0EB006L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8)); + Assert.assertEquals(0x893CEF564CB7858L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 8, PRIME)); + Assert.assertEquals(0xC6198C4C9CC49E17L, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14)); + Assert.assertEquals(0x4E21BEF7164D4BBL, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, 14, PRIME)); + Assert.assertEquals(0xBCF5FAEDEE1F2B5AL, + hasher.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE)); + Assert.assertEquals(0x6F680C877A358FE5L, + XXH64.hashUnsafeBytes(BUFFER, Platform.BYTE_ARRAY_OFFSET, SIZE, PRIME)); } } From 9ddd1e2ceac8155b30beebb6bbfdcd32296fab2d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 13 Mar 2018 23:31:08 +0900 Subject: [PATCH 0467/1161] [MINOR][SQL][TEST] Create table using `dataSourceName` in `HadoopFsRelationTest` ## What changes were proposed in this pull request? This PR fixes a minor issue in `HadoopFsRelationTest`, that you should create table using `dataSourceName` instead of `parquet`. The issue won't affect the correctness, but it will generate wrong error message in case the test fails. ## How was this patch tested? Exsiting tests. Author: Xingbo Jiang Closes #20780 from jiangxb1987/dataSourceName. --- .../apache/spark/sql/sources/HadoopFsRelationTest.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 80aff446bc24b..53397991e59dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -335,16 +335,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") - intercept[AnalysisException] { + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") + val msg = intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") - } + }.getMessage + assert(msg.contains("Table `t` already exists")) } } test("saveAsTable()/load() - non-partitioned table - Ignore") { withTable("t") { - sql("CREATE TABLE t(i INT) USING parquet") + sql(s"CREATE TABLE t(i INT) USING $dataSourceName") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") assert(spark.table("t").collect().isEmpty) } From 918fb9beee6a2fd499b8f18dfe0d460f078f5290 Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Tue, 13 Mar 2018 11:31:32 -0700 Subject: [PATCH 0468/1161] [SPARK-23547][SQL] Cleanup the .pipeout file when the Hive Session closed ## What changes were proposed in this pull request? ![2018-03-07_121010](https://user-images.githubusercontent.com/24823338/37073232-922e10d2-2200-11e8-8172-6e03aa984b39.png) when the hive session closed, we should also cleanup the .pipeout file. ## How was this patch tested? Added test cases. Author: zuotingbing Closes #20702 from zuotingbing/SPARK-23547. --- .../service/cli/session/HiveSessionImpl.java | 18 +++++++++++ .../HiveThriftServer2Suites.scala | 32 ++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index fc818bc69c761..f59cdcd3188e6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -641,6 +641,8 @@ public void close() throws HiveSQLException { opHandleSet.clear(); // Cleanup session log directory. cleanupSessionLogDir(); + // Cleanup pipeout file. + cleanupPipeoutFile(); HiveHistory hiveHist = sessionState.getHiveHistory(); if (null != hiveHist) { hiveHist.closeStream(); @@ -665,6 +667,22 @@ public void close() throws HiveSQLException { } } + private void cleanupPipeoutFile() { + String lScratchDir = hiveConf.getVar(ConfVars.LOCALSCRATCHDIR); + String sessionID = hiveConf.getVar(ConfVars.HIVESESSIONID); + + File[] fileAry = new File(lScratchDir).listFiles( + (dir, name) -> name.startsWith(sessionID) && name.endsWith(".pipeout")); + + for (File file : fileAry) { + try { + FileUtils.forceDelete(file); + } catch (Exception e) { + LOG.error("Failed to cleanup pipeout file: " + file, e); + } + } + } + private void cleanupSessionLogDir() { if (isOperationLogEnabled) { try { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b32c547cefefe..192f33a45e273 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import java.io.File +import java.io.{File, FilenameFilter} import java.net.URL import java.nio.charset.StandardCharsets import java.sql.{Date, DriverManager, SQLException, Statement} +import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -613,6 +614,28 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { bufferSrc.close() } } + + test("SPARK-23547 Cleanup the .pipeout file when the Hive Session closed") { + def pipeoutFileList(sessionID: UUID): Array[File] = { + lScratchDir.listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + name.startsWith(sessionID.toString) && name.endsWith(".pipeout") + } + }) + } + + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + val sessionID = sessionHandle.getSessionId + + assert(pipeoutFileList(sessionID).length == 1) + + client.closeSession(sessionHandle) + + assert(pipeoutFileList(sessionID).length == 0) + } + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -807,6 +830,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl private val pidDir: File = Utils.createTempDir(namePrefix = "thriftserver-pid") protected var logPath: File = _ protected var operationLogPath: File = _ + protected var lScratchDir: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -844,6 +868,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath + | --hiveconf ${ConfVars.LOCALSCRATCHDIR}=$lScratchDir | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -873,6 +898,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() operationLogPath = Utils.createTempDir() operationLogPath.delete() + lScratchDir = Utils.createTempDir() + lScratchDir.delete() logPath = null logTailingProcess = null @@ -956,6 +983,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl operationLogPath.delete() operationLogPath = null + lScratchDir.delete() + lScratchDir = null + Option(logPath).foreach(_.delete()) logPath = null From 1098933b0ac5cdb18101d3aebefa773c2ce05a50 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 13 Mar 2018 23:04:16 +0100 Subject: [PATCH 0469/1161] [SPARK-23598][SQL] Make methods in BufferedRowIterator public to avoid runtime error for a large query ## What changes were proposed in this pull request? This PR fixes runtime error regarding a large query when a generated code has split classes. The issue is `append()`, `stopEarly()`, and other methods are not accessible from split classes that are not subclasses of `BufferedRowIterator`. This PR fixes this issue by making them `public`. Before applying the PR, we see the following exception by running the attached program with `CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD=-1`. ``` test("SPARK-23598") { // When set -1 to CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD, an exception is thrown val df_pet_age = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") df_pet_age.groupBy("name").avg("age").show() } ``` Exception: ``` 19:40:52.591 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable 19:41:32.319 ERROR org.apache.spark.executor.Executor: Exception in task 0.0 in stage 0.0 (TID 0) java.lang.IllegalAccessError: tried to access method org.apache.spark.sql.execution.BufferedRowIterator.shouldStop()Z from class org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1 at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1$agg_NestedClass1.agg_doAggregateWithKeys$(generated.java:203) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(generated.java:160) at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:616) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408) at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:125) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96) at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread.java:745) ... ``` Generated code (line 195 calles `stopEarly()`). ``` /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage1(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=1 /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private boolean agg_initAgg; /* 010 */ private boolean agg_bufIsNull; /* 011 */ private double agg_bufValue; /* 012 */ private boolean agg_bufIsNull1; /* 013 */ private long agg_bufValue1; /* 014 */ private agg_FastHashMap agg_fastHashMap; /* 015 */ private org.apache.spark.unsafe.KVIterator agg_fastHashMapIter; /* 016 */ private org.apache.spark.unsafe.KVIterator agg_mapIter; /* 017 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap; /* 018 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter; /* 019 */ private scala.collection.Iterator inputadapter_input; /* 020 */ private boolean agg_agg_isNull11; /* 021 */ private boolean agg_agg_isNull25; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[] agg_mutableStateArray1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder[2]; /* 023 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray2 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2]; /* 024 */ private UnsafeRow[] agg_mutableStateArray = new UnsafeRow[2]; /* 025 */ /* 026 */ public GeneratedIteratorForCodegenStage1(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 031 */ partitionIndex = index; /* 032 */ this.inputs = inputs; /* 033 */ /* 034 */ agg_fastHashMap = new agg_FastHashMap(((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getTaskMemoryManager(), ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getEmptyAggregationBuffer()); /* 035 */ agg_hashMap = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).createHashMap(); /* 036 */ inputadapter_input = inputs[0]; /* 037 */ agg_mutableStateArray[0] = new UnsafeRow(1); /* 038 */ agg_mutableStateArray1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[0], 32); /* 039 */ agg_mutableStateArray2[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[0], 1); /* 040 */ agg_mutableStateArray[1] = new UnsafeRow(3); /* 041 */ agg_mutableStateArray1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_mutableStateArray[1], 32); /* 042 */ agg_mutableStateArray2[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_mutableStateArray1[1], 3); /* 043 */ /* 044 */ } /* 045 */ /* 046 */ public class agg_FastHashMap { /* 047 */ private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; /* 048 */ private int[] buckets; /* 049 */ private int capacity = 1 << 16; /* 050 */ private double loadFactor = 0.5; /* 051 */ private int numBuckets = (int) (capacity / loadFactor); /* 052 */ private int maxSteps = 2; /* 053 */ private int numRows = 0; /* 054 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[1] /* keyName */), org.apache.spark.sql.types.DataTypes.StringType); /* 055 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add(((java.lang.String) references[2] /* keyName */), org.apache.spark.sql.types.DataTypes.DoubleType) /* 056 */ .add(((java.lang.String) references[3] /* keyName */), org.apache.spark.sql.types.DataTypes.LongType); /* 057 */ private Object emptyVBase; /* 058 */ private long emptyVOff; /* 059 */ private int emptyVLen; /* 060 */ private boolean isBatchFull = false; /* 061 */ /* 062 */ public agg_FastHashMap( /* 063 */ org.apache.spark.memory.TaskMemoryManager taskMemoryManager, /* 064 */ InternalRow emptyAggregationBuffer) { /* 065 */ batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch /* 066 */ .allocate(keySchema, valueSchema, taskMemoryManager, capacity); /* 067 */ /* 068 */ final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); /* 069 */ final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); /* 070 */ /* 071 */ emptyVBase = emptyBuffer; /* 072 */ emptyVOff = Platform.BYTE_ARRAY_OFFSET; /* 073 */ emptyVLen = emptyBuffer.length; /* 074 */ /* 075 */ buckets = new int[numBuckets]; /* 076 */ java.util.Arrays.fill(buckets, -1); /* 077 */ } /* 078 */ /* 079 */ public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(UTF8String agg_key) { /* 080 */ long h = hash(agg_key); /* 081 */ int step = 0; /* 082 */ int idx = (int) h & (numBuckets - 1); /* 083 */ while (step < maxSteps) { /* 084 */ // Return bucket index if it's either an empty slot or already contains the key /* 085 */ if (buckets[idx] == -1) { /* 086 */ if (numRows < capacity && !isBatchFull) { /* 087 */ // creating the unsafe for new entry /* 088 */ UnsafeRow agg_result = new UnsafeRow(1); /* 089 */ org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder /* 090 */ = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, /* 091 */ 32); /* 092 */ org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter /* 093 */ = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( /* 094 */ agg_holder, /* 095 */ 1); /* 096 */ agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed /* 097 */ agg_rowWriter.zeroOutNullBytes(); /* 098 */ agg_rowWriter.write(0, agg_key); /* 099 */ agg_result.setTotalSize(agg_holder.totalSize()); /* 100 */ Object kbase = agg_result.getBaseObject(); /* 101 */ long koff = agg_result.getBaseOffset(); /* 102 */ int klen = agg_result.getSizeInBytes(); /* 103 */ /* 104 */ UnsafeRow vRow /* 105 */ = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen); /* 106 */ if (vRow == null) { /* 107 */ isBatchFull = true; /* 108 */ } else { /* 109 */ buckets[idx] = numRows++; /* 110 */ } /* 111 */ return vRow; /* 112 */ } else { /* 113 */ // No more space /* 114 */ return null; /* 115 */ } /* 116 */ } else if (equals(idx, agg_key)) { /* 117 */ return batch.getValueRow(buckets[idx]); /* 118 */ } /* 119 */ idx = (idx + 1) & (numBuckets - 1); /* 120 */ step++; /* 121 */ } /* 122 */ // Didn't find it /* 123 */ return null; /* 124 */ } /* 125 */ /* 126 */ private boolean equals(int idx, UTF8String agg_key) { /* 127 */ UnsafeRow row = batch.getKeyRow(buckets[idx]); /* 128 */ return (row.getUTF8String(0).equals(agg_key)); /* 129 */ } /* 130 */ /* 131 */ private long hash(UTF8String agg_key) { /* 132 */ long agg_hash = 0; /* 133 */ /* 134 */ int agg_result = 0; /* 135 */ byte[] agg_bytes = agg_key.getBytes(); /* 136 */ for (int i = 0; i < agg_bytes.length; i++) { /* 137 */ int agg_hash1 = agg_bytes[i]; /* 138 */ agg_result = (agg_result ^ (0x9e3779b9)) + agg_hash1 + (agg_result << 6) + (agg_result >>> 2); /* 139 */ } /* 140 */ /* 141 */ agg_hash = (agg_hash ^ (0x9e3779b9)) + agg_result + (agg_hash << 6) + (agg_hash >>> 2); /* 142 */ /* 143 */ return agg_hash; /* 144 */ } /* 145 */ /* 146 */ public org.apache.spark.unsafe.KVIterator rowIterator() { /* 147 */ return batch.rowIterator(); /* 148 */ } /* 149 */ /* 150 */ public void close() { /* 151 */ batch.close(); /* 152 */ } /* 153 */ /* 154 */ } /* 155 */ /* 156 */ protected void processNext() throws java.io.IOException { /* 157 */ if (!agg_initAgg) { /* 158 */ agg_initAgg = true; /* 159 */ long wholestagecodegen_beforeAgg = System.nanoTime(); /* 160 */ agg_nestedClassInstance1.agg_doAggregateWithKeys(); /* 161 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[8] /* aggTime */).add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000); /* 162 */ } /* 163 */ /* 164 */ // output the result /* 165 */ /* 166 */ while (agg_fastHashMapIter.next()) { /* 167 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_fastHashMapIter.getKey(); /* 168 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_fastHashMapIter.getValue(); /* 169 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 170 */ /* 171 */ if (shouldStop()) return; /* 172 */ } /* 173 */ agg_fastHashMap.close(); /* 174 */ /* 175 */ while (agg_mapIter.next()) { /* 176 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 177 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 178 */ wholestagecodegen_nestedClassInstance.agg_doAggregateWithKeysOutput(agg_aggKey, agg_aggBuffer); /* 179 */ /* 180 */ if (shouldStop()) return; /* 181 */ } /* 182 */ /* 183 */ agg_mapIter.close(); /* 184 */ if (agg_sorter == null) { /* 185 */ agg_hashMap.free(); /* 186 */ } /* 187 */ } /* 188 */ /* 189 */ private wholestagecodegen_NestedClass wholestagecodegen_nestedClassInstance = new wholestagecodegen_NestedClass(); /* 190 */ private agg_NestedClass1 agg_nestedClassInstance1 = new agg_NestedClass1(); /* 191 */ private agg_NestedClass agg_nestedClassInstance = new agg_NestedClass(); /* 192 */ /* 193 */ private class agg_NestedClass1 { /* 194 */ private void agg_doAggregateWithKeys() throws java.io.IOException { /* 195 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 196 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 197 */ int inputadapter_value = inputadapter_row.getInt(0); /* 198 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 199 */ UTF8String inputadapter_value1 = inputadapter_isNull1 ? /* 200 */ null : (inputadapter_row.getUTF8String(1)); /* 201 */ /* 202 */ agg_nestedClassInstance.agg_doConsume(inputadapter_row, inputadapter_value, inputadapter_value1, inputadapter_isNull1); /* 203 */ if (shouldStop()) return; /* 204 */ } /* 205 */ /* 206 */ agg_fastHashMapIter = agg_fastHashMap.rowIterator(); /* 207 */ agg_mapIter = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).finishAggregate(agg_hashMap, agg_sorter, ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* peakMemory */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[5] /* spillSize */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[6] /* avgHashProbe */)); /* 208 */ /* 209 */ } /* 210 */ /* 211 */ } /* 212 */ /* 213 */ private class wholestagecodegen_NestedClass { /* 214 */ private void agg_doAggregateWithKeysOutput(UnsafeRow agg_keyTerm, UnsafeRow agg_bufferTerm) /* 215 */ throws java.io.IOException { /* 216 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[7] /* numOutputRows */).add(1); /* 217 */ /* 218 */ boolean agg_isNull35 = agg_keyTerm.isNullAt(0); /* 219 */ UTF8String agg_value37 = agg_isNull35 ? /* 220 */ null : (agg_keyTerm.getUTF8String(0)); /* 221 */ boolean agg_isNull36 = agg_bufferTerm.isNullAt(0); /* 222 */ double agg_value38 = agg_isNull36 ? /* 223 */ -1.0 : (agg_bufferTerm.getDouble(0)); /* 224 */ boolean agg_isNull37 = agg_bufferTerm.isNullAt(1); /* 225 */ long agg_value39 = agg_isNull37 ? /* 226 */ -1L : (agg_bufferTerm.getLong(1)); /* 227 */ /* 228 */ agg_mutableStateArray1[1].reset(); /* 229 */ /* 230 */ agg_mutableStateArray2[1].zeroOutNullBytes(); /* 231 */ /* 232 */ if (agg_isNull35) { /* 233 */ agg_mutableStateArray2[1].setNullAt(0); /* 234 */ } else { /* 235 */ agg_mutableStateArray2[1].write(0, agg_value37); /* 236 */ } /* 237 */ /* 238 */ if (agg_isNull36) { /* 239 */ agg_mutableStateArray2[1].setNullAt(1); /* 240 */ } else { /* 241 */ agg_mutableStateArray2[1].write(1, agg_value38); /* 242 */ } /* 243 */ /* 244 */ if (agg_isNull37) { /* 245 */ agg_mutableStateArray2[1].setNullAt(2); /* 246 */ } else { /* 247 */ agg_mutableStateArray2[1].write(2, agg_value39); /* 248 */ } /* 249 */ agg_mutableStateArray[1].setTotalSize(agg_mutableStateArray1[1].totalSize()); /* 250 */ append(agg_mutableStateArray[1]); /* 251 */ /* 252 */ } /* 253 */ /* 254 */ } /* 255 */ /* 256 */ private class agg_NestedClass { /* 257 */ private void agg_doConsume(InternalRow inputadapter_row, int agg_expr_0, UTF8String agg_expr_1, boolean agg_exprIsNull_1) throws java.io.IOException { /* 258 */ UnsafeRow agg_unsafeRowAggBuffer = null; /* 259 */ UnsafeRow agg_fastAggBuffer = null; /* 260 */ /* 261 */ if (true) { /* 262 */ if (!agg_exprIsNull_1) { /* 263 */ agg_fastAggBuffer = agg_fastHashMap.findOrInsert( /* 264 */ agg_expr_1); /* 265 */ } /* 266 */ } /* 267 */ // Cannot find the key in fast hash map, try regular hash map. /* 268 */ if (agg_fastAggBuffer == null) { /* 269 */ // generate grouping key /* 270 */ agg_mutableStateArray1[0].reset(); /* 271 */ /* 272 */ agg_mutableStateArray2[0].zeroOutNullBytes(); /* 273 */ /* 274 */ if (agg_exprIsNull_1) { /* 275 */ agg_mutableStateArray2[0].setNullAt(0); /* 276 */ } else { /* 277 */ agg_mutableStateArray2[0].write(0, agg_expr_1); /* 278 */ } /* 279 */ agg_mutableStateArray[0].setTotalSize(agg_mutableStateArray1[0].totalSize()); /* 280 */ int agg_value7 = 42; /* 281 */ /* 282 */ if (!agg_exprIsNull_1) { /* 283 */ agg_value7 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(agg_expr_1.getBaseObject(), agg_expr_1.getBaseOffset(), agg_expr_1.numBytes(), agg_value7); /* 284 */ } /* 285 */ if (true) { /* 286 */ // try to get the buffer from hash map /* 287 */ agg_unsafeRowAggBuffer = /* 288 */ agg_hashMap.getAggregationBufferFromUnsafeRow(agg_mutableStateArray[0], agg_value7); /* 289 */ } /* 290 */ // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based /* 291 */ // aggregation after processing all input rows. /* 292 */ if (agg_unsafeRowAggBuffer == null) { /* 293 */ if (agg_sorter == null) { /* 294 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter(); /* 295 */ } else { /* 296 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter()); /* 297 */ } /* 298 */ /* 299 */ // the hash map had be spilled, it should have enough memory now, /* 300 */ // try to allocate buffer again. /* 301 */ agg_unsafeRowAggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow( /* 302 */ agg_mutableStateArray[0], agg_value7); /* 303 */ if (agg_unsafeRowAggBuffer == null) { /* 304 */ // failed to allocate the first page /* 305 */ throw new OutOfMemoryError("No enough memory for aggregation"); /* 306 */ } /* 307 */ } /* 308 */ /* 309 */ } /* 310 */ /* 311 */ if (agg_fastAggBuffer != null) { /* 312 */ // common sub-expressions /* 313 */ boolean agg_isNull21 = false; /* 314 */ long agg_value23 = -1L; /* 315 */ if (!false) { /* 316 */ agg_value23 = (long) agg_expr_0; /* 317 */ } /* 318 */ // evaluate aggregate function /* 319 */ boolean agg_isNull23 = true; /* 320 */ double agg_value25 = -1.0; /* 321 */ /* 322 */ boolean agg_isNull24 = agg_fastAggBuffer.isNullAt(0); /* 323 */ double agg_value26 = agg_isNull24 ? /* 324 */ -1.0 : (agg_fastAggBuffer.getDouble(0)); /* 325 */ if (!agg_isNull24) { /* 326 */ agg_agg_isNull25 = true; /* 327 */ double agg_value27 = -1.0; /* 328 */ do { /* 329 */ boolean agg_isNull26 = agg_isNull21; /* 330 */ double agg_value28 = -1.0; /* 331 */ if (!agg_isNull21) { /* 332 */ agg_value28 = (double) agg_value23; /* 333 */ } /* 334 */ if (!agg_isNull26) { /* 335 */ agg_agg_isNull25 = false; /* 336 */ agg_value27 = agg_value28; /* 337 */ continue; /* 338 */ } /* 339 */ /* 340 */ boolean agg_isNull27 = false; /* 341 */ double agg_value29 = -1.0; /* 342 */ if (!false) { /* 343 */ agg_value29 = (double) 0; /* 344 */ } /* 345 */ if (!agg_isNull27) { /* 346 */ agg_agg_isNull25 = false; /* 347 */ agg_value27 = agg_value29; /* 348 */ continue; /* 349 */ } /* 350 */ /* 351 */ } while (false); /* 352 */ /* 353 */ agg_isNull23 = false; // resultCode could change nullability. /* 354 */ agg_value25 = agg_value26 + agg_value27; /* 355 */ /* 356 */ } /* 357 */ boolean agg_isNull29 = false; /* 358 */ long agg_value31 = -1L; /* 359 */ if (!false && agg_isNull21) { /* 360 */ boolean agg_isNull31 = agg_fastAggBuffer.isNullAt(1); /* 361 */ long agg_value33 = agg_isNull31 ? /* 362 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 363 */ agg_isNull29 = agg_isNull31; /* 364 */ agg_value31 = agg_value33; /* 365 */ } else { /* 366 */ boolean agg_isNull32 = true; /* 367 */ long agg_value34 = -1L; /* 368 */ /* 369 */ boolean agg_isNull33 = agg_fastAggBuffer.isNullAt(1); /* 370 */ long agg_value35 = agg_isNull33 ? /* 371 */ -1L : (agg_fastAggBuffer.getLong(1)); /* 372 */ if (!agg_isNull33) { /* 373 */ agg_isNull32 = false; // resultCode could change nullability. /* 374 */ agg_value34 = agg_value35 + 1L; /* 375 */ /* 376 */ } /* 377 */ agg_isNull29 = agg_isNull32; /* 378 */ agg_value31 = agg_value34; /* 379 */ } /* 380 */ // update fast row /* 381 */ if (!agg_isNull23) { /* 382 */ agg_fastAggBuffer.setDouble(0, agg_value25); /* 383 */ } else { /* 384 */ agg_fastAggBuffer.setNullAt(0); /* 385 */ } /* 386 */ /* 387 */ if (!agg_isNull29) { /* 388 */ agg_fastAggBuffer.setLong(1, agg_value31); /* 389 */ } else { /* 390 */ agg_fastAggBuffer.setNullAt(1); /* 391 */ } /* 392 */ } else { /* 393 */ // common sub-expressions /* 394 */ boolean agg_isNull7 = false; /* 395 */ long agg_value9 = -1L; /* 396 */ if (!false) { /* 397 */ agg_value9 = (long) agg_expr_0; /* 398 */ } /* 399 */ // evaluate aggregate function /* 400 */ boolean agg_isNull9 = true; /* 401 */ double agg_value11 = -1.0; /* 402 */ /* 403 */ boolean agg_isNull10 = agg_unsafeRowAggBuffer.isNullAt(0); /* 404 */ double agg_value12 = agg_isNull10 ? /* 405 */ -1.0 : (agg_unsafeRowAggBuffer.getDouble(0)); /* 406 */ if (!agg_isNull10) { /* 407 */ agg_agg_isNull11 = true; /* 408 */ double agg_value13 = -1.0; /* 409 */ do { /* 410 */ boolean agg_isNull12 = agg_isNull7; /* 411 */ double agg_value14 = -1.0; /* 412 */ if (!agg_isNull7) { /* 413 */ agg_value14 = (double) agg_value9; /* 414 */ } /* 415 */ if (!agg_isNull12) { /* 416 */ agg_agg_isNull11 = false; /* 417 */ agg_value13 = agg_value14; /* 418 */ continue; /* 419 */ } /* 420 */ /* 421 */ boolean agg_isNull13 = false; /* 422 */ double agg_value15 = -1.0; /* 423 */ if (!false) { /* 424 */ agg_value15 = (double) 0; /* 425 */ } /* 426 */ if (!agg_isNull13) { /* 427 */ agg_agg_isNull11 = false; /* 428 */ agg_value13 = agg_value15; /* 429 */ continue; /* 430 */ } /* 431 */ /* 432 */ } while (false); /* 433 */ /* 434 */ agg_isNull9 = false; // resultCode could change nullability. /* 435 */ agg_value11 = agg_value12 + agg_value13; /* 436 */ /* 437 */ } /* 438 */ boolean agg_isNull15 = false; /* 439 */ long agg_value17 = -1L; /* 440 */ if (!false && agg_isNull7) { /* 441 */ boolean agg_isNull17 = agg_unsafeRowAggBuffer.isNullAt(1); /* 442 */ long agg_value19 = agg_isNull17 ? /* 443 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 444 */ agg_isNull15 = agg_isNull17; /* 445 */ agg_value17 = agg_value19; /* 446 */ } else { /* 447 */ boolean agg_isNull18 = true; /* 448 */ long agg_value20 = -1L; /* 449 */ /* 450 */ boolean agg_isNull19 = agg_unsafeRowAggBuffer.isNullAt(1); /* 451 */ long agg_value21 = agg_isNull19 ? /* 452 */ -1L : (agg_unsafeRowAggBuffer.getLong(1)); /* 453 */ if (!agg_isNull19) { /* 454 */ agg_isNull18 = false; // resultCode could change nullability. /* 455 */ agg_value20 = agg_value21 + 1L; /* 456 */ /* 457 */ } /* 458 */ agg_isNull15 = agg_isNull18; /* 459 */ agg_value17 = agg_value20; /* 460 */ } /* 461 */ // update unsafe row buffer /* 462 */ if (!agg_isNull9) { /* 463 */ agg_unsafeRowAggBuffer.setDouble(0, agg_value11); /* 464 */ } else { /* 465 */ agg_unsafeRowAggBuffer.setNullAt(0); /* 466 */ } /* 467 */ /* 468 */ if (!agg_isNull15) { /* 469 */ agg_unsafeRowAggBuffer.setLong(1, agg_value17); /* 470 */ } else { /* 471 */ agg_unsafeRowAggBuffer.setNullAt(1); /* 472 */ } /* 473 */ /* 474 */ } /* 475 */ /* 476 */ } /* 477 */ /* 478 */ } /* 479 */ /* 480 */ } ``` ## How was this patch tested? Added UT into `WholeStageCodegenSuite` Author: Kazuaki Ishizaki Closes #20779 from kiszk/SPARK-23598. --- .../spark/sql/execution/BufferedRowIterator.java | 12 ++++++++---- .../spark/sql/execution/WholeStageCodegenSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 730a4ae8d5605..74c9c05992719 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -62,10 +62,14 @@ public long durationMs() { */ public abstract void init(int index, Iterator[] iters); + /* + * Attributes of the following four methods are public. Thus, they can be also accessed from + * methods in inner classes. See SPARK-23598 + */ /** * Append a row to currentRows. */ - protected void append(InternalRow row) { + public void append(InternalRow row) { currentRows.add(row); } @@ -75,7 +79,7 @@ protected void append(InternalRow row) { * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. * This interface is mainly used to limit the number of input rows. */ - protected boolean stopEarly() { + public boolean stopEarly() { return false; } @@ -84,14 +88,14 @@ protected boolean stopEarly() { * * If it returns true, the caller should exit the loop (return from processNext()). */ - protected boolean shouldStop() { + public boolean shouldStop() { return !currentRows.isEmpty(); } /** * Increase the peak execution memory for current task. */ - protected void incPeakExecutionMemory(long size) { + public void incPeakExecutionMemory(long size) { TaskContext.get().taskMetrics().incPeakExecutionMemory(size); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 0fb9dd2017a09..4b40e4ef7571c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan @@ -307,4 +309,14 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { // a different query can result in codegen cache miss, that's by design } } + + test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") + for (i <- 0 until 70) { + df = df.groupBy("name").agg(avg("age").alias("age")) + } + assert(df.limit(1).collect() === Array(Row("bat", 8.0))) + } + } } From 279b3db8970809104c30941254e57e3d62da5041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Wed, 14 Mar 2018 18:36:31 -0700 Subject: [PATCH 0470/1161] [SPARK-22915][MLLIB] Streaming tests for spark.ml.feature, from N to Z MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: - NGramSuite - NormalizerSuite - OneHotEncoderEstimatorSuite - OneHotEncoderSuite - PCASuite - PolynomialExpansionSuite - QuantileDiscretizerSuite - RFormulaSuite - SQLTransformerSuite - StandardScalerSuite - StopWordsRemoverSuite - StringIndexerSuite - TokenizerSuite - RegexTokenizerSuite - VectorAssemblerSuite - VectorIndexerSuite - VectorSizeHintSuite - VectorSlicerSuite - Word2VecSuite # How was this patch tested? They are unit test. Author: “attilapiros” Closes #20686 from attilapiros/SPARK-22915. --- .../apache/spark/ml/feature/NGramSuite.scala | 23 +- .../spark/ml/feature/NormalizerSuite.scala | 57 ++--- .../feature/OneHotEncoderEstimatorSuite.scala | 193 ++++++++--------- .../spark/ml/feature/OneHotEncoderSuite.scala | 124 ++++++----- .../apache/spark/ml/feature/PCASuite.scala | 14 +- .../ml/feature/PolynomialExpansionSuite.scala | 62 +++--- .../ml/feature/QuantileDiscretizerSuite.scala | 198 +++++++++-------- .../spark/ml/feature/RFormulaSuite.scala | 158 +++++++------- .../ml/feature/SQLTransformerSuite.scala | 35 +-- .../ml/feature/StandardScalerSuite.scala | 33 +-- .../ml/feature/StopWordsRemoverSuite.scala | 37 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 204 +++++++++--------- .../spark/ml/feature/TokenizerSuite.scala | 30 +-- .../spark/ml/feature/VectorIndexerSuite.scala | 183 +++++++++------- .../ml/feature/VectorSizeHintSuite.scala | 88 +++++--- .../spark/ml/feature/VectorSlicerSuite.scala | 27 +-- .../spark/ml/feature/Word2VecSuite.scala | 28 +-- .../org/apache/spark/ml/util/MLTest.scala | 33 ++- 18 files changed, 809 insertions(+), 718 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index d4975c0b4e20e..e5956ee9942aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} + @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NGramSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.NGramSuite._ import testImplicits._ test("default behavior yields bigram features") { @@ -83,16 +81,11 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setN(3) testDefaultReadWrite(t) } -} - -object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("nGrams", "wantedNGrams") - .collect() - .foreach { case Row(actualNGrams, wantedNGrams) => + def testNGram(t: NGram, dataFrame: DataFrame): Unit = { + testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { + case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => assert(actualNGrams === wantedNGrams) - } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index c75027fb4553d..eff57f1223af4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,21 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class NormalizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @transient var data: Array[Vector] = _ - @transient var dataFrame: DataFrame = _ - @transient var normalizer: Normalizer = _ @transient var l1Normalized: Array[Vector] = _ @transient var l2Normalized: Array[Vector] = _ @@ -62,49 +58,40 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.dense(0.897906166, 0.113419726, 0.42532397), Vectors.sparse(3, Seq()) ) - - dataFrame = data.map(NormalizerSuite.FeatureData).toSeq.toDF() - normalizer = new Normalizer() - .setInputCol("features") - .setOutputCol("normalized_features") - } - - def collectResult(result: DataFrame): Array[Vector] = { - result.select("normalized_features").collect().map { - case Row(features: Vector) => features - } } - def assertTypeOfVector(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { case (v1: DenseVector, v2: DenseVector) => true case (v1: SparseVector, v2: SparseVector) => true case _ => false }, "The vector type should be preserved after normalization.") } - def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = { - assert((lhs, rhs).zipped.forall { (vector1, vector2) => - vector1 ~== vector2 absTol 1E-5 - }, "The vector value is not correct after normalization.") + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1E-5, "The vector value is not correct after normalization.") } test("Normalization with default parameter") { - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized") + val dataFrame: DataFrame = data.zip(l2Normalized).seq.toDF("features", "expected") - assertValues(result, l2Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("Normalization with setter") { - normalizer.setP(1) + val dataFrame: DataFrame = data.zip(l1Normalized).seq.toDF("features", "expected") + val normalizer = new Normalizer().setInputCol("features").setOutputCol("normalized").setP(1) - val result = collectResult(normalizer.transform(dataFrame)) - - assertTypeOfVector(data, result) - - assertValues(result, l1Normalized) + testTransformer[(Vector, Vector)](dataFrame, normalizer, "features", "normalized", "expected") { + case Row(features: Vector, normalized: Vector, expected: Vector) => + assertTypeOfVector(normalized, features) + assertValues(normalized, expected) + } } test("read/write") { @@ -115,7 +102,3 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testDefaultReadWrite(t) } } - -private object NormalizerSuite { - case class FeatureData(features: Vector) -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 1d3f845586426..d549e13262273 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ -class OneHotEncoderEstimatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class OneHotEncoderEstimatorSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,13 +55,10 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -87,11 +82,9 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -103,11 +96,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("size")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } test("input column without ML attribute") { @@ -116,11 +110,12 @@ class OneHotEncoderEstimatorSuite .setInputCols(Array("index")) .setOutputCols(Array("encoded")) val model = encoder.fit(df) - val output = model.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } } test("read/write") { @@ -151,29 +146,30 @@ class OneHotEncoderEstimatorSuite val df = spark.createDataFrame(sc.parallelize(data), schema) - val dfWithTypes = df - .withColumn("shortInput", df("input").cast(ShortType)) - .withColumn("longInput", df("input").cast(LongType)) - .withColumn("intInput", df("input").cast(IntegerType)) - .withColumn("floatInput", df("input").cast(FloatType)) - .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) - - val cols = Array("input", "shortInput", "longInput", "intInput", - "floatInput", "decimalInput") - for (col <- cols) { - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array(col)) + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) + val estimator = new OneHotEncoderEstimator() + .setInputCols(Array("input")) .setOutputCols(Array("output")) .setDropLast(false) - val model = encoder.fit(dfWithTypes) - val encoded = model.transform(dfWithTypes) - - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) - } + val model = estimator.fit(dfWithTypes) + testTransformer(dfWithTypes, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } @@ -202,12 +198,16 @@ class OneHotEncoderEstimatorSuite assert(encoder.getDropLast === false) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -233,12 +233,16 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("output1", "output2")) val model = encoder.fit(df) - val encoded = model.transform(df) - encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) - }.collect().foreach { case (vec1, vec2, vec3, vec4) => - assert(vec1 === vec2) - assert(vec3 === vec4) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model, + "output1", + "output2", + "expected1", + "expected2") { + case Row(output1: Vector, output2: Vector, expected1: Vector, expected2: Vector) => + assert(output1 === expected1) + assert(output2 === expected2) } } @@ -253,10 +257,12 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).show - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "encoded") + } test("Can't transform on negative input") { @@ -268,10 +274,11 @@ class OneHotEncoderEstimatorSuite .setOutputCols(Array("encoded")) val model = encoder.fit(trainingDF) - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Negative value: -1.0. Input can't be negative") + testTransformerByInterceptingException[(Int, Int)]( + testDF, + model, + expectedMessagePart = "Negative value: -1.0. Input can't be negative", + firstResultCol = "encoded") } test("Keep on invalid values: dropLast = false") { @@ -295,11 +302,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(false) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -324,11 +329,9 @@ class OneHotEncoderEstimatorSuite .setDropLast(true) val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - encoded.select("output", "expected").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector)](testDF, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) } } @@ -355,19 +358,15 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(df) model.setDropLast(false) - val encoded1 = model.transform(df) - encoded1.select("output", "expected1").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected1") { + case Row(output: Vector, expected1: Vector) => + assert(output === expected1) } model.setDropLast(true) - val encoded2 = model.transform(df) - encoded2.select("output", "expected2").rdd.map { r => - (r.getAs[Vector](0), r.getAs[Vector](1)) - }.collect().foreach { case (vec1, vec2) => - assert(vec1 === vec2) + testTransformer[(Double, Vector, Vector)](df, model, "output", "expected2") { + case Row(output: Vector, expected2: Vector) => + assert(output === expected2) } } @@ -392,13 +391,14 @@ class OneHotEncoderEstimatorSuite val model = encoder.fit(trainingDF) model.setHandleInvalid("error") - val err = intercept[SparkException] { - model.transform(testDF).collect() - } - err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + testTransformerByInterceptingException[(Double, Vector)]( + testDF, + model, + expectedMessagePart = "Unseen value: 3.0. To handle unseen values", + firstResultCol = "output") model.setHandleInvalid("keep") - model.transform(testDF).collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](testDF, model, "output") { _ => } } test("Transforming on mismatched attributes") { @@ -413,9 +413,10 @@ class OneHotEncoderEstimatorSuite val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") .select(col("size").as("size", testAttr.toMetadata())) - val err = intercept[Exception] { - model.transform(testDF).collect() - } - err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + testTransformerByInterceptingException[(Double)]( + testDF, + model, + expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", + firstResultCol = "encoded") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index c44c6813a94be..41b32b2ffa096 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ class OneHotEncoderSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -54,16 +54,19 @@ class OneHotEncoderSuite assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("OneHotEncoder dropLast = true") { @@ -71,16 +74,19 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") - val encoded = encoder.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), - (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) - assert(output === expected) + val expected = Seq( + (0, Vectors.sparse(2, Seq((0, 1.0)))), + (1, Vectors.sparse(2, Seq())), + (2, Vectors.sparse(2, Seq((1, 1.0)))), + (3, Vectors.sparse(2, Seq((0, 1.0)))), + (4, Vectors.sparse(2, Seq((0, 1.0)))), + (5, Vectors.sparse(2, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = transformed.join(expected, "id") + testTransformer[(Int, String, Double, Vector)](withExpected, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } } test("input column with ML attribute") { @@ -90,20 +96,22 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("size") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) - assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + testTransformerByGlobalCheckFunc[(Double)](df, encoder, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } } + test("input column without ML attribute") { val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") - val output = encoder.transform(df) - val group = AttributeGroup.fromStructField(output.schema("encoded")) + val rows = encoder.transform(df).select("encoded").collect() + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) assert(group.size === 2) assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) @@ -119,29 +127,41 @@ class OneHotEncoderSuite test("OneHotEncoder with varying types") { val df = stringIndexed() - val dfWithTypes = df - .withColumn("shortLabel", df("labelIndex").cast(ShortType)) - .withColumn("longLabel", df("labelIndex").cast(LongType)) - .withColumn("intLabel", df("labelIndex").cast(IntegerType)) - .withColumn("floatLabel", df("labelIndex").cast(FloatType)) - .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) - val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", - "floatLabel", "decimalLabel") - for (col <- cols) { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val expected = Seq( + (0, Vectors.sparse(3, Seq((0, 1.0)))), + (1, Vectors.sparse(3, Seq((2, 1.0)))), + (2, Vectors.sparse(3, Seq((1, 1.0)))), + (3, Vectors.sparse(3, Seq((0, 1.0)))), + (4, Vectors.sparse(3, Seq((0, 1.0)))), + (5, Vectors.sparse(3, Seq((1, 1.0))))).toDF("id", "expected") + + val withExpected = df.join(expected, "id") + + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = withExpected.select(col("labelIndex") + .cast(t.numericType).as("labelIndex", attr.toMetadata()), col("expected")) val encoder = new OneHotEncoder() - .setInputCol(col) + .setInputCol("labelIndex") .setOutputCol("labelVec") .setDropLast(false) - val encoded = encoder.transform(dfWithTypes) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + + testTransformer(dfWithTypes, encoder, "labelVec", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 3067a52a4df76..531b1d7c4d9f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PCASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -62,10 +60,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val pcaModel = pca.fit(df) MLTestingUtils.checkCopyAndUids(pca, pcaModel) - - pcaModel.transform(df).select("pca_features", "expected").collect().foreach { - case Row(x: Vector, y: Vector) => - assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") { + case Row(result: Vector, expected: Vector) => + assert(result ~== expected absTol 1e-5, + "Transformed vector is different with expected vector.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index e4b0ddf98bfad..0be7aa6c83f29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.exceptions.TestFailedException - -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class PolynomialExpansionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class PolynomialExpansionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,6 +55,18 @@ class PolynomialExpansionSuite -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), Vectors.sparse(19, Array.empty, Array.empty)) + def assertTypeOfVector(lhs: Vector, rhs: Vector): Unit = { + assert((lhs, rhs) match { + case (v1: DenseVector, v2: DenseVector) => true + case (v1: SparseVector, v2: SparseVector) => true + case _ => false + }, "The vector type should be preserved after polynomial expansion.") + } + + def assertValues(lhs: Vector, rhs: Vector): Unit = { + assert(lhs ~== rhs absTol 1e-1, "The vector value is not correct after polynomial expansion.") + } + test("Polynomial expansion with default parameter") { val df = data.zip(twoDegreeExpansion).toSeq.toDF("features", "expected") @@ -67,13 +74,10 @@ class PolynomialExpansionSuite .setInputCol("features") .setOutputCol("polyFeatures") - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -85,13 +89,10 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(3) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { - case Row(expanded: DenseVector, expected: DenseVector) => - assert(expanded ~== expected absTol 1e-1) - case Row(expanded: SparseVector, expected: SparseVector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { + case Row(expanded: Vector, expected: Vector) => + assertTypeOfVector(expanded, expected) + assertValues(expanded, expected) } } @@ -103,11 +104,9 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") .setDegree(1) - polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, polynomialExpansion, "polyFeatures", "expected") { case Row(expanded: Vector, expected: Vector) => - assert(expanded ~== expected absTol 1e-1) - case _ => - throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + assertValues(expanded, expected) } } @@ -133,12 +132,13 @@ class PolynomialExpansionSuite .setOutputCol("polyFeatures") for (i <- Seq(10, 11)) { - val transformed = t.setDegree(i) - .transform(df) - .select(s"expectedPoly${i}size", "polyFeatures") - .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } - - assert(transformed.collect.forall(identity)) + testTransformer[(Vector, Int, Int)]( + df, + t.setDegree(i), + s"expectedPoly${i}size", + "polyFeatures") { case Row(size: Int, expected: Vector) => + assert(size === expected.size) + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6c363799dd300..b009038bbd833 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,15 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql._ -import org.apache.spark.sql.functions.udf -class QuantileDiscretizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,19 +36,19 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + val model = discretizer.fit(df) - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + val relativeError = discretizer.getRelativeError + val numGoodBuckets = result.groupBy("result").count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}").count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") } test("Test on data with high proportion of duplicated values") { @@ -67,11 +63,14 @@ class QuantileDiscretizerSuite .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets == expectedNumBuckets, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBuckets but found $observedNumBuckets") + } } test("Test transform on data with NaN value") { @@ -90,17 +89,20 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData.toSeq.toDF("input") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result") } List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ case(u, v) => discretizer.setHandleInvalid(u) val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result", "expected").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double)](dataFrame, model, "result", "expected") { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") @@ -119,14 +121,17 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(5) - val result = discretizer.fit(trainDF).transform(testDF) - val firstBucketSize = result.filter(result("result") === 0.0).count - val lastBucketSize = result.filter(result("result") === 4.0).count + val model = discretizer.fit(trainDF) + testTransformerByGlobalCheckFunc[(Double)](testDF, model, "result") { rows => + val result = rows.map { r => Tuple1(r.getDouble(0)) }.toDF("result") + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(firstBucketSize === 30L, - s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") - assert(lastBucketSize === 31L, - s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") + } } test("read/write") { @@ -167,21 +172,24 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - - val relativeError = discretizer.getRelativeError - val isGoodBucket = udf { - (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) - } - - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") - - val numGoodBuckets = result.groupBy("result" + i).count.filter(isGoodBucket($"count")).count - assert(numGoodBuckets === numBuckets, - "Bucket sizes are not within expected relative error tolerance.") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + val relativeError = discretizer.getRelativeError + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") + + val numGoodBuckets = result + .groupBy("result" + i) + .count + .filter(s"abs(count - ${datasetSize / numBuckets}) <= ${relativeError * datasetSize}") + .count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") + } } } @@ -198,12 +206,16 @@ class QuantileDiscretizerSuite .setInputCols(Array("input1", "input2")) .setOutputCols(Array("result1", "result2")) .setNumBuckets(numBuckets) - val result = discretizer.fit(df).transform(df) - for (i <- 1 to 2) { - val observedNumBuckets = result.select("result" + i).distinct.count - assert(observedNumBuckets == expectedNumBucket, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBucket but found ($observedNumBuckets") + val model = discretizer.fit(df) + testTransformerByGlobalCheckFunc[(Double, Double)](df, model, "result1", "result2") { rows => + val result = + rows.map { r => Tuple2(r.getDouble(0), r.getDouble(1)) }.toDF("result1", "result2") + for (i <- 1 to 2) { + val observedNumBuckets = result.select("result" + i).distinct.count + assert(observedNumBuckets == expectedNumBucket, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBucket but found ($observedNumBuckets") + } } } @@ -226,9 +238,12 @@ class QuantileDiscretizerSuite withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { val dataFrame: DataFrame = validData1.zip(validData2).toSeq.toDF("input1", "input2") - intercept[SparkException] { - discretizer.fit(dataFrame).transform(dataFrame).collect() - } + val model = discretizer.fit(dataFrame) + testTransformerByInterceptingException[(Double, Double)]( + dataFrame, + model, + expectedMessagePart = "Bucketizer encountered NaN value.", + firstResultCol = "result1") } List(("keep", expectedKeep1, expectedKeep2), ("skip", expectedSkip1, expectedSkip2)).foreach { @@ -237,8 +252,14 @@ class QuantileDiscretizerSuite val dataFrame: DataFrame = validData1.zip(validData2).zip(v).zip(w).map { case (((a, b), c), d) => (a, b, c, d) }.toSeq.toDF("input1", "input2", "expected1", "expected2") - val result = discretizer.fit(dataFrame).transform(dataFrame) - result.select("result1", "expected1", "result2", "expected2").collect().foreach { + val model = discretizer.fit(dataFrame) + testTransformer[(Double, Double, Double, Double)]( + dataFrame, + model, + "result1", + "expected1", + "result2", + "expected2") { case Row(x: Double, y: Double, z: Double, w: Double) => assert(x === y && w === z) } @@ -270,9 +291,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(numBucketsArray) - discretizer.fit(df).transform(df). - select("result1", "expected1", "result2", "expected2", "result3", "expected3") - .collect().foreach { + val model = discretizer.fit(df) + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df, + model, + "result1", + "expected1", + "result2", + "expected2", + "result3", + "expected3") { case Row(r1: Double, e1: Double, r2: Double, e2: Double, r3: Double, e3: Double) => assert(r1 === e1, s"The result value is not correct after bucketing. Expected $e1 but found $r1") @@ -324,20 +352,16 @@ class QuantileDiscretizerSuite .setStages(Array(discretizerForCol1, discretizerForCol2, discretizerForCol3)) .fit(df) - val resultForMultiCols = plForMultiCols.transform(df) - .select("result1", "result2", "result3") - .collect() - - val resultForSingleCol = plForSingleCol.transform(df) - .select("result1", "result2", "result3") - .collect() + val expected = plForSingleCol.transform(df).select("result1", "result2", "result3").collect() - resultForSingleCol.zip(resultForMultiCols).foreach { - case (rowForSingle, rowForMultiCols) => - assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && - rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && - rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2)) - } + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + plForMultiCols, + "result1", + "result2", + "result3") { rows => + assert(rows === expected) + } } test("Multiple Columns: Comparing setting numBuckets with setting numBucketsArray " + @@ -364,18 +388,16 @@ class QuantileDiscretizerSuite .setOutputCols(Array("result1", "result2", "result3")) .setNumBucketsArray(Array(10, 10, 10)) - val result1 = discretizerSingleNumBuckets.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - val result2 = discretizerNumBucketsArray.fit(df).transform(df) - .select("result1", "result2", "result3") - .collect() - - result1.zip(result2).foreach { - case (row1, row2) => - assert(row1.getDouble(0) == row2.getDouble(0) && - row1.getDouble(1) == row2.getDouble(1) && - row1.getDouble(2) == row2.getDouble(2)) + val model = discretizerSingleNumBuckets.fit(df) + val expected = model.transform(df).select("result1", "result2", "result3").collect() + + testTransformerByGlobalCheckFunc[(Double, Double, Double)]( + df, + discretizerNumBucketsArray.fit(df), + "result1", + "result2", + "result3") { rows => + assert(rows === expected) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index bfe38d32dd77d..27d570f0b68ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -32,10 +31,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { def testRFormulaTransform[A: Encoder]( dataframe: DataFrame, formulaModel: RFormulaModel, - expected: DataFrame): Unit = { + expected: DataFrame, + expectedAttributes: AttributeGroup*): Unit = { + val resultSchema = formulaModel.transformSchema(dataframe.schema) + assert(resultSchema.json === expected.schema.json) + assert(resultSchema === expected.schema) val (first +: rest) = expected.schema.fieldNames.toSeq val expectedRows = expected.collect() testTransformerByGlobalCheckFunc[A](dataframe, formulaModel, first, rest: _*) { rows => + assert(rows.head.schema.toString() == resultSchema.toString()) + for (expectedAttributeGroup <- expectedAttributes) { + val attributeGroup = + AttributeGroup.fromStructField(rows.head.schema(expectedAttributeGroup.name)) + assert(attributeGroup === expectedAttributeGroup) + } assert(rows === expectedRows) } } @@ -49,15 +58,10 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val model = formula.fit(original) MLTestingUtils.checkCopyAndUids(formula, model) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0) ).toDF("id", "v1", "v2", "features", "label") - // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -73,9 +77,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y") val model = formula.fit(original) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "y", "features") val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == model.transform(original).schema.toString) + testRFormulaTransform[(Int, Double)](original, model, expected) } test("label column already exists but forceIndexLabel was set with true") { @@ -93,9 +101,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { intercept[IllegalArgumentException] { model.transformSchema(original.schema) } - intercept[IllegalArgumentException] { - model.transform(original) - } + testTransformerByInterceptingException[(Int, Boolean)]( + original, + model, + "Label column already exists and is not of type NumericType.", + "x") } test("allow missing label column for test datasets") { @@ -105,21 +115,22 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == model.transform(original).schema.toString) + val expected = Seq( + (0, 1.0, Vectors.dense(0.0)), + (2, 2.0, Vectors.dense(2.0)) + ).toDF("x", "_not_y", "features") + testRFormulaTransform[(Int, Double)](original, model, expected) } test("allow empty label") { val original = Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), (7, 8.0, 9.0, Vectors.dense(8.0, 9.0)) ).toDF("id", "a", "b", "features") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, Double, Double)](original, model, expected) } @@ -128,15 +139,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) ).toDF("id", "a", "b", "features", "label") - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -175,9 +183,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { val model = formula.setStringIndexerOrderType(orderType).fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected(idx)) idx += 1 } @@ -218,9 +223,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { ).toDF("id", "a", "b", "features", "label") val model = formula.fit(original) - val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - assert(result.schema.toString == resultSchema.toString) testRFormulaTransform[(Int, String, Int)](original, model, expected) } @@ -254,19 +256,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model1 = formula1.fit(original) - val result1 = model1.transform(original) - val resultSchema1 = model1.transformSchema(original.schema) - // Note the column order is different between R and Spark. - val expected1 = Seq( - (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), - (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), - (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), - (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) - ).toDF("id", "a", "b", "c", "features", "label") - assert(result1.schema.toString == resultSchema1.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1) - - val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) val expectedAttrs1 = new AttributeGroup( "features", Array[Attribute]( @@ -275,14 +264,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new BinaryAttribute(Some("a_bar"), Some(3)), new BinaryAttribute(Some("b_zz"), Some(4)), new NumericAttribute(Some("c"), Some(5)))) - assert(attrs1 === expectedAttrs1) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + + testRFormulaTransform[(Int, String, String, Int)](original, model1, expected1, expectedAttrs1) // There is no impact for string terms interaction. val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") .setStringIndexerOrderType(StringIndexer.alphabetDesc) val model2 = formula2.fit(original) - val result2 = model2.transform(original) - val resultSchema2 = model2.transformSchema(original.schema) // Note the column order is different between R and Spark. val expected2 = Seq( (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), @@ -290,10 +285,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) ).toDF("id", "a", "b", "c", "features", "label") - assert(result2.schema.toString == resultSchema2.toString) - testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2) - - val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) val expectedAttrs2 = new AttributeGroup( "features", Array[Attribute]( @@ -304,7 +295,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(5)), new NumericAttribute(Some("a_bar:b_zq"), Some(6)), new NumericAttribute(Some("c"), Some(7)))) - assert(attrs2 === expectedAttrs2) + + testRFormulaTransform[(Int, String, String, Int)](original, model2, expected2, expectedAttrs2) } test("index string label") { @@ -313,13 +305,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) + val attr = NominalAttribute.defaultAttr val expected = Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) ).toDF("id", "a", "b", "features", "label") - // assert(result.schema.toString == resultSchema.toString) + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(String, String, Int)](original, model, expected) } @@ -329,13 +322,14 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) - val expected = spark.createDataFrame( - Seq( + val attr = NominalAttribute.defaultAttr + val expected = Seq( (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) - ).toDF("id", "a", "b", "features", "label") + .toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Double, String, Int)](original, model, expected) } @@ -344,15 +338,20 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + .toDF("id", "a", "b", "features", "label") val expectedAttrs = new AttributeGroup( "features", Array( new BinaryAttribute(Some("a_bar"), Some(1)), new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) + } test("vector attribute generation") { @@ -360,14 +359,19 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) .toDF("id", "vec") val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val attrs = new AttributeGroup("vec", 2) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0)) + .toDF("id", "vec", "features", "label") + .select($"id", $"vec".as("vec", attrs.toMetadata()), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec_0"), Some(1)), new NumericAttribute(Some("vec_1"), Some(2)))) - assert(attrs === expectedAttrs) + + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("vector attribute generation with unnamed input attrs") { @@ -381,31 +385,31 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) - val result = model.transform(original) - val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expected = Seq( + (1, Vectors.dense(0.0, 1.0), Vectors.dense(0.0, 1.0), 1.0), + (2, Vectors.dense(1.0, 2.0), Vectors.dense(1.0, 2.0), 2.0) + ).toDF("id", "vec2", "features", "label") + .select($"id", $"vec2".as("vec2", metadata), $"features", $"label") val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("vec2_0"), Some(1)), new NumericAttribute(Some("vec2_1"), Some(2)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Vector)](original, model, expected, expectedAttrs) } test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = Seq((1, 2, 4, 2), (2, 3, 4, 1)).toDF("a", "b", "c", "d") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0) ).toDF("a", "b", "c", "d", "features", "label") - testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, Int, Int, Int)](original, model, expected, expectedAttrs) } test("factor numeric interaction") { @@ -414,7 +418,6 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) .toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), @@ -423,15 +426,13 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0) ).toDF("id", "a", "b", "features", "label") - testRFormulaTransform[(Int, String, Int)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( new NumericAttribute(Some("a_baz:b"), Some(1)), new NumericAttribute(Some("a_bar:b"), Some(2)), new NumericAttribute(Some("a_foo:b"), Some(3)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, Int)](original, model, expected, expectedAttrs) } test("factor factor interaction") { @@ -439,14 +440,12 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { val original = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b") val model = formula.fit(original) - val result = model.transform(original) val expected = Seq( (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0) ).toDF("id", "a", "b", "features", "label") testRFormulaTransform[(Int, String, String)](original, model, expected) - val attrs = AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", Array[Attribute]( @@ -454,7 +453,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { new NumericAttribute(Some("a_bar:b_zz"), Some(2)), new NumericAttribute(Some("a_foo:b_zq"), Some(3)), new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) - assert(attrs === expectedAttrs) + testRFormulaTransform[(Int, String, String)](original, model, expected, expectedAttrs) } test("read/write: RFormula") { @@ -517,9 +516,11 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen features. val formula1 = new RFormula().setFormula("id ~ a + b") - intercept[SparkException] { - formula1.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula1.fit(df1), + "Unseen label:", + "features") val model1 = formula1.setHandleInvalid("skip").fit(df1) val model2 = formula1.setHandleInvalid("keep").fit(df1) @@ -538,21 +539,28 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { // Handle unseen labels. val formula2 = new RFormula().setFormula("b ~ a + id") - intercept[SparkException] { - formula2.fit(df1).transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String, String)]( + df2, + formula2.fit(df1), + "Unseen label:", + "label") + val model3 = formula2.setHandleInvalid("skip").fit(df1) val model4 = formula2.setHandleInvalid("keep").fit(df1) + val attr = NominalAttribute.defaultAttr val expected3 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) + val expected4 = Seq( (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0), (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0) ).toDF("id", "a", "b", "features", "label") + .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata())) testRFormulaTransform[(Int, String, String)](df2, model3, expected3) testRFormulaTransform[(Int, String, String)](df2, model4, expected4) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 673a146e619f2..cf09418d8e0a2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.storage.StorageLevel -class SQLTransformerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class SQLTransformerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -37,14 +34,22 @@ class SQLTransformerSuite val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") - val result = sqlTrans.transform(original) - val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) + val expected = Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)) .toDF("id", "v1", "v2", "v3", "v4") - assert(result.schema.toString == resultSchema.toString) - assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) - assert(original.sparkSession.catalog.listTables().count() == 0) + val resultSchema = sqlTrans.transformSchema(original.schema) + testTransformerByGlobalCheckFunc[(Int, Double, Double)]( + original, + sqlTrans, + "id", + "v1", + "v2", + "v3", + "v4") { rows => + assert(rows.head.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(rows == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) + } } test("read/write") { @@ -63,13 +68,13 @@ class SQLTransformerSuite } test("SPARK-22538: SQLTransformer should not unpersist given dataset") { - val df = spark.range(10) + val df = spark.range(10).toDF() df.cache() df.count() assert(df.storageLevel != StorageLevel.NONE) - new SQLTransformer() + val sqlTrans = new SQLTransformer() .setStatement("SELECT id + 1 AS id1 FROM __THIS__") - .transform(df) + testTransformerByGlobalCheckFunc[Long](df, sqlTrans, "id1") { _ => } assert(df.storageLevel != StorageLevel.NONE) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 350ba44baa1eb..c5c49d67194e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { +class StandardScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -60,12 +57,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext ) } - def assertResult(df: DataFrame): Unit = { - df.select("standardized_features", "expected").collect().foreach { - case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, - "The vector value is not correct after standardization.") - } + def assertResult: Row => Unit = { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, + "The vector value is not correct after standardization.") } test("params") { @@ -83,7 +78,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext val standardScaler0 = standardScalerEst0.fit(df0) MLTestingUtils.checkCopyAndUids(standardScalerEst0, standardScaler0) - assertResult(standardScaler0.transform(df0)) + testTransformer[(Vector, Vector)](df0, standardScaler0, "standardized_features", "expected")( + assertResult) } test("Standardization with setter") { @@ -112,9 +108,12 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithStd(false) .fit(df3) - assertResult(standardScaler1.transform(df1)) - assertResult(standardScaler2.transform(df2)) - assertResult(standardScaler3.transform(df3)) + testTransformer[(Vector, Vector)](df1, standardScaler1, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df2, standardScaler2, "standardized_features", "expected")( + assertResult) + testTransformer[(Vector, Vector)](df3, standardScaler3, "standardized_features", "expected")( + assertResult) } test("sparse data and withMean") { @@ -130,7 +129,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext .setWithMean(true) .setWithStd(false) .fit(df) - assertResult(standardScaler.transform(df)) + testTransformer[(Vector, Vector)](df, standardScaler, "standardized_features", "expected")( + assertResult) } test("StandardScaler read/write") { @@ -149,4 +149,5 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 5262b146b184e..21259a50916d2 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -17,28 +17,20 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - -object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("filtered", "expected") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} -class StopWordsRemoverSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { - import StopWordsRemoverSuite._ import testImplicits._ + def testStopWordsRemover(t: StopWordsRemover, dataFrame: DataFrame): Unit = { + testTransformer[(Array[String], Array[String])](dataFrame, t, "filtered", "expected") { + case Row(tokens: Seq[_], wantedTokens: Seq[_]) => + assert(tokens === wantedTokens) + } + } + test("StopWordsRemover default") { val remover = new StopWordsRemover() .setInputCol("raw") @@ -151,9 +143,10 @@ class StopWordsRemoverSuite .setOutputCol(outputCol) val dataSet = Seq((Seq("The", "the", "swift"), Seq("swift"))).toDF("raw", outputCol) - val thrown = intercept[IllegalArgumentException] { - testStopWordsRemover(remover, dataSet) - } - assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + testTransformerByInterceptingException[(Array[String], Array[String])]( + dataSet, + remover, + s"requirement failed: Column $outputCol already exists.", + "expected") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 775a04d3df050..df24367177011 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} -class StringIndexerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class StringIndexerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -46,19 +43,23 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val indexerModel = indexer.fit(df) - MLTestingUtils.checkCopyAndUids(indexer, indexerModel) - - val transformed = indexerModel.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("a", "c", "b")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq( + (0, 0.0), + (1, 2.0), + (2, 1.0), + (3, 0.0), + (4, 0.0), + (5, 1.0) + ).toDF("id", "labelIndex") + + testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + assert(rows.seq === expected.collect().toSeq) + } } test("StringIndexerUnseen") { @@ -70,36 +71,38 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) + // Verify we throw by default with unseen values - intercept[SparkException] { - indexer.transform(df2).collect() - } + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer, + "Unseen label:", + "labelIndex") - indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformedSkip = indexer.transform(df2) - val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + indexer.setHandleInvalid("skip") + + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrSkip = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows.seq === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - // Verify that we keep the unseen records - val transformedKeep = indexer.transform(df2) - val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 - val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)).toDF() + + // Verify that we keep the unseen records + testTransformerByGlobalCheckFunc[(Int, String)](df2, indexer, "id", "labelIndex") { rows => + val attrKeep = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexer with a numeric input column") { @@ -109,16 +112,14 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("100", "300", "200")) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // 100 -> 0, 200 -> 2, 300 -> 1 - val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) - assert(output === expected) + val expected = Seq((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "id", "labelIndex") { rows => + val attr = Attribute.fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + assert(rows === expected.collect().toSeq) + } } test("StringIndexer with NULLs") { @@ -133,37 +134,36 @@ class StringIndexerSuite withClue("StringIndexer should throw error when setHandleInvalid=error " + "when given NULL values") { - intercept[SparkException] { - indexer.setHandleInvalid("error") - indexer.fit(df).transform(df2).collect() - } + indexer.setHandleInvalid("error") + testTransformerByInterceptingException[(Int, String)]( + df2, + indexer.fit(df), + "StringIndexer encountered NULL value.", + "labelIndex") } indexer.setHandleInvalid("skip") - val transformedSkip = indexer.fit(df).transform(df2) - val attrSkip = Attribute - .fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("b", "a")) - val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet + val modelSkip = indexer.fit(df) // a -> 1, b -> 0 - val expectedSkip = Set((0, 1.0), (1, 0.0)) - assert(outputSkip === expectedSkip) + val expectedSkip = Seq((0, 1.0), (1, 0.0)).toDF() + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelSkip, "id", "labelIndex") { rows => + val attrSkip = + Attribute.fromStructField(rows.head.schema("labelIndex")).asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(rows === expectedSkip.collect().toSeq) + } indexer.setHandleInvalid("keep") - val transformedKeep = indexer.fit(df).transform(df2) - val attrKeep = Attribute - .fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("b", "a", "__unknown")) - val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet // a -> 1, b -> 0, null -> 2 - val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) - assert(outputKeep === expectedKeep) + val expectedKeep = Seq((0, 1.0), (1, 0.0), (3, 2.0)).toDF() + val modelKeep = indexer.fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df2, modelKeep, "id", "labelIndex") { rows => + val attrKeep = Attribute + .fromStructField(rows.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + assert(rows === expectedKeep.collect().toSeq) + } } test("StringIndexerModel should keep silent if the input column does not exist.") { @@ -171,7 +171,9 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = spark.range(0L, 10L).toDF() - assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) + testTransformerByGlobalCheckFunc[Long](df, indexerModel, "id") { rows => + assert(rows.toSet === df.collect().toSet) + } } test("StringIndexerModel can't overwrite output column") { @@ -188,9 +190,12 @@ class StringIndexerSuite .setOutputCol("indexedInput") .fit(df) - intercept[IllegalArgumentException] { - indexer.setOutputCol("output").transform(df) - } + testTransformerByInterceptingException[(Int, String)]( + df, + indexer.setOutputCol("output"), + "Output column output already exists.", + "labelIndex") + } test("StringIndexer read/write") { @@ -223,7 +228,8 @@ class StringIndexerSuite .setInputCol("index") .setOutputCol("actual") .setLabels(labels) - idxToStr0.transform(df0).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df0, idxToStr0, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -234,7 +240,8 @@ class StringIndexerSuite val idxToStr1 = new IndexToString() .setInputCol("indexWithAttr") .setOutputCol("actual") - idxToStr1.transform(df1).select("actual", "expected").collect().foreach { + + testTransformer[(Int, String)](df1, idxToStr1, "actual", "expected") { case Row(actual, expected) => assert(actual === expected) } @@ -252,9 +259,10 @@ class StringIndexerSuite .setInputCol("labelIndex") .setOutputCol("sameLabel") .setLabels(indexer.labels) - idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { - case Row(a: String, b: String) => - assert(a === b) + + testTransformer[(Int, String, Double)](transformed, idx2str, "sameLabel", "label") { + case Row(sameLabel, label) => + assert(sameLabel === label) } } @@ -286,10 +294,11 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") .fit(df) - val transformed = indexer.transform(df) - val attrs = - NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true) - assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + testTransformerByGlobalCheckFunc[(Int, String)](df, indexer, "labelIndex") { rows => + val attrs = + NominalAttribute.decodeStructField(rows.head.schema("labelIndex"), preserveName = true) + assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex") + } } test("StringIndexer order types") { @@ -299,18 +308,17 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") - val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), - Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), - Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), - Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) + val expected = Seq(Seq((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)), + Seq((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)), + Seq((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)), + Seq((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0))) var idx = 0 for (orderType <- StringIndexer.supportedStringOrderType) { - val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df) - val output = transformed.select("id", "labelIndex").rdd.map { r => - (r.getInt(0), r.getDouble(1)) - }.collect().toSet - assert(output === expected(idx)) + val model = indexer.setStringOrderType(orderType).fit(df) + testTransformerByGlobalCheckFunc[(Int, String)](df, model, "id", "labelIndex") { rows => + assert(rows === expected(idx).toDF().collect().toSeq) + } idx += 1 } } @@ -328,7 +336,11 @@ class StringIndexerSuite .setOutputCol("CITYIndexed") .fit(dfNoBristol) - val dfWithIndex = model.transform(dfNoBristol) - assert(dfWithIndex.filter($"CITYIndexed" === 1.0).count == 1) + testTransformerByGlobalCheckFunc[(String, String, String)]( + dfNoBristol, + model, + "CITYIndexed") { rows => + assert(rows.toList.count(_.getDouble(0) == 1.0) === 1) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index c895659a2d8be..be59b0af2c78e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class TokenizerSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) @@ -42,12 +40,17 @@ class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } } -class RegexTokenizerSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RegexTokenizerSuite extends MLTest with DefaultReadWriteTest { - import org.apache.spark.ml.feature.RegexTokenizerSuite._ import testImplicits._ + def testRegexTokenizer(t: RegexTokenizer, dataframe: DataFrame): Unit = { + testTransformer[(String, Seq[String])](dataframe, t, "tokens", "wantedTokens") { + case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } + } + test("params") { ParamsSuite.checkParams(new RegexTokenizer) } @@ -105,14 +108,3 @@ class RegexTokenizerSuite } } -object RegexTokenizerSuite extends SparkFunSuite { - - def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { - t.transform(dataset) - .select("tokens", "wantedTokens") - .collect() - .foreach { case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 69a7b75e32eb7..e5675e31bbecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest with Logging { +class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging { import testImplicits._ import VectorIndexerSuite.FeatureData @@ -128,18 +126,27 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopyAndUids(vectorIndexer, model) - model.transform(densePoints1) // should work - model.transform(sparsePoints1) // should work + testTransformer[FeatureData](densePoints1, model, "indexed") { _ => } + testTransformer[FeatureData](sparsePoints1, model, "indexed") { _ => } + // If the data is local Dataset, it throws AssertionError directly. - intercept[AssertionError] { - model.transform(densePoints2).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called on " + + "vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2, + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } // If the data is distributed Dataset, it throws SparkException // which is the wrapper of AssertionError. - intercept[SparkException] { - model.transform(densePoints2.repartition(2)).collect() - logInfo("Did not throw error when fit, transform were called on vectors of different lengths") + withClue("Did not throw error when fit, transform were called " + + "on vectors of different lengths") { + testTransformerByInterceptingException[FeatureData]( + densePoints2.repartition(2), + model, + "VectorIndexerModel expected vector of length 3 but found length 4", + "indexed") } intercept[SparkException] { vectorIndexer.fit(badPoints) @@ -178,46 +185,48 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val categoryMaps = model.categoryMaps // Chose correct categorical features assert(categoryMaps.keys.toSet === categoricalFeatures) - val transformed = model.transform(data).select("indexed") - val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) - val featureAttrs = AttributeGroup.fromStructField(transformed.schema("indexed")) - assert(featureAttrs.name === "indexed") - assert(featureAttrs.attributes.get.length === model.numFeatures) - categoricalFeatures.foreach { feature: Int => - val origValueSet = collectedData.map(_(feature)).toSet - val targetValueIndexSet = Range(0, origValueSet.size).toSet - val catMap = categoryMaps(feature) - assert(catMap.keys.toSet === origValueSet) // Correct categories - assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices - if (origValueSet.contains(0.0)) { - assert(catMap(0.0) === 0) // value 0 gets index 0 - } - // Check transformed data - assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) - // Check metadata - val featureAttr = featureAttrs(feature) - assert(featureAttr.index.get === feature) - featureAttr match { - case attr: BinaryAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - case attr: NominalAttribute => - assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) - assert(attr.isOrdinal.get === false) - case _ => - throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + val transformed = rows.map { r => Tuple1(r.getAs[Vector](0)) }.toDF("indexed") + val indexedRDD: RDD[Vector] = transformed.rdd.map(_.getAs[Vector](0)) + val featureAttrs = AttributeGroup.fromStructField(rows.head.schema("indexed")) + assert(featureAttrs.name === "indexed") + assert(featureAttrs.attributes.get.length === model.numFeatures) + categoricalFeatures.foreach { feature: Int => + val origValueSet = collectedData.map(_(feature)).toSet + val targetValueIndexSet = Range(0, origValueSet.size).toSet + val catMap = categoryMaps(feature) + assert(catMap.keys.toSet === origValueSet) // Correct categories + assert(catMap.values.toSet === targetValueIndexSet) // Correct category indices + if (origValueSet.contains(0.0)) { + assert(catMap(0.0) === 0) // value 0 gets index 0 + } + // Check transformed data + assert(indexedRDD.map(_(feature)).collect().toSet === targetValueIndexSet) + // Check metadata + val featureAttr = featureAttrs(feature) + assert(featureAttr.index.get === feature) + featureAttr match { + case attr: BinaryAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + case attr: NominalAttribute => + assert(attr.values.get === origValueSet.toArray.sorted.map(_.toString)) + assert(attr.isOrdinal.get === false) + case _ => + throw new RuntimeException(errMsg + s". Categorical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } - } - // Check numerical feature metadata. - Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) - .foreach { feature: Int => - val featureAttr = featureAttrs(feature) - featureAttr match { - case attr: NumericAttribute => - assert(featureAttr.index.get === feature) - case _ => - throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + - s" metadata check. Found feature attribute: $featureAttr.") + // Check numerical feature metadata. + Range(0, model.numFeatures).filter(feature => !categoricalFeatures.contains(feature)) + .foreach { feature: Int => + val featureAttr = featureAttrs(feature) + featureAttr match { + case attr: NumericAttribute => + assert(featureAttr.index.get === feature) + case _ => + throw new RuntimeException(errMsg + s". Numerical feature $feature failed" + + s" metadata check. Found feature attribute: $featureAttr.") + } } } } catch { @@ -236,25 +245,32 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext (sparsePoints1, sparsePoints1TestInvalid))) { val vectorIndexer = getIndexer.setMaxCategories(4).setHandleInvalid("error") val model = vectorIndexer.fit(points) - intercept[SparkException] { - model.transform(pointsTestInvalid).collect() - } + testTransformerByInterceptingException[FeatureData]( + pointsTestInvalid, + model, + "VectorIndexer encountered invalid value", + "indexed") val vectorIndexer1 = getIndexer.setMaxCategories(4).setHandleInvalid("skip") val model1 = vectorIndexer1.fit(points) - val invalidTransformed1 = model1.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - val transformed1 = model1.transform(points).select("indexed").collect().map(_(0)) - assert(transformed1 === invalidTransformed1) - + val expected = Seq( + Vectors.dense(1.0, 2.0, 0.0), + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(0.0, 0.0, 1.0), + Vectors.dense(1.0, 3.0, 2.0)) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } + testTransformerByGlobalCheckFunc[FeatureData](points, model1, "indexed") { rows => + assert(rows.map(_(0)) == expected) + } val vectorIndexer2 = getIndexer.setMaxCategories(4).setHandleInvalid("keep") val model2 = vectorIndexer2.fit(points) - val invalidTransformed2 = model2.transform(pointsTestInvalid).select("indexed") - .collect().map(_(0)) - assert(invalidTransformed2 === transformed1 ++ Array( - Vectors.dense(2.0, 2.0, 0.0), - Vectors.dense(0.0, 4.0, 2.0), - Vectors.dense(1.0, 3.0, 3.0)) - ) + testTransformerByGlobalCheckFunc[FeatureData](pointsTestInvalid, model2, "indexed") { rows => + assert(rows.map(_(0)) == expected ++ Array( + Vectors.dense(2.0, 2.0, 0.0), + Vectors dense(0.0, 4.0, 2.0), + Vectors.dense(1.0, 3.0, 3.0))) + } } } @@ -263,12 +279,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val points = data.collect().map(_.getAs[Vector](0)) val vectorIndexer = getIndexer.setMaxCategories(maxCategories) val model = vectorIndexer.fit(data) - val indexedPoints = - model.transform(data).select("indexed").rdd.map(_.getAs[Vector](0)).collect() - points.zip(indexedPoints).foreach { - case (orig: SparseVector, indexed: SparseVector) => - assert(orig.indices.length == indexed.indices.length) - case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + testTransformerByGlobalCheckFunc[FeatureData](data, model, "indexed") { rows => + points.zip(rows.map(_(0))).foreach { + case (orig: SparseVector, indexed: SparseVector) => + assert(orig.indices.length == indexed.indices.length) + case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen + } } } checkSparsity(sparsePoints1, maxCategories = 2) @@ -286,17 +302,18 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. - val indexedPoints = model.transform(densePoints1WithMeta) - val transAttributes: Array[Attribute] = - AttributeGroup.fromStructField(indexedPoints.schema("indexed")).attributes.get - featureAttributes.zip(transAttributes).foreach { case (orig, trans) => - assert(orig.name === trans.name) - (orig, trans) match { - case (orig: NumericAttribute, trans: NumericAttribute) => - assert(orig.max.nonEmpty && orig.max === trans.max) - case _ => + testTransformerByGlobalCheckFunc[FeatureData](densePoints1WithMeta, model, "indexed") { rows => + val transAttributes: Array[Attribute] = + AttributeGroup.fromStructField(rows.head.schema("indexed")).attributes.get + featureAttributes.zip(transAttributes).foreach { case (orig, trans) => + assert(orig.name === trans.name) + (orig, trans) match { + case (orig: NumericAttribute, trans: NumericAttribute) => + assert(orig.max.nonEmpty && orig.max === trans.max) + case _ => // do nothing // TODO: Once input features marked as categorical are handled correctly, check that here. + } } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala index f6c9a76599fae..d89d10b320d84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.StreamTest class VectorSizeHintSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -40,16 +38,23 @@ class VectorSizeHintSuite val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue") val noSizeTransformer = new VectorSizeHint().setInputCol("vector") - intercept[NoSuchElementException] (noSizeTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noSizeTransformer, + "Failed to find a default value for size", + "vector") intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema)) val noInputColTransformer = new VectorSizeHint().setSize(2) - intercept[NoSuchElementException] (noInputColTransformer.transform(data)) + testTransformerByInterceptingException[(Vector, Int)]( + data, + noInputColTransformer, + "Failed to find a default value for inputCol", + "vector") intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema)) } test("Adding size to column of vectors.") { - val size = 3 val vectorColName = "vector" val denseVector = Vectors.dense(1, 2, 3) @@ -66,12 +71,15 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrame) - assert( - AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size, - "Transformer did not add expected size data.") - val numRows = withSize.collect().length - assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataFrame, transformer, vectorColName) { + rows => { + assert( + AttributeGroup.fromStructField(rows.head.schema(vectorColName)).size == size, + "Transformer did not add expected size data.") + val numRows = rows.length + assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.") + } + } } } @@ -93,14 +101,16 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - val withSize = transformer.transform(dataFrameWithMetadata) - - val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName)) - assert(newGroup.size === size, "Column has incorrect size metadata.") - assert( - newGroup.attributes.get === group.attributes.get, - "VectorSizeHint did not preserve attributes.") - withSize.collect + testTransformerByGlobalCheckFunc[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + vectorColName) { rows => + val newGroup = AttributeGroup.fromStructField(rows.head.schema(vectorColName)) + assert(newGroup.size === size, "Column has incorrect size metadata.") + assert( + newGroup.attributes.get === group.attributes.get, + "VectorSizeHint did not preserve attributes.") + } } } @@ -120,7 +130,11 @@ class VectorSizeHintSuite .setInputCol(vectorColName) .setSize(size) .setHandleInvalid(handleInvalid) - intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata)) + testTransformerByInterceptingException[(Int, Int, Int, Vector)]( + dataFrameWithMetadata, + transformer, + "Trying to set size of vectors in `vector` to 4 but size already set to 3.", + vectorColName) } } @@ -136,18 +150,36 @@ class VectorSizeHintSuite .setHandleInvalid("error") .setSize(3) - intercept[SparkException](sizeHint.transform(dataWithNull).collect()) - intercept[SparkException](sizeHint.transform(dataWithShort).collect()) + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithNull, + sizeHint, + "Got null vector in VectorSizeHint", + "vector") + + testTransformerByInterceptingException[Tuple1[Vector]]( + dataWithShort, + sizeHint, + "VectorSizeHint Expecting a vector of size 3 but got 1", + "vector") sizeHint.setHandleInvalid("skip") - assert(sizeHint.transform(dataWithNull).count() === 1) - assert(sizeHint.transform(dataWithShort).count() === 1) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 1) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 1) + } sizeHint.setHandleInvalid("optimistic") - assert(sizeHint.transform(dataWithNull).count() === 2) - assert(sizeHint.transform(dataWithShort).count() === 2) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithNull, sizeHint, "vector") { rows => + assert(rows.length === 2) + } + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataWithShort, sizeHint, "vector") { rows => + assert(rows.length === 2) + } } + test("read/write") { val sizeHint = new VectorSizeHint() .setInputCol("myInputCol") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 1746ce53107c4..3d90f9d9ac764 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{StructField, StructType} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class VectorSlicerSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { val slicer = new VectorSlicer().setInputCol("feature") @@ -84,12 +84,12 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") - def validateResults(df: DataFrame): Unit = { - df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + def validateResults(rows: Seq[Row]): Unit = { + rows.foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 === vec2) } - val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) - val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected")) assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => assert(a === b) @@ -97,13 +97,16 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De } vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array(1)).setNames(Array("f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( + validateResults) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 10682ba176aca..b59c4e7967338 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.util.Utils -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class Word2VecSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -36,10 +36,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val numOfWords = sentence.split(" ").size val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -70,17 +66,13 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul // These expectations are just magic values, characterizing the current // behavior. The test needs to be updated to be more general, see SPARK-11502 val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) - model.transform(docDF).select("result", "expected").collect().foreach { + testTransformer[(Seq[String], Vector)](docDF, model, "result", "expected") { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } test("getVectors") { - - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -119,9 +111,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val spark = this.spark - import spark.implicits._ - val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -154,9 +143,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val spark = this.spark - import spark.implicits._ - val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val docDF = doc.zip(doc).toDF("text", "alsotext") @@ -227,8 +213,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec works with input that is non-nullable (NGram)") { - val spark = this.spark - import spark.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") @@ -243,7 +227,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .fit(ngramDF) // Just test that this transformation succeeds - model.transform(ngramDF).collect() + testTransformerByGlobalCheckFunc[(Seq[String], Seq[String])](ngramDF, model, "result") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 17678aa611a48..795fd0e2ac0e4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,9 +22,10 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.{PipelineModel, Transformer} +import org.apache.spark.ml.Transformer import org.apache.spark.sql.{DataFrame, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession import org.apache.spark.util.Utils @@ -62,8 +63,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => val columnNames = dataframe.schema.fieldNames val stream = MemoryStream[A] - val streamDF = stream.toDS().toDF(columnNames: _*) - + val columnsWithMetadata = dataframe.schema.map { structField => + col(structField.name).as(structField.name, structField.metadata) + } + val streamDF = stream.toDS().toDF(columnNames: _*).select(columnsWithMetadata: _*) val data = dataframe.as[A].collect() val streamOutput = transformer.transform(streamDF) @@ -108,5 +111,29 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => otherResultCols: _*)(globalCheckFunction) testTransformerOnDF(dataframe, transformer, firstResultCol, otherResultCols: _*)(globalCheckFunction) + } + + def testTransformerByInterceptingException[A : Encoder]( + dataframe: DataFrame, + transformer: Transformer, + expectedMessagePart : String, + firstResultCol: String) { + + def hasExpectedMessage(exception: Throwable): Boolean = + exception.getMessage.contains(expectedMessagePart) || + (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart)) + + withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { + val exceptionOnDf = intercept[Throwable] { + testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnDf)) + } + withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") { + val exceptionOnStreamData = intercept[Throwable] { + testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit) + } + assert(hasExpectedMessage(exceptionOnStreamData)) + } } } From 4f5bad615b47d743b8932aea1071652293981604 Mon Sep 17 00:00:00 2001 From: smallory Date: Thu, 15 Mar 2018 11:58:54 +0900 Subject: [PATCH 0471/1161] [SPARK-23642][DOCS] AccumulatorV2 subclass isZero scaladoc fix Added/corrected scaladoc for isZero on the DoubleAccumulator, CollectionAccumulator, and LongAccumulator subclasses of AccumulatorV2, particularly noting where there are requirements in addition to having a value of zero in order to return true. ## What changes were proposed in this pull request? Three scaladoc comments are updated in AccumulatorV2.scala No changes outside of comment blocks were made. ## How was this patch tested? Running "sbt unidoc", fixing style errors found, and reviewing the resulting local scaladoc in firefox. Author: smallory Closes #20790 from smallory/patch-1. --- .../main/scala/org/apache/spark/util/AccumulatorV2.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index f4a736d6d439a..0f84ea9752cf5 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -290,7 +290,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private var _count = 0L /** - * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + * * @since 2.0.0 */ override def isZero: Boolean = _sum == 0L && _count == 0 @@ -368,6 +369,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private var _sum = 0.0 private var _count = 0L + /** + * Returns false if this accumulator has had any values added to it or the sum is non-zero. + */ override def isZero: Boolean = _sum == 0.0 && _count == 0 override def copy(): DoubleAccumulator = { @@ -441,6 +445,9 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) + /** + * Returns false if this accumulator instance has any values in it. + */ override def isZero: Boolean = _list.isEmpty override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator From 7c3e8995f18a1fb57c1f2c1b98a1d47590e28f38 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Thu, 15 Mar 2018 00:04:28 -0700 Subject: [PATCH 0472/1161] [SPARK-23533][SS] Add support for changing ContinuousDataReader's startOffset ## What changes were proposed in this pull request? As discussion in #20675, we need add a new interface `ContinuousDataReaderFactory` to support the requirements of setting start offset in Continuous Processing. ## How was this patch tested? Existing UT. Author: Yuanjian Li Closes #20689 from xuanyuanking/SPARK-23533. --- .../sql/kafka010/KafkaContinuousReader.scala | 11 +++++- .../reader/ContinuousDataReaderFactory.java | 35 +++++++++++++++++++ .../ContinuousRateStreamSource.scala | 15 +++++++- 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index ecd1170321f3f..6e56b0a72d671 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -164,7 +164,16 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends DataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] + require(kafkaOffset.topicPartition == topicPartition, + s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") + new KafkaContinuousDataReader( + topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) + } + override def createDataReader(): KafkaContinuousDataReader = { new KafkaContinuousDataReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java new file mode 100644 index 0000000000000..a61697649c43e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; + +/** + * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can + * implement this interface to provide creating {@link DataReader} with particular offset. + */ +@InterfaceStability.Evolving +public interface ContinuousDataReaderFactory extends DataReaderFactory { + /** + * Create a DataReader with particular offset as its startOffset. + * + * @param offset offset want to set as the DataReader's startOffset. + */ + DataReader createDataReaderWithOffset(PartitionOffset offset); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index b63d8d3e20650..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -106,7 +106,20 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends DataReaderFactory[Row] { + extends ContinuousDataReaderFactory[Row] { + + override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] + require(rateStreamOffset.partition == partitionIndex, + s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") + new RateStreamContinuousDataReader( + rateStreamOffset.currentValue, + rateStreamOffset.currentTimeMs, + partitionIndex, + increment, + rowsPerSecond) + } + override def createDataReader(): DataReader[Row] = new RateStreamContinuousDataReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) From 56e8f48a43eb51e8582db2461a585b13a771a00a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 15 Mar 2018 10:55:33 -0700 Subject: [PATCH 0473/1161] [SPARK-23695][PYTHON] Fix the error message for Kinesis streaming tests ## What changes were proposed in this pull request? This PR proposes to fix the error message for Kinesis in PySpark when its jar is missing but explicitly enabled. ```bash ENABLE_KINESIS_TESTS=1 SPARK_TESTING=1 bin/pyspark pyspark.streaming.tests ``` Before: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1572, in % kinesis_asl_assembly_dir) + NameError: name 'kinesis_asl_assembly_dir' is not defined ``` After: ``` Skipped test_flume_stream (enable by setting environment variable ENABLE_FLUME_TESTS=1Skipped test_kafka_stream (enable by setting environment variable ENABLE_KAFKA_0_8_TESTS=1Traceback (most recent call last): File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 174, in _run_module_as_main "__main__", fname, loader, pkg_name) File "/usr/local/Cellar/python/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code exec code in run_globals File "/.../spark/python/pyspark/streaming/tests.py", line 1576, in "You need to build Spark with 'build/sbt -Pkinesis-asl " Exception: Failed to find Spark Streaming Kinesis assembly jar in /.../spark/external/kinesis-asl-assembly. You need to build Spark with 'build/sbt -Pkinesis-asl assembly/package streaming-kinesis-asl-assembly/assembly'or 'build/mvn -Pkinesis-asl package' before running this test. ``` ## How was this patch tested? Manually tested. Author: hyukjinkwon Closes #20834 from HyukjinKwon/minor-variable. --- python/pyspark/streaming/tests.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 71f8101e34c50..7dde7c0928c08 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1503,10 +1503,13 @@ def search_flume_assembly_jar(): return jars[0] -def search_kinesis_asl_assembly_jar(): +def _kinesis_asl_assembly_dir(): SPARK_HOME = os.environ["SPARK_HOME"] - kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") - jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") + return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly") + + +def search_kinesis_asl_assembly_jar(): + jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly") if not jars: return None elif len(jars) > 1: @@ -1569,7 +1572,7 @@ def search_kinesis_asl_assembly_jar(): else: raise Exception( ("Failed to find Spark Streaming Kinesis assembly jar in %s. " - % kinesis_asl_assembly_dir) + + % _kinesis_asl_assembly_dir()) + "You need to build Spark with 'build/sbt -Pkinesis-asl " "assembly/package streaming-kinesis-asl-assembly/assembly'" "or 'build/mvn -Pkinesis-asl package' before running this test.") From 15c3c983008557165cc91713ddaf2dbd6d5a506c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 15 Mar 2018 19:54:58 +0100 Subject: [PATCH 0474/1161] [HOT-FIX] Fix SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 ## What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88263/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88260/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88257/testReport https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/88224/testReport These tests all failed: ``` org.apache.spark.memory.SparkOutOfMemoryError: Unable to acquire 262144 bytes of memory, got 224631 at org.apache.spark.memory.MemoryConsumer.throwOom(MemoryConsumer.java:157) at org.apache.spark.memory.MemoryConsumer.allocateArray(MemoryConsumer.java:98) at org.apache.spark.unsafe.map.BytesToBytesMap.allocate(BytesToBytesMap.java:787) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:204) at org.apache.spark.unsafe.map.BytesToBytesMap.(BytesToBytesMap.java:219) ... ``` This PR ignore this test. ## How was this patch tested? N/A Author: Yuming Wang Closes #20835 from wangyum/SPARK-23598. --- .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4b40e4ef7571c..9180a22c260f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -310,7 +310,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { + ignore("SPARK-23598: Codegen working for lots of aggregation operations without runtime errors") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { var df = Seq((8, "bat"), (15, "mouse"), (5, "horse")).toDF("age", "name") for (i <- 0 until 70) { From 7618896e855579f111dd92cd76794a5672a087e5 Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 15 Mar 2018 17:04:39 -0700 Subject: [PATCH 0475/1161] [SPARK-23658][LAUNCHER] InProcessAppHandle uses the wrong class in getLogger ## What changes were proposed in this pull request? Changed `Logger` in `InProcessAppHandle` to use `InProcessAppHandle` instead of `ChildProcAppHandle` Author: Sahil Takiar Closes #20815 from sahilTakiar/master. --- .../main/java/org/apache/spark/launcher/InProcessAppHandle.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java index 4b740d3fad20e..15fbca0facef2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java @@ -25,7 +25,7 @@ class InProcessAppHandle extends AbstractAppHandle { private static final String THREAD_NAME_FMT = "spark-app-%d: '%s'"; - private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName()); + private static final Logger LOG = Logger.getLogger(InProcessAppHandle.class.getName()); private static final AtomicLong THREAD_IDS = new AtomicLong(); // Avoid really long thread names. From 18f8575e0166c6997569358d45bdae2cf45bf624 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 15 Mar 2018 17:12:01 -0700 Subject: [PATCH 0476/1161] [SPARK-23671][CORE] Fix condition to enable the SHS thread pool. Author: Marcelo Vanzin Closes #20814 from vanzin/SPARK-23671. --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index f9d0b5ee4e23e..ace6d9e00c838 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -173,7 +173,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * Fixed size thread pool to fetch and parse log files. */ private val replayExecutor: ExecutorService = { - if (Utils.isTesting) { + if (!Utils.isTesting) { ThreadUtils.newDaemonFixedThreadPool(NUM_PROCESSING_THREADS, "log-replay-executor") } else { MoreExecutors.sameThreadExecutor() From 3675af7247e841e9a689666dc20891ba55c612b3 Mon Sep 17 00:00:00 2001 From: Ye Zhou Date: Thu, 15 Mar 2018 17:15:53 -0700 Subject: [PATCH 0477/1161] [SPARK-23608][CORE][WEBUI] Add synchronization in SHS between attachSparkUI and detachSparkUI functions to avoid concurrent modification issue to Jetty Handlers Jetty handlers are dynamically attached/detached while SHS is running. But the attach and detach operations might be taking place at the same time due to the async in load/clear in Guava Cache. ## What changes were proposed in this pull request? Add synchronization between attachSparkUI and detachSparkUI in SHS. ## How was this patch tested? With this patch, the jetty handlers missing issue never happens again in our production cluster SHS. Author: Ye Zhou Closes #20744 from zhouyejoe/SPARK-23608. --- .../apache/spark/deploy/history/HistoryServer.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 0ec4afad0308c..611fa563a7cd9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -150,14 +150,18 @@ class HistoryServer( ui: SparkUI, completed: Boolean) { assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") - ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) + handlers.synchronized { + ui.getHandlers.foreach(attachHandler) + addFilters(ui.getHandlers, conf) + } } /** Detach a reconstructed UI from this server. Only valid after bind(). */ override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") - ui.getHandlers.foreach(detachHandler) + handlers.synchronized { + ui.getHandlers.foreach(detachHandler) + } provider.onUIDetached(appId, attemptId, ui) } From c2632edebd978716dbfa7874a2fc0a8f5a4a9951 Mon Sep 17 00:00:00 2001 From: myroslavlisniak Date: Thu, 15 Mar 2018 17:20:17 -0700 Subject: [PATCH 0478/1161] [SPARK-23670][SQL] Fix memory leak on SparkPlanGraphWrapper Clean up SparkPlanGraphWrapper objects from InMemoryStore together with cleaning up SQLExecutionUIData existing unit test was extended to check also SparkPlanGraphWrapper object count vanzin Author: myroslavlisniak Closes #20813 from myroslavlisniak/master. --- .../apache/spark/sql/execution/ui/SQLAppStatusListener.scala | 5 ++++- .../apache/spark/sql/execution/ui/SQLAppStatusStore.scala | 4 ++++ .../spark/sql/execution/ui/SQLAppStatusListenerSuite.scala | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 53fb9a0cc21cf..71e9f93c4566e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -334,7 +334,10 @@ class SQLAppStatusListener( val view = kvstore.view(classOf[SQLExecutionUIData]).index("completionTime").first(0L) val toDelete = KVUtils.viewToSeq(view, countToDelete.toInt)(_.completionTime.isDefined) - toDelete.foreach { e => kvstore.delete(e.getClass(), e.executionId) } + toDelete.foreach { e => + kvstore.delete(e.getClass(), e.executionId) + kvstore.delete(classOf[SparkPlanGraphWrapper], e.executionId) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 9a76584717f42..241001a857c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -54,6 +54,10 @@ class SQLAppStatusStore( store.count(classOf[SQLExecutionUIData]) } + def planGraphCount(): Long = { + store.count(classOf[SparkPlanGraphWrapper]) + } + def executionMetrics(executionId: Long): Map[Long, String] = { def metricsFromStore(): Option[Map[Long, String]] = { val exec = store.read(classOf[SQLExecutionUIData], executionId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 85face3994fd4..f3f08839c1d3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -611,6 +611,7 @@ class SQLAppStatusListenerMemoryLeakSuite extends SparkFunSuite { sc.listenerBus.waitUntilEmpty(10000) val statusStore = spark.sharedState.statusStore assert(statusStore.executionsCount() <= 50) + assert(statusStore.planGraphCount() <= 50) // No live data should be left behind after all executions end. assert(statusStore.listener.get.noLiveData()) } From ca83526de55f0f8784df58cc8b7c0a7cb0c96e23 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 16 Mar 2018 15:12:26 +0800 Subject: [PATCH 0479/1161] [SPARK-23644][CORE][UI] Use absolute path for REST call in SHS ## What changes were proposed in this pull request? SHS is using a relative path for the REST API call to get the list of the application is a relative path call. In case of the SHS being consumed through a proxy, it can be an issue if the path doesn't end with a "/". Therefore, we should use an absolute path for the REST call as it is done for all the other resources. ## How was this patch tested? manual tests Before the change: ![screen shot 2018-03-10 at 4 22 02 pm](https://user-images.githubusercontent.com/8821783/37244190-8ccf9d40-2485-11e8-8fa9-345bc81472fc.png) After the change: ![screen shot 2018-03-10 at 4 36 34 pm 1](https://user-images.githubusercontent.com/8821783/37244201-a1922810-2485-11e8-8856-eeab2bf5e180.png) Author: Marco Gaido Closes #20794 from mgaido91/SPARK-23644. --- .../main/resources/org/apache/spark/ui/static/historypage.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index f0b2a5a833a99..abc2ec0fa6531 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -113,7 +113,7 @@ $(document).ready(function() { status: (requestedIncomplete ? "running" : "completed") }; - $.getJSON("api/v1/applications", appParams, function(response,status,jqXHR) { + $.getJSON(uiRoot + "/api/v1/applications", appParams, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { @@ -151,7 +151,7 @@ $(document).ready(function() { "showCompletedColumns": !requestedIncomplete, } - $.get("static/historypage-template.html", function(template) { + $.get(uiRoot + "/static/historypage-template.html", function(template) { var sibling = historySummary.prev(); historySummary.detach(); var apps = $(Mustache.render($(template).filter("#history-summary-template").html(),data)); From c952000487ee003200221b3c4e25dcb06e359f0a Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 16 Mar 2018 16:22:03 +0800 Subject: [PATCH 0480/1161] [SPARK-23635][YARN] AM env variable should not overwrite same name env variable set through spark.executorEnv. ## What changes were proposed in this pull request? In the current Spark on YARN code, AM always will copy and overwrite its env variables to executors, so we cannot set different values for executors. To reproduce issue, user could start spark-shell like: ``` ./bin/spark-shell --master yarn-client --conf spark.executorEnv.SPARK_ABC=executor_val --conf spark.yarn.appMasterEnv.SPARK_ABC=am_val ``` Then check executor env variables by ``` sc.parallelize(1 to 1).flatMap \{ i => sys.env.toSeq }.collect.foreach(println) ``` We will always get `am_val` instead of `executor_val`. So we should not let AM to overwrite specifically set executor env variables. ## How was this patch tested? Added UT and tested in local cluster. Author: jerryshao Closes #20799 from jerryshao/SPARK-23635. --- .../spark/deploy/yarn/ExecutorRunnable.scala | 22 +++++++----- .../spark/deploy/yarn/YarnClusterSuite.scala | 36 +++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 3f4d236571ffd..ab08698035c98 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -220,12 +220,6 @@ private[yarn] class ExecutorRunnable( val env = new HashMap[String, String]() Client.populateClasspath(null, conf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) - sparkConf.getExecutorEnv.foreach { case (key, value) => - // This assumes each executor environment variable set here is a path - // This is kept for backward compatibility and consistency with hadoop - YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) - } - // lookup appropriate http scheme for container log urls val yarnHttpPolicy = conf.get( YarnConfiguration.YARN_HTTP_POLICY_KEY, @@ -233,6 +227,20 @@ private[yarn] class ExecutorRunnable( ) val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } + + sparkConf.getExecutorEnv.foreach { case (key, value) => + if (key == Environment.CLASSPATH.name()) { + // If the key of env variable is CLASSPATH, we assume it is a path and append it. + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) + } else { + // For other env variables, simply overwrite the value. + env(key) = value + } + } + // Add log urls container.foreach { c => sys.env.get("SPARK_USER").foreach { user => @@ -245,8 +253,6 @@ private[yarn] class ExecutorRunnable( } } - System.getenv().asScala.filterKeys(_.startsWith("SPARK")) - .foreach { case (k, v) => env(k) = v } env } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 33d400a5b1b2e..a129be7c06b53 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -225,6 +225,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite { finalState should be (SparkAppHandle.State.FAILED) } + test("executor env overwrite AM env in client mode") { + testExecutorEnv(true) + } + + test("executor env overwrite AM env in cluster mode") { + testExecutorEnv(false) + } + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), @@ -305,6 +313,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { checkResult(finalState, executorResult, "OVERRIDDEN") } + private def testExecutorEnv(clientMode: Boolean): Unit = { + val result = File.createTempFile("result", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(ExecutorEnvTestApp.getClass), + appArgs = Seq(result.getAbsolutePath), + extraConf = Map( + "spark.yarn.appMasterEnv.TEST_ENV" -> "am_val", + "spark.executorEnv.TEST_ENV" -> "executor_val" + ) + ) + checkResult(finalState, result, "true") + } } private[spark] class SaveExecutorInfo extends SparkListener { @@ -526,3 +545,20 @@ private object SparkContextTimeoutApp { } } + +private object ExecutorEnvTestApp { + + def main(args: Array[String]): Unit = { + val status = args(0) + val sparkConf = new SparkConf() + val sc = new SparkContext(sparkConf) + val executorEnvs = sc.parallelize(Seq(1)).flatMap { _ => sys.env }.collect().toMap + val result = sparkConf.getExecutorEnv.forall { case (k, v) => + executorEnvs.get(k).contains(v) + } + + Files.write(result.toString, new File(status), StandardCharsets.UTF_8) + sc.stop() + } + +} From 5414abca4fec6a68174c34d22d071c20027e959d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 16 Mar 2018 09:36:30 -0700 Subject: [PATCH 0481/1161] [SPARK-23553][TESTS] Tests should not assume the default value of `spark.sql.sources.default` ## What changes were proposed in this pull request? Currently, some tests have an assumption that `spark.sql.sources.default=parquet`. In fact, that is a correct assumption, but that assumption makes it difficult to test new data source format. This PR aims to - Improve test suites more robust and makes it easy to test new data sources in the future. - Test new native ORC data source with the full existing Apache Spark test coverage. As an example, the PR uses `spark.sql.sources.default=orc` during reviews. The value should be `parquet` when this PR is accepted. ## How was this patch tested? Pass the Jenkins with updated tests. Author: Dongjoon Hyun Closes #20705 from dongjoon-hyun/SPARK-23553. --- python/pyspark/sql/readwriter.py | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 +-- .../columnar/InMemoryColumnarQuerySuite.scala | 5 +- .../sql/execution/command/DDLSuite.scala | 11 ++- .../ParquetPartitionDiscoverySuite.scala | 10 +++ .../sql/test/DataFrameReaderWriterSuite.scala | 3 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 72 +++++++++---------- .../PartitionProviderCompatibilitySuite.scala | 6 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 11 +-- .../sql/hive/execution/SQLQuerySuite.scala | 19 ++--- 11 files changed, 81 insertions(+), 71 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 803f561ece67b..facc16bc53108 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -147,8 +147,8 @@ def load(self, path=None, format=None, schema=None, **options): or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param options: all other string options - >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, - ... opt2=1, opt3='str') + >>> df = spark.read.format("parquet").load('python/test_support/sql/parquet_partitioned', + ... opt1=True, opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8f14575c3325f..640affc10ee58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2150,7 +2150,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("data source table created in InMemoryCatalog should be able to read/write") { withTable("tbl") { - sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") + val provider = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE tbl(i INT, j STRING) USING $provider") checkAnswer(sql("SELECT i, j FROM tbl"), Nil) Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") @@ -2474,9 +2475,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-16975: Column-partition path starting '_' should be handled correctly") { withTempDir { dir => - val parquetDir = new File(dir, "parquet").getCanonicalPath - spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir) - spark.read.parquet(parquetDir) + val dataDir = new File(dir, "data").getCanonicalPath + spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(dataDir) + spark.read.load(dataDir) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index dc1766fb9a785..26b63e8e8490f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -487,7 +487,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") { - withSQLConf("spark.sql.cbo.enabled" -> "true") { + // This test case depends on the size of parquet in statistics. + withSQLConf( + SQLConf.CBO_ENABLED.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "parquet") { withTempPath { workDir => withTable("table1") { val workDirPath = workDir.getAbsolutePath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4041176262426..4df8fbfe1c0db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -154,10 +154,15 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo Seq(4 -> "d").toDF("i", "j").write.saveAsTable("t1") val e = intercept[AnalysisException] { - Seq(5 -> "e").toDF("i", "j").write.mode("append").format("json").saveAsTable("t1") + val format = if (spark.sessionState.conf.defaultDataSourceName.equalsIgnoreCase("json")) { + "orc" + } else { + "json" + } + Seq(5 -> "e").toDF("i", "j").write.mode("append").format(format).saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `JsonFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index edb3da904d10d..e887c9734a8b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -57,6 +57,16 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val timeZone = TimeZone.getDefault() val timeZoneId = timeZone.getID + protected override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "parquet") + } + + protected override def afterAll(): Unit = { + spark.conf.unset(SQLConf.DEFAULT_DATA_SOURCE_NAME.key) + super.afterAll() + } + test("column type inference") { def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { assert(inferPartitionColumnValue(raw, true, timeZone) === literal) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a707a88dfa670..14b1feb2adc20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -562,7 +562,8 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be "and a same-name temp view exist") { withTable("same_name") { withTempView("same_name") { - sql("CREATE TABLE same_name(id LONG) USING parquet") + val format = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE same_name(id LONG) USING $format") spark.range(10).createTempView("same_name") spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") checkAnswer(spark.table("same_name"), spark.range(10).toDF()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 859099a321bf7..d93215fefb810 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -591,7 +591,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (ArrayType)") { - withTable("arrayInParquet") { + withTable("array") { { val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") val expectedSchema = @@ -604,9 +604,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("arrayInParquet") + .saveAsTable("array") } { @@ -621,25 +620,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("arrayInParquet") + .insertInto("array") } (Tuple1(Seq(4, 5)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + .saveAsTable("array") // This one internally calls df2.insertInto. (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") .write .mode(SaveMode.Append) - .saveAsTable("arrayInParquet") + .saveAsTable("array") - sparkSession.catalog.refreshTable("arrayInParquet") + sparkSession.catalog.refreshTable("array") checkAnswer( - sql("SELECT a FROM arrayInParquet"), + sql("SELECT a FROM array"), Row(ArrayBuffer(1, null)) :: Row(ArrayBuffer(2, 3)) :: Row(ArrayBuffer(4, 5)) :: @@ -648,7 +646,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("Pre insert nullability check (MapType)") { - withTable("mapInParquet") { + withTable("map") { { val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") val expectedSchema = @@ -661,9 +659,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Overwrite) - .saveAsTable("mapInParquet") + .saveAsTable("map") } { @@ -678,27 +675,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(df.schema === expectedSchema) df.write - .format("parquet") .mode(SaveMode.Append) - .insertInto("mapInParquet") + .insertInto("map") } (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + .saveAsTable("map") // This one internally calls df2.insertInto. (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") .write - .format("parquet") .mode(SaveMode.Append) - .saveAsTable("mapInParquet") + .saveAsTable("map") - sparkSession.catalog.refreshTable("mapInParquet") + sparkSession.catalog.refreshTable("map") checkAnswer( - sql("SELECT a FROM mapInParquet"), + sql("SELECT a FROM map"), Row(Map(1 -> null)) :: Row(Map(2 -> 3)) :: Row(Map(4 -> 5)) :: @@ -852,52 +846,52 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv (from to to).map(i => i -> s"str$i").toDF("c1", "c2") } - withTable("insertParquet") { - createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + withTable("t") { + createDF(0, 9).write.saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 9).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.saveAsTable("t") } - createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(10, 19).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + sql("SELECT p.c1, p.c2 FROM t p WHERE p.c1 > 5"), (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + createDF(20, 29).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 25"), (6 to 24).map(i => Row(i, s"str$i"))) intercept[AnalysisException] { - createDF(30, 39).write.saveAsTable("insertParquet") + createDF(30, 39).write.saveAsTable("t") } - createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 35"), (6 to 34).map(i => Row(i, s"str$i"))) - createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + createDF(40, 49).write.mode(SaveMode.Append).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 5 AND p.c1 < 45"), (6 to 44).map(i => Row(i, s"str$i"))) - createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + sql("SELECT p.c1, c2 FROM t p WHERE p.c1 > 51 AND p.c1 < 55"), (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (50 to 59).map(i => Row(i, s"str$i"))) - createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("t") checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), + sql("SELECT p.c1, c2 FROM t p"), (70 to 79).map(i => Row(i, s"str$i"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index 9440a17677ebf..80afc9d8f44bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -37,11 +37,11 @@ class PartitionProviderCompatibilitySuite spark.range(5).selectExpr("id as fieldOne", "id as partCol").write .partitionBy("partCol") .mode("overwrite") - .parquet(dir.getAbsolutePath) + .save(dir.getAbsolutePath) spark.sql(s""" |create table $tableName (fieldOne long, partCol int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${dir.toURI}") |partitioned by (partCol)""".stripMargin) } @@ -358,7 +358,7 @@ class PartitionProviderCompatibilitySuite try { spark.sql(s""" |create table test (id long, P1 int, P2 int) - |using parquet + |using ${spark.sessionState.conf.defaultDataSourceName} |options (path "${base.toURI}") |partitioned by (P1, P2)""".stripMargin) spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.toURI}'") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 54d3962a46b4d..1a86c604d5da3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -417,7 +417,7 @@ class PartitionedTablePerfStatsSuite import spark.implicits._ Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) HiveCatalogMetrics.reset() - spark.read.parquet(dir.getAbsolutePath) + spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 65be244418670..db76ec9d084cb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1658,8 +1658,8 @@ class HiveDDLSuite Seq(5 -> "e").toDF("i", "j") .write.format("hive").mode("append").saveAsTable("t1") } - assert(e.message.contains("The format of the existing table default.t1 is " + - "`ParquetFileFormat`. It doesn't match the specified format `HiveFileFormat`.")) + assert(e.message.contains("The format of the existing table default.t1 is ")) + assert(e.message.contains("It doesn't match the specified format `HiveFileFormat`.")) } } @@ -1709,11 +1709,12 @@ class HiveDDLSuite spark.sessionState.catalog.getTableMetadata(TableIdentifier(tblName)).schema.map(_.name) } + val provider = spark.sessionState.conf.defaultDataSourceName withTable("t", "t1", "t2", "t3", "t4", "t5", "t6") { - sql("CREATE TABLE t(a int, b int, c int, d int) USING parquet PARTITIONED BY (d, b)") + sql(s"CREATE TABLE t(a int, b int, c int, d int) USING $provider PARTITIONED BY (d, b)") assert(getTableColumns("t") == Seq("a", "c", "d", "b")) - sql("CREATE TABLE t1 USING parquet PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") + sql(s"CREATE TABLE t1 USING $provider PARTITIONED BY (d, b) AS SELECT 1 a, 1 b, 1 c, 1 d") assert(getTableColumns("t1") == Seq("a", "c", "d", "b")) Seq((1, 1, 1, 1)).toDF("a", "b", "c", "d").write.partitionBy("d", "b").saveAsTable("t2") @@ -1723,7 +1724,7 @@ class HiveDDLSuite val dataPath = new File(new File(path, "d=1"), "b=1").getCanonicalPath Seq(1 -> 1).toDF("a", "c").write.save(dataPath) - sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}'") + sql(s"CREATE TABLE t3 USING $provider LOCATION '${path.toURI}'") assert(getTableColumns("t3") == Seq("a", "c", "d", "b")) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index baabc4a3bca2c..73f83d593bbfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -516,24 +516,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS with default fileformat") { val table = "ctas1" val ctas = s"CREATE TABLE IF NOT EXISTS $table SELECT key k, value FROM src" - withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { - withSQLConf("hive.default.fileformat" -> "textfile") { + Seq("orc", "parquet").foreach { dataSourceFormat => + withSQLConf( + SQLConf.CONVERT_CTAS.key -> "true", + SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> dataSourceFormat, + "hive.default.fileformat" -> "textfile") { withTable(table) { sql(ctas) - // We should use parquet here as that is the default datasource fileformat. The default - // datasource file format is controlled by `spark.sql.sources.default` configuration. + // The default datasource file format is controlled by `spark.sql.sources.default`. // This testcase verifies that setting `hive.default.fileformat` has no impact on // the target table's fileformat in case of CTAS. - assert(sessionState.conf.defaultDataSourceName === "parquet") - checkRelation(tableName = table, isDataSourceTable = true, format = "parquet") + checkRelation(tableName = table, isDataSourceTable = true, format = dataSourceFormat) } } - withSQLConf("spark.sql.sources.default" -> "orc") { - withTable(table) { - sql(ctas) - checkRelation(tableName = table, isDataSourceTable = true, format = "orc") - } - } } } From dffeac3691daa620446ae949c5b147518d128e08 Mon Sep 17 00:00:00 2001 From: Sebastian Arzt Date: Fri, 16 Mar 2018 12:25:58 -0500 Subject: [PATCH 0482/1161] [SPARK-18371][STREAMING] Spark Streaming backpressure generates batch with large number of records MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Omit rounding of backpressure rate. Effects: - no batch with large number of records is created when rate from PID estimator is one - the number of records per batch and partition is more fine-grained improving backpressure accuracy ## How was this patch tested? This was tested by running: - `mvn test -pl external/kafka-0-8` - `mvn test -pl external/kafka-0-10` - a streaming application which was suffering from the issue JasonMWhite The contribution is my original work and I license the work to the project under the project’s open source license Author: Sebastian Arzt Closes #17774 from arzt/kafka-back-pressure. --- .../kafka010/DirectKafkaInputDStream.scala | 6 +-- .../kafka010/DirectKafkaStreamSuite.scala | 48 +++++++++++++++++ .../kafka/DirectKafkaInputDStream.scala | 6 +-- .../kafka/DirectKafkaStreamSuite.scala | 51 +++++++++++++++++++ 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0fa3287f36db8..9cb2448fea0f4 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -138,17 +138,17 @@ private[spark] class DirectKafkaInputDStream[K, V]( lagPerPartition.map { case (tp, lag) => val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 453b5e5ab20d3..8524743ee2846 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -617,6 +617,54 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = getKafkaParams() + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimateRate = 1L + val fromOffsets = Map( + new TopicPartition(topic, 0) -> 0L, + new TopicPartition(topic, 1) -> 0L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 0L + ) + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { + currentOffsets = fromOffsets + override val rateController = Some(new ConstantRateController(id, null, estimateRate)) + } + } + + val offsets = Map[TopicPartition, Long]( + new TopicPartition(topic, 0) -> 0, + new TopicPartition(topic, 1) -> 100L, + new TopicPartition(topic, 2) -> 200L, + new TopicPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 1) -> 10L, + new TopicPartition(topic, 2) -> 20L, + new TopicPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d52c230eb7849..d6dd0744441e4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -104,17 +104,17 @@ class DirectKafkaInputDStream[ val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 06ef5bc3f8bd0..3fea6cfd910bf 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -456,6 +456,57 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimatedRate = 1L + val kafkaStream = withClue("Error creating direct stream") { + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val fromOffsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 0L, + TopicAndPartition(topic, 2) -> 0L, + TopicAndPartition(topic, 3) -> 0L + ) + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, fromOffsets, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, null) { + override def getLatestRate() = estimatedRate + }) + } + } + + val offsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 100L, + TopicAndPartition(topic, 2) -> 200L, + TopicAndPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + TopicAndPartition(topic, 0) -> 1L, + TopicAndPartition(topic, 1) -> 10L, + TopicAndPartition(topic, 2) -> 20L, + TopicAndPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = { From 88d8de9260edf6e9d5449ff7ef6e35d16051fc9f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 16 Mar 2018 18:28:16 +0100 Subject: [PATCH 0483/1161] [SPARK-23581][SQL] Add interpreted unsafe projection ## What changes were proposed in this pull request? We currently can only create unsafe rows using code generation. This is a problem for situations in which code generation fails. There is no fallback, and as a result we cannot execute the query. This PR adds an interpreted version of `UnsafeProjection`. The implementation is modeled after `InterpretedMutableProjection`. It stores the expression results in a `GenericInternalRow`, and it then uses a conversion function to convert the `GenericInternalRow` into an `UnsafeRow`. This PR does not implement the actual code generated to interpreted fallback logic. This will be done in a follow-up. ## How was this patch tested? I am piggybacking on exiting `UnsafeProjection` tests, and I have added an interpreted version for each of these. Author: Herman van Hovell Closes #20750 from hvanhovell/SPARK-23581. --- .../codegen/UnsafeArrayWriter.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 30 +- .../expressions/codegen/UnsafeWriter.java | 43 ++ .../sql/catalyst/expressions/Expression.scala | 26 ++ .../InterpretedUnsafeProjection.scala | 366 ++++++++++++++++++ .../MonotonicallyIncreasingID.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 19 +- .../codegen/GenerateUnsafeProjection.scala | 2 +- .../expressions/randomExpressions.scala | 6 +- .../expressions/ComplexTypeSuite.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 20 +- .../expressions/ObjectExpressionsSuite.scala | 21 +- .../catalyst/expressions/ScalaUDFSuite.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 56 +-- 14 files changed, 555 insertions(+), 74 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 791e8d80e6cba..82cd1b24607e1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -30,7 +30,7 @@ * A helper class to write data into global row buffer using `UnsafeArrayData` format, * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. */ -public class UnsafeArrayWriter { +public final class UnsafeArrayWriter extends UnsafeWriter { private BufferHolder holder; @@ -83,7 +83,7 @@ private long getElementOffset(int ordinal, int elementSize) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, long currentCursor, int size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { assertIndexIsValid(ordinal); final long relativeOffset = currentCursor - startingOffset; final long offsetAndSize = (relativeOffset << 32) | (long)size; @@ -96,49 +96,31 @@ private void setNullBit(int ordinal) { BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); } - public void setNullBoolean(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false); - } - - public void setNullByte(int ordinal) { + public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); } - public void setNullShort(int ordinal) { + public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); } - public void setNullInt(int ordinal) { + public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); } - public void setNullLong(int ordinal) { + public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); } - public void setNullFloat(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0); - } - - public void setNullDouble(int ordinal) { - setNullBit(ordinal); - // put zero into the corresponding field when set null - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0); - } - - public void setNull(int ordinal) { setNullLong(ordinal); } + public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 5d9515c0725da..2620bbcfb87a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -38,7 +38,7 @@ * beginning of the global row buffer, we don't need to update `startingOffset` and can just call * `zeroOutNullBytes` before writing new data. */ -public class UnsafeRowWriter { +public final class UnsafeRowWriter extends UnsafeWriter { private final BufferHolder holder; // The offset of the global buffer where we start to write this row. @@ -93,18 +93,38 @@ public void setNullAt(int ordinal) { Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); } + @Override + public void setNull1Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull2Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull4Bytes(int ordinal) { + setNullAt(ordinal); + } + + @Override + public void setNull8Bytes(int ordinal) { + setNullAt(ordinal); + } + public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, long size) { + public void setOffsetAndSize(int ordinal, int size) { setOffsetAndSize(ordinal, holder.cursor, size); } - public void setOffsetAndSize(int ordinal, long currentCursor, long size) { + public void setOffsetAndSize(int ordinal, int currentCursor, int size) { final long relativeOffset = currentCursor - startingOffset; final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | size; + final long offsetAndSize = (relativeOffset << 32) | (long) size; Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); } @@ -174,7 +194,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { if (input == null || !input.changePrecision(precision, scale)) { BitSetMethods.set(holder.buffer, startingOffset, ordinal); // keep the offset for future update - setOffsetAndSize(ordinal, 0L); + setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); assert bytes.length <= 16; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java new file mode 100644 index 0000000000000..c94b5c7a367ef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Base class for writing Unsafe* structures. + */ +public abstract class UnsafeWriter { + public abstract void setNull1Bytes(int ordinal); + public abstract void setNull2Bytes(int ordinal); + public abstract void setNull4Bytes(int ordinal); + public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); + public abstract void write(int ordinal, byte value); + public abstract void write(int ordinal, short value); + public abstract void write(int ordinal, int value); + public abstract void write(int ordinal, long value); + public abstract void write(int ordinal, float value); + public abstract void write(int ordinal, double value); + public abstract void write(int ordinal, Decimal input, int precision, int scale); + public abstract void write(int ordinal, UTF8String input); + public abstract void write(int ordinal, byte[] input); + public abstract void write(int ordinal, CalendarInterval input); + public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index ed90b185865a0..d7f9e38915dd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -328,6 +328,32 @@ trait Nondeterministic extends Expression { protected def evalInternal(input: InternalRow): Any } +/** + * An expression that contains mutable state. A stateful expression is always non-deterministic + * because the results it produces during evaluation are not only dependent on the given input + * but also on its internal state. + * + * The state of the expressions is generally not exposed in the parameter list and this makes + * comparing stateful expressions problematic because similar stateful expressions (with the same + * parameter list) but with different internal state will be considered equal. This is especially + * problematic during tree transformations. In order to counter this the `fastEquals` method for + * stateful expressions only returns `true` for the same reference. + * + * A stateful expression should never be evaluated multiple times for a single row. This should + * only be a problem for interpreted execution. This can be prevented by creating fresh copies + * of the stateful expression before execution, these can be made using the `freshCopy` function. + */ +trait Stateful extends Nondeterministic { + /** + * Return a fresh uninitialized copy of the stateful expression. + */ + def freshCopy(): Stateful + + /** + * Only the same reference is considered equal. + */ + override def fastEquals(other: TreeNode[_]): Boolean = this eq other +} /** * A leaf expression, i.e. one without any child expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala new file mode 100644 index 0000000000000..0da5ece7e47fe --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{UserDefinedType, _} +import org.apache.spark.unsafe.Platform + +/** + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, a consumer + * should copy the row if it is being buffered. This class is not thread safe. + * + * @param expressions that produces the resulting fields. These expressions must be bound + * to a schema. + */ +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends UnsafeProjection { + import InterpretedUnsafeProjection._ + + /** Number of (top level) fields in the resulting row. */ + private[this] val numFields = expressions.length + + /** Array that expression results. */ + private[this] val values = new Array[Any](numFields) + + /** The row representing the expression results. */ + private[this] val intermediate = new GenericInternalRow(values) + + /** The row returned by the projection. */ + private[this] val result = new UnsafeRow(numFields) + + /** The buffer which holds the resulting row's backing data. */ + private[this] val holder = new BufferHolder(result, numFields * 32) + + /** The writer that writes the intermediate result to the result row. */ + private[this] val writer: InternalRow => Unit = { + val rowWriter = new UnsafeRowWriter(holder, numFields) + val baseWriter = generateStructWriter( + holder, + rowWriter, + expressions.map(e => StructField("", e.dataType, e.nullable))) + if (!expressions.exists(_.nullable)) { + // No nullable fields. The top-level null bit mask will always be zeroed out. + baseWriter + } else { + // Zero out the null bit mask before we write the row. + row => { + rowWriter.zeroOutNullBytes() + baseWriter(row) + } + } + } + + override def initialize(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + }) + } + + override def apply(row: InternalRow): UnsafeRow = { + // Put the expression results in the intermediate row. + var i = 0 + while (i < numFields) { + values(i) = expressions(i).eval(row) + i += 1 + } + + // Write the intermediate row to an unsafe row. + holder.reset() + writer(intermediate) + result.setTotalSize(holder.totalSize()) + result + } +} + +/** + * Helper functions for creating an [[InterpretedUnsafeProjection]]. + */ +object InterpretedUnsafeProjection extends UnsafeProjectionCreator { + + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + // We need to make sure that we do not reuse stateful expressions. + val cleanedExpressions = exprs.map(_.transform { + case s: Stateful => s.freshCopy() + }) + new InterpretedUnsafeProjection(cleanedExpressions.toArray) + } + + /** + * Generate a struct writer function. The generated function writes an [[InternalRow]] to the + * given buffer using the given [[UnsafeRowWriter]]. + */ + private def generateStructWriter( + bufferHolder: BufferHolder, + rowWriter: UnsafeRowWriter, + fields: Array[StructField]): InternalRow => Unit = { + val numFields = fields.length + + // Create field writers. + val fieldWriters = fields.map { field => + generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + } + // Create basic writer. + row => { + var i = 0 + while (i < numFields) { + fieldWriters(i).apply(row, i) + i += 1 + } + } + } + + /** + * Generate a writer function for a struct field, array element, map key or map value. The + * generated function writes the element at an index in a [[SpecializedGetters]] object (row + * or array) to the given buffer using the given [[UnsafeWriter]]. + */ + private def generateFieldWriter( + bufferHolder: BufferHolder, + writer: UnsafeWriter, + dt: DataType, + nullable: Boolean): (SpecializedGetters, Int) => Unit = { + + // Create the the basic writer. + val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match { + case BooleanType => + (v, i) => writer.write(i, v.getBoolean(i)) + + case ByteType => + (v, i) => writer.write(i, v.getByte(i)) + + case ShortType => + (v, i) => writer.write(i, v.getShort(i)) + + case IntegerType | DateType => + (v, i) => writer.write(i, v.getInt(i)) + + case LongType | TimestampType => + (v, i) => writer.write(i, v.getLong(i)) + + case FloatType => + (v, i) => writer.write(i, v.getFloat(i)) + + case DoubleType => + (v, i) => writer.write(i, v.getDouble(i)) + + case DecimalType.Fixed(precision, scale) => + (v, i) => writer.write(i, v.getDecimal(i, precision, scale), precision, scale) + + case CalendarIntervalType => + (v, i) => writer.write(i, v.getInterval(i)) + + case BinaryType => + (v, i) => writer.write(i, v.getBinary(i)) + + case StringType => + (v, i) => writer.write(i, v.getUTF8String(i)) + + case StructType(fields) => + val numFields = fields.length + val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) + val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getStruct(i, fields.length) match { + case row: UnsafeRow => + writeUnsafeData( + bufferHolder, + row.getBaseObject, + row.getBaseOffset, + row.getSizeInBytes) + case row => + // Nested struct. We don't know where this will start because a row can be + // variable length, so we need to update the offsets and zero out the bit mask. + rowWriter.reset() + structWriter.apply(row) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case ArrayType(elementType, containsNull) => + val arrayWriter = new UnsafeArrayWriter + val elementSize = getElementSize(elementType) + val elementWriter = generateFieldWriter( + bufferHolder, + arrayWriter, + elementType, + containsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case MapType(keyType, valueType, valueContainsNull) => + val keyArrayWriter = new UnsafeArrayWriter + val keySize = getElementSize(keyType) + val keyWriter = generateFieldWriter( + bufferHolder, + keyArrayWriter, + keyType, + nullable = false) + val valueArrayWriter = new UnsafeArrayWriter + val valueSize = getElementSize(valueType) + val valueWriter = generateFieldWriter( + bufferHolder, + valueArrayWriter, + valueType, + valueContainsNull) + (v, i) => { + val tmpCursor = bufferHolder.cursor + v.getMap(i) match { + case map: UnsafeMapData => + writeUnsafeData( + bufferHolder, + map.getBaseObject, + map.getBaseOffset, + map.getSizeInBytes) + case map => + // preserve 8 bytes to write the key array numBytes later. + bufferHolder.grow(8) + bufferHolder.cursor += 8 + + // Write the keys and write the numBytes of key array into the first 8 bytes. + writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) + Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + + // Write the values. + writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + } + writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + } + + case udt: UserDefinedType[_] => + generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + + case NullType => + (_, _) => {} + + case _ => + throw new SparkException(s"Unsupported data type $dt") + } + + // Always wrap the writer with a null safe version. + dt match { + case _: UserDefinedType[_] => + // The null wrapper depends on the sql type and not on the UDT. + unsafeWriter + case DecimalType.Fixed(precision, _) if precision > Decimal.MAX_LONG_DIGITS => + // We can't call setNullAt() for DecimalType with precision larger than 18, we call write + // directly. We can use the unwrapped writer directly. + unsafeWriter + case BooleanType | ByteType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull1Bytes(i) + } + } + case ShortType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull2Bytes(i) + } + } + case IntegerType | DateType | FloatType => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull4Bytes(i) + } + } + case _ => + (v, i) => { + if (!v.isNullAt(i)) { + unsafeWriter(v, i) + } else { + writer.setNull8Bytes(i) + } + } + } + } + + /** + * Get the number of bytes elements of a data type will occupy in the fixed part of an + * [[UnsafeArrayData]] object. Reference types are stored as an 8 byte combination of an + * offset (upper 4 bytes) and a length (lower 4 bytes), these point to the variable length + * portion of the array object. Primitives take up to 8 bytes, depending on the size of the + * underlying data type. + */ + private def getElementSize(dataType: DataType): Int = dataType match { + case NullType | StringType | BinaryType | CalendarIntervalType | + _: DecimalType | _: StructType | _: ArrayType | _: MapType => 8 + case _ => dataType.defaultSize + } + + /** + * Write an array to the buffer. If the array is already in serialized form (an instance of + * [[UnsafeArrayData]]) then we copy the bytes directly, otherwise we do an element-by-element + * copy. + */ + private def writeArray( + bufferHolder: BufferHolder, + arrayWriter: UnsafeArrayWriter, + elementWriter: (SpecializedGetters, Int) => Unit, + array: ArrayData, + elementSize: Int): Unit = array match { + case unsafe: UnsafeArrayData => + writeUnsafeData( + bufferHolder, + unsafe.getBaseObject, + unsafe.getBaseOffset, + unsafe.getSizeInBytes) + case _ => + val numElements = array.numElements() + arrayWriter.initialize(bufferHolder, numElements, elementSize) + var i = 0 + while (i < numElements) { + elementWriter.apply(array, i) + i += 1 + } + } + + /** + * Write an opaque block of data to the buffer. This is used to copy + * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. + */ + private def writeUnsafeData( + bufferHolder: BufferHolder, + baseObject: AnyRef, + baseOffset: Long, + sizeInBytes: Int) : Unit = { + bufferHolder.grow(sizeInBytes) + Platform.copyMemory( + baseObject, + baseOffset, + bufferHolder.buffer, + bufferHolder.cursor, + sizeInBytes) + bufferHolder.cursor += sizeInBytes + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 4523079060896..dd523d312e3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DataType, LongType} within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. """) -case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { +case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { /** * Record ID within each partition. By being transient, count's value is reset to 0 every time @@ -79,4 +79,6 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def prettyName: String = "monotonically_increasing_id" override def sql: String = s"$prettyName()" + + override def freshCopy(): MonotonicallyIncreasingID = MonotonicallyIncreasingID() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 64b94f0a2c103..3cd73682188bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,8 +108,7 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -object UnsafeProjection { - +trait UnsafeProjectionCreator { /** * Returns an UnsafeProjection for given StructType. * @@ -127,13 +126,13 @@ object UnsafeProjection { } /** - * Returns an UnsafeProjection for given sequence of Expressions (bounded). + * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { val unsafeExprs = exprs.map(_ transform { case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) }) - GenerateUnsafeProjection.generate(unsafeExprs) + createProjection(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -146,6 +145,18 @@ object UnsafeProjection { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + /** + * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. + */ + protected def createProjection(exprs: Seq[Expression]): UnsafeProjection +} + +object UnsafeProjection extends UnsafeProjectionCreator { + + override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(exprs) + } + /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. * TODO: refactor the plumbing and clean this up. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 22717f5954a45..6682ba55b18b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -247,7 +247,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); + $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); } else { $writeElement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6c9937dacc70b..f36633867316e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic { +abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -85,6 +85,8 @@ case class Rand(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } + + override def freshCopy(): Rand = Rand(child) } object Rand { @@ -120,6 +122,8 @@ case class Randn(child: Expression) extends RDG { final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } + + override def freshCopy(): Randn = Randn(child) } object Randn { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 84190f0bd5f7d..b4138ce366b3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -180,7 +180,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { null, null) } intercept[RuntimeException] { - checkEvalutionWithUnsafeProjection( + checkEvaluationWithUnsafeProjection( CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 58d0c07622eb9..c6343b1cbf600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -60,7 +60,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expr.dataType)) { - checkEvalutionWithUnsafeProjection(expr, catalystValue, inputRow) + checkEvaluationWithUnsafeProjection(expr, catalystValue, inputRow) } checkEvaluationWithOptimization(expr, catalystValue, inputRow) } @@ -187,11 +187,20 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { plan(inputRow).get(0, expression.dataType) } - protected def checkEvalutionWithUnsafeProjection( + protected def checkEvaluationWithUnsafeProjection( expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) + checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) + } + + protected def checkEvaluationWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow, + factory: UnsafeProjectionCreator): Unit = { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" if (expected == null) { @@ -203,7 +212,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } else { val lit = InternalRow(expected, expected) val expectedRow = - UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + factory.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") @@ -213,7 +222,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { private def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow): InternalRow = { + inputRow: InternalRow = EmptyRow, + factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index ffeec2a38c532..1f6964dfef598 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -45,16 +45,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) - checkEvalutionWithUnsafeProjection( - structEncoder.serializer.head, structExpected, structInputRow) + checkEvaluationWithUnsafeProjection( + structEncoder.serializer.head, + structExpected, + structInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4)))) val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) - checkEvalutionWithUnsafeProjection( - arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + checkEvaluationWithUnsafeProjection( + arrayEncoder.serializer.head, + arrayExpected, + arrayInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -67,8 +73,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new ArrayBasedMapData( new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) - checkEvalutionWithUnsafeProjection( - mapEncoder.serializer.head, mapExpected, mapInputRow) + checkEvaluationWithUnsafeProjection( + mapEncoder.serializer.head, + mapExpected, + mapInputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } test("SPARK-23585: UnwrapOption should support interpreted execution") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 10e3ffd0dff97..e083ae0089244 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -43,7 +43,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e1.getMessage.contains("Failed to execute user defined function")) val e2 = intercept[SparkException] { - checkEvalutionWithUnsafeProjection(udf, null) + checkEvaluationWithUnsafeProjection(udf, null) } assert(e2.getMessage.contains("Failed to execute user defined function")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index cf3cbe270753e..c07da122cd7b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +33,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - test("basic conversion with only primitive types") { - val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) - val converter = UnsafeProjection.create(fieldTypes) + private def testWithFactory( + name: String)( + f: UnsafeProjectionCreator => Unit): Unit = { + test(name) { + f(UnsafeProjection) + f(InterpretedUnsafeProjection) + } + } + testWithFactory("basic conversion with only primitive types") { factory => + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) @@ -71,9 +79,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - test("basic conversion with primitive, string and binary types") { + testWithFactory("basic conversion with primitive, string and binary types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -90,9 +98,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - test("basic conversion with primitive, string, date and timestamp types") { + testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) row.setLong(0, 0) @@ -119,7 +127,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - test("null handling") { + testWithFactory("null handling") { factory => val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -135,7 +143,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DecimalType.SYSTEM_DEFAULT // ArrayType(IntegerType) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { val r = new SpecificInternalRow(fieldTypes) @@ -240,7 +248,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - test("NaN canonicalization") { + testWithFactory("NaN canonicalization") { factory => val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -251,17 +259,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - test("basic conversion with struct type") { + testWithFactory("basic conversion with struct type") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(1)) @@ -317,12 +325,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - test("basic conversion with array type") { + testWithFactory("basic conversion with array type") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(1, 2)) @@ -347,12 +355,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - test("basic conversion with map type") { + testWithFactory("basic conversion with map type") { factory => val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val map1 = createMap(1, 2)(3, 4) @@ -393,12 +401,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - test("basic conversion with struct and array") { + testWithFactory("basic conversion with struct and array") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createArray(1))) @@ -432,12 +440,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with struct and map") { + testWithFactory("basic conversion with struct and map") { factory => val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, InternalRow(createMap(1)(2))) @@ -478,12 +486,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - test("basic conversion with array and map") { + testWithFactory("basic conversion with array and map") { factory => val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) ) - val converter = UnsafeProjection.create(fieldTypes) + val converter = factory.create(fieldTypes) val row = new GenericInternalRow(fieldTypes.length) row.update(0, createArray(createMap(1)(2))) From 9945b0227efcd952c8e835453b2831a8c6d5d607 Mon Sep 17 00:00:00 2001 From: Ricardo Martinelli de Oliveira Date: Fri, 16 Mar 2018 10:37:11 -0700 Subject: [PATCH 0484/1161] [SPARK-23680] Fix entrypoint.sh to properly support Arbitrary UIDs ## What changes were proposed in this pull request? As described in SPARK-23680, entrypoint.sh returns an error code because of a command pipeline execution where it is expected in case of Openshift environments, where arbitrary UIDs are used to run containers ## How was this patch tested? This patch was manually tested by using docker-image-toll.sh script to generate a Spark driver image and running an example against an OpenShift cluster. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ricardo Martinelli de Oliveira Closes #20822 from rimolive/rmartine-spark-23680. --- .../kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3d67b0a702dd4..d0cf284f035ea 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -22,7 +22,10 @@ set -ex # Check whether there is a passwd entry for the container UID myuid=$(id -u) mygid=$(id -g) +# turn off -e for getent because it will return error code in anonymous uid case +set +e uidentry=$(getent passwd $myuid) +set -e # If there is no passwd entry for the container UID, attempt to create one if [ -z "$uidentry" ] ; then From bd201bf61e8e1713deb91b962f670c76c9e3492b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 16 Mar 2018 11:11:07 -0700 Subject: [PATCH 0485/1161] [SPARK-23623][SS] Avoid concurrent use of cached consumers in CachedKafkaConsumer ## What changes were proposed in this pull request? CacheKafkaConsumer in the project `kafka-0-10-sql` is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built with the assumption there will be only one task using trying to read the same Kafka TopicPartition at the same time. Hence, the cache was keyed by the TopicPartition a consumer is supposed to read. And any cases where this assumption may not be true, we have SparkPlan flag to disable the use of a cache. So it was up to the planner to correctly identify when it was not safe to use the cache and set the flag accordingly. Fundamentally, this is the wrong way to approach the problem. It is HARD for a high-level planner to reason about the low-level execution model, whether there will be multiple tasks in the same query trying to read the same partition. Case in point, 2.3.0 introduced stream-stream joins, and you can build a streaming self-join query on Kafka. It's pretty non-trivial to figure out how this leads to two tasks reading the same partition twice, possibly concurrently. And due to the non-triviality, it is hard to figure this out in the planner and set the flag to avoid the cache / consumer pool. And this can inadvertently lead to ConcurrentModificationException ,or worse, silent reading of incorrect data. Here is a better way to design this. The planner shouldnt have to understand these low-level optimizations. Rather the consumer pool should be smart enough avoid concurrent use of a cached consumer. Currently, it tries to do so but incorrectly (the flag inuse is not checked when returning a cached consumer, see [this](https://github.com/apache/spark/blob/master/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala#L403)). If there is another request for the same partition as a currently in-use consumer, the pool should automatically return a fresh consumer that should be closed when the task is done. Then the planner does not have to have a flag to avoid reuses. This PR is a step towards that goal. It does the following. - There are effectively two kinds of consumer that may be generated - Cached consumer - this should be returned to the pool at task end - Non-cached consumer - this should be closed at task end - A trait called KafkaConsumer is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply called `val consumer = KafkaConsumer.acquire` and then `consumer.release()`. - If there is request for a consumer that is in-use, then a new consumer is generated. - If there is a concurrent attempt of the same task, then a new consumer is generated, and the existing cached consumer is marked for close upon release. - In addition, I renamed the classes because CachedKafkaConsumer is a misnomer given that what it returns may or may not be cached. This PR does not remove the planner flag to avoid reuse to make this patch safe enough for merging in branch-2.3. This can be done later in master-only. ## How was this patch tested? A new stress test that verifies it is safe to concurrently get consumers for the same partition from the consumer pool. Author: Tathagata Das Closes #20767 from tdas/SPARK-23623. --- .../sql/kafka010/KafkaContinuousReader.scala | 5 +- ...Consumer.scala => KafkaDataConsumer.scala} | 242 ++++++++++++------ .../sql/kafka010/KafkaMicroBatchReader.scala | 22 +- .../spark/sql/kafka010/KafkaSourceRDD.scala | 23 +- .../kafka010/CachedKafkaConsumerSuite.scala | 34 --- .../sql/kafka010/KafkaDataConsumerSuite.scala | 124 +++++++++ 6 files changed, 295 insertions(+), 155 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{CachedKafkaConsumer.scala => KafkaDataConsumer.scala} (66%) delete mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala create mode 100644 external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 6e56b0a72d671..e7e27876088f3 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -196,8 +196,7 @@ class KafkaContinuousDataReader( kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { - private val consumer = - CachedKafkaConsumer.createUncached(topicPartition.topic, topicPartition.partition, kafkaParams) + private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter private var nextKafkaOffset = startOffset @@ -245,6 +244,6 @@ class KafkaContinuousDataReader( } override def close(): Unit = { - consumer.close() + consumer.release() } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala similarity index 66% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index e97881cb0a163..48508d057a540 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -27,30 +27,73 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.KafkaDataConsumer.AvailableOffsetRange import org.apache.spark.sql.kafka010.KafkaSourceProvider._ import org.apache.spark.util.UninterruptibleThread +private[kafka010] sealed trait KafkaDataConsumer { + /** + * Get the record for the given offset if available. Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. + * + * @param offset the offset to fetch. + * @param untilOffset the max offset to fetch. Exclusive. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return record at + * offset if available, or throw exception.when `failOnDataLoss` is `false`, + * this method will either return record at offset if available, or return + * the next earliest available record less than untilOffset, or null. It + * will not throw any exception. + */ + def get( + offset: Long, + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + internalConsumer.get(offset, untilOffset, pollTimeoutMs, failOnDataLoss) + } + + /** + * Return the available offset range of the current partition. It's a pair of the earliest offset + * and the latest offset. + */ + def getAvailableOffsetRange(): AvailableOffsetRange = internalConsumer.getAvailableOffsetRange() + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + protected def internalConsumer: InternalKafkaConsumer +} + /** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + * A wrapper around Kafka's KafkaConsumer that throws error when data loss is detected. + * This is not for direct use outside this file. */ -private[kafka010] case class CachedKafkaConsumer private( +private[kafka010] case class InternalKafkaConsumer( topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) extends Logging { - import CachedKafkaConsumer._ + import InternalKafkaConsumer._ private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - private var consumer = createConsumer + @volatile private var consumer = createConsumer /** indicates whether this consumer is in use or not */ - private var inuse = true + @volatile var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + @volatile var markedForClose = false /** Iterator to the already fetch data */ - private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] - private var nextOffsetInFetchedData = UNKNOWN_OFFSET + @volatile private var fetchedData = + ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + @volatile private var nextOffsetInFetchedData = UNKNOWN_OFFSET /** Create a KafkaConsumer to fetch records for `topicPartition` */ private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = { @@ -61,8 +104,6 @@ private[kafka010] case class CachedKafkaConsumer private( c } - case class AvailableOffsetRange(earliest: Long, latest: Long) - private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { case ut: UninterruptibleThread => ut.runUninterruptibly(body) @@ -313,21 +354,51 @@ private[kafka010] case class CachedKafkaConsumer private( } } -private[kafka010] object CachedKafkaConsumer extends Logging { - private val UNKNOWN_OFFSET = -2L +private[kafka010] object KafkaDataConsumer extends Logging { + + case class AvailableOffsetRange(earliest: Long, latest: Long) + + private case class CachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + assert(internalConsumer.inUse) // make sure this has been set to true + override def release(): Unit = { KafkaDataConsumer.release(internalConsumer) } + } + + private case class NonCachedKafkaDataConsumer(internalConsumer: InternalKafkaConsumer) + extends KafkaDataConsumer { + override def release(): Unit = { internalConsumer.close() } + } - private case class CacheKey(groupId: String, topicPartition: TopicPartition) + private case class CacheKey(groupId: String, topicPartition: TopicPartition) { + def this(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object]) = + this(kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String], topicPartition) + } + // This cache has the following important properties. + // - We make a best-effort attempt to maintain the max size of the cache as configured capacity. + // The capacity is not guaranteed to be maintained, especially when there are more active + // tasks simultaneously using consumers than the capacity. private lazy val cache = { val conf = SparkEnv.get.conf val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) - new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer](capacity, 0.75f, true) { override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { - if (entry.getValue.inuse == false && this.size > capacity) { - logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + - s"removing consumer for ${entry.getKey}") + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > capacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") try { entry.getValue.close() } catch { @@ -342,80 +413,87 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } - def releaseKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - synchronized { - val consumer = cache.get(key) - if (consumer != null) { - consumer.inuse = false - } else { - logWarning(s"Attempting to release consumer that does not exist") - } - } - } - /** - * Removes (and closes) the Kafka Consumer for the given topic, partition and group id. + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by any one + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. */ - def removeKafkaConsumer( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): Unit = { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) + def acquire( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + useCache: Boolean): KafkaDataConsumer = synchronized { + val key = new CacheKey(topicPartition, kafkaParams) + val existingInternalConsumer = cache.get(key) - synchronized { - val removedConsumer = cache.remove(key) - if (removedConsumer != null) { - removedConsumer.close() + lazy val newInternalConsumer = new InternalKafkaConsumer(topicPartition, kafkaParams) + + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumer if any and + // start with a new one. + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + } } + cache.remove(key) // Invalidate the cache in any case + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (!useCache) { + // If planner asks to not reuse consumers, then do not use it, return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + cache.put(key, newInternalConsumer) + newInternalConsumer.inUse = true + CachedKafkaDataConsumer(newInternalConsumer) + + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + NonCachedKafkaDataConsumer(newInternalConsumer) + + } else { + // If consumer is already cached and is currently not in use, then return that consumer + existingInternalConsumer.inUse = true + CachedKafkaDataConsumer(existingInternalConsumer) } } - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def getOrCreate( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val topicPartition = new TopicPartition(topic, partition) - val key = CacheKey(groupId, topicPartition) - - // If this is reattempt at running the task, then invalidate cache and start with - // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { - removeKafkaConsumer(topic, partition, kafkaParams) - val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) - consumer.inuse = true - cache.put(key, consumer) - consumer - } else { - if (!cache.containsKey(key)) { - cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + private def release(intConsumer: InternalKafkaConsumer): Unit = { + synchronized { + + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(intConsumer.topicPartition, intConsumer.kafkaParams) + val cachedIntConsumer = cache.get(key) + if (intConsumer.eq(cachedIntConsumer)) { + // The released consumer is the same object as the cached one. + if (intConsumer.markedForClose) { + intConsumer.close() + cache.remove(key) + } else { + intConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + intConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache") } - val consumer = cache.get(key) - consumer.inuse = true - consumer } } +} - /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */ - def createUncached( - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { - new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) - } +private[kafka010] object InternalKafkaConsumer extends Logging { + + private val UNKNOWN_OFFSET = -2L private def reportDataLoss0( failOnDataLoss: Boolean, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a5f3a249b11c..2ed49ba3f5495 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -321,17 +321,8 @@ private[kafka010] case class KafkaMicroBatchDataReader( failOnDataLoss: Boolean, reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { - private val consumer = { - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. We - // uses `assign` here, hence we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } - } + private val consumer = KafkaDataConsumer.acquire( + offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) private val rangeToRead = resolveRange(offsetRange) private val converter = new KafkaRecordToUnsafeRowConverter @@ -360,14 +351,7 @@ private[kafka010] case class KafkaMicroBatchDataReader( } override def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer( - offsetRange.topicPartition.topic, offsetRange.topicPartition.partition, executorKafkaParams) - } + consumer.release() } private def resolveRange(range: KafkaOffsetRange): KafkaOffsetRange = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 66b3409c0cd04..498e344ea39f4 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -52,7 +52,7 @@ private[kafka010] case class KafkaSourceRDDPartition( * An RDD that reads data from Kafka based on offset ranges across multiple partitions. * Additionally, it allows preferred locations to be set for each topic + partition, so that * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition - * and cached KafkaConsumers (see [[CachedKafkaConsumer]] can be used read data efficiently. + * and cached KafkaConsumers (see [[KafkaDataConsumer]] can be used read data efficiently. * * @param sc the [[SparkContext]] * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors @@ -126,14 +126,9 @@ private[kafka010] class KafkaSourceRDD( val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = - if (!reuseKafkaConsumer) { - // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we - // uses `assign`, we don't need to worry about the "group.id" conflicts. - CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) - } else { - CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) - } + val consumer = KafkaDataConsumer.acquire( + sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) + val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -167,13 +162,7 @@ private[kafka010] class KafkaSourceRDD( } override protected def close(): Unit = { - if (!reuseKafkaConsumer) { - // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - consumer.close() - } else { - // Indicate that we're no longer using this consumer - CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) - } + consumer.release() } } // Release consumer, either by removing it or indicating we're no longer using it @@ -184,7 +173,7 @@ private[kafka010] class KafkaSourceRDD( } } - private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = { + private def resolveRange(consumer: KafkaDataConsumer, range: KafkaSourceRDDOffsetRange) = { if (range.fromOffset < 0 || range.untilOffset < 0) { // Late bind the offset range val availableOffsetRange = consumer.getAvailableOffsetRange() diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala deleted file mode 100644 index 7aa7dd096c07b..0000000000000 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.kafka010 - -import org.scalatest.PrivateMethodTester - -import org.apache.spark.sql.test.SharedSQLContext - -class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester { - - test("SPARK-19886: Report error cause correctly in reportDataLoss") { - val cause = new Exception("D'oh!") - val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) - val e = intercept[IllegalStateException] { - CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) - } - assert(e.getCause === cause) - } -} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..0d0fb9c3ab5af --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.PrivateMethodTester + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.ThreadUtils + +class KafkaDataConsumerSuite extends SharedSQLContext with PrivateMethodTester { + + protected var testUtils: KafkaTestUtils = _ + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils(Map[String, Object]()) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + test("SPARK-19886: Report error cause correctly in reportDataLoss") { + val cause = new Exception("D'oh!") + val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0) + val e = intercept[IllegalStateException] { + InternalKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause)) + } + assert(e.getCause === cause) + } + + test("SPARK-23623: concurrent use of KafkaDataConsumer") { + val topic = "topic" + Random.nextInt() + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic, 1) + testUtils.sendMessages(topic, data.toArray) + val topicPartition = new TopicPartition(topic, 0) + + import ConsumerConfig._ + val kafkaParams = Map[String, Object]( + GROUP_ID_CONFIG -> "groupId", + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ) + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + TaskContext.setTaskContext(taskContext) + val consumer = KafkaDataConsumer.acquire( + topicPartition, kafkaParams.asJava, useCache) + try { + val range = consumer.getAvailableOffsetRange() + val rcvd = range.earliest until range.latest map { offset => + val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadpool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadpool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadpool.shutdown() + } + } +} From 8a72734f33f6a0abbd3207b0d661633c8b25d9ad Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 16 Mar 2018 11:42:57 -0700 Subject: [PATCH 0486/1161] [SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list ## What changes were proposed in this pull request? Added a class method to construct CountVectorizerModel from a list of vocabulary strings, equivalent to the Scala version. Introduced a common param base class `_CountVectorizerParams` to allow the Python model to also own the parameters. This now matches the Scala class hierarchy. ## How was this patch tested? Added to CountVectorizer doctests to do a transform on a model constructed from vocab, and unit test to verify params and vocab are constructed correctly. Author: Bryan Cutler Closes #16770 from BryanCutler/pyspark-CountVectorizerModel-vocab_ctor-SPARK-15009. --- python/pyspark/ml/feature.py | 168 +++++++++++++++++++++++------------ python/pyspark/ml/tests.py | 32 ++++++- 2 files changed, 142 insertions(+), 58 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f2e357f0bede5..a1ceb7f02da8b 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,12 +19,12 @@ if sys.version > '3': basestring = str -from pyspark import since, keyword_only +from pyspark import since, keyword_only, SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm from pyspark.ml.common import inherit_doc __all__ = ['Binarizer', @@ -403,8 +403,69 @@ def getSplits(self): return self.getOrDefault(self.splits) +class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`. + """ + + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", + typeConverter=TypeConverters.toFloat) + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0", typeConverter=TypeConverters.toFloat) + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", + typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) + + def __init__(self, *args): + super(_CountVectorizerParams, self).__init__(*args) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + + @since("1.6.0") + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + @since("1.6.0") + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + @since("1.6.0") + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + + @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. @@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, >>> loadedModel = CountVectorizerModel.load(modelPath) >>> loadedModel.vocabulary == model.vocabulary True + >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"], + ... inputCol="raw", outputCol="vectors") + >>> fromVocabModel.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... .. versionadded:: 1.6.0 """ - minTF = Param( - Params._dummy(), "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then this " + - "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", - typeConverter=TypeConverters.toFloat) - minDF = Param( - Params._dummy(), "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0", typeConverter=TypeConverters.toFloat) - vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", - typeConverter=TypeConverters.toInt) - binary = Param( - Params._dummy(), "binary", "Binary toggle to control the output vector values." + - " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + - " for discrete probabilistic models that model binary events rather than integer counts." + - " Default False", typeConverter=TypeConverters.toBoolean) - @keyword_only def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, outputCol=None): @@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -497,13 +544,6 @@ def setMinTF(self, value): """ return self._set(minTF=value) - @since("1.6.0") - def getMinTF(self): - """ - Gets the value of minTF or its default value. - """ - return self.getOrDefault(self.minTF) - @since("1.6.0") def setMinDF(self, value): """ @@ -511,13 +551,6 @@ def setMinDF(self, value): """ return self._set(minDF=value) - @since("1.6.0") - def getMinDF(self): - """ - Gets the value of minDF or its default value. - """ - return self.getOrDefault(self.minDF) - @since("1.6.0") def setVocabSize(self, value): """ @@ -525,13 +558,6 @@ def setVocabSize(self, value): """ return self._set(vocabSize=value) - @since("1.6.0") - def getVocabSize(self): - """ - Gets the value of vocabSize or its default value. - """ - return self.getOrDefault(self.vocabSize) - @since("2.0.0") def setBinary(self, value): """ @@ -539,24 +565,40 @@ def setBinary(self, value): """ return self._set(binary=value) - @since("2.0.0") - def getBinary(self): - """ - Gets the value of binary or its default value. - """ - return self.getOrDefault(self.binary) - def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): +@inherit_doc +class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`CountVectorizer`. .. versionadded:: 1.6.0 """ + @classmethod + @since("2.4.0") + def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): + """ + Construct the model directly from a vocabulary list of strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class) + model = CountVectorizerModel._create_from_java_class( + "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if minTF is not None: + model.setMinTF(minTF) + if binary is not None: + model.setBinary(binary) + model._set(vocabSize=len(vocabulary)) + return model + @property @since("1.6.0") def vocabulary(self): @@ -565,6 +607,20 @@ def vocabulary(self): """ return self._call_java("vocabulary") + @since("2.4.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + return self._set(minTF=value) + + @since("2.4.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + return self._set(binary=value) + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6dee6938d8916..fd45fd00b270b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -679,6 +679,34 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + def test_rformula_force_index_label(self): df = self.spark.createDataFrame([ (1.0, 1.0, "a"), @@ -2019,8 +2047,8 @@ def test_java_params(self): pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and issubclass(cls, JavaParams)\ - and not inspect.isabstract(cls): + if not name.endswith('Model') and not name.endswith('Params')\ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) From 8a1efe3076f29259151f1fba2ff894487efb6c4e Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Fri, 16 Mar 2018 15:40:21 -0700 Subject: [PATCH 0487/1161] [SPARK-23683][SQL] FileCommitProtocol.instantiate() hardening ## What changes were proposed in this pull request? With SPARK-20236, `FileCommitProtocol.instantiate()` looks for a three argument constructor, passing in the `dynamicPartitionOverwrite` parameter. If there is no such constructor, it falls back to the classic two-arg one. When `InsertIntoHadoopFsRelationCommand` passes down that `dynamicPartitionOverwrite` flag `to FileCommitProtocol.instantiate(`), it assumes that the instantiated protocol supports the specific requirements of dynamic partition overwrite. It does not notice when this does not hold, and so the output generated may be incorrect. This patch changes `FileCommitProtocol.instantiate()` so when `dynamicPartitionOverwrite == true`, it requires the protocol implementation to have a 3-arg constructor. Classic two arg constructors are supported when it is false. Also it adds some debug level logging for anyone trying to understand what's going on. ## How was this patch tested? Unit tests verify that * classes with only 2-arg constructor cannot be used with dynamic overwrite * classes with only 2-arg constructor can be used without dynamic overwrite * classes with 3 arg constructors can be used with both. * the fallback to any two arg ctor takes place after the attempt to load the 3-arg ctor, * passing in invalid class types fail as expected (regression tests on expected behavior) Author: Steve Loughran Closes #20824 from steveloughran/stevel/SPARK-23683-protocol-instantiate. --- .../internal/io/FileCommitProtocol.scala | 11 +- ...FileCommitProtocolInstantiationSuite.scala | 148 ++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 6d0059b6a0272..e6e9c9e328853 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import org.apache.hadoop.fs._ import org.apache.hadoop.mapreduce._ +import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -132,7 +133,7 @@ abstract class FileCommitProtocol { } -object FileCommitProtocol { +object FileCommitProtocol extends Logging { class TaskCommitMessage(val obj: Any) extends Serializable object EmptyTaskCommitMessage extends TaskCommitMessage(null) @@ -145,15 +146,23 @@ object FileCommitProtocol { jobId: String, outputPath: String, dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { + + logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" + + s" dynamic=$dynamicPartitionOverwrite") val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] // First try the constructor with arguments (jobId: String, outputPath: String, // dynamicPartitionOverwrite: Boolean). // If that doesn't exist, try the one with (jobId: string, outputPath: String). try { val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + logDebug("Using (String, String, Boolean) constructor") ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) } catch { case _: NoSuchMethodException => + logDebug("Falling back to (String, String) constructor") + require(!dynamicPartitionOverwrite, + "Dynamic Partition Overwrite is enabled but" + + s" the committer ${className} does not have the appropriate constructor") val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) ctor.newInstance(jobId, outputPath) } diff --git a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala new file mode 100644 index 0000000000000..2bd32fc927e21 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for instantiation of FileCommitProtocol implementations. + */ +class FileCommitProtocolInstantiationSuite extends SparkFunSuite { + + test("Dynamic partitions require appropriate constructor") { + + // you cannot instantiate a two-arg client with dynamic partitions + // enabled. + val ex = intercept[IllegalArgumentException] { + instantiateClassic(true) + } + // check the contents of the message and rethrow if unexpected. + // this preserves the stack trace of the unexpected + // exception. + if (!ex.toString.contains("Dynamic Partition Overwrite")) { + fail(s"Wrong text in caught exception $ex", ex) + } + } + + test("Standard partitions work with classic constructor") { + instantiateClassic(false) + } + + test("Three arg constructors have priority") { + assert(3 == instantiateNew(false).argCount, + "Wrong constructor argument count") + } + + test("Three arg constructors have priority when dynamic") { + assert(3 == instantiateNew(true).argCount, + "Wrong constructor argument count") + } + + test("The protocol must be of the correct class") { + intercept[ClassCastException] { + FileCommitProtocol.instantiate( + classOf[Other].getCanonicalName, + "job", + "path", + false) + } + } + + test("If there is no matching constructor, class hierarchy is irrelevant") { + intercept[NoSuchMethodException] { + FileCommitProtocol.instantiate( + classOf[NoMatchingArgs].getCanonicalName, + "job", + "path", + false) + } + } + + /** + * Create a classic two-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateClassic(dynamic: Boolean): ClassicConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[ClassicConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[ClassicConstructorCommitProtocol] + } + + /** + * Create a three-arg protocol instance. + * @param dynamic dyanmic partitioning mode + * @return the instance + */ + private def instantiateNew( + dynamic: Boolean): FullConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[FullConstructorCommitProtocol].getCanonicalName, + "job", + "path", + dynamic).asInstanceOf[FullConstructorCommitProtocol] + } + +} + +/** + * This protocol implementation does not have the new three-arg + * constructor. + */ +private class ClassicConstructorCommitProtocol(arg1: String, arg2: String) + extends HadoopMapReduceCommitProtocol(arg1, arg2) { +} + +/** + * This protocol implementation does have the new three-arg constructor + * alongside the original, and a 4 arg one for completeness. + * The final value of the real constructor is the number of arguments + * used in the 2- and 3- constructor, for test assertions. + */ +private class FullConstructorCommitProtocol( + arg1: String, + arg2: String, + b: Boolean, + val argCount: Int) + extends HadoopMapReduceCommitProtocol(arg1, arg2, b) { + + def this(arg1: String, arg2: String) = { + this(arg1, arg2, false, 2) + } + + def this(arg1: String, arg2: String, b: Boolean) = { + this(arg1, arg2, false, 3) + } +} + +/** + * This has the 2-arity constructor, but isn't the right class. + */ +private class Other(arg1: String, arg2: String) { + +} + +/** + * This has no matching arguments as well as being the wrong class. + */ +private class NoMatchingArgs() { + +} + From 61487b308b0169e3108c2ad31674a0c80b8ac5f3 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 18 Mar 2018 20:24:14 +0900 Subject: [PATCH 0488/1161] [SPARK-23706][PYTHON] spark.conf.get(value, default=None) should produce None in PySpark ## What changes were proposed in this pull request? Scala: ``` scala> spark.conf.get("hey", null) res1: String = null ``` ``` scala> spark.conf.get("spark.sql.sources.partitionOverwriteMode", null) res2: String = null ``` Python: **Before** ``` >>> spark.conf.get("hey", None) ... py4j.protocol.Py4JJavaError: An error occurred while calling o30.get. : java.util.NoSuchElementException: hey ... ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) u'STATIC' ``` **After** ``` >>> spark.conf.get("hey", None) is None True ``` ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode", None) is None True ``` *Note that this PR preserves the case below: ``` >>> spark.conf.get("spark.sql.sources.partitionOverwriteMode") u'STATIC' ``` ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Closes #20841 from HyukjinKwon/spark-conf-get. --- python/pyspark/sql/conf.py | 9 +++++---- python/pyspark/sql/context.py | 8 ++++---- python/pyspark/sql/tests.py | 11 +++++++++++ 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index d929834aeeaa5..b82224b6194ed 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -17,7 +17,7 @@ import sys -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix @@ -39,15 +39,16 @@ def set(self, key, value): @ignore_unicode_prefix @since(2.0) - def get(self, key, default=None): + def get(self, key, default=_NoValue): """Returns the value of Spark runtime configuration property for the given key, assuming it is set. """ self._checkType(key, "key") - if default is None: + if default is _NoValue: return self._jconf.get(key) else: - self._checkType(default, "default") + if default is not None: + self._checkType(default, "default") return self._jconf.get(key, default) @ignore_unicode_prefix diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 6cb90399dd616..e9ec7ba866761 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -22,7 +22,7 @@ if sys.version >= '3': basestring = unicode = str -from pyspark import since +from pyspark import since, _NoValue from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame @@ -124,11 +124,11 @@ def setConf(self, key, value): @ignore_unicode_prefix @since(1.3) - def getConf(self, key, defaultValue=None): + def getConf(self, key, defaultValue=_NoValue): """Returns the value of Spark SQL configuration property for the given key. - If the key is not set and defaultValue is not None, return - defaultValue. If the key is not set and defaultValue is None, return + If the key is not set and defaultValue is set, return + defaultValue. If the key is not set and defaultValue is not set, return the system default value. >>> sqlContext.getConf("spark.sql.shuffle.partitions") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 480815d27333f..a0d547ad620e5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2504,6 +2504,17 @@ def test_conf(self): spark.conf.unset("bogo") self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") + self.assertEqual(spark.conf.get("hyukjin", None), None) + + # This returns 'STATIC' because it's the default value of + # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in + # `spark.conf.get` is unset. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC") + + # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but + # `defaultValue` in `spark.conf.get` is set to None. + self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + def test_current_database(self): spark = self.spark spark.catalog._reset() From 745c8c0901ac522ba92c1356ca74bd0dd7701496 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Mon, 19 Mar 2018 13:31:21 +0800 Subject: [PATCH 0489/1161] [SPARK-23708][CORE] Correct comment for function addShutDownHook in ShutdownHookManager ## What changes were proposed in this pull request? Minor modification.Comment below is not right. ``` /** * Adds a shutdown hook with the given priority. Hooks with lower priority values run * first. * * param hook The code to run during shutdown. * return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } ``` ## How was this patch tested? UT Author: zhoukang Closes #20845 from caneGuy/zhoukang/fix-shutdowncomment. --- .../main/scala/org/apache/spark/util/ShutdownHookManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 4001fac3c3d5a..b702838fa257f 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -143,7 +143,7 @@ private[spark] object ShutdownHookManager extends Logging { } /** - * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * Adds a shutdown hook with the given priority. Hooks with higher priority values run * first. * * @param hook The code to run during shutdown. From 4de638c1976dea74761bbe5c30da808178ee885d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 19 Mar 2018 09:41:43 +0100 Subject: [PATCH 0490/1161] [SPARK-23599][SQL] Add a UUID generator from Pseudo-Random Numbers ## What changes were proposed in this pull request? This patch adds a UUID generator from Pseudo-Random Numbers. We can use it later to have deterministic `UUID()` expression. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20817 from viirya/SPARK-23599. --- .../catalyst/util/RandomUUIDGenerator.scala | 43 ++++++++++++++ .../util/RandomUUIDGeneratorSuite.scala | 57 +++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala new file mode 100644 index 0000000000000..4fe07a071c1ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGenerator.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.UUID + +import org.apache.commons.math3.random.MersenneTwister + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This class is used to generate a UUID from Pseudo-Random Numbers. + * + * For the algorithm, see RFC 4122: A Universally Unique IDentifier (UUID) URN Namespace, + * section 4.4 "Algorithms for Creating a UUID from Truly Random or Pseudo-Random Numbers". + */ +case class RandomUUIDGenerator(randomSeed: Long) { + private val random = new MersenneTwister(randomSeed) + + def getNextUUID(): UUID = { + val mostSigBits = (random.nextLong() & 0xFFFFFFFFFFFF0FFFL) | 0x0000000000004000L + val leastSigBits = (random.nextLong() | 0x8000000000000000L) & 0xBFFFFFFFFFFFFFFFL + + new UUID(mostSigBits, leastSigBits) + } + + def getNextUUIDUTF8String(): UTF8String = UTF8String.fromString(getNextUUID().toString()) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala new file mode 100644 index 0000000000000..b75739e5a3a65 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/RandomUUIDGeneratorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite + +class RandomUUIDGeneratorSuite extends SparkFunSuite { + test("RandomUUIDGenerator should generate version 4, variant 2 UUIDs") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + for (_ <- 0 to 100) { + val uuid = generator.getNextUUID() + assert(uuid.version() == 4) + assert(uuid.variant() == 2) + } + } + + test("UUID from RandomUUIDGenerator should be deterministic") { + val r1 = new Random(100) + val generator1 = RandomUUIDGenerator(r1.nextLong()) + val r2 = new Random(100) + val generator2 = RandomUUIDGenerator(r2.nextLong()) + val r3 = new Random(101) + val generator3 = RandomUUIDGenerator(r3.nextLong()) + + for (_ <- 0 to 100) { + val uuid1 = generator1.getNextUUID() + val uuid2 = generator2.getNextUUID() + val uuid3 = generator3.getNextUUID() + assert(uuid1 == uuid2) + assert(uuid1 != uuid3) + } + } + + test("Get UTF8String UUID") { + val generator = RandomUUIDGenerator(new Random().nextLong()) + val utf8StringUUID = generator.getNextUUIDUTF8String() + val uuid = java.util.UUID.fromString(utf8StringUUID.toString) + assert(uuid.version() == 4 && uuid.variant() == 2 && utf8StringUUID.toString == uuid.toString) + } +} From f15906da153f139b698e192ec6f82f078f896f1e Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Mon, 19 Mar 2018 11:29:56 -0700 Subject: [PATCH 0491/1161] [SPARK-22839][K8S] Remove the use of init-container for downloading remote dependencies ## What changes were proposed in this pull request? Removal of the init-container for downloading remote dependencies. Built off of the work done by vanzin in an attempt to refactor driver/executor configuration elaborated in [this](https://issues.apache.org/jira/browse/SPARK-22839) ticket. ## How was this patch tested? This patch was tested with unit and integration tests. Author: Ilan Filonenko Closes #20669 from ifilonenko/remove-init-container. --- bin/docker-image-tool.sh | 9 +- .../org/apache/spark/deploy/SparkSubmit.scala | 2 - docs/running-on-kubernetes.md | 71 +------- .../spark/examples/SparkRemoteFileTest.scala | 48 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 73 +------- .../apache/spark/deploy/k8s/Constants.scala | 21 +-- .../deploy/k8s/InitContainerBootstrap.scala | 120 ------------- .../spark/deploy/k8s/KubernetesUtils.scala | 63 +------ .../k8s/PodWithDetachedInitContainer.scala | 31 ---- .../deploy/k8s/SparkPodInitContainer.scala | 116 ------------- .../k8s/submit/DriverConfigOrchestrator.scala | 45 +---- .../submit/KubernetesClientApplication.scala | 84 +++++---- .../steps/BasicDriverConfigurationStep.scala | 32 ++-- .../steps/DependencyResolutionStep.scala | 18 +- .../DriverInitContainerBootstrapStep.scala | 95 ----------- .../DriverKubernetesCredentialsStep.scala | 2 +- .../BasicInitContainerConfigurationStep.scala | 67 -------- .../InitContainerConfigOrchestrator.scala | 79 --------- .../InitContainerConfigurationStep.scala | 25 --- .../InitContainerMountSecretsStep.scala | 36 ---- .../initcontainer/InitContainerSpec.scala | 37 ---- .../cluster/k8s/ExecutorPodFactory.scala | 43 +---- .../k8s/KubernetesClusterManager.scala | 65 +------ .../k8s/SparkPodInitContainerSuite.scala | 86 ---------- .../spark/deploy/k8s/submit/ClientSuite.scala | 82 ++++----- .../DriverConfigOrchestratorSuite.scala | 41 +---- .../BasicDriverConfigurationStepSuite.scala | 8 +- .../steps/DependencyResolutionStepSuite.scala | 32 ++-- ...riverInitContainerBootstrapStepSuite.scala | 160 ------------------ ...cInitContainerConfigurationStepSuite.scala | 95 ----------- ...InitContainerConfigOrchestratorSuite.scala | 80 --------- .../InitContainerMountSecretsStepSuite.scala | 52 ------ .../cluster/k8s/ExecutorPodFactorySuite.scala | 67 +------- .../src/main/dockerfiles/spark/Dockerfile | 1 - .../src/main/dockerfiles/spark/entrypoint.sh | 20 +-- 35 files changed, 241 insertions(+), 1665 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 0d0f564bb8b9b..f090240065bf1 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -64,9 +64,11 @@ function build { error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi + local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + docker build "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$IMG_PATH/spark/Dockerfile" . + -f "$DOCKERFILE" . } function push { @@ -84,6 +86,7 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: + -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -113,10 +116,12 @@ fi REPO= TAG= -while getopts mr:t: option +DOCKERFILE= +while getopts f:mr:t: option do case "${option}" in + f) DOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; m) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e381965c52ba..329bde08718fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -320,8 +320,6 @@ object SparkSubmit extends CommandLineUtils with Logging { printErrorAndExit("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => printErrorAndExit("R applications are currently not supported for Kubernetes.") - case (KUBERNETES, CLIENT) => - printErrorAndExit("Client mode is currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 3c7586e8544ba..975b28de47e20 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -126,29 +126,6 @@ Those dependencies can be added to the classpath by referencing them with `local dependencies in custom-built Docker images in `spark-submit`. Note that using application dependencies from the submission client's local file system is currently not yet supported. - -### Using Remote Dependencies -When there are application dependencies hosted in remote locations like HDFS or HTTP servers, the driver and executor pods -need a Kubernetes [init-container](https://kubernetes.io/docs/concepts/workloads/pods/init-containers/) for downloading -the dependencies so the driver and executor containers can use them locally. - -The init-container handles remote dependencies specified in `spark.jars` (or the `--jars` option of `spark-submit`) and -`spark.files` (or the `--files` option of `spark-submit`). It also handles remotely hosted main application resources, e.g., -the main application jar. The following shows an example of using remote dependencies with the `spark-submit` command: - -```bash -$ bin/spark-submit \ - --master k8s://https://: \ - --deploy-mode cluster \ - --name spark-pi \ - --class org.apache.spark.examples.SparkPi \ - --jars https://path/to/dependency1.jar,https://path/to/dependency2.jar - --files hdfs://host:port/path/to/file1,hdfs://host:port/path/to/file2 - --conf spark.executor.instances=5 \ - --conf spark.kubernetes.container.image= \ - https://path/to/examples.jar -``` - ## Secret Management Kubernetes [Secrets](https://kubernetes.io/docs/concepts/configuration/secret/) can be used to provide credentials for a Spark application to access secured services. To mount a user-specified secret into the driver container, users can use @@ -163,10 +140,6 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` -Note that if an init-container is used, any secret mounted into the driver container will also be mounted into the -init-container of the driver. Similarly, any secret mounted into an executor container will also be mounted into the -init-container of the executor. - ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -604,51 +577,12 @@ specific to Spark on Kubernetes. the Driver process. The user can specify multiple of these to set multiple environment variables. - - spark.kubernetes.mountDependencies.jarsDownloadDir - /var/spark-data/spark-jars - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.filesDownloadDir - /var/spark-data/spark-files - - Location to download jars to in the driver and executors. - This directory must be empty and will be mounted as an empty directory volume on the driver and executor pods. - - - - spark.kubernetes.mountDependencies.timeout - 300s - - Timeout in seconds before aborting the attempt to download and unpack dependencies from remote locations into - the driver and executor pods. - - - - spark.kubernetes.mountDependencies.maxSimultaneousDownloads - 5 - - Maximum number of remote dependencies to download simultaneously in a driver or executor pod. - - - - spark.kubernetes.initContainer.image - (value of spark.kubernetes.container.image) - - Custom container image for the init container of both driver and executors. - - spark.kubernetes.driver.secrets.[SecretName] (none) Add the Kubernetes Secret named SecretName to the driver pod on the path specified in the value. For example, - spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the driver pod. + spark.kubernetes.driver.secrets.spark-secret=/etc/secrets. @@ -656,8 +590,7 @@ specific to Spark on Kubernetes. (none) Add the Kubernetes Secret named SecretName to the executor pod on the path specified in the value. For example, - spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. Note that if an init-container is used, - the secret will also be added to the init-container in the executor pod. + spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. \ No newline at end of file diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala new file mode 100644 index 0000000000000..64076f2deb706 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkRemoteFileTest.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.io.File + +import org.apache.spark.SparkFiles +import org.apache.spark.sql.SparkSession + +/** Usage: SparkRemoteFileTest [file] */ +object SparkRemoteFileTest { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: SparkRemoteFileTest ") + System.exit(1) + } + val spark = SparkSession + .builder() + .appName("SparkRemoteFileTest") + .getOrCreate() + val sc = spark.sparkContext + val rdd = sc.parallelize(Seq(1)).map(_ => { + val localLocation = SparkFiles.get(args(0)) + println(s"${args(0)} is stored at: $localLocation") + new File(localLocation).isFile + }) + val truthCheck = rdd.collect().head + println(s"Mounting of ${args(0)} was $truthCheck") + spark.stop() + } +} +// scalastyle:on println diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 471196ac0e3f6..da34a7e06238a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -79,6 +79,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_DRIVER_SUBMIT_CHECK = + ConfigBuilder("spark.kubernetes.submitInDriver") + .internal() + .booleanConf + .createOptional + val KUBERNETES_EXECUTOR_LIMIT_CORES = ConfigBuilder("spark.kubernetes.executor.limit.cores") .doc("Specify the hard cpu limit for each executor pod") @@ -135,73 +141,6 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") - val JARS_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.jarsDownloadDir") - .doc("Location to download jars to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pod.") - .stringConf - .createWithDefault("/var/spark-data/spark-jars") - - val FILES_DOWNLOAD_LOCATION = - ConfigBuilder("spark.kubernetes.mountDependencies.filesDownloadDir") - .doc("Location to download files to in the driver and executors. When using " + - "spark-submit, this directory must be empty and will be mounted as an empty directory " + - "volume on the driver and executor pods.") - .stringConf - .createWithDefault("/var/spark-data/spark-files") - - val INIT_CONTAINER_IMAGE = - ConfigBuilder("spark.kubernetes.initContainer.image") - .doc("Image for the driver and executor's init-container for downloading dependencies.") - .fallbackConf(CONTAINER_IMAGE) - - val INIT_CONTAINER_MOUNT_TIMEOUT = - ConfigBuilder("spark.kubernetes.mountDependencies.timeout") - .doc("Timeout before aborting the attempt to download and unpack dependencies from remote " + - "locations into the driver and executor pods.") - .timeConf(TimeUnit.SECONDS) - .createWithDefault(300) - - val INIT_CONTAINER_MAX_THREAD_POOL_SIZE = - ConfigBuilder("spark.kubernetes.mountDependencies.maxSimultaneousDownloads") - .doc("Maximum number of remote dependencies to download simultaneously in a driver or " + - "executor pod.") - .intConf - .createWithDefault(5) - - val INIT_CONTAINER_REMOTE_JARS = - ConfigBuilder("spark.kubernetes.initContainer.remoteJars") - .doc("Comma-separated list of jar URIs to download in the init-container. This is " + - "calculated from spark.jars.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_REMOTE_FILES = - ConfigBuilder("spark.kubernetes.initContainer.remoteFiles") - .doc("Comma-separated list of file URIs to download in the init-container. This is " + - "calculated from spark.files.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_NAME = - ConfigBuilder("spark.kubernetes.initContainer.configMapName") - .doc("Name of the config map to use in the init-container that retrieves submitted files " + - "for the executor.") - .internal() - .stringConf - .createOptional - - val INIT_CONTAINER_CONFIG_MAP_KEY_CONF = - ConfigBuilder("spark.kubernetes.initContainer.configMapKey") - .doc("Key for the entry in the init container config map for submitted files that " + - "corresponds to the properties for this init-container.") - .internal() - .stringConf - .createOptional - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 9411956996843..8da5f24044aad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -63,22 +63,13 @@ private[spark] object Constants { val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" - val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" - val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" - val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" - val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" - val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" - - // Bootstrapping dependencies with the init-container - val INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME = "download-jars-volume" - val INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME = "download-files-volume" - val INIT_CONTAINER_PROPERTIES_FILE_VOLUME = "spark-init-properties" - val INIT_CONTAINER_PROPERTIES_FILE_DIR = "/etc/spark-init" - val INIT_CONTAINER_PROPERTIES_FILE_NAME = "spark-init.properties" - val INIT_CONTAINER_PROPERTIES_FILE_PATH = - s"$INIT_CONTAINER_PROPERTIES_FILE_DIR/$INIT_CONTAINER_PROPERTIES_FILE_NAME" - val INIT_CONTAINER_SECRET_VOLUME_NAME = "spark-init-secret" + val ENV_SPARK_CONF_DIR = "SPARK_CONF_DIR" + // Spark app configs for containers + val SPARK_CONF_VOLUME = "spark-conf-volume" + val SPARK_CONF_DIR_INTERNAL = "/opt/spark/conf" + val SPARK_CONF_FILE_NAME = "spark.properties" + val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala deleted file mode 100644 index f6a57dfe00171..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/InitContainerBootstrap.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EmptyDirVolumeSource, EnvVarBuilder, PodBuilder, VolumeMount, VolumeMountBuilder} - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Bootstraps an init-container for downloading remote dependencies. This is separated out from - * the init-container steps API because this component can be used to bootstrap init-containers - * for both the driver and executors. - */ -private[spark] class InitContainerBootstrap( - initContainerImage: String, - imagePullPolicy: String, - jarsDownloadPath: String, - filesDownloadPath: String, - configMapName: String, - configMapKey: String, - sparkRole: String, - sparkConf: SparkConf) { - - /** - * Bootstraps an init-container that downloads dependencies to be used by a main container. - */ - def bootstrapInitContainer( - original: PodWithDetachedInitContainer): PodWithDetachedInitContainer = { - val sharedVolumeMounts = Seq[VolumeMount]( - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withMountPath(jarsDownloadPath) - .build(), - new VolumeMountBuilder() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withMountPath(filesDownloadPath) - .build()) - - val customEnvVarKeyPrefix = sparkRole match { - case SPARK_POD_DRIVER_ROLE => KUBERNETES_DRIVER_ENV_KEY - case SPARK_POD_EXECUTOR_ROLE => "spark.executorEnv." - case _ => throw new SparkException(s"$sparkRole is not a valid Spark pod role") - } - val customEnvVars = sparkConf.getAllWithPrefix(customEnvVarKeyPrefix).toSeq.map { - case (key, value) => - new EnvVarBuilder() - .withName(key) - .withValue(value) - .build() - } - - val initContainer = new ContainerBuilder(original.initContainer) - .withName("spark-init") - .withImage(initContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(customEnvVars.asJava) - .addNewVolumeMount() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withMountPath(INIT_CONTAINER_PROPERTIES_FILE_DIR) - .endVolumeMount() - .addToVolumeMounts(sharedVolumeMounts: _*) - .addToArgs("init") - .addToArgs(INIT_CONTAINER_PROPERTIES_FILE_PATH) - .build() - - val podWithBasicVolumes = new PodBuilder(original.pod) - .editSpec() - .addNewVolume() - .withName(INIT_CONTAINER_PROPERTIES_FILE_VOLUME) - .withNewConfigMap() - .withName(configMapName) - .addNewItem() - .withKey(configMapKey) - .withPath(INIT_CONTAINER_PROPERTIES_FILE_NAME) - .endItem() - .endConfigMap() - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_JARS_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .addNewVolume() - .withName(INIT_CONTAINER_DOWNLOAD_FILES_VOLUME_NAME) - .withEmptyDir(new EmptyDirVolumeSource()) - .endVolume() - .endSpec() - .build() - - val mainContainer = new ContainerBuilder(original.mainContainer) - .addToVolumeMounts(sharedVolumeMounts: _*) - .addNewEnv() - .withName(ENV_MOUNTED_FILES_DIR) - .withValue(filesDownloadPath) - .endEnv() - .build() - - PodWithDetachedInitContainer( - podWithBasicVolumes, - initContainer, - mainContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 37331d8bbf9b7..5bc070147d3a8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,10 +16,6 @@ */ package org.apache.spark.deploy.k8s -import java.io.File - -import io.fabric8.kubernetes.api.model.{Container, Pod, PodBuilder} - import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -43,72 +39,23 @@ private[spark] object KubernetesUtils { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } - /** - * Append the given init-container to a pod's list of init-containers. - * - * @param originalPodSpec original specification of the pod - * @param initContainer the init-container to add to the pod - * @return the pod with the init-container added to the list of InitContainers - */ - def appendInitContainer(originalPodSpec: Pod, initContainer: Container): Pod = { - new PodBuilder(originalPodSpec) - .editOrNewSpec() - .addToInitContainers(initContainer) - .endSpec() - .build() - } - /** * For the given collection of file URIs, resolves them as follows: - * - File URIs with scheme file:// are resolved to the given download path. * - File URIs with scheme local:// resolve to just the path of the URI. * - Otherwise, the URIs are returned as-is. */ - def resolveFileUris( - fileUris: Iterable[String], - fileDownloadPath: String): Iterable[String] = { - fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, false) - } - } - - /** - * If any file uri has any scheme other than local:// it is mapped as if the file - * was downloaded to the file download path. Otherwise, it is mapped to the path - * part of the URI. - */ - def resolveFilePaths(fileUris: Iterable[String], fileDownloadPath: String): Iterable[String] = { + def resolveFileUrisAndPath(fileUris: Iterable[String]): Iterable[String] = { fileUris.map { uri => - resolveFileUri(uri, fileDownloadPath, true) - } - } - - /** - * Get from a given collection of file URIs the ones that represent remote files. - */ - def getOnlyRemoteFiles(uris: Iterable[String]): Iterable[String] = { - uris.filter { uri => - val scheme = Utils.resolveURI(uri).getScheme - scheme != "file" && scheme != "local" + resolveFileUri(uri) } } - private def resolveFileUri( - uri: String, - fileDownloadPath: String, - assumesDownloaded: Boolean): String = { + private def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { - case "local" => - fileUri.getPath - case _ => - if (assumesDownloaded || fileScheme == "file") { - val fileName = new File(fileUri.getPath).getName - s"$fileDownloadPath/$fileName" - } else { - uri - } + case "local" => fileUri.getPath + case _ => uri } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala deleted file mode 100644 index 0b79f8b12e806..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/PodWithDetachedInitContainer.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, Pod} - -/** - * Represents a pod with a detached init-container (not yet added to the pod). - * - * @param pod the pod - * @param initContainer the init-container in the pod - * @param mainContainer the main container in the pod - */ -private[spark] case class PodWithDetachedInitContainer( - pod: Pod, - initContainer: Container, - mainContainer: Container) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala deleted file mode 100644 index c0f08786b76a1..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.concurrent.TimeUnit - -import scala.concurrent.{ExecutionContext, Future} - -import org.apache.spark.{SecurityManager => SparkSecurityManager, SparkConf} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -/** - * Process that fetches files from a resource staging server and/or arbitrary remote locations. - * - * The init-container can handle fetching files from any of those sources, but not all of the - * sources need to be specified. This allows for composing multiple instances of this container - * with different configurations for different download sources, or using the same container to - * download everything at once. - */ -private[spark] class SparkPodInitContainer( - sparkConf: SparkConf, - fileFetcher: FileFetcher) extends Logging { - - private val maxThreadPoolSize = sparkConf.get(INIT_CONTAINER_MAX_THREAD_POOL_SIZE) - private implicit val downloadExecutor = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("download-executor", maxThreadPoolSize)) - - private val jarsDownloadDir = new File(sparkConf.get(JARS_DOWNLOAD_LOCATION)) - private val filesDownloadDir = new File(sparkConf.get(FILES_DOWNLOAD_LOCATION)) - - private val remoteJars = sparkConf.get(INIT_CONTAINER_REMOTE_JARS) - private val remoteFiles = sparkConf.get(INIT_CONTAINER_REMOTE_FILES) - - private val downloadTimeoutMinutes = sparkConf.get(INIT_CONTAINER_MOUNT_TIMEOUT) - - def run(): Unit = { - logInfo(s"Downloading remote jars: $remoteJars") - downloadFiles( - remoteJars, - jarsDownloadDir, - s"Remote jars download directory specified at $jarsDownloadDir does not exist " + - "or is not a directory.") - - logInfo(s"Downloading remote files: $remoteFiles") - downloadFiles( - remoteFiles, - filesDownloadDir, - s"Remote files download directory specified at $filesDownloadDir does not exist " + - "or is not a directory.") - - downloadExecutor.shutdown() - downloadExecutor.awaitTermination(downloadTimeoutMinutes, TimeUnit.MINUTES) - } - - private def downloadFiles( - filesCommaSeparated: Option[String], - downloadDir: File, - errMessage: String): Unit = { - filesCommaSeparated.foreach { files => - require(downloadDir.isDirectory, errMessage) - Utils.stringToSeq(files).foreach { file => - Future[Unit] { - fileFetcher.fetchFile(file, downloadDir) - } - } - } - } -} - -private class FileFetcher(sparkConf: SparkConf, securityManager: SparkSecurityManager) { - - def fetchFile(uri: String, targetDir: File): Unit = { - Utils.fetchFile( - url = uri, - targetDir = targetDir, - conf = sparkConf, - securityMgr = securityManager, - hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf), - timestamp = System.currentTimeMillis(), - useCache = false) - } -} - -object SparkPodInitContainer extends Logging { - - def main(args: Array[String]): Unit = { - logInfo("Starting init-container to download Spark application dependencies.") - val sparkConf = new SparkConf(true) - if (args.nonEmpty) { - Utils.loadDefaultSparkProperties(sparkConf, args(0)) - } - - val securityManager = new SparkSecurityManager(sparkConf) - val fileFetcher = new FileFetcher(sparkConf, securityManager) - new SparkPodInitContainer(sparkConf, fileFetcher).run() - logInfo("Finished downloading application dependencies.") - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index ae70904621184..b4d3f04a1bc32 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -16,16 +16,11 @@ */ package org.apache.spark.deploy.k8s.submit -import java.util.UUID - -import com.google.common.primitives.Longs - import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.InitContainerConfigOrchestrator import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils @@ -34,13 +29,11 @@ import org.apache.spark.util.Utils * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to * configure the Spark driver pod. The returned steps will be applied one by one in the given * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. It uses the InitContainerConfigOrchestrator to - * configure the driver init-container if one is needed, i.e., when there are remote dependencies - * to localize. + * to construct and create the driver pod. */ private[spark] class DriverConfigOrchestrator( kubernetesAppId: String, - launchTime: Long, + kubernetesResourceNamePrefix: String, mainAppResource: Option[MainAppResource], appName: String, mainClass: String, @@ -50,15 +43,8 @@ private[spark] class DriverConfigOrchestrator( // The resource name prefix is derived from the Spark application name, making it easy to connect // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the // application the user submitted. - private val kubernetesResourceNamePrefix = { - val uuid = UUID.nameUUIDFromBytes(Longs.toByteArray(launchTime)).toString.replaceAll("-", "") - s"$appName-$uuid".toLowerCase.replaceAll("\\.", "-") - } private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val initContainerConfigMapName = s"$kubernetesResourceNamePrefix-init-config" - private val jarsDownloadPath = sparkConf.get(JARS_DOWNLOAD_LOCATION) - private val filesDownloadPath = sparkConf.get(FILES_DOWNLOAD_LOCATION) def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -126,9 +112,7 @@ private[spark] class DriverConfigOrchestrator( val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { Seq(new DependencyResolutionStep( sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath)) + sparkFiles)) } else { Nil } @@ -139,33 +123,12 @@ private[spark] class DriverConfigOrchestrator( Nil } - val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { - val orchestrator = new InitContainerConfigOrchestrator( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - imagePullPolicy, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME, - sparkConf) - val bootstrapStep = new DriverInitContainerBootstrapStep( - orchestrator.getAllConfigurationSteps, - initContainerConfigMapName, - INIT_CONTAINER_PROPERTIES_FILE_NAME) - - Seq(bootstrapStep) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - mountSecretsStep ++ - initContainerBootstrapStep + mountSecretsStep } private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index 5884348cb3e41..e16d1add600b2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -16,14 +16,14 @@ */ package org.apache.spark.deploy.k8s.submit +import java.io.StringWriter import java.util.{Collections, UUID} - -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.util.control.NonFatal +import java.util.Properties import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication @@ -32,6 +32,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -93,10 +94,8 @@ private[spark] class Client( kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, - watcher: LoggingPodStatusWatcher) extends Logging { - - private val driverJavaOptions = sparkConf.get( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) + watcher: LoggingPodStatusWatcher, + kubernetesResourceNamePrefix: String) extends Logging { /** * Run command that initializes a DriverSpec that will be updated after each @@ -110,33 +109,31 @@ private[spark] class Client( for (nextStep <- submissionSteps) { currentDriverSpec = nextStep.configureDriver(currentDriverSpec) } - - val resolvedDriverJavaOpts = currentDriverSpec - .driverSparkConf - // Remove this as the options are instead extracted and set individually below using - // environment variables with prefix SPARK_JAVA_OPT_. - .remove(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS) - .getAll - .map { - case (confKey, confValue) => s"-D$confKey=$confValue" - } ++ driverJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) - val driverJavaOptsEnvs: Seq[EnvVar] = resolvedDriverJavaOpts.zipWithIndex.map { - case (option, index) => - new EnvVarBuilder() - .withName(s"$ENV_JAVA_OPT_PREFIX$index") - .withValue(option) - .build() - } - + val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" + val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the + // Spark command builder to pickup on the Java Options present in the ConfigMap val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) - .addAllToEnv(driverJavaOptsEnvs.asJava) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() .build() val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) .editSpec() .addToContainers(resolvedDriverContainer) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap() + .withName(configMapName) + .endConfigMap() + .endVolume() .endSpec() .build() - Utils.tryWithResource( kubernetesClient .pods() @@ -145,7 +142,8 @@ private[spark] class Client( val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = currentDriverSpec.otherKubernetesResources + val otherKubernetesResources = + currentDriverSpec.otherKubernetesResources ++ Seq(configMap) addDriverOwnerReference(createdDriverPod, otherKubernetesResources) kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } @@ -180,6 +178,26 @@ private[spark] class Client( originalMetadata.setOwnerReferences(Collections.singletonList(driverPodOwnerReference)) } } + + // Build a Config Map that will house spark conf properties in a single file for spark-submit + private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + val properties = new Properties() + conf.getAll.foreach { case (k, v) => + properties.setProperty(k, v) + } + val propertiesWriter = new StringWriter() + properties.store(propertiesWriter, + s"Java properties built from Kubernetes config map with name: $configMapName") + + val namespace = conf.get(KUBERNETES_NAMESPACE) + new ConfigMapBuilder() + .withNewMetadata() + .withName(configMapName) + .withNamespace(namespace) + .endMetadata() + .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) + .build() + } } /** @@ -202,6 +220,9 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") + val kubernetesResourceNamePrefix = { + s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") + } // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -211,7 +232,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val orchestrator = new DriverConfigOrchestrator( kubernetesAppId, - launchTime, + kubernetesResourceNamePrefix, clientArguments.mainAppResource, appName, clientArguments.mainClass, @@ -231,7 +252,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesClient, waitForAppCompletion, appName, - watcher) + watcher, + kubernetesResourceNamePrefix) client.run() } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 164e2e5594778..347c4d2d66826 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -26,6 +26,7 @@ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.KubernetesUtils import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} +import org.apache.spark.launcher.SparkLauncher /** * Performs basic configuration for the driver pod. @@ -56,8 +57,6 @@ private[spark] class BasicDriverConfigurationStep( // Memory settings private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val driverMemoryString = sparkConf.get( - DRIVER_MEMORY.key, DRIVER_MEMORY.defaultValueString) private val memoryOverheadMiB = sparkConf .get(DRIVER_MEMORY_OVERHEAD) .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) @@ -103,24 +102,12 @@ private[spark] class BasicDriverConfigurationStep( ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) } - val driverContainer = new ContainerBuilder(driverSpec.driverContainer) + val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) .withName(DRIVER_CONTAINER_NAME) .withImage(driverContainerImage) .withImagePullPolicy(imagePullPolicy) .addAllToEnv(driverCustomEnvs.asJava) .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_MEMORY) - .withValue(driverMemoryString) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_MAIN_CLASS) - .withValue(mainClass) - .endEnv() - .addNewEnv() - .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.mkString(" ")) - .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) .withValueFrom(new EnvVarSourceBuilder() @@ -134,7 +121,16 @@ private[spark] class BasicDriverConfigurationStep( .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") - .build() + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + + val driverContainer = appArgs.toList match { + case "" :: Nil | Nil => driverContainerWithoutArgs.build() + case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() + } val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() @@ -152,10 +148,14 @@ private[spark] class BasicDriverConfigurationStep( .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set("spark.app.id", kubernetesAppId) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) + // to set the config variables to allow client-mode spark-submit from driver + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) driverSpec.copy( driverPod = baseDriverPod, driverSparkConf = resolvedSparkConf, driverContainer = driverContainer) } + } + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala index d4b83235b4e3b..43de329f239ad 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala @@ -30,13 +30,11 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec */ private[spark] class DependencyResolutionStep( sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String) extends DriverConfigurationStep { + sparkFiles: Seq[String]) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUris(sparkJars, jarsDownloadPath) - val resolvedSparkFiles = KubernetesUtils.resolveFileUris(sparkFiles, filesDownloadPath) + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) val sparkConf = driverSpec.driverSparkConf.clone() if (resolvedSparkJars.nonEmpty) { @@ -45,14 +43,12 @@ private[spark] class DependencyResolutionStep( if (resolvedSparkFiles.nonEmpty) { sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) } - - val resolvedClasspath = KubernetesUtils.resolveFilePaths(sparkJars, jarsDownloadPath) - val resolvedDriverContainer = if (resolvedClasspath.nonEmpty) { + val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { new ContainerBuilder(driverSpec.driverContainer) .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedClasspath.mkString(File.pathSeparator)) - .endEnv() + .withName(ENV_MOUNTED_CLASSPATH) + .withValue(resolvedSparkJars.mkString(File.pathSeparator)) + .endEnv() .build() } else { driverSpec.driverContainer diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala deleted file mode 100644 index 9fb3dafdda540..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStep.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringWriter -import java.util.Properties - -import io.fabric8.kubernetes.api.model.{ConfigMap, ConfigMapBuilder, ContainerBuilder, HasMetadata} - -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} - -/** - * Configures the driver init-container that localizes remote dependencies into the driver pod. - * It applies the given InitContainerConfigurationSteps in the given order to produce a final - * InitContainerSpec that is then used to configure the driver pod with the init-container attached. - * It also builds a ConfigMap that will be mounted into the init-container. The ConfigMap carries - * configuration properties for the init-container. - */ -private[spark] class DriverInitContainerBootstrapStep( - steps: Seq[InitContainerConfigurationStep], - configMapName: String, - configMapKey: String) - extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - var initContainerSpec = InitContainerSpec( - properties = Map.empty[String, String], - driverSparkConf = Map.empty[String, String], - initContainer = new ContainerBuilder().build(), - driverContainer = driverSpec.driverContainer, - driverPod = driverSpec.driverPod, - dependentResources = Seq.empty[HasMetadata]) - for (nextStep <- steps) { - initContainerSpec = nextStep.configureInitContainer(initContainerSpec) - } - - val configMap = buildConfigMap( - configMapName, - configMapKey, - initContainerSpec.properties) - val resolvedDriverSparkConf = driverSpec.driverSparkConf - .clone() - .set(INIT_CONTAINER_CONFIG_MAP_NAME, configMapName) - .set(INIT_CONTAINER_CONFIG_MAP_KEY_CONF, configMapKey) - .setAll(initContainerSpec.driverSparkConf) - val resolvedDriverPod = KubernetesUtils.appendInitContainer( - initContainerSpec.driverPod, initContainerSpec.initContainer) - - driverSpec.copy( - driverPod = resolvedDriverPod, - driverContainer = initContainerSpec.driverContainer, - driverSparkConf = resolvedDriverSparkConf, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ - initContainerSpec.dependentResources ++ - Seq(configMap)) - } - - private def buildConfigMap( - configMapName: String, - configMapKey: String, - config: Map[String, String]): ConfigMap = { - val properties = new Properties() - config.foreach { entry => - properties.setProperty(entry._1, entry._2) - } - val propertiesWriter = new StringWriter() - properties.store(propertiesWriter, - s"Java properties built from Kubernetes config map with name: $configMapName " + - s"and config map key: $configMapKey") - new ConfigMapBuilder() - .withNewMetadata() - .withName(configMapName) - .endMetadata() - .addToData(configMapKey, propertiesWriter.toString) - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala index ccc18908658f1..2424e63999a82 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala @@ -99,7 +99,7 @@ private[spark] class DriverKubernetesCredentialsStep( }.getOrElse(driverSpec.driverPod) ) - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { secret => + val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => new ContainerBuilder(driverSpec.driverContainer) .addNewVolumeMount() .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala deleted file mode 100644 index 01469853dacc2..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStep.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.KubernetesUtils - -/** - * Performs basic configuration for the driver init-container with most of the work delegated to - * the given InitContainerBootstrap. - */ -private[spark] class BasicInitContainerConfigurationStep( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - bootstrap: InitContainerBootstrap) - extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec): InitContainerSpec = { - val remoteJarsToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkJars) - val remoteFilesToDownload = KubernetesUtils.getOnlyRemoteFiles(sparkFiles) - val remoteJarsConf = if (remoteJarsToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_JARS.key -> remoteJarsToDownload.mkString(",")) - } else { - Map() - } - val remoteFilesConf = if (remoteFilesToDownload.nonEmpty) { - Map(INIT_CONTAINER_REMOTE_FILES.key -> remoteFilesToDownload.mkString(",")) - } else { - Map() - } - - val baseInitContainerConfig = Map( - JARS_DOWNLOAD_LOCATION.key -> jarsDownloadPath, - FILES_DOWNLOAD_LOCATION.key -> filesDownloadPath) ++ - remoteJarsConf ++ - remoteFilesConf - - val bootstrapped = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - spec.driverPod, - spec.initContainer, - spec.driverContainer)) - - spec.copy( - initContainer = bootstrapped.initContainer, - driverContainer = bootstrapped.mainContainer, - driverPod = bootstrapped.pod, - properties = spec.properties ++ baseInitContainerConfig) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala deleted file mode 100644 index f2c29c7ce1076..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestrator.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -/** - * Figures out and returns the complete ordered list of InitContainerConfigurationSteps required to - * configure the driver init-container. The returned steps will be applied in the given order to - * produce a final InitContainerSpec that is used to construct the driver init-container in - * DriverInitContainerBootstrapStep. This class is only used when an init-container is needed, i.e., - * when there are remote application dependencies to localize. - */ -private[spark] class InitContainerConfigOrchestrator( - sparkJars: Seq[String], - sparkFiles: Seq[String], - jarsDownloadPath: String, - filesDownloadPath: String, - imagePullPolicy: String, - configMapName: String, - configMapKey: String, - sparkConf: SparkConf) { - - private val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - - def getAllConfigurationSteps: Seq[InitContainerConfigurationStep] = { - val initContainerBootstrap = new InitContainerBootstrap( - initContainerImage, - imagePullPolicy, - jarsDownloadPath, - filesDownloadPath, - configMapName, - configMapKey, - SPARK_POD_DRIVER_ROLE, - sparkConf) - val baseStep = new BasicInitContainerConfigurationStep( - sparkJars, - sparkFiles, - jarsDownloadPath, - filesDownloadPath, - initContainerBootstrap) - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - // Mount user-specified driver secrets also into the driver's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The driver's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new InitContainerMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq(baseStep) ++ mountSecretsStep - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala deleted file mode 100644 index 0372ad5270951..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigurationStep.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -/** - * Represents a step in configuring the driver init-container. - */ -private[spark] trait InitContainerConfigurationStep { - - def configureInitContainer(spec: InitContainerSpec): InitContainerSpec -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala deleted file mode 100644 index 0daa7b95e8aae..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -/** - * An init-container configuration step for mounting user-specified secrets onto user-specified - * paths. - * - * @param bootstrap a utility actually handling mounting of the secrets - */ -private[spark] class InitContainerMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { - - override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - // Mount the secret volumes given that the volumes have already been added to the driver pod - // when mounting the secrets into the main driver container. - val initContainer = bootstrap.mountSecrets(spec.initContainer) - spec.copy(initContainer = initContainer) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala deleted file mode 100644 index b52c343f0c0ed..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerSpec.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{Container, HasMetadata, Pod} - -/** - * Represents a specification of the init-container for the driver pod. - * - * @param properties properties that should be set on the init-container - * @param driverSparkConf Spark configuration properties that will be carried back to the driver - * @param initContainer the init-container object - * @param driverContainer the driver container object - * @param driverPod the driver pod object - * @param dependentResources resources the init-container depends on to work - */ -private[spark] case class InitContainerSpec( - properties: Map[String, String], - driverSparkConf: Map[String, String], - initContainer: Container, - driverContainer: Container, - driverPod: Pod, - dependentResources: Seq[HasMetadata]) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 141bd2827e7c5..98cbd5607da00 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -34,18 +34,10 @@ import org.apache.spark.util.Utils * @param sparkConf Spark configuration * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto * user-specified paths into the executor container - * @param initContainerBootstrap an optional component for bootstrapping the executor init-container - * if one is needed, i.e., when there are remote dependencies to - * localize - * @param initContainerMountSecretsBootstrap an optional component for mounting user-specified - * secrets onto user-specified paths into the executor - * init-container */ private[spark] class ExecutorPodFactory( sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap], - initContainerBootstrap: Option[InitContainerBootstrap], - initContainerMountSecretsBootstrap: Option[MountSecretsBootstrap]) { + mountSecretsBootstrap: Option[MountSecretsBootstrap]) { private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) @@ -94,8 +86,6 @@ private[spark] class ExecutorPodFactory( private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - private val executorJarsDownloadDir = sparkConf.get(JARS_DOWNLOAD_LOCATION) - /** * Configure and construct an executor pod with the given parameters. */ @@ -147,8 +137,9 @@ private[spark] class ExecutorPodFactory( (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), - (ENV_EXECUTOR_ID, executorId), - (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) .map(env => new EnvVarBuilder() .withName(env._1) .withValue(env._2) @@ -221,30 +212,10 @@ private[spark] class ExecutorPodFactory( (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) - val (bootstrappedPod, bootstrappedContainer) = - initContainerBootstrap.map { bootstrap => - val podWithInitContainer = bootstrap.bootstrapInitContainer( - PodWithDetachedInitContainer( - maybeSecretsMountedPod, - new ContainerBuilder().build(), - maybeSecretsMountedContainer)) - - val (pod, mayBeSecretsMountedInitContainer) = - initContainerMountSecretsBootstrap.map { bootstrap => - // Mount the secret volumes given that the volumes have already been added to the - // executor pod when mounting the secrets into the main executor container. - (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) - }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) - - val bootstrappedPod = KubernetesUtils.appendInitContainer( - pod, mayBeSecretsMountedInitContainer) - - (bootstrappedPod, podWithInitContainer.mainContainer) - }.getOrElse((maybeSecretsMountedPod, maybeSecretsMountedContainer)) - new PodBuilder(bootstrappedPod) + new PodBuilder(maybeSecretsMountedPod) .editSpec() - .addToContainers(bootstrappedContainer) + .addToContainers(maybeSecretsMountedContainer) .endSpec() .build() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index a942db6ae02db..ff5f6801da2a3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -33,7 +33,9 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { - if (masterURL.startsWith("k8s") && sc.deployMode == "client") { + if (masterURL.startsWith("k8s") && + sc.deployMode == "client" && + !sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK).getOrElse(false)) { throw new SparkException("Client mode is currently not supported for Kubernetes.") } @@ -44,74 +46,23 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val sparkConf = sc.getConf - val initContainerConfigMap = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_NAME) - val initContainerConfigMapKey = sparkConf.get(INIT_CONTAINER_CONFIG_MAP_KEY_CONF) - - if (initContainerConfigMap.isEmpty) { - logWarning("The executor's init-container config map is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - if (initContainerConfigMapKey.isEmpty) { - logWarning("The executor's init-container config map key is not specified. Executors will " + - "therefore not attempt to fetch remote or submitted dependencies.") - } - - // Only set up the bootstrap if they've provided both the config map key and the config map - // name. The config map might not be provided if init-containers aren't being used to - // bootstrap dependencies. - val initContainerBootstrap = for { - configMap <- initContainerConfigMap - configMapKey <- initContainerConfigMapKey - } yield { - val initContainerImage = sparkConf - .get(INIT_CONTAINER_IMAGE) - .getOrElse(throw new SparkException( - "Must specify the init-container image when there are remote dependencies")) - new InitContainerBootstrap( - initContainerImage, - sparkConf.get(CONTAINER_IMAGE_PULL_POLICY), - sparkConf.get(JARS_DOWNLOAD_LOCATION), - sparkConf.get(FILES_DOWNLOAD_LOCATION), - configMap, - configMapKey, - SPARK_POD_EXECUTOR_ROLE, - sparkConf) - } - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) } else { None } - // Mount user-specified executor secrets also into the executor's init-container. The - // init-container may need credentials in the secrets to be able to download remote - // dependencies. The executor's main container and its init-container share the secrets - // because the init-container is sort of an implementation details and this sharing - // avoids introducing a dedicated configuration property just for the init-container. - val initContainerMountSecretsBootstrap = if (initContainerBootstrap.nonEmpty && - executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, - Some(sparkConf.get(KUBERNETES_NAMESPACE)), + Some(sc.conf.get(KUBERNETES_NAMESPACE)), KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - sparkConf, + sc.conf, Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory( - sparkConf, - mountSecretBootstrap, - initContainerBootstrap, - initContainerMountSecretsBootstrap) + val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala deleted file mode 100644 index e0f29ecd0fb53..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import java.io.File -import java.util.UUID - -import com.google.common.base.Charsets -import com.google.common.io.Files -import org.mockito.Mockito -import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.util.Utils - -class SparkPodInitContainerSuite extends SparkFunSuite with BeforeAndAfter { - - private val DOWNLOAD_JARS_SECRET_LOCATION = createTempFile("txt") - private val DOWNLOAD_FILES_SECRET_LOCATION = createTempFile("txt") - - private var downloadJarsDir: File = _ - private var downloadFilesDir: File = _ - private var downloadJarsSecretValue: String = _ - private var downloadFilesSecretValue: String = _ - private var fileFetcher: FileFetcher = _ - - override def beforeAll(): Unit = { - downloadJarsSecretValue = Files.toString( - new File(DOWNLOAD_JARS_SECRET_LOCATION), Charsets.UTF_8) - downloadFilesSecretValue = Files.toString( - new File(DOWNLOAD_FILES_SECRET_LOCATION), Charsets.UTF_8) - } - - before { - downloadJarsDir = Utils.createTempDir() - downloadFilesDir = Utils.createTempDir() - fileFetcher = mock[FileFetcher] - } - - after { - downloadJarsDir.delete() - downloadFilesDir.delete() - } - - test("Downloads from remote server should invoke the file fetcher") { - val sparkConf = getSparkConfForRemoteFileDownloads - val initContainerUnderTest = new SparkPodInitContainer(sparkConf, fileFetcher) - initContainerUnderTest.run() - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/jar1.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("hdfs://localhost:9000/jar2.jar", downloadJarsDir) - Mockito.verify(fileFetcher).fetchFile("http://localhost:9000/file.txt", downloadFilesDir) - } - - private def getSparkConfForRemoteFileDownloads: SparkConf = { - new SparkConf(true) - .set(INIT_CONTAINER_REMOTE_JARS, - "http://localhost:9000/jar1.jar,hdfs://localhost:9000/jar2.jar") - .set(INIT_CONTAINER_REMOTE_FILES, - "http://localhost:9000/file.txt") - .set(JARS_DOWNLOAD_LOCATION, downloadJarsDir.getAbsolutePath) - .set(FILES_DOWNLOAD_LOCATION, downloadFilesDir.getAbsolutePath) - } - - private def createTempFile(extension: String): String = { - val dir = Utils.createTempDir() - val file = new File(dir, s"${UUID.randomUUID().toString}.$extension") - Files.write(UUID.randomUUID().toString, file, Charsets.UTF_8) - file.getAbsolutePath - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index bf4ec04893204..6a501592f42a3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -38,6 +38,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_UID = "pod-id" private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" + private val KUBERNETES_RESOURCE_PREFIX = "resource-example" private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -61,6 +62,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ + private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) @@ -94,7 +96,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) @@ -108,62 +111,52 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { SecondTestConfigurationStep.containerName) } - test("The client should create the secondary Kubernetes resources.") { + test("The client should create Kubernetes resources") { + val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" + val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( submissionSteps, - new SparkConf(false), + new SparkConf(false) + .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), kubernetesClient, false, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues - assert(otherCreatedResources.size === 1) - val createdResource = Iterables.getOnlyElement(otherCreatedResources).asInstanceOf[Secret] - assert(createdResource.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(createdResource.getData.asScala === + assert(otherCreatedResources.size === 2) + val secrets = otherCreatedResources.toArray + .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val configMaps = otherCreatedResources.toArray + .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) + assert(secrets.nonEmpty) + val secret = secrets.head + assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) + assert(secret.getData.asScala === Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(createdResource.getMetadata.getOwnerReferences) + val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) assert(ownerReference.getName === createdPod.getMetadata.getName) assert(ownerReference.getKind === DRIVER_POD_KIND) assert(ownerReference.getUid === DRIVER_POD_UID) assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) - } - - test("The client should attach the driver container with the appropriate JVM options.") { - val sparkConf = new SparkConf(false) - .set("spark.logConf", "true") - .set( - org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, - "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails") - val submissionClient = new Client( - submissionSteps, - sparkConf, - kubernetesClient, - false, - "spark", - loggingPodStatusWatcher) - submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue + assert(configMaps.nonEmpty) + val configMap = configMaps.head + assert(configMap.getMetadata.getName === + s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") + assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( + "spark.custom-conf=custom-conf-value")) val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverJvmOptsEnvs = driverContainer.getEnv.asScala.filter { env => - env.getName.startsWith(ENV_JAVA_OPT_PREFIX) - }.sortBy(_.getName) - assert(driverJvmOptsEnvs.size === 4) - - val expectedJvmOptsValues = Seq( - "-Dspark.logConf=true", - s"-D${SecondTestConfigurationStep.sparkConfKey}=" + - s"${SecondTestConfigurationStep.sparkConfValue}", - "-XX:+HeapDumpOnOutOfMemoryError", - "-XX:+PrintGCDetails") - driverJvmOptsEnvs.zip(expectedJvmOptsValues).zipWithIndex.foreach { - case ((resolvedEnv, expectedJvmOpt), index) => - assert(resolvedEnv.getName === s"$ENV_JAVA_OPT_PREFIX$index") - assert(resolvedEnv.getValue === expectedJvmOpt) - } + val driverEnv = driverContainer.getEnv.asScala.head + assert(driverEnv.getName === ENV_SPARK_CONF_DIR) + assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) + val driverMount = driverContainer.getVolumeMounts.asScala.head + assert(driverMount.getName === SPARK_CONF_VOLUME) + assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) } test("Waiting for app completion should stall on the watcher") { @@ -173,7 +166,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { kubernetesClient, true, "spark", - loggingPodStatusWatcher) + loggingPodStatusWatcher, + KUBERNETES_RESOURCE_PREFIX) submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } @@ -209,13 +203,11 @@ private object FirstTestConfigurationStep extends DriverConfigurationStep { } private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" val annotationValue = "submitted" val sparkConfKey = "spark.custom-conf" val sparkConfValue = "custom-conf-value" val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { val modifiedPod = new PodBuilder(driverSpec.driverPod) .editMetadata() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala index 033d303e946fd..df34d2dbcb5be 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala @@ -25,7 +25,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { private val DRIVER_IMAGE = "driver-image" private val IC_IMAGE = "init-container-image" private val APP_ID = "spark-app-id" - private val LAUNCH_TIME = 975256L + private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" private val APP_NAME = "spark" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" private val APP_ARGS = Array("arg1", "arg2") @@ -38,7 +38,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -49,15 +49,14 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep] - ) + classOf[DependencyResolutionStep]) } test("Base submission steps without a main app resource.") { val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Option.empty, APP_NAME, MAIN_CLASS, @@ -67,31 +66,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { orchestrator, classOf[BasicDriverConfigurationStep], classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep] - ) - } - - test("Submission steps with an init-container.") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(INIT_CONTAINER_IMAGE.key, IC_IMAGE) - .set("spark.jars", "hdfs://localhost:9000/var/apps/jars/jar1.jar") - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - LAUNCH_TIME, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverInitContainerBootstrapStep]) + classOf[DriverKubernetesCredentialsStep]) } test("Submission steps with driver secrets to mount") { @@ -102,7 +77,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") val orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(mainAppResource), APP_NAME, MAIN_CLASS, @@ -122,7 +97,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, DRIVER_IMAGE) var orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, @@ -135,7 +110,7 @@ class DriverConfigOrchestratorSuite extends SparkFunSuite { sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") orchestrator = new DriverConfigOrchestrator( APP_ID, - LAUNCH_TIME, + KUBERNETES_RESOURCE_PREFIX, Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), APP_NAME, MAIN_CLASS, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index b136f2c02ffba..ce068531c7673 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -73,16 +73,13 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - assert(preparedDriverSpec.driverContainer.getEnv.size === 7) + assert(preparedDriverSpec.driverContainer.getEnv.size === 4) val envs = preparedDriverSpec.driverContainer .getEnv .asScala .map(env => (env.getName, env.getValue)) .toMap assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(ENV_DRIVER_MEMORY) === "256M") - assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") @@ -112,7 +109,8 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX) + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") assert(resolvedSparkConf === expectedSparkConf) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala index 991b03cafb76c..ca43fc97dc991 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala @@ -29,24 +29,17 @@ import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DependencyResolutionStepSuite extends SparkFunSuite { private val SPARK_JARS = Seq( - "hdfs://localhost:9000/apps/jars/jar1.jar", - "file:///home/user/apps/jars/jar2.jar", - "local:///var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "local:///var/apps/jars/jar2.jar") private val SPARK_FILES = Seq( - "file:///home/user/apps/files/file1.txt", - "hdfs://localhost:9000/apps/files/file2.txt", - "local:///var/apps/files/file3.txt") - - private val JARS_DOWNLOAD_PATH = "/mnt/spark-data/jars" - private val FILES_DOWNLOAD_PATH = "/mnt/spark-data/files" + "apps/files/file1.txt", + "local:///var/apps/files/file2.txt") test("Added dependencies should be resolved in Spark configuration and environment") { val dependencyResolutionStep = new DependencyResolutionStep( SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH) + SPARK_FILES) val driverPod = new PodBuilder().build() val baseDriverSpec = KubernetesDriverSpec( driverPod = driverPod, @@ -58,24 +51,19 @@ class DependencyResolutionStepSuite extends SparkFunSuite { assert(preparedDriverSpec.otherKubernetesResources.isEmpty) val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet val expectedResolvedSparkJars = Set( - "hdfs://localhost:9000/apps/jars/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + "apps/jars/jar1.jar", + "/var/apps/jars/jar2.jar") assert(resolvedSparkJars === expectedResolvedSparkJars) val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet val expectedResolvedSparkFiles = Set( - s"$FILES_DOWNLOAD_PATH/file1.txt", - s"hdfs://localhost:9000/apps/files/file2.txt", - s"/var/apps/files/file3.txt") + "apps/files/file1.txt", + "/var/apps/files/file2.txt") assert(resolvedSparkFiles === expectedResolvedSparkFiles) val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala assert(driverEnv.size === 1) assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = Set( - s"$JARS_DOWNLOAD_PATH/jar1.jar", - s"$JARS_DOWNLOAD_PATH/jar2.jar", - "/var/apps/jars/jar3.jar") + val expectedResolvedDriverClasspath = expectedResolvedSparkJars assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala deleted file mode 100644 index 758871e2ba356..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverInitContainerBootstrapStepSuite.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.StringReader -import java.util.Properties - -import scala.collection.JavaConverters._ - -import com.google.common.collect.Maps -import io.fabric8.kubernetes.api.model.{ConfigMap, ContainerBuilder, HasMetadata, PodBuilder, SecretBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.deploy.k8s.submit.steps.initcontainer.{InitContainerConfigurationStep, InitContainerSpec} -import org.apache.spark.util.Utils - -class DriverInitContainerBootstrapStepSuite extends SparkFunSuite { - - private val CONFIG_MAP_NAME = "spark-init-config-map" - private val CONFIG_MAP_KEY = "spark-init-config-map-key" - - test("The init container bootstrap step should use all of the init container steps") { - val baseDriverSpec = KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val initContainerSteps = Seq( - FirstTestInitContainerConfigurationStep, - SecondTestInitContainerConfigurationStep) - val bootstrapStep = new DriverInitContainerBootstrapStep( - initContainerSteps, - CONFIG_MAP_NAME, - CONFIG_MAP_KEY) - - val preparedDriverSpec = bootstrapStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverPod.getMetadata.getLabels.asScala === - FirstTestInitContainerConfigurationStep.additionalLabels) - val additionalDriverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(additionalDriverEnv.size === 1) - assert(additionalDriverEnv.head.getName === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvKey) - assert(additionalDriverEnv.head.getValue === - FirstTestInitContainerConfigurationStep.additionalMainContainerEnvValue) - - assert(preparedDriverSpec.otherKubernetesResources.size === 2) - assert(preparedDriverSpec.otherKubernetesResources.contains( - FirstTestInitContainerConfigurationStep.additionalKubernetesResource)) - assert(preparedDriverSpec.otherKubernetesResources.exists { - case configMap: ConfigMap => - val hasMatchingName = configMap.getMetadata.getName == CONFIG_MAP_NAME - val configMapData = configMap.getData.asScala - val hasCorrectNumberOfEntries = configMapData.size == 1 - val initContainerPropertiesRaw = configMapData(CONFIG_MAP_KEY) - val initContainerProperties = new Properties() - Utils.tryWithResource(new StringReader(initContainerPropertiesRaw)) { - initContainerProperties.load(_) - } - val initContainerPropertiesMap = Maps.fromProperties(initContainerProperties).asScala - val expectedInitContainerProperties = Map( - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyKey -> - SecondTestInitContainerConfigurationStep.additionalInitContainerPropertyValue) - val hasMatchingProperties = initContainerPropertiesMap == expectedInitContainerProperties - hasMatchingName && hasCorrectNumberOfEntries && hasMatchingProperties - - case _ => false - }) - - val initContainers = preparedDriverSpec.driverPod.getSpec.getInitContainers - assert(initContainers.size() === 1) - val initContainerEnv = initContainers.get(0).getEnv.asScala - assert(initContainerEnv.size === 1) - assert(initContainerEnv.head.getName === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvKey) - assert(initContainerEnv.head.getValue === - SecondTestInitContainerConfigurationStep.additionalInitContainerEnvValue) - - val expectedSparkConf = Map( - INIT_CONTAINER_CONFIG_MAP_NAME.key -> CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY_CONF.key -> CONFIG_MAP_KEY, - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfKey -> - SecondTestInitContainerConfigurationStep.additionalDriverSparkConfValue) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - } -} - -private object FirstTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - - val additionalLabels = Map("additionalLabelkey" -> "additionalLabelValue") - val additionalMainContainerEnvKey = "TEST_ENV_MAIN_KEY" - val additionalMainContainerEnvValue = "TEST_ENV_MAIN_VALUE" - val additionalKubernetesResource = new SecretBuilder() - .withNewMetadata() - .withName("test-secret") - .endMetadata() - .addToData("secret-key", "secret-value") - .build() - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val driverPod = new PodBuilder(initContainerSpec.driverPod) - .editOrNewMetadata() - .addToLabels(additionalLabels.asJava) - .endMetadata() - .build() - val mainContainer = new ContainerBuilder(initContainerSpec.driverContainer) - .addNewEnv() - .withName(additionalMainContainerEnvKey) - .withValue(additionalMainContainerEnvValue) - .endEnv() - .build() - initContainerSpec.copy( - driverPod = driverPod, - driverContainer = mainContainer, - dependentResources = initContainerSpec.dependentResources ++ - Seq(additionalKubernetesResource)) - } -} - -private object SecondTestInitContainerConfigurationStep extends InitContainerConfigurationStep { - val additionalInitContainerEnvKey = "TEST_ENV_INIT_KEY" - val additionalInitContainerEnvValue = "TEST_ENV_INIT_VALUE" - val additionalInitContainerPropertyKey = "spark.initcontainer.testkey" - val additionalInitContainerPropertyValue = "testvalue" - val additionalDriverSparkConfKey = "spark.driver.testkey" - val additionalDriverSparkConfValue = "spark.driver.testvalue" - - override def configureInitContainer(initContainerSpec: InitContainerSpec): InitContainerSpec = { - val initContainer = new ContainerBuilder(initContainerSpec.initContainer) - .addNewEnv() - .withName(additionalInitContainerEnvKey) - .withValue(additionalInitContainerEnvValue) - .endEnv() - .build() - val initContainerProperties = initContainerSpec.properties ++ - Map(additionalInitContainerPropertyKey -> additionalInitContainerPropertyValue) - val driverSparkConf = initContainerSpec.driverSparkConf ++ - Map(additionalDriverSparkConfKey -> additionalDriverSparkConfValue) - initContainerSpec.copy( - initContainer = initContainer, - properties = initContainerProperties, - driverSparkConf = driverSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala deleted file mode 100644 index 4553f9f6b1d45..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/BasicInitContainerConfigurationStepSuite.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito.when -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, PodWithDetachedInitContainer} -import org.apache.spark.deploy.k8s.Config._ - -class BasicInitContainerConfigurationStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val POD_LABEL = Map("bootstrap" -> "true") - private val INIT_CONTAINER_NAME = "init-container" - private val DRIVER_CONTAINER_NAME = "driver-container" - - @Mock - private var podAndInitContainerBootstrap : InitContainerBootstrap = _ - - before { - MockitoAnnotations.initMocks(this) - when(podAndInitContainerBootstrap.bootstrapInitContainer( - any[PodWithDetachedInitContainer])).thenAnswer(new Answer[PodWithDetachedInitContainer] { - override def answer(invocation: InvocationOnMock) : PodWithDetachedInitContainer = { - val pod = invocation.getArgumentAt(0, classOf[PodWithDetachedInitContainer]) - pod.copy( - pod = new PodBuilder(pod.pod) - .withNewMetadata() - .addToLabels("bootstrap", "true") - .endMetadata() - .withNewSpec().endSpec() - .build(), - initContainer = new ContainerBuilder() - .withName(INIT_CONTAINER_NAME) - .build(), - mainContainer = new ContainerBuilder() - .withName(DRIVER_CONTAINER_NAME) - .build() - )}}) - } - - test("additionalDriverSparkConf with mix of remote files and jars") { - val baseInitStep = new BasicInitContainerConfigurationStep( - SPARK_JARS, - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - podAndInitContainerBootstrap) - val expectedDriverSparkConf = Map( - JARS_DOWNLOAD_LOCATION.key -> JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_LOCATION.key -> FILES_DOWNLOAD_PATH, - INIT_CONTAINER_REMOTE_JARS.key -> "hdfs://localhost:9000/app/jars/jar1.jar", - INIT_CONTAINER_REMOTE_FILES.key -> "hdfs://localhost:9000/app/files/file1.txt") - val initContainerSpec = InitContainerSpec( - Map.empty[String, String], - Map.empty[String, String], - new Container(), - new Container(), - new Pod, - Seq.empty[HasMetadata]) - val returnContainerSpec = baseInitStep.configureInitContainer(initContainerSpec) - assert(expectedDriverSparkConf === returnContainerSpec.properties) - assert(returnContainerSpec.initContainer.getName === INIT_CONTAINER_NAME) - assert(returnContainerSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(returnContainerSpec.driverPod.getMetadata.getLabels.asScala === POD_LABEL) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala deleted file mode 100644 index 09b42e4484d86..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerConfigOrchestratorSuite.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ - -class InitContainerConfigOrchestratorSuite extends SparkFunSuite { - - private val DOCKER_IMAGE = "init-container" - private val SPARK_JARS = Seq( - "hdfs://localhost:9000/app/jars/jar1.jar", "file:///app/jars/jar2.jar") - private val SPARK_FILES = Seq( - "hdfs://localhost:9000/app/files/file1.txt", "file:///app/files/file2.txt") - private val JARS_DOWNLOAD_PATH = "/var/data/jars" - private val FILES_DOWNLOAD_PATH = "/var/data/files" - private val DOCKER_IMAGE_PULL_POLICY: String = "IfNotPresent" - private val CUSTOM_LABEL_KEY = "customLabel" - private val CUSTOM_LABEL_VALUE = "customLabelValue" - private val INIT_CONTAINER_CONFIG_MAP_NAME = "spark-init-config-map" - private val INIT_CONTAINER_CONFIG_MAP_KEY = "spark-init-config-map-key" - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("including basic configuration step") { - val sparkConf = new SparkConf(true) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$CUSTOM_LABEL_KEY", CUSTOM_LABEL_VALUE) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.lengthCompare(1) == 0) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - } - - test("including step to mount user-specified secrets") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DOCKER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - - val orchestrator = new InitContainerConfigOrchestrator( - SPARK_JARS.take(1), - SPARK_FILES, - JARS_DOWNLOAD_PATH, - FILES_DOWNLOAD_PATH, - DOCKER_IMAGE_PULL_POLICY, - INIT_CONTAINER_CONFIG_MAP_NAME, - INIT_CONTAINER_CONFIG_MAP_KEY, - sparkConf) - val initSteps = orchestrator.getAllConfigurationSteps - assert(initSteps.length === 2) - assert(initSteps.head.isInstanceOf[BasicInitContainerConfigurationStep]) - assert(initSteps(1).isInstanceOf[InitContainerMountSecretsStep]) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala deleted file mode 100644 index 7ac0bde80dfe6..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps.initcontainer - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} - -class InitContainerMountSecretsStepSuite extends SparkFunSuite { - - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/init-container" - - test("mounts all given secrets") { - val baseInitContainerSpec = InitContainerSpec( - Map.empty, - Map.empty, - new ContainerBuilder().build(), - new ContainerBuilder().build(), - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - Seq.empty) - val secretNamesToMountPaths = Map( - SECRET_FOO -> SECRET_MOUNT_PATH, - SECRET_BAR -> SECRET_MOUNT_PATH) - - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) - val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( - baseInitContainerSpec) - val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.containerHasVolume( - initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a3c615be031d2..7755b93835047 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -19,15 +19,13 @@ package org.apache.spark.scheduler.cluster.k8s import scala.collection.JavaConverters._ import io.fabric8.kubernetes.api.model._ -import org.mockito.{AdditionalAnswers, MockitoAnnotations} -import org.mockito.Matchers.any -import org.mockito.Mockito._ +import org.mockito.MockitoAnnotations import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.MountSecretsBootstrap class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { @@ -55,10 +53,11 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) } test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None, None, None) + val factory = new ExecutorPodFactory(baseConf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -89,7 +88,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -101,7 +100,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - val factory = new ExecutorPodFactory(conf, None, None, None) + val factory = new ExecutorPodFactory(conf, None) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) @@ -116,11 +115,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val conf = baseConf.clone() val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - None, - None) + val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) @@ -138,50 +133,6 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } - test("init-container bootstrap step adds an init container") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - - val factory = new ExecutorPodFactory( - conf, - None, - Some(initContainerBootstrap), - None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getInitContainers.size() === 1) - checkOwnerReferences(executor, driverPodUid) - } - - test("init-container with secrets mount bootstrap") { - val conf = baseConf.clone() - val initContainerBootstrap = mock(classOf[InitContainerBootstrap]) - when(initContainerBootstrap.bootstrapInitContainer( - any(classOf[PodWithDetachedInitContainer]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - - val factory = new ExecutorPodFactory( - conf, - Some(secretsBootstrap), - Some(initContainerBootstrap), - Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getVolumes.size() === 1) - assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) - assert(executor.getSpec.getInitContainers.size() === 1) - assert(SecretVolumeUtils.containerHasVolume( - executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) - - checkOwnerReferences(executor, driverPodUid) - } - // There is always exactly one controller reference, and it points to the driver pod. private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { assert(executor.getMetadata.getOwnerReferences.size() === 1) @@ -197,8 +148,8 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef ENV_EXECUTOR_CORES -> "1", ENV_EXECUTOR_MEMORY -> "1g", ENV_APPLICATION_ID -> "dummy", - ENV_EXECUTOR_POD_IP -> null, - ENV_MOUNTED_CLASSPATH -> "/var/spark-data/spark-jars/*") ++ additionalEnvVars + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 491b7cf692478..9badf8556afc3 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -40,7 +40,6 @@ RUN set -ex && \ COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin -COPY conf /opt/spark/conf COPY ${img_path}/spark/entrypoint.sh /opt/ COPY examples /opt/spark/examples COPY data /opt/spark/data diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index d0cf284f035ea..3e166116aa3fd 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -56,14 +56,10 @@ fi case "$SPARK_K8S_CMD" in driver) CMD=( - ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" - -cp "$SPARK_CLASSPATH" - -Xms$SPARK_DRIVER_MEMORY - -Xmx$SPARK_DRIVER_MEMORY - -Dspark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS - $SPARK_DRIVER_CLASS - $SPARK_DRIVER_ARGS + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" ) ;; @@ -83,14 +79,6 @@ case "$SPARK_K8S_CMD" in ) ;; - init) - CMD=( - "$SPARK_HOME/bin/spark-class" - "org.apache.spark.deploy.k8s.SparkPodInitContainer" - "$@" - ) - ;; - *) echo "Unknown command: $SPARK_K8S_CMD" 1>&2 exit 1 From 5f4deff19511b6870f056eba5489104b9cac05a9 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 19 Mar 2018 18:02:04 -0700 Subject: [PATCH 0492/1161] [SPARK-23660] Fix exception in yarn cluster mode when application ended fast ## What changes were proposed in this pull request? Yarn throws the following exception in cluster mode when the application is really small: ``` 18/03/07 23:34:22 WARN netty.NettyRpcEnv: Ignored failure: java.util.concurrent.RejectedExecutionException: Task java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask7c974942 rejected from java.util.concurrent.ScheduledThreadPoolExecutor1eea9d2d[Terminated, pool size = 0, active threads = 0, queued tasks = 0, completed tasks = 0] 18/03/07 23:34:22 ERROR yarn.ApplicationMaster: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:205) at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:92) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:76) at org.apache.spark.deploy.yarn.YarnAllocator.(YarnAllocator.scala:102) at org.apache.spark.deploy.yarn.YarnRMClient.register(YarnRMClient.scala:77) at org.apache.spark.deploy.yarn.ApplicationMaster.registerAM(ApplicationMaster.scala:450) at org.apache.spark.deploy.yarn.ApplicationMaster.runDriver(ApplicationMaster.scala:493) at org.apache.spark.deploy.yarn.ApplicationMaster.org$apache$spark$deploy$yarn$ApplicationMaster$$runImpl(ApplicationMaster.scala:345) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply$mcV$sp(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anonfun$run$2.apply(ApplicationMaster.scala:260) at org.apache.spark.deploy.yarn.ApplicationMaster$$anon$5.run(ApplicationMaster.scala:810) at java.security.AccessController.doPrivileged(Native Method) at javax.security.auth.Subject.doAs(Subject.java:422) at org.apache.hadoop.security.UserGroupInformation.doAs(UserGroupInformation.java:1920) at org.apache.spark.deploy.yarn.ApplicationMaster.doAsUser(ApplicationMaster.scala:809) at org.apache.spark.deploy.yarn.ApplicationMaster.run(ApplicationMaster.scala:259) at org.apache.spark.deploy.yarn.ApplicationMaster$.main(ApplicationMaster.scala:834) at org.apache.spark.deploy.yarn.ApplicationMaster.main(ApplicationMaster.scala) Caused by: org.apache.spark.rpc.RpcEnvStoppedException: RpcEnv already stopped. at org.apache.spark.rpc.netty.Dispatcher.postMessage(Dispatcher.scala:158) at org.apache.spark.rpc.netty.Dispatcher.postLocalMessage(Dispatcher.scala:135) at org.apache.spark.rpc.netty.NettyRpcEnv.ask(NettyRpcEnv.scala:229) at org.apache.spark.rpc.netty.NettyRpcEndpointRef.ask(NettyRpcEnv.scala:523) at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:91) ... 17 more 18/03/07 23:34:22 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 13, (reason: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: ) ``` Example application: ``` object ExampleApp { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("ExampleApp") val sc = new SparkContext(conf) try { // Do nothing } finally { sc.stop() } } ``` This PR pauses user class thread after `SparkContext` created and keeps it so until application master initialises properly. ## How was this patch tested? Automated: Existing unit tests Manual: Application submitted into small cluster Author: Gabor Somogyi Closes #20807 from gaborgsomogyi/SPARK-23660. --- .../spark/deploy/yarn/ApplicationMaster.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2f88feb0f1fdf..6e35d23def6f0 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -418,7 +418,19 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def sparkContextInitialized(sc: SparkContext) = { - sparkContextPromise.success(sc) + sparkContextPromise.synchronized { + // Notify runDriver function that SparkContext is available + sparkContextPromise.success(sc) + // Pause the user class thread in order to make proper initialization in runDriver function. + sparkContextPromise.wait() + } + } + + private def resumeDriver(): Unit = { + // When initialization in runDriver happened the user class thread has to be resumed. + sparkContextPromise.synchronized { + sparkContextPromise.notify() + } } private def registerAM( @@ -497,6 +509,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // if the user app did not create a SparkContext. throw new IllegalStateException("User did not initialize spark context!") } + resumeDriver() userClassThread.join() } catch { case e: SparkException if e.getCause().isInstanceOf[TimeoutException] => @@ -506,6 +519,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") + } finally { + resumeDriver() } } From 566321852b2d60641fe86acbc8914b4a7063b58e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 19 Mar 2018 21:25:37 -0700 Subject: [PATCH 0493/1161] [SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible ## What changes were proposed in this pull request? https://github.com/apache/spark/commit/d6632d185e147fcbe6724545488ad80dce20277e added an useful util ```python contextmanager def sql_conf(self, pairs): ... ``` to allow configuration set/unset within a block: ```python with self.sql_conf({"spark.blah.blah.blah", "blah"}) # test codes ``` This PR proposes to use this util where possible in PySpark tests. Note that there look already few places affecting tests without restoring the original value back in unittest classes. ## How was this patch tested? Manually tested via: ``` ./run-tests --modules=pyspark-sql --python-executables=python2 ./run-tests --modules=pyspark-sql --python-executables=python3 ``` Author: hyukjinkwon Closes #20830 from HyukjinKwon/cleanup-sql-conf. --- python/pyspark/sql/tests.py | 130 ++++++++++++++---------------------- 1 file changed, 50 insertions(+), 80 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a0d547ad620e5..39d6c5226f138 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2461,17 +2461,13 @@ def test_join_without_on(self): df1 = self.spark.range(1).toDF("a") df2 = self.spark.range(1).toDF("b") - try: - self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + with self.sql_conf({"spark.sql.crossJoin.enabled": False}): self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) - self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): actual = df1.join(df2, how="inner").collect() expected = [Row(a=0, b=0)] self.assertEqual(actual, expected) - finally: - # We should unset this. Otherwise, other tests are affected. - self.spark.conf.unset("spark.sql.crossJoin.enabled") # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): @@ -2943,21 +2939,18 @@ def test_create_dateframe_from_pandas_with_dst(self): self.assertPandasEqual(pdf, df.toPandas()) orig_env_tz = os.environ.get('TZ', None) - orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone') try: tz = 'America/Los_Angeles' os.environ['TZ'] = tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', tz) - - df = self.spark.createDataFrame(pdf) - self.assertPandasEqual(pdf, df.toPandas()) + with self.sql_conf({'spark.sql.session.timeZone': tz}): + df = self.spark.createDataFrame(pdf) + self.assertPandasEqual(pdf, df.toPandas()) finally: del os.environ['TZ'] if orig_env_tz is not None: os.environ['TZ'] = orig_env_tz time.tzset() - self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz) class HiveSparkSubmitTests(SparkSubmitTests): @@ -3562,12 +3555,11 @@ def test_null_conversion(self): self.assertTrue(all([c == 1 for c in null_counts])) def _toPandas_arrow_toggle(self, df): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): pdf = df.toPandas() - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + pdf_arrow = df.toPandas() + return pdf, pdf_arrow def test_toPandas_arrow_toggle(self): @@ -3579,16 +3571,17 @@ def test_toPandas_arrow_toggle(self): def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertPandasEqual(pdf_arrow_la, pdf_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) + self.assertPandasEqual(pdf_arrow_la, pdf_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) self.assertPandasEqual(pdf_arrow_ny, pdf_ny) @@ -3601,8 +3594,6 @@ def test_toPandas_respect_session_timezone(self): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) self.assertPandasEqual(pdf_ny, pdf_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_pandas_round_trip(self): pdf = self.create_pandas_data_frame() @@ -3618,12 +3609,11 @@ def test_filtered_frame(self): self.assertTrue(pdf.empty) def _createDataFrame_toggle(self, pdf, schema=None): - self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") - try: + with self.sql_conf({"spark.sql.execution.arrow.enabled": False}): df_no_arrow = self.spark.createDataFrame(pdf, schema=schema) - finally: - self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + return df_no_arrow, df_arrow def test_createDataFrame_toggle(self): @@ -3634,18 +3624,18 @@ def test_createDataFrame_toggle(self): def test_createDataFrame_respect_session_timezone(self): from datetime import timedelta pdf = self.create_pandas_data_frame() - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) - result_la = df_no_arrow_la.collect() - result_arrow_la = df_arrow_la.collect() - self.assertEqual(result_la, result_arrow_la) - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema) + result_la = df_no_arrow_la.collect() + result_arrow_la = df_arrow_la.collect() + self.assertEqual(result_la, result_arrow_la) + + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema) result_ny = df_no_arrow_ny.collect() result_arrow_ny = df_arrow_ny.collect() @@ -3658,8 +3648,6 @@ def test_createDataFrame_respect_session_timezone(self): for k, v in row.asDict().items()}) for row in result_la] self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_createDataFrame_with_schema(self): pdf = self.create_pandas_data_frame() @@ -4336,9 +4324,7 @@ def gen_timestamps(id): def test_vectorized_udf_check_config(self): from pyspark.sql.functions import pandas_udf, col import pandas as pd - orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None) - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3) - try: + with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df = self.spark.range(10, numPartitions=1) @pandas_udf(returnType=LongType()) @@ -4348,11 +4334,6 @@ def check_records_per_batch(x): result = df.select(check_records_per_batch(col("id"))).collect() for (r,) in result: self.assertTrue(r <= 3) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch") - else: - self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value) def test_vectorized_udf_timestamps_respect_session_timezone(self): from pyspark.sql.functions import pandas_udf, col @@ -4371,30 +4352,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): internal_value = pandas_udf( lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType()) - orig_tz = self.spark.conf.get("spark.sql.session.timeZone") - try: - timezone = "America/New_York" - self.spark.conf.set("spark.sql.session.timeZone", timezone) - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") - try: - df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ - .withColumn("internal_value", internal_value(col("timestamp"))) - result_la = df_la.select(col("idx"), col("internal_value")).collect() - # Correct result_la by adjusting 3 hours difference between Los Angeles and New York - diff = 3 * 60 * 60 * 1000 * 1000 * 1000 - result_la_corrected = \ - df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() - finally: - self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") + timezone = "America/New_York" + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": False, + "spark.sql.session.timeZone": timezone}): + df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ + .withColumn("internal_value", internal_value(col("timestamp"))) + result_la = df_la.select(col("idx"), col("internal_value")).collect() + # Correct result_la by adjusting 3 hours difference between Los Angeles and New York + diff = 3 * 60 * 60 * 1000 * 1000 * 1000 + result_la_corrected = \ + df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect() + with self.sql_conf({ + "spark.sql.execution.pandas.respectSessionTimeZone": True, + "spark.sql.session.timeZone": timezone}): df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \ .withColumn("internal_value", internal_value(col("timestamp"))) result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect() self.assertNotEqual(result_ny, result_la) self.assertEqual(result_ny, result_la_corrected) - finally: - self.spark.conf.set("spark.sql.session.timeZone", orig_tz) def test_nondeterministic_vectorized_udf(self): # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations @@ -5170,9 +5148,7 @@ def test_complex_expressions(self): def test_retain_group_columns(self): from pyspark.sql.functions import sum, lit, col - orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None) - self.spark.conf.set("spark.sql.retainGroupColumns", False) - try: + with self.sql_conf({"spark.sql.retainGroupColumns": False}): df = self.data sum_udf = self.pandas_agg_sum_udf @@ -5180,12 +5156,6 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) - finally: - if orig_value is None: - self.spark.conf.unset("spark.sql.retainGroupColumns") - else: - self.spark.conf.set("spark.sql.retainGroupColumns", orig_value) - def test_invalid_args(self): from pyspark.sql.functions import mean From 5e7bc2acef4a1e11d0d8056ef5c12cd5c8f220da Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 20 Mar 2018 10:34:56 -0700 Subject: [PATCH 0494/1161] [SPARK-23649][SQL] Skipping chars disallowed in UTF-8 ## What changes were proposed in this pull request? The mapping of UTF-8 char's first byte to char's size doesn't cover whole range 0-255. It is defined only for 0-253: https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L60-L65 https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java#L190 If the first byte of a char is 253-255, IndexOutOfBoundsException is thrown. Besides of that values for 244-252 are not correct according to recent unicode standard for UTF-8: http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf As a consequence of the exception above, the length of input string in UTF-8 encoding cannot be calculated if the string contains chars started from 253 code. It is visible on user's side as for example crashing of schema inferring of csv file which contains such chars but the file can be read if the schema is specified explicitly or if the mode set to multiline. The proposed changes build correct mapping of first byte of UTF-8 char to its size (now it covers all cases) and skip disallowed chars (counts it as one octet). ## How was this patch tested? Added a test and a file with a char which is disallowed in UTF-8 - 0xFF. Author: Maxim Gekk Closes #20796 from MaxGekk/skip-wrong-utf8-chars. --- .../apache/spark/unsafe/types/UTF8String.java | 48 +++++++++++++++---- .../spark/unsafe/types/UTF8StringSuite.java | 23 ++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b0d0c44823e68..5d468aed42337 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -57,12 +57,43 @@ public final class UTF8String implements Comparable, Externalizable, public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } - private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6}; + /** + * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which + * indicates the size of the char. See Unicode standard in page 126, Table 3-6: + * http://www.unicode.org/versions/Unicode10.0.0/UnicodeStandard-10.0.pdf + * + * Binary Hex Comments + * 0xxxxxxx 0x00..0x7F Only byte of a 1-byte character encoding + * 10xxxxxx 0x80..0xBF Continuation bytes (1-3 continuation bytes) + * 110xxxxx 0xC0..0xDF First byte of a 2-byte character encoding + * 1110xxxx 0xE0..0xEF First byte of a 3-byte character encoding + * 11110xxx 0xF0..0xF7 First byte of a 4-byte character encoding + * + * As a consequence of the well-formedness conditions specified in + * Table 3-7 (page 126), the following byte values are disallowed in UTF-8: + * C0–C1, F5–FF. + */ + private static byte[] bytesOfCodePointInUTF8 = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0x0F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..0x1F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..0x2F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..0x3F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..0x4F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..0x5F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..0x6F + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..0x7F + // Continuation bytes cannot appear as the first byte + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..0x8F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..0x9F + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..0xAF + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..0xBF + 0, 0, // 0xC0..0xC1 - disallowed in UTF-8 + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..0xCF + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..0xDF + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..0xEF + 4, 4, 4, 4, 4, // 0xF0..0xF4 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8 + }; private static final boolean IS_LITTLE_ENDIAN = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; @@ -187,8 +218,9 @@ public void writeTo(OutputStream out) throws IOException { * @param b The first byte of a code point */ private static int numBytesForFirstByte(final byte b) { - final int offset = (b & 0xFF) - 192; - return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1; + final int offset = b & 0xFF; + byte numBytes = bytesOfCodePointInUTF8[offset]; + return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8 } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 9b303fa5bc6c5..7c34d419574ef 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -58,8 +58,12 @@ private static void checkBasic(String str, int len) { @Test public void basicTest() { checkBasic("", 0); - checkBasic("hello", 5); + checkBasic("¡", 1); // 2 bytes char + checkBasic("ку", 2); // 2 * 2 bytes chars + checkBasic("hello", 5); // 5 * 1 byte chars checkBasic("大 千 世 界", 7); + checkBasic("︽﹋%", 3); // 3 * 3 bytes chars + checkBasic("\uD83E\uDD19", 1); // 4 bytes char } @Test @@ -791,4 +795,21 @@ public void trimRightWithTrimString() { assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); } + + @Test + public void skipWrongFirstByte() { + int[] wrongFirstBytes = { + 0x80, 0x9F, 0xBF, // Skip Continuation bytes + 0xC0, 0xC2, // 0xC0..0xC1 - disallowed in UTF-8 + // 0xF5..0xFF - disallowed in UTF-8 + 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, + 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF + }; + byte[] c = new byte[1]; + + for (int i = 0; i < wrongFirstBytes.length; ++i) { + c[0] = (byte)wrongFirstBytes[i]; + assertEquals(fromBytes(c).numChars(), 1); + } + } } From 7f5e8aa2606b0ee0297ceb6f4603bd368e3b0291 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 20 Mar 2018 11:14:34 -0700 Subject: [PATCH 0495/1161] [SPARK-21898][ML] Feature parity for KolmogorovSmirnovTest in MLlib ## What changes were proposed in this pull request? Feature parity for KolmogorovSmirnovTest in MLlib. Implement `DataFrame` interface for `KolmogorovSmirnovTest` in `mllib.stat`. ## How was this patch tested? Test suite added. Author: WeichenXu Author: jkbradley Closes #19108 from WeichenXu123/ml-ks-test. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 113 ++++++++++++++ .../stat/JavaKolmogorovSmirnovTestSuite.java | 84 +++++++++++ .../ml/stat/KolmogorovSmirnovTestSuite.scala | 140 ++++++++++++++++++ 3 files changed, 337 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala new file mode 100644 index 0000000000000..8d80e7768cb6e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import scala.annotation.varargs + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.function.Function +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.stat.{Statistics => OldStatistics} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col + +/** + * :: Experimental :: + * + * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * For more information on KS Test: + * @see + * Kolmogorov-Smirnov test (Wikipedia) + */ +@Experimental +@Since("2.4.0") +object KolmogorovSmirnovTest { + + /** Used to construct output schema of test */ + private case class KolmogorovSmirnovTestResult( + pValue: Double, + statistic: Double) + + private def getSampleRDD(dataset: DataFrame, sampleCol: String): RDD[Double] = { + SchemaUtils.checkNumericType(dataset.schema, sampleCol) + import dataset.sparkSession.implicits._ + dataset.select(col(sampleCol).cast("double")).as[Double].rdd + } + + /** + * Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a + * continuous distribution. By comparing the largest difference between the empirical cumulative + * distribution of the sample data and the theoretical distribution we can provide a test for the + * the null hypothesis that the sample data comes from that theoretical distribution. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } + + /** + * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` + */ + @Since("2.4.0") + def test(dataset: DataFrame, sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) + } + + /** + * Convenience function to conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability + * distribution equality. Currently supports the normal distribution, taking as parameters + * the mean and standard deviation. + * + * @param dataset a `DataFrame` containing the sample of data to test + * @param sampleCol Name of sample column in dataset, of any numerical type + * @param distName a `String` name for a theoretical distribution, currently only support "norm". + * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @return DataFrame containing the test result for the input sampled data. + * This DataFrame will contain a single Row with the following fields: + * - `pValue: Double` + * - `statistic: Double` + */ + @Since("2.4.0") + @varargs + def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + val spark = dataset.sparkSession + + val rdd = getSampleRDD(dataset, sampleCol) + val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) + spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( + testResult.pValue, testResult.statistic))) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java new file mode 100644 index 0000000000000..021272dd5a40c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; + + +public class JavaKolmogorovSmirnovTestSuite extends SharedSparkSession { + + private transient Dataset dataset; + + @Override + public void setUp() throws IOException { + super.setUp(); + List points = Arrays.asList(0.1, 1.1, 10.1, -1.1); + + dataset = spark.createDataset(points, Encoders.DOUBLE()).toDF("sample"); + } + + @Test + public void testKSTestCDF() { + // Create theoretical distributions + NormalDistribution stdNormalDist = new NormalDistribution(0, 1); + + // set seeds + Long seed = 10L; + stdNormalDist.reseedRandomGenerator(seed); + Function stdNormalCDF = (x) -> stdNormalDist.cumulativeProbability(x); + + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", stdNormalCDF).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } + + @Test + public void testKSTestNamedDistribution() { + double pThreshold = 0.05; + + // Comparing a standard normal sample to a standard normal distribution + Row results = KolmogorovSmirnovTest + .test(dataset, "sample", "norm", 0.0, 1.0).head(); + double pValue1 = results.getDouble(0); + // Cannot reject null hypothesis + assert(pValue1 > pThreshold); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala new file mode 100644 index 0000000000000..1312de3a1b522 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTestSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.stat + +import org.apache.commons.math3.distribution.{ExponentialDistribution, NormalDistribution, + RealDistribution, UniformRealDistribution} +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => Math3KSTest} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Row + +class KolmogorovSmirnovTestSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + def apacheCommonMath3EquivalenceTest( + sampleDist: RealDistribution, + theoreticalDist: RealDistribution, + theoreticalDistByName: (String, Array[Double]), + rejectNullHypothesis: Boolean): Unit = { + + // set seeds + val seed = 10L + sampleDist.reseedRandomGenerator(seed) + if (theoreticalDist != null) { + theoreticalDist.reseedRandomGenerator(seed) + } + + // Sample data from the distributions and parallelize it + val n = 100000 + val sampledArray = sampleDist.sample(n) + val sampledDF = sc.parallelize(sampledArray, 10).toDF("sample") + + // Use a apache math commons local KS test to verify calculations + val ksTest = new Math3KSTest() + val pThreshold = 0.05 + + // Comparing a standard normal sample to a standard normal distribution + val Row(pValue1: Double, statistic1: Double) = + if (theoreticalDist != null) { + val cdf = (x: Double) => theoreticalDist.cumulativeProbability(x) + KolmogorovSmirnovTest.test(sampledDF, "sample", cdf).head() + } else { + KolmogorovSmirnovTest.test(sampledDF, "sample", + theoreticalDistByName._1, + theoreticalDistByName._2: _* + ).head() + } + val theoreticalDistMath3 = if (theoreticalDist == null) { + assert(theoreticalDistByName._1 == "norm") + val params = theoreticalDistByName._2 + new NormalDistribution(params(0), params(1)) + } else { + theoreticalDist + } + val referenceStat1 = ksTest.kolmogorovSmirnovStatistic(theoreticalDistMath3, sampledArray) + val referencePVal1 = 1 - ksTest.cdf(referenceStat1, n) + // Verify vs apache math commons ks test + assert(statistic1 ~== referenceStat1 relTol 1e-4) + assert(pValue1 ~== referencePVal1 relTol 1e-4) + + if (rejectNullHypothesis) { + assert(pValue1 < pThreshold) + } else { + assert(pValue1 > pThreshold) + } + } + + test("1 sample Kolmogorov-Smirnov test: apache commons math3 implementation equivalence") { + // Create theoretical distributions + val stdNormalDist = new NormalDistribution(0.0, 1.0) + val expDist = new ExponentialDistribution(0.6) + val uniformDist = new UniformRealDistribution(0.0, 1.0) + val expDist2 = new ExponentialDistribution(0.2) + val stdNormByName = Tuple2("norm", Array(0.0, 1.0)) + + apacheCommonMath3EquivalenceTest(stdNormalDist, null, stdNormByName, false) + apacheCommonMath3EquivalenceTest(expDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(uniformDist, null, stdNormByName, true) + apacheCommonMath3EquivalenceTest(expDist, expDist2, null, true) + } + + test("1 sample Kolmogorov-Smirnov test: R implementation equivalence") { + /* + Comparing results with R's implementation of Kolmogorov-Smirnov for 1 sample + > sessionInfo() + R version 3.2.0 (2015-04-16) + Platform: x86_64-apple-darwin13.4.0 (64-bit) + > set.seed(20) + > v <- rnorm(20) + > v + [1] 1.16268529 -0.58592447 1.78546500 -1.33259371 -0.44656677 0.56960612 + [7] -2.88971761 -0.86901834 -0.46170268 -0.55554091 -0.02013537 -0.15038222 + [13] -0.62812676 1.32322085 -1.52135057 -0.43742787 0.97057758 0.02822264 + [19] -0.08578219 0.38921440 + > ks.test(v, pnorm, alternative = "two.sided") + + One-sample Kolmogorov-Smirnov test + + data: v + D = 0.18874, p-value = 0.4223 + alternative hypothesis: two-sided + */ + + val rKSStat = 0.18874 + val rKSPVal = 0.4223 + val rData = sc.parallelize( + Array( + 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501, + -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555, + -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063, + -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691, + 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942 + ) + ).toDF("sample") + val Row(pValue: Double, statistic: Double) = KolmogorovSmirnovTest + .test(rData, "sample", "norm", 0, 1).head() + assert(statistic ~== rKSStat relTol 1e-4) + assert(pValue ~== rKSPVal relTol 1e-4) + } +} From 2c4b9962fdf8c1beb66126ca41628c72eb6c2383 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 20 Mar 2018 11:46:51 -0700 Subject: [PATCH 0496/1161] [SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. ## What changes were proposed in this pull request? Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory. Note that this means reader factories end up being constructed as partitioning is checked; let me know if you think that could be a problem. ## How was this patch tested? existing unit tests Author: Jose Torres Author: Jose Torres Closes #20726 from jose-torres/SPARK-23574. --- .../v2/reader/SupportsReportPartitioning.java | 3 ++ .../datasources/v2/DataSourceRDD.scala | 4 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 29 ++++++++++++++----- .../ContinuousDataSourceRDDIter.scala | 4 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- .../sql/streaming/StreamingQuerySuite.scala | 4 +-- 6 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 5405a916951b8..607628746e873 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -23,6 +23,9 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. + * + * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving public interface SupportsReportPartitioning extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 5ed0ba71e94c7..f85971be394b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[T]]) + @transient private val readerFactories: Seq[DataReaderFactory[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index cb691ba297076..3a5e7bf89e142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.DataSourceV2 import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch /** * Physical plan node for scanning data from a data source. @@ -56,6 +58,15 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + SinglePartition + + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + SinglePartition + + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + SinglePartition + case s: SupportsReportPartitioning => new DataSourcePartitioning( s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) @@ -63,29 +74,33 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories() + private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala case _ => reader.createDataReaderFactories().asScala.map { new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] - }.asJava + } } - private lazy val inputRDD: RDD[InternalRow] = reader match { + private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories()) - .asInstanceOf[RDD[InternalRow]] + r.createBatchDataReaderFactories().asScala + } + private lazy val inputRDD: RDD[InternalRow] = reader match { case _: ContinuousReader => EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size())) + .askSync[Unit](SetReaderPartitions(readerFactories.size)) new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) .asInstanceOf[RDD[InternalRow]] + case r: SupportsScanColumnarBatch if r.enableBatchRead() => + new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + case _ => new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala index cf02c0dda25d7..06754f01657d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala @@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils class ContinuousDataSourceRDD( sc: SparkContext, sqlContext: SQLContext, - @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]]) + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs override protected def getPartitions: Array[Partition] = { - readerFactories.asScala.zipWithIndex.map { + readerFactories.zipWithIndex.map { case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) }.toArray } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 1157a350461d8..e0a53272cd222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("SPARK-23574: no shuffle exchange with single partition") { + val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*")) + assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty) + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } +class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceReader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + } + } + + override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +} + class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 3f9aa0d1fa5be..ebc9a87b23f84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.durationMs.get("setOffsetRange") === 50) assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 0) + assert(progress.durationMs.get("queryPlanning") === 200) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 350) + assert(progress.durationMs.get("addBatch") === 150) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) From 477d6bd7265e255fd43e53edda02019b32f29bb2 Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 20 Mar 2018 13:27:50 -0700 Subject: [PATCH 0497/1161] [SPARK-23500][SQL] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? Complex type simplification optimizer rules were not applied to the entire plan, just the expressions reachable from the root node. This patch fixes the rules to transform the entire plan. ## How was this patch tested? New unit test + ran sql / core tests. Author: Henry Robinson Author: Henry Robinson Closes #20687 from henryr/spark-25000. --- .../sql/catalyst/optimizer/ComplexTypes.scala | 61 ++++++++----------- .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../optimizer/complexTypesSuite.scala | 55 +++++++++++++++-- 3 files changed, 76 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index be0009ec8c760..db7d6d3254bd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -18,39 +18,39 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule /** -* push down operations into [[CreateNamedStructLike]]. -*/ -object SimplifyCreateStructOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field extraction + * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions. + */ +object SimplifyExtractValueOps extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // One place where this optimization is invalid is an aggregation where the select + // list expression is a function of a grouping expression: + // + // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b) + // + // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this + // optimization for Aggregates (although this misses some cases where the optimization + // can be made). + case a: Aggregate => a + case p => p.transformExpressionsUp { + // Remove redundant field extraction. case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) => createNamedStructLike.valExprs(ordinal) - } - } -} -/** -* push down operations into [[CreateArray]]. -*/ -object SimplifyCreateArrayOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { - // push down field selection (array of structs) - case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) => - // instead f selecting the field on the entire array, - // select it from each member of the array. - // pushing down the operation this way open other optimizations opportunities - // (i.e. struct(...,x,...).x) + // Remove redundant array indexing. + case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) => + // Instead of selecting the field on the entire array, select it from each member + // of the array. Pushing down the operation this way may open other optimizations + // opportunities (i.e. struct(...,x,...).x) CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name)))) - // push down item selection. + + // Remove redundant map lookup. case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) => - // instead of creating the array and then selecting one row, - // remove array creation altgether. + // Instead of creating the array and then selecting one row, remove array creation + // altogether. if (idx >= 0 && idx < elems.size) { // valid index elems(idx) @@ -58,18 +58,7 @@ object SimplifyCreateArrayOps extends Rule[LogicalPlan] { // out of bounds, mimic the runtime behavior and return null Literal(null, ga.dataType) } - } - } -} - -/** -* push down operations into [[CreateMap]]. -*/ -object SimplifyCreateMapOps extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformExpressionsUp { case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems) } } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 91208479be03b..2829d1d81eb1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) EliminateSerialization, RemoveRedundantAliases, RemoveRedundantProject, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps, + SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index de544ac314789..e44a6692ad8e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -44,14 +44,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { BooleanSimplification, SimplifyConditionals, SimplifyBinaryComparison, - SimplifyCreateStructOps, - SimplifyCreateArrayOps, - SimplifyCreateMapOps) :: Nil + SimplifyExtractValueOps) :: Nil } val idAtt = ('id).long.notNull + val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt ) + lazy val relation = LocalRelation(idAtt, nullableIdAtt) test("explicit get from namedStruct") { val query = relation @@ -321,7 +320,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( CaseWhen(Seq( (EqualTo(2L, 'id), ('id + 1L)), - // these two are possible matches, we can't tell untill runtime + // these two are possible matches, we can't tell until runtime (EqualTo(2L, ('id + 1L)), ('id + 2L)), (EqualTo(2L, 'id + 2L), Literal.create(null, LongType)), // this is a definite match (two constants), @@ -331,4 +330,50 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .analyze comparePlans(Optimizer execute rel, expected) } + + test("SPARK-23500: Simplify complex ops that aren't at the plan root") { + val structRel = relation + .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") + .groupBy($"foo")("1").analyze + val structExpected = relation + .select('nullable_id as "foo") + .groupBy($"foo")("1").analyze + comparePlans(Optimizer execute structRel, structExpected) + + // These tests must use nullable attributes from the base relation for the following reason: + // in the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to a1, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. In the 'expected' plans, + // the grouping expressions have the same nullability as the original attribute in the + // relation. If that attribute is non-nullable, the tests will fail as the plans will + // compare differently, so for these tests we must use a nullable attribute. See + // SPARK-23634. + val arrayRel = relation + .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") + .groupBy($"a1")("1").analyze + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze + comparePlans(Optimizer execute arrayRel, arrayExpected) + + val mapRel = relation + .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") + .groupBy($"m1")("1").analyze + val mapExpected = relation + .select('nullable_id as "m1") + .groupBy($"m1")("1").analyze + comparePlans(Optimizer execute mapRel, mapExpected) + } + + test("SPARK-23500: Ensure that aggregation expressions are not simplified") { + // Make sure that aggregation exprs are correctly ignored. Maps can't be used in + // grouping exprs so aren't tested here. + val structAggRel = relation.groupBy( + CreateNamedStruct(Seq("att1", 'nullable_id)))( + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze + comparePlans(Optimizer execute structAggRel, structAggRel) + + val arrayAggRel = relation.groupBy( + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze + comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + } } From 983e8d9d64b6b1304c43ea6e5dffdc1078138ef9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 20 Mar 2018 23:17:49 -0700 Subject: [PATCH 0498/1161] [SPARK-23666][SQL] Do not display exprIds of Alias in user-facing info. ## What changes were proposed in this pull request? To drop `exprId`s for `Alias` in user-facing info., this pr added an entry for `Alias` in `NonSQLExpression.sql` ## How was this patch tested? Added tests in `UDFSuite`. Author: Takeshi Yamamuro Closes #20827 from maropu/SPARK-23666. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/expressions/Expression.scala | 1 + .../scala/org/apache/spark/sql/UDFSuite.scala | 132 ++++++++++-------- 3 files changed, 78 insertions(+), 56 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 0e092e0e37ccf..5b47fd77f2cbc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1806,6 +1806,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d7f9e38915dd5..38caf67d465d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -288,6 +288,7 @@ trait NonSQLExpression extends Expression { final override def sql: String = { transform { case a: Attribute => new PrettyAttribute(a) + case a: Alias => PrettyAttribute(a.sql, a.dataType) }.toString } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index af6a10b425b9f..21afdc7e2a33f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -144,73 +144,81 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a WHERE") { - spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + withTempView("integerData") { + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toDF() - df.createOrReplaceTempView("integerData") + val df = sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + df.createOrReplaceTempView("integerData") - val result = - sql("SELECT * FROM integerData WHERE oneArgFilter(key)") - assert(result.count() === 20) + val result = + sql("SELECT * FROM integerData WHERE oneArgFilter(key)") + assert(result.count() === 20) + } } test("UDF in a HAVING") { - spark.udf.register("havingFilter", (n: Long) => { n > 5 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT g, SUM(v) as s - | FROM groupData - | GROUP BY g - | HAVING havingFilter(s) - """.stripMargin) - - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT g, SUM(v) as s + | FROM groupData + | GROUP BY g + | HAVING havingFilter(s) + """.stripMargin) + + assert(result.count() === 2) + } } test("UDF in a GROUP BY") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT SUM(v) - | FROM groupData - | GROUP BY groupFunction(v) - """.stripMargin) - assert(result.count() === 2) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT SUM(v) + | FROM groupData + | GROUP BY groupFunction(v) + """.stripMargin) + assert(result.count() === 2) + } } test("UDFs everywhere") { - spark.udf.register("groupFunction", (n: Int) => { n > 10 }) - spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) - spark.udf.register("whereFilter", (n: Int) => { n < 150 }) - spark.udf.register("timesHundred", (n: Long) => { n * 100 }) - - val df = Seq(("red", 1), ("red", 2), ("blue", 10), - ("green", 100), ("green", 200)).toDF("g", "v") - df.createOrReplaceTempView("groupData") - - val result = - sql( - """ - | SELECT timesHundred(SUM(v)) as v100 - | FROM groupData - | WHERE whereFilter(v) - | GROUP BY groupFunction(v) - | HAVING havingFilter(v100) - """.stripMargin) - assert(result.count() === 1) + withTempView("groupData") { + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) + + val df = Seq(("red", 1), ("red", 2), ("blue", 10), + ("green", 100), ("green", 200)).toDF("g", "v") + df.createOrReplaceTempView("groupData") + + val result = + sql( + """ + | SELECT timesHundred(SUM(v)) as v100 + | FROM groupData + | WHERE whereFilter(v) + | GROUP BY groupFunction(v) + | HAVING havingFilter(v100) + """.stripMargin) + assert(result.count() === 1) + } } test("struct UDF") { @@ -304,4 +312,16 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } + + test("SPARK-23666 Do not display exprId in argument names") { + withTempView("x") { + Seq(((1, 2), 3)).toDF("a", "b").createOrReplaceTempView("x") + spark.udf.register("f", (a: Int) => a) + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + spark.sql("SELECT f(a._1) FROM x").show + } + assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)")) + } + } } From 500b21c3d6247015e550be7e144e9b4b26fe28be Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 21 Mar 2018 10:19:02 -0500 Subject: [PATCH 0499/1161] [SPARK-23568][ML] Use metadata numAttributes if available in Silhouette ## What changes were proposed in this pull request? Silhouette need to know the number of features. This was taken using `first` and checking the size of the vector. Despite this works fine, if the number of attributes is present in metadata, we can avoid to trigger a job for this and use the metadata value. This can help improving performances of course. ## How was this patch tested? existing UTs + added UT Author: Marco Gaido Closes #20719 from mgaido91/SPARK-23568. --- .../ml/evaluation/ClusteringEvaluator.scala | 22 ++++++++++++++---- .../evaluation/ClusteringEvaluatorSuite.scala | 23 ++++++++++++++++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 8d4ae562b3d2b..4353c46781e9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} @@ -170,6 +171,15 @@ private[evaluation] abstract class Silhouette { def overallScore(df: DataFrame, scoreColumn: Column): Double = { df.select(avg(scoreColumn)).collect()(0).getDouble(0) } + + protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = { + val group = AttributeGroup.fromStructField(dataFrame.schema(columnName)) + if (group.size < 0) { + dataFrame.select(col(columnName)).first().getAs[Vector](0).size + } else { + group.size + } + } } /** @@ -360,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { df: DataFrame, predictionCol: String, featuresCol: String): Map[Double, ClusterStats] = { - val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm")) .rdd @@ -552,8 +562,11 @@ private[evaluation] object CosineSilhouette extends Silhouette { * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a * its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`). */ - def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = { - val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size + def computeClusterStats( + df: DataFrame, + featuresCol: String, + predictionCol: String): Map[Double, (Vector, Long)] = { + val numFeatures = getNumberOfFeatures(df, featuresCol) val clustersStatsRDD = df.select( col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName)) .rdd @@ -626,7 +639,8 @@ private[evaluation] object CosineSilhouette extends Silhouette { normalizeFeatureUDF(col(featuresCol))) // compute aggregate values for clusters needed by the algorithm - val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol) + val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol, + predictionCol) // Silhouette is reasonable only when the number of clusters is greater then 1 assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 3bf34770f5687..2c175ff68e0b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ @@ -100,4 +102,23 @@ class ClusteringEvaluatorSuite } } + test("SPARK-23568: we should use metadata to determine features number") { + val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size + val attrGroup = new AttributeGroup("features", attributesNum) + val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label") + require(AttributeGroup.fromStructField(df.schema("features")) + .numAttributes.isDefined, "numAttributes metadata should be defined") + val evaluator = new ClusteringEvaluator() + .setFeaturesCol("features") + .setPredictionCol("label") + + // with the proper metadata we compute correctly the result + assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5) + + val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1) + val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()), + $"label") + // with wrong metadata the evaluator throws an Exception + intercept[SparkException](evaluator.evaluate(dfWrong)) + } } From bf09f2f71276d3b3a84a8f89109bd785a066c3e6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 21 Mar 2018 09:39:14 -0700 Subject: [PATCH 0500/1161] [SPARK-10884][ML] Support prediction on single instance for regression and classification related models ## What changes were proposed in this pull request? Support prediction on single instance for regression and classification related models (i.e., PredictionModel, ClassificationModel and their sub classes). Add corresponding test cases. ## How was this patch tested? Test cases added. Author: WeichenXu Closes #19381 from WeichenXu123/single_prediction. --- .../scala/org/apache/spark/ml/Predictor.scala | 5 ++-- .../spark/ml/classification/Classifier.scala | 6 ++--- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 2 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 17 ++++++++++++- .../classification/GBTClassifierSuite.scala | 9 +++++++ .../ml/classification/LinearSVCSuite.scala | 6 +++++ .../LogisticRegressionSuite.scala | 9 +++++++ .../MultilayerPerceptronClassifierSuite.scala | 12 ++++++++++ .../ml/classification/NaiveBayesSuite.scala | 22 +++++++++++++++++ .../RandomForestClassifierSuite.scala | 16 +++++++++++++ .../DecisionTreeRegressorSuite.scala | 15 ++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 8 +++++++ .../GeneralizedLinearRegressionSuite.scala | 8 +++++++ .../ml/regression/LinearRegressionSuite.scala | 7 ++++++ .../RandomForestRegressorSuite.scala | 24 +++++++++++++++---- .../org/apache/spark/ml/util/MLTest.scala | 15 ++++++++++-- 25 files changed, 176 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 08b0cb9b8f6a5..d8f3dfa874439 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. */ - protected def predict(features: FeaturesType): Double + @Since("2.4.0") + def predict(features: FeaturesType): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 9d1d5aa1e0cff..7e5790ab70ee9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} @@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur /** * Predict label for the given features. - * This internal method is used to implement `transform()` and output [[predictionCol]]. + * This method is used to implement `transform()` and output [[predictionCol]]. * * This default implementation for classification predicts the index of the maximum value * from `predictRaw()`. */ - override protected def predict(features: FeaturesType): Double = { + override def predict(features: FeaturesType): Double = { raw2prediction(predictRaw(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec52..65cce697d8202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f11bc1d8fe415..cd44489f618b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -267,7 +267,7 @@ class GBTClassificationModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization if (isDefined(thresholds)) { super.predict(features) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index ce400f4f1faf7..8f950cd28c3aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -316,7 +316,7 @@ class LinearSVCModel private[classification] ( BLAS.dot(features, coefficients) + intercept } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { if (margin(features) > $(threshold)) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fa191604218db..3ae4db3f3f965 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1090,7 +1090,7 @@ class LogisticRegressionModel private[spark] ( * Predict label for the given feature vector. * The behavior of this can be adjusted using `thresholds`. */ - override protected def predict(features: Vector): Double = if (isMultinomial) { + override def predict(features: Vector): Double = if (isMultinomial) { super.predict(features) } else { // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index fd4c98f22132f..af2e4699924e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( * Predict label for the given features. * This internal method is used to implement `transform()` and output [[predictionCol]]. */ - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { LabelConverter.decodeLabel(mlpModel.predict(features)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0291a57487c47..ad154fcd010cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -178,7 +178,7 @@ class DecisionTreeRegressionModel private[ml] ( private[ml] def this(rootNode: Node, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index f41d15b62dddd..6569ff2a5bfc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -230,7 +230,7 @@ class GBTRegressionModel private[ml]( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 917a4d238d467..9f1f2405c428e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1010,7 +1010,7 @@ class GeneralizedLinearRegressionModel private[ml] ( private lazy val familyAndLink = FamilyAndLink(this) - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { predict(features, 0.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6d3fe7a6c748c..92510154d500e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -699,7 +699,7 @@ class LinearRegressionModel private[ml] ( } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { dot(features, coefficients) + intercept } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 200b234b79978..2d594460c2475 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override def predict(features: Vector): Double = { // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index eeb0324187c5b..2930f4900d50e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -264,6 +264,21 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, DecisionTreeClassificationModel](this, newTree, newData) } + test("prediction on single instance") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + testPredictionModelSinglePrediction(newTree, newData) + } + test("training with 1-category categorical feature") { val data = sc.parallelize(Seq( LabeledPoint(0, Vectors.dense(0, 2, 3)), diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 092b4a01d5b0d..57796069f6052 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -197,6 +197,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, GBTClassificationModel](this, gbtModel, validationDataset) } + test("prediction on single instance") { + + val gbt = new GBTClassifier().setSeed(123) + val trainingDataset = trainData.toDF("label", "features") + val gbtModel = gbt.fit(trainingDataset) + + testPredictionModelSinglePrediction(gbtModel, trainingDataset) + } + test("GBT parameter stepSize should be in interval (0, 1]") { withClue("GBT parameter stepSize should be in interval (0, 1]") { intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a93825b8a812d..c05c896df5cb1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -201,6 +201,12 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { dataset.as[LabeledPoint], estimator, modelEquals, 42L) } + test("prediction on single instance") { + val trainer = new LinearSVC() + val model = trainer.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(model, smallBinaryDataset) + } + test("linearSVC comparison with R e1071 and scikit-learn") { val trainer1 = new LinearSVC() .setRegParam(0.00002) // set regParam = 2.0 / datasize / c diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9987cbf6ba116..36b7e51f93d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -499,6 +499,15 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { Vector, LogisticRegressionModel](this, model, smallBinaryDataset) } + test("prediction on single instance") { + val blor = new LogisticRegression().setFamily("binomial") + val blorModel = blor.fit(smallBinaryDataset) + testPredictionModelSinglePrediction(blorModel, smallBinaryDataset) + val mlor = new LogisticRegression().setFamily("multinomial") + val mlorModel = mlor.fit(smallMultinomialDataset) + testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset) + } + test("coefficients and intercept methods") { val mlr = new LogisticRegression().setMaxIter(1).setFamily("multinomial") val mlrModel = mlr.fit(smallMultinomialDataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index daa58a56896d7..6b5fe6e49ffea 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -76,6 +76,18 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe } } + test("prediction on single instance") { + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(123L) + .setMaxIter(100) + .setSolver("l-bfgs") + val model = trainer.fit(dataset) + testPredictionModelSinglePrediction(model, dataset) + } + test("Predicted class probabilities: calibration on toy dataset") { val layers = Array[Int](4, 5, 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 49115c8a4db30..5f9ab98a2c3ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -167,6 +167,28 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest { Vector, NaiveBayesModel](this, model, testDataset) } + test("prediction on single instance") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) + + val trainDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF() + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") + val model = nb.fit(trainDataset) + + val validationDataset = + generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF() + + testPredictionModelSinglePrediction(model, validationDataset) + } + test("Naive Bayes with weighted samples") { val numClasses = 3 def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 02a9d5c2a18c0..ba4a9cf082785 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -155,6 +155,22 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { Vector, RandomForestClassificationModel](this, model, df) } + test("prediction on single instance") { + val rdd = orderedLabeledPoints5_20 + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + testPredictionModelSinglePrediction(model, df) + } + test("Fitting without numClasses in metadata") { val df: DataFrame = TreeTests.featureImportanceData(sc).toDF() val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 68a1218c23ece..29a438396516b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { assert(importances.toArray.forall(_ >= 0.0)) } + test("prediction on single instance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 11c593b521e65..fad11d078250f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(trainData.toDF()) + testPredictionModelSinglePrediction(model, validationData.toDF) + } + test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ef2ff94a5e213..d5bcbb221783e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest assert(model.getLink === "identity") } + test("prediction on single instance") { + val glr = new GeneralizedLinearRegression + val model = glr.setFamily("gaussian").setLink("identity") + .fit(datasetGaussianIdentity) + + testPredictionModelSinglePrediction(model, datasetGaussianIdentity) + } + test("generalized linear regression: gaussian family against glm") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index d42cb1714478f..9b19f63eba1bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -636,6 +636,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { } } + test("prediction on single instance") { + val trainer = new LinearRegression + val model = trainer.fit(datasetWithDenseFeature) + + testPredictionModelSinglePrediction(model, datasetWithDenseFeature) + } + test("linear regression model with constant label") { /* R code: diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 8b8e8a655f47b..e83c49f932973 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,22 +19,22 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest{ +class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs + import testImplicits._ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ @@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex regressionTestWithContinuousFeatures(rf) } + test("prediction on single instance") { + val rf = new RandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + testPredictionModelSinglePrediction(model, df) + } + test("Feature importance with toy data") { val rf = new RandomForestRegressor() .setImpurity("variance") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 795fd0e2ac0e4..76d41f9b23715 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -22,8 +22,9 @@ import java.io.File import org.scalatest.Suite import org.apache.spark.SparkContext -import org.apache.spark.ml.Transformer -import org.apache.spark.sql.{DataFrame, Encoder, Row} +import org.apache.spark.ml.{PredictionModel, Transformer} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.col import org.apache.spark.sql.streaming.StreamTest @@ -136,4 +137,14 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => assert(hasExpectedMessage(exceptionOnStreamData)) } } + + def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _], + dataset: Dataset[_]): Unit = { + + model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol) + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction === model.predict(features)) + } + } } From 8d79113b812a91073d2c24a3a9ad94cc3b90b24a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 21 Mar 2018 09:46:47 -0700 Subject: [PATCH 0501/1161] [SPARK-23577][SQL] Supports custom line separator for text datasource ## What changes were proposed in this pull request? This PR proposes to add `lineSep` option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. ## How was this patch tested? Manual tests and unit tests were added. Author: hyukjinkwon Closes #20727 from HyukjinKwon/linesep-text. --- python/pyspark/sql/readwriter.py | 14 ++++--- python/pyspark/sql/streaming.py | 8 +++- python/pyspark/sql/tests.py | 24 ++++++++++- .../apache/spark/sql/DataFrameReader.scala | 30 ++++++++------ .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/HadoopFileLinesReader.scala | 23 ++++++++++- .../datasources/text/TextFileFormat.scala | 16 ++++---- .../datasources/text/TextOptions.scala | 12 ++++++ .../sql/streaming/DataStreamReader.scala | 12 +++++- .../datasources/text/TextSuite.scala | 40 +++++++++++++++++++ 10 files changed, 147 insertions(+), 34 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index facc16bc53108..e5288636c596e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -304,16 +304,18 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, paths, wholetext=False): + def text(self, paths, wholetext=False, lineSep=None): """ Loads text files and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() @@ -322,7 +324,7 @@ def text(self, paths, wholetext=False): >>> df.collect() [Row(value=u'hello\\nthis')] """ - self._set_opts(wholetext=wholetext) + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(paths, basestring): paths = [paths] return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @@ -804,18 +806,20 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): self._jwrite.parquet(path) @since(1.6) - def text(self, path, compression=None): + def text(self, path, compression=None, lineSep=None): """Saves the content of the DataFrame in a text file at the specified path. :param path: the path in any Hadoop supported file system :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. """ - self._set_opts(compression=compression) + self._set_opts(compression=compression, lineSep=lineSep) self._jwrite.text(path) @since(2.0) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index e8966c20a8f42..07f9ac1b5aa9e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -531,17 +531,20 @@ def parquet(self, path): @ignore_unicode_prefix @since(2.0) - def text(self, path): + def text(self, path, wholetext=False, lineSep=None): """ Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a string column named "value", and followed by partitioned columns if there are any. - Each line in the text file is a new row in the resulting DataFrame. + By default, each line in the text file is a new row in the resulting DataFrame. .. note:: Evolving. :param paths: string, or list of strings, for input path(s). + :param wholetext: if true, read each file from input path(s) as a single row. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) >>> text_sdf.isStreaming @@ -549,6 +552,7 @@ def text(self, path): >>> "value" in str(text_sdf.schema) True """ + self._set_opts(wholetext=wholetext, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.text(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 39d6c5226f138..967cc83166f3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -648,7 +648,29 @@ def test_non_existed_udaf(self): self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) - def test_multiLine_json(self): + def test_linesep_text(self): + df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",") + expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'), + Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'), + Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'), + Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df.write.text(tpath, lineSep="!") + expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'), + Row(value=u'Tom!30!"My name is Tom"'), + Row(value=u'Hyukjin!25!"I am Hyukjin'), + Row(value=u''), Row(value=u'I love Spark!"'), + Row(value=u'!')] + readback = self.spark.read.text(tpath) + self.assertEqual(readback.collect(), expected) + finally: + shutil.rmtree(tpath) + + def test_multiline_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", multiLine=True) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0139913aaa4e2..1a5e47508c070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -647,14 +647,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * You can set the following text-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    - * By default, each line in the text files is a new row in the resulting DataFrame. - * - * Usage example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.read.text("/path/to/spark/README.md") @@ -663,6 +656,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().text("/path/to/spark/README.md") * }}} * + * You can set the following text-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input paths * @since 1.6.0 */ @@ -686,11 +687,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * You can set the following textFile-specific option(s) for reading text files: - *
      - *
    • `wholetext` ( default `false`): If true, read a file as a single row and not split by "\n". - *
    • - *
    * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: @@ -700,6 +696,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * spark.read().textFile("/path/to/spark/README.md") * }}} * + * You can set the following textFile-specific option(s) for reading text files: + *
      + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • + *
    + * * @param paths input path * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ed7a9100cc7f1..bb93889dc55e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -587,6 +587,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 83cf26c63a175..00a78f7343c59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -30,9 +30,22 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl /** * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. + * + * @param file A part (i.e. "block") of a single file that should be read line by line. + * @param lineSeparator A line separator that should be used for each line. If the value is `None`, + * it covers `\r`, `\r\n` and `\n`. + * @param conf Hadoop configuration + * + * @note The behavior when `lineSeparator` is `None` (covering `\r`, `\r\n` and `\n`) is defined + * by [[LineRecordReader]], not within Spark. */ class HadoopFileLinesReader( - file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { + file: PartitionedFile, + lineSeparator: Option[Array[Byte]], + conf: Configuration) extends Iterator[Text] with Closeable { + + def this(file: PartitionedFile, conf: Configuration) = this(file, None, conf) + private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -42,7 +55,13 @@ class HadoopFileLinesReader( Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - val reader = new LineRecordReader() + + val reader = lineSeparator match { + case Some(sep) => new LineRecordReader(sep) + // If the line separator is `None`, it covers `\r`, `\r\n` and `\n`. + case _ => new LineRecordReader() + } + reader.initialize(fileSplit, hadoopAttemptContext) new RecordReaderIterator(reader) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index c661e9bd3b94c..9647f09867643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.execution.datasources.text -import java.io.Closeable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext @@ -89,7 +86,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(path, dataSchema, context) + new TextOutputWriter(path, dataSchema, textOptions.lineSeparatorInWrite, context) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -113,18 +110,18 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions.wholeText) + readToUnsafeMem(broadcastedHadoopConf, requiredSchema, textOptions) } private def readToUnsafeMem( conf: Broadcast[SerializableConfiguration], requiredSchema: StructType, - wholeTextMode: Boolean): (PartitionedFile) => Iterator[UnsafeRow] = { + textOptions: TextOptions): (PartitionedFile) => Iterator[UnsafeRow] = { (file: PartitionedFile) => { val confValue = conf.value.value - val reader = if (!wholeTextMode) { - new HadoopFileLinesReader(file, confValue) + val reader = if (!textOptions.wholeText) { + new HadoopFileLinesReader(file, textOptions.lineSeparatorInRead, confValue) } else { new HadoopFileWholeTextReader(file, confValue) } @@ -152,6 +149,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { class TextOutputWriter( path: String, dataSchema: StructType, + lineSeparator: Array[Byte], context: TaskAttemptContext) extends OutputWriter { @@ -162,7 +160,7 @@ class TextOutputWriter( val utf8string = row.getUTF8String(0) utf8string.writeTo(writer) } - writer.write('\n') + writer.write(lineSeparator) } override def close(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 2a661561ab51e..18698df9fd8e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.text +import java.nio.charset.StandardCharsets + import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} /** @@ -39,9 +41,19 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean + private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => + require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInWrite: Array[Byte] = + lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } private[text] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index c393dcdfdd7e5..9b17406a816b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -387,7 +387,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads text files and returns a `DataFrame` whose schema starts with a string column named * "value", and followed by partitioned columns if there are any. * - * Each line in the text files is a new row in the resulting DataFrame. For example: + * By default, each line in the text files is a new row in the resulting DataFrame. For example: * {{{ * // Scala: * spark.readStream.text("/path/to/directory/") @@ -400,6 +400,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @since 2.0.0 @@ -413,7 +417,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * If the directory structure of the text files contains partitioning information, those are * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. * - * Each line in the text file is a new element in the resulting Dataset. For example: + * By default, each line in the text file is a new element in the resulting Dataset. For example: * {{{ * // Scala: * spark.readStream.textFile("/path/to/spark/README.md") @@ -426,6 +430,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
      *
    • `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be * considered in every trigger.
    • + *
    • `wholetext` (default `false`): If true, read a file as a single row and not split by "\n". + *
    • + *
    • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
    • *
    * * @param path input path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 33287044f279e..e8a5299d6ba9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -18,10 +18,13 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.spark.TestUtils import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -172,6 +175,43 @@ class TextSuite extends QueryTest with SharedSQLContext { } } + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-23577: Support line separator - lineSep: '$lineSep'") { + // Read + val values = Seq("a", "b", "\nc") + val data = values.mkString(lineSep) + val dataWithTrailingLineSep = s"$data$lineSep" + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, Seq("a", "b", "\nc").toDF()) + } + } + + // Write + withTempPath { path => + values.toDF().coalesce(1) + .write.option("lineSep", lineSep).text(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert(readBack === s"a${lineSep}b${lineSep}\nc${lineSep}") + } + + // Roundtrip + withTempPath { path => + val df = values.toDF() + df.write.option("lineSep", lineSep).text(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).text(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + testLineSeparator(lineSep) + } + private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } From 98d0ea3f6091730285293321a50148f69e94c9cd Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 21 Mar 2018 09:52:28 -0700 Subject: [PATCH 0502/1161] [SPARK-23264][SQL] Fix scala.MatchError in literals.sql.out ## What changes were proposed in this pull request? To fix `scala.MatchError` in `literals.sql.out`, this pr added an entry for `CalendarIntervalType` in `QueryExecution.toHiveStructString`. ## How was this patch tested? Existing tests and added tests in `literals.sql` Author: Takeshi Yamamuro Closes #20872 from maropu/FixIntervalTests. --- .../spark/sql/execution/QueryExecution.scala | 2 ++ .../resources/sql-tests/inputs/literals.sql | 3 +++ .../sql-tests/results/literals.sql.out | 20 ++++++++++++------- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 7cae24bf5976c..15379a0663f7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -155,6 +155,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" case (decimal, DecimalType()) => decimal.toString + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -178,6 +179,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) + case (interval, CalendarIntervalType) => interval.toString case (other, tpe) if primitiveTypes.contains(tpe) => other.toString } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql index 37b4b7606d12b..a743cf1ec2cde 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -105,3 +105,6 @@ select X'XuZ'; -- Hive literal_double test. SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8; + +-- map + interval test +select map(1, interval 1 day, 2, interval 3 week); diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 95d4413148f64..b8c91dc8b59a4 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 43 +-- Number of queries: 44 -- !query 0 @@ -323,19 +323,17 @@ select timestamp '2016-33-11 20:54:00.000' -- !query 34 select interval 13.123456789 seconds, interval -13.123456789 second -- !query 34 schema -struct<> +struct -- !query 34 output -scala.MatchError -(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 13 seconds 123 milliseconds 456 microseconds interval -12 seconds -876 milliseconds -544 microseconds -- !query 35 select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond -- !query 35 schema -struct<> +struct -- !query 35 output -scala.MatchError -(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2) +interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds 9 -- !query 36 @@ -416,3 +414,11 @@ SELECT 3.14, -3.14, 3.14e8, 3.14e-8, -3.14e8, -3.14e-8, 3.14e+8, 3.14E8, 3.14E-8 struct<3.14:decimal(3,2),-3.14:decimal(3,2),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10),-3.14E+8:decimal(3,-6),-3.14E-8:decimal(10,10),3.14E+8:decimal(3,-6),3.14E+8:decimal(3,-6),3.14E-8:decimal(10,10)> -- !query 42 output 3.14 -3.14 314000000 0.0000000314 -314000000 -0.0000000314 314000000 314000000 0.0000000314 + + +-- !query 43 +select map(1, interval 1 day, 2, interval 3 week) +-- !query 43 schema +struct> +-- !query 43 output +{1:interval 1 days,2:interval 3 weeks} From 918c7e99afdcea05c36626e230636c4f8aabf82c Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 21 Mar 2018 10:06:26 -0700 Subject: [PATCH 0503/1161] [SPARK-23288][SS] Fix output metrics with parquet sink ## What changes were proposed in this pull request? Output metrics were not filled when parquet sink used. This PR fixes this problem by passing a `BasicWriteJobStatsTracker` in `FileStreamSink`. ## How was this patch tested? Additional unit test added. Author: Gabor Somogyi Closes #20745 from gaborgsomogyi/SPARK-23288. --- .../command/DataWritingCommand.scala | 11 +--- .../datasources/BasicWriteStatsTracker.scala | 25 +++++++-- .../execution/streaming/FileStreamSink.scala | 10 +++- .../sql/streaming/FileStreamSinkSuite.scala | 52 +++++++++++++++++++ 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e56f8105fc9a7..e11dbd201004d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} @@ -45,15 +44,7 @@ trait DataWritingCommand extends Command { // Output columns of the analyzed input query plan def outputColumns: Seq[Attribute] - lazy val metrics: Map[String, SQLMetric] = { - val sparkContext = SparkContext.getActive.get - Map( - "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"), - "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") - ) - } + lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 9dbbe9946ee99..69c03d862391e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker( totalNumOutput += summary.numRows } - metrics("numFiles").add(numFiles) - metrics("numOutputBytes").add(totalNumBytes) - metrics("numOutputRows").add(totalNumOutput) - metrics("numParts").add(numPartitions) + metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes) + metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput) + metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList) } } + +object BasicWriteJobStatsTracker { + private val NUM_FILES_KEY = "numFiles" + private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes" + private val NUM_OUTPUT_ROWS_KEY = "numOutputRows" + private val NUM_PARTS_KEY = "numParts" + + def metrics: Map[String, SQLMetric] = { + val sparkContext = SparkContext.getActive.get + Map( + NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"), + NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"), + NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part") + ) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 87a17cebdc10c..b3d12f67b5d63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.util.SerializableConfiguration object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -97,6 +98,11 @@ class FileStreamSink( new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics) + } + override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") @@ -131,7 +137,7 @@ class FileStreamSink( hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, - statsTrackers = Nil, + statsTrackers = Seq(basicWriteJobStatsTracker), options = options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 31e5527d7366a..cf41d7e0e4fe1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.hadoop.fs.Path +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -405,4 +406,55 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("SPARK-23288 writing and checking output metrics") { + Seq("parquet", "orc", "text", "json").foreach { format => + val inputData = MemoryStream[String] + val df = inputData.toDF() + + withTempDir { outputDir => + withTempDir { checkpointDir => + + var query: StreamingQuery = null + + var numTasks = 0 + var recordsWritten: Long = 0L + var bytesWritten: Long = 0L + try { + spark.sparkContext.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val outputMetrics = taskEnd.taskMetrics.outputMetrics + recordsWritten += outputMetrics.recordsWritten + bytesWritten += outputMetrics.bytesWritten + numTasks += 1 + } + }) + + query = + df.writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format(format) + .start(outputDir.getCanonicalPath) + + inputData.addData("1", "2", "3") + inputData.addData("4", "5") + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + + assert(numTasks > 0) + assert(recordsWritten === 5) + // This is heavily file type/version specific but should be filled + assert(bytesWritten > 0) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + } } From 2b89e4aa2e8bd8b88f6e5eb60d95c1a58e5c4ace Mon Sep 17 00:00:00 2001 From: akonopko Date: Wed, 21 Mar 2018 14:40:21 -0500 Subject: [PATCH 0504/1161] [SPARK-18580][DSTREAM][KAFKA] Add spark.streaming.backpressure.initialRate to direct Kafka streams ## What changes were proposed in this pull request? Add `spark.streaming.backpressure.initialRate` to direct Kafka Streams for Kafka 0.8 and 0.10 This is required in order to be able to use backpressure with huge lags, which cannot be processed at once. Without this parameter `DirectKafkaInputDStream` with backpressure enabled would try to get all the possible data from Kafka before adjusting consumption rate ## How was this patch tested? - Tests added to `org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala` and `org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala` - Manual tests on YARN cluster Author: akonopko Author: Alexander Konopko Closes #19431 from akonopko/SPARK-18580-initialrate. --- .../kafka010/DirectKafkaInputDStream.scala | 8 ++- .../kafka010/DirectKafkaStreamSuite.scala | 51 +++++++++++++++- .../kafka/DirectKafkaInputDStream.scala | 9 ++- .../kafka/DirectKafkaStreamSuite.scala | 59 ++++++++++++++++++- 4 files changed, 120 insertions(+), 7 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 9cb2448fea0f4..215b7cab703fb 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -56,6 +56,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + val executorKafkaParams = { val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams) KafkaUtils.fixKafkaParams(ekp) @@ -126,7 +129,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 8524743ee2846..35e4678f2e3c8 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.kafka010 import java.io.File import java.lang.{ Long => JLong } -import java.util.{ Arrays, HashMap => JHashMap, Map => JMap } +import java.util.{ Arrays, HashMap => JHashMap, Map => JMap, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -34,7 +34,7 @@ import org.apache.kafka.common.serialization.StringDeserializer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} @@ -617,6 +617,53 @@ class DirectKafkaStreamSuite ssc.stop() } + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) + } + kafkaStream.start() + + val input = Map(new TopicPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicPartition(topic, 0) -> maxMessagesPerPartition)) // we run for half a second + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = getKafkaParams() diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d6dd0744441e4..9297c39d170c4 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -91,9 +91,16 @@ class DirectKafkaInputDStream[ private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) + private val initialRate = context.sparkContext.getConf.getLong( + "spark.streaming.backpressure.initialRate", 0) + protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate()) + + val estimatedRateLimit = rateController.map { x => { + val lr = x.getLatestRate() + if (lr > 0) lr else initialRate + }} // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 3fea6cfd910bf..ecca38784e777 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import java.io.File -import java.util.Arrays +import java.util.{ Arrays, UUID } import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicLong @@ -32,12 +32,11 @@ import kafka.serializer.StringDecoder import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils @@ -456,6 +455,60 @@ class DirectKafkaStreamSuite ssc.stop() } + test("use backpressure.initialRate with backpressure") { + backpressureTest(maxRatePerPartition = 1000, initialRate = 500, maxMessagesPerPartition = 250) + } + + test("backpressure.initialRate should honor maxRatePerPartition") { + backpressureTest(maxRatePerPartition = 300, initialRate = 1000, maxMessagesPerPartition = 150) + } + + private def backpressureTest( + maxRatePerPartition: Int, + initialRate: Int, + maxMessagesPerPartition: Int) = { + + val topic = UUID.randomUUID().toString + val topicPartitions = Set(TopicAndPartition(topic, 0)) + kafkaTestUtils.createTopic(topic, 1) + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.backpressure.enabled", "true") + .set("spark.streaming.backpressure.initialRate", initialRate.toString) + .set("spark.streaming.kafka.maxRatePerPartition", maxRatePerPartition.toString) + + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + + val kafkaStream = withClue("Error creating direct stream") { + val kc = new KafkaCluster(kafkaParams) + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val m = kc.getEarliestLeaderOffsets(topicPartitions) + .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset)) + + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, m, messageHandler) + } + kafkaStream.start() + + val input = Map(new TopicAndPartition(topic, 0) -> 1000L) + + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicAndPartition(topic, 0) -> maxMessagesPerPartition)) + + kafkaStream.stop() + } + test("maxMessagesPerPartition with zero offset and rate equal to one") { val topic = "backpressure" val kafkaParams = Map( From a091ee676b8707819e94d92693956237310a6145 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 21 Mar 2018 13:52:03 -0700 Subject: [PATCH 0505/1161] [MINOR] Fix Java lint from new JavaKolmogorovSmirnovTestSuite ## What changes were proposed in this pull request? Fix lint-java from https://github.com/apache/spark/pull/19108 addition of JavaKolmogorovSmirnovTestSuite Author: Joseph K. Bradley Closes #20875 from jkbradley/kstest-lint-fix. --- .../spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java index 021272dd5a40c..830f668fe07b8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java @@ -18,18 +18,11 @@ package org.apache.spark.ml.stat; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.commons.math3.distribution.NormalDistribution; -import org.apache.spark.ml.linalg.VectorUDT; -import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.junit.Test; import org.apache.spark.SharedSparkSession; From 0604beaff2baa2d0fed86c0c87fd2a16a1838b5f Mon Sep 17 00:00:00 2001 From: Mihaly Toth Date: Wed, 21 Mar 2018 17:05:39 -0700 Subject: [PATCH 0506/1161] [SPARK-23729][CORE] Respect URI fragment when resolving globs Firstly, glob resolution will not result in swallowing the remote name part (that is preceded by the `#` sign) in case of `--files` or `--archives` options Moreover in the special case of multiple resolutions when the remote naming does not make sense and error is returned. Enhanced current test and wrote additional test for the error case Author: Mihaly Toth Closes #20853 from misutoth/glob-with-remote-name. --- .../apache/spark/deploy/DependencyUtils.scala | 34 +++++++++++---- .../org/apache/spark/deploy/SparkSubmit.scala | 13 ++++++ .../spark/deploy/SparkSubmitSuite.scala | 41 +++++++++++++++---- 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ecc82d7ac8001..ab319c860ee69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy import java.io.File +import java.net.URI import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.{MutableURLClassLoader, Utils} private[deploy] object DependencyUtils { @@ -137,16 +138,31 @@ private[deploy] object DependencyUtils { def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { require(paths != null, "paths cannot be null.") Utils.stringToSeq(paths).flatMap { path => - val uri = Utils.resolveURI(path) - uri.getScheme match { - case "local" | "http" | "https" | "ftp" => Array(path) - case _ => - val fs = FileSystem.get(uri, hadoopConf) - Option(fs.globStatus(new Path(uri))).map { status => - status.filter(_.isFile).map(_.getPath.toUri.toString) - }.getOrElse(Array(path)) + val (base, fragment) = splitOnFragment(path) + (resolveGlobPath(base, hadoopConf), fragment) match { + case (resolved, Some(_)) if resolved.length > 1 => throw new SparkException( + s"${base.toString} resolves ambiguously to multiple files: ${resolved.mkString(",")}") + case (resolved, Some(namedAs)) => resolved.map(_ + "#" + namedAs) + case (resolved, _) => resolved } }.mkString(",") } + private def splitOnFragment(path: String): (URI, Option[String]) = { + val uri = Utils.resolveURI(path) + val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) + (withoutFragment, Option(uri.getFragment)) + } + + private def resolveGlobPath(uri: URI, hadoopConf: Configuration): Array[String] = { + uri.getScheme match { + case "local" | "http" | "https" | "ftp" => Array(uri.toString) + case _ => + val fs = FileSystem.get(uri, hadoopConf) + Option(fs.globStatus(new Path(uri))).map { status => + status.filter(_.isFile).map(_.getPath.toUri.toString) + }.getOrElse(Array(uri.toString)) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329bde08718fe..3965f17f4b56e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -245,6 +245,19 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { + try { + doPrepareSubmitEnvironment(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage) + throw e + } + } + + private def doPrepareSubmitEnvironment( + args: SparkSubmitArguments, + conf: Option[HadoopConfiguration] = None) + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d265643a80b4e..2d0c192db4915 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.net.URI import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -606,10 +606,13 @@ class SparkSubmitSuite } test("resolves command line argument paths correctly") { - val jars = "/jar1,/jar2" // --jars - val files = "local:/file1,file2" // --files - val archives = "file:/archive1,archive2" // --archives - val pyFiles = "py-file1,py-file2" // --py-files + val dir = Utils.createTempDir() + val archive = Paths.get(dir.toPath.toString, "single.zip") + Files.createFile(archive) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" // Test jars and files val clArgs = Seq( @@ -636,9 +639,10 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) - appArgs2.archives should be (Utils.resolveURIs(archives)) + appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.archives") should fullyMatch regex + ("file:/archive1,file:.*#archive3") // Test python files val clArgs3 = Seq( @@ -657,6 +661,29 @@ class SparkSubmitSuite conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } + test("ambiguous archive mapping results in error message") { + val dir = Utils.createTempDir() + val archive1 = Paths.get(dir.toPath.toString, "first.zip") + val archive2 = Paths.get(dir.toPath.toString, "second.zip") + Files.createFile(archive1) + Files.createFile(archive2) + val jars = "/jar1,/jar2" + val files = "local:/file1,file2" + val archives = s"file:/archive1,${dir.toPath.toAbsolutePath.toString}/*.zip#archive3" + val pyFiles = "py-file1,py-file2" + + // Test files and archives (Yarn) + val clArgs2 = Seq( + "--master", "yarn", + "--class", "org.SomeClass", + "--files", files, + "--archives", archives, + "thejar.jar" + ) + + testPrematureExit(clArgs2.toArray, "resolves ambiguously to multiple files") + } + test("resolves config paths correctly") { val jars = "/jar1,/jar2" // spark.jars val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files From 95e51ff849a4c46cae463636b1ee393042469e7b Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 21 Mar 2018 21:21:36 -0700 Subject: [PATCH 0507/1161] [SPARK-23760][SQL] CodegenContext.withSubExprEliminationExprs should save/restore CSE state correctly ## What changes were proposed in this pull request? Fixed `CodegenContext.withSubExprEliminationExprs()` so that it saves/restores CSE state correctly. ## How was this patch tested? Added new unit test to verify that the old CSE state is indeed saved and restored around the `withSubExprEliminationExprs()` call. Manually verified that this test fails without this patch. Author: Kris Mok Closes #20870 from rednaxelafx/codegen-subexpr-fix. --- .../expressions/codegen/CodeGenerator.scala | 16 +++---- .../expressions/CodeGenerationSuite.scala | 44 +++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index fe5e63ec0a2bb..84b1e3fbda876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,7 +402,7 @@ class CodegenContext { val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] @@ -921,14 +921,12 @@ class CodegenContext { newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])( f: => Seq[ExprCode]): Seq[ExprCode] = { val oldsubExprEliminationExprs = subExprEliminationExprs - subExprEliminationExprs.clear - newSubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = newSubExprEliminationExprs val genCodes = f // Restore previous subExprEliminationExprs - subExprEliminationExprs.clear - oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _) + subExprEliminationExprs = oldsubExprEliminationExprs genCodes } @@ -942,7 +940,7 @@ class CodegenContext { def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree) @@ -955,10 +953,10 @@ class CodegenContext { // Generate the code for this expression tree. val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) - e.foreach(subExprEliminationExprs.put(_, state)) + e.foreach(localSubExprEliminationExprs.put(_, state)) eval.code.trim } - SubExprCodes(codes, subExprEliminationExprs.toMap) + SubExprCodes(codes, localSubExprEliminationExprs.toMap) } /** @@ -1006,7 +1004,7 @@ class CodegenContext { subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) - e.foreach(subExprEliminationExprs.put(_, state)) + subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 64c13e8972036..398b6767654fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -442,4 +442,48 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(CodeGenerator.calculateParamLength( Seq.range(0, 100).map(x => Literal(x.toLong))) == 201) } + + test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { + + val ref = BoundReference(0, IntegerType, true) + val add1 = Add(ref, ref) + val add2 = Add(add1, add1) + + // raw testing of basic functionality + { + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) + assert(ctx.subExprEliminationExprs.contains(ref)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + } + + // emulate an actual codegen workload + { + val ctx = new CodegenContext + // before + ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE + assert(ctx.subExprEliminationExprs.contains(add1)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + } + } } From 5c9eaa6b585e9febd782da8eb6490b24d0d39ff3 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 21 Mar 2018 21:49:02 -0700 Subject: [PATCH 0508/1161] [SPARK-23372][SQL] Writing empty struct in parquet fails during execution. It should fail earlier in the processing. ## What changes were proposed in this pull request? Currently we allow writing data frames with empty schema into a file based datasource for certain file formats such as JSON, ORC etc. For formats such as Parquet and Text, we raise error at different times of execution. For text format, we return error from the driver early on in processing where as for format such as parquet, the error is raised from executor. **Example** spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) **Results in** ``` SQL org.apache.parquet.schema.InvalidSchemaException: Cannot write a schema with an empty group: message spark_schema { } at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:27) at org.apache.parquet.schema.TypeUtil$1.visit(TypeUtil.java:37) at org.apache.parquet.schema.MessageType.accept(MessageType.java:58) at org.apache.parquet.schema.TypeUtil.checkValidWriteSchema(TypeUtil.java:23) at org.apache.parquet.hadoop.ParquetFileWriter.(ParquetFileWriter.java:225) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:342) at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:302) at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.(ParquetOutputWriter.scala:37) at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anon$1.newInstance(ParquetFileFormat.scala:151) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.newOutputWriter(FileFormatWriter.scala:376) at org.apache.spark.sql.execution.datasources.FileFormatWriter$SingleDirectoryWriteTask.execute(FileFormatWriter.scala:387) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:278) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:276) at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1411) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:281) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:206) at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:205) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87) at org.apache.spark.scheduler.Task.run(Task.scala:109) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617) at java.lang.Thread.run(Thread. ``` In this PR, we unify the error processing and raise error on attempt to write empty schema based dataframes into file based datasource (orc, parquet, text , csv, json etc) early on in the processing. ## How was this patch tested? Unit tests added in FileBasedDatasourceSuite. Author: Dilip Biswal Closes #20579 from dilipbiswal/spark-23372. --- docs/sql-programming-guide.md | 1 + .../execution/datasources/DataSource.scala | 26 ++++++++++++++++- .../spark/sql/FileBasedDataSourceSuite.scala | 28 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5b47fd77f2cbc..421e2eaf62bfb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1807,6 +1807,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 35fcff69b14d8..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils @@ -546,6 +546,7 @@ case class DataSource( case dataSource: CreatableRelationProvider => SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => + DataSource.validateSchema(data.schema) planForWritingFileFormat(format, mode, data) case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") @@ -719,4 +720,27 @@ object DataSource extends Logging { } globPath } + + /** + * Called before writing into a FileFormat based data source to make sure the + * supplied schema is not empty. + * @param schema + */ + private def validateSchema(schema: StructType): Unit = { + def hasEmptySchema(schema: StructType): Boolean = { + schema.size == 0 || schema.find { + case StructField(_, b: StructType, _, _) => hasEmptySchema(b) + case _ => false + }.isDefined + } + + + if (hasEmptySchema(schema)) { + throw new AnalysisException( + s""" + |Datasource does not support writing empty or nested empty schemas. + |Please make sure the data schema has at least one or more column(s). + """.stripMargin) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index bd3071bcf9010..06303099f5310 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { @@ -107,6 +108,33 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + allFileBasedDataSources.foreach { format => + test(s"SPARK-23372 error while writing empty schema files using $format") { + withTempPath { outputPath => + val errMsg = intercept[AnalysisException] { + spark.emptyDataFrame.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + + // Nested empty schema + withTempPath { outputPath => + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", StructType(Nil)), + StructField("c", IntegerType) + )) + val df = spark.createDataFrame(sparkContext.emptyRDD[Row], schema) + val errMsg = intercept[AnalysisException] { + df.write.format(format).save(outputPath.toString) + } + assert(errMsg.getMessage.contains( + "Datasource does not support writing empty or nested empty schemas")) + } + } + } + allFileBasedDataSources.foreach { format => test(s"SPARK-22146 read files containing special characters using $format") { withTempDir { dir => From 4d37008c78d7d6b8f8a649b375ecc090700eca4f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 19:57:32 +0100 Subject: [PATCH 0509/1161] [SPARK-23599][SQL] Use RandomUUIDGenerator in Uuid expression ## What changes were proposed in this pull request? As stated in Jira, there are problems with current `Uuid` expression which uses `java.util.UUID.randomUUID` for UUID generation. This patch uses the newly added `RandomUUIDGenerator` for UUID generation. So we can make `Uuid` deterministic between retries. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh Closes #20861 from viirya/SPARK-23599-2. --- .../sql/catalyst/analysis/Analyzer.scala | 16 ++++ .../spark/sql/catalyst/expressions/misc.scala | 26 +++++-- .../ResolvedUuidExpressionsSuite.scala | 73 +++++++++++++++++++ .../expressions/ExpressionEvalHelper.scala | 5 +- .../expressions/MiscExpressionsSuite.scala | 19 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 6 ++ 6 files changed, 136 insertions(+), 9 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7848f88bda1c9..e821e96522f7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer +import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ @@ -177,6 +178,7 @@ class Analyzer( TimeWindowing :: ResolveInlineTables(conf) :: ResolveTimeZone(conf) :: + ResolvedUuidExpressions :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -1994,6 +1996,20 @@ class Analyzer( } } + /** + * Set the seed for random number generation in Uuid expressions. + */ + object ResolvedUuidExpressions extends Rule[LogicalPlan] { + private lazy val random = new Random() + + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { + case p if p.resolved => p + case p => p transformExpressionsUp { + case Uuid(None) => Uuid(Some(random.nextLong())) + } + } + } + /** * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the * null check. When user defines a UDF with primitive parameters, there is no way to tell if the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 38e4fe44b15ab..ec93620038cff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -122,18 +123,33 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid() extends LeafExpression { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { - override lazy val deterministic: Boolean = false + def this() = this(None) + + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false override def dataType: DataType = StringType - override def eval(input: InternalRow): Any = UTF8String.fromString(UUID.randomUUID().toString) + @transient private[this] var randomGenerator: RandomUUIDGenerator = _ + + override protected def initializeInternal(partitionIndex: Int): Unit = + randomGenerator = RandomUUIDGenerator(randomSeed.get + partitionIndex) + + override protected def evalInternal(input: InternalRow): Any = + randomGenerator.getNextUUIDUTF8String() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = s"final UTF8String ${ev.value} = " + - s"UTF8String.fromString(java.util.UUID.randomUUID().toString());", isNull = "false") + val randomGen = ctx.freshName("randomGen") + ctx.addMutableState("org.apache.spark.sql.catalyst.util.RandomUUIDGenerator", randomGen, + forceInline = true, + useFreshName = false) + ctx.addPartitionInitializationStatement(s"$randomGen = " + + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + + s"${randomSeed.get}L + partitionIndex);") + ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + isNull = "false") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala new file mode 100644 index 0000000000000..fe57c199b8744 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} + +/** + * Test suite for resolving Uuid expressions. + */ +class ResolvedUuidExpressionsSuite extends AnalysisTest { + + private lazy val a = 'a.int + private lazy val r = LocalRelation(a) + private lazy val uuid1 = Uuid().as('_uuid1) + private lazy val uuid2 = Uuid().as('_uuid2) + private lazy val uuid3 = Uuid().as('_uuid3) + private lazy val uuid1Ref = uuid1.toAttribute + + private val analyzer = getAnalyzer(caseSensitive = true) + + private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = { + plan.flatMap { + case p => + p.expressions.flatMap(_.collect { + case u: Uuid => u + }) + } + } + + test("analyzed plan sets random seed for Uuid expression") { + val plan = r.select(a, uuid1) + val resolvedPlan = analyzer.executeAndCheck(plan) + getUuidExpressions(resolvedPlan).foreach { u => + assert(u.resolved) + assert(u.randomSeed.isDefined) + } + } + + test("Uuid expressions should have different random seeds") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan = analyzer.executeAndCheck(plan) + assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3) + } + + test("Different analyzed plans should have different random seeds in Uuids") { + val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3) + val resolvedPlan1 = analyzer.executeAndCheck(plan) + val resolvedPlan2 = analyzer.executeAndCheck(plan) + val uuids1 = getUuidExpressions(resolvedPlan1) + val uuids2 = getUuidExpressions(resolvedPlan2) + assert(uuids1.distinct.length == 3) + assert(uuids2.distinct.length == 3) + assert(uuids1.intersect(uuids2).length == 0) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c6343b1cbf600..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -176,7 +176,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithGeneratedMutableProjection( + protected def evaluateWithGeneratedMutableProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Any = { val plan = generateProject( @@ -220,7 +220,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - private def evaluateWithUnsafeProjection( + protected def evaluateWithUnsafeProjection( expression: Expression, inputRow: InternalRow = EmptyRow, factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { @@ -233,6 +233,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { Alias(expression, s"Optimized($expression)2")() :: Nil), expression) + plan.initialize(0) plan(inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index c3d08bf68c7bb..3383d421f5616 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.io.PrintStream +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -42,8 +44,21 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("uuid") { - checkEvaluation(Length(Uuid()), 36) - assert(evaluateWithoutCodegen(Uuid()) !== evaluateWithoutCodegen(Uuid())) + checkEvaluation(Length(Uuid(Some(0))), 36) + val r = new Random() + val seed1 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) === evaluateWithoutCodegen(Uuid(seed1))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) === + evaluateWithGeneratedMutableProjection(Uuid(seed1))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) === + evaluateWithUnsafeProjection(Uuid(seed1))) + + val seed2 = Some(r.nextLong()) + assert(evaluateWithoutCodegen(Uuid(seed1)) !== evaluateWithoutCodegen(Uuid(seed2))) + assert(evaluateWithGeneratedMutableProjection(Uuid(seed1)) !== + evaluateWithGeneratedMutableProjection(Uuid(seed2))) + assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== + evaluateWithUnsafeProjection(Uuid(seed2))) } test("PrintToStderr") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8b66f77b2f923..f7b3393f65cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Uuid import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -2264,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(0, 10) :: Nil) assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + + test("Uuid expressions should produce same results at retries in the same DataFrame") { + val df = spark.range(1).select($"id", new Column(Uuid())) + checkAnswer(df, df.collect()) + } } From a649fcf32a7e610da2a2b4e3d94f5d1372c825d6 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 22 Mar 2018 21:20:41 -0700 Subject: [PATCH 0510/1161] [MINOR][PYTHON] Remove unused codes in schema parsing logics of PySpark ## What changes were proposed in this pull request? This PR proposes to remove out unused codes, `_ignore_brackets_split` and `_BRACKETS`. `_ignore_brackets_split` was introduced in https://github.com/apache/spark/commit/d57daf1f7732a7ac54a91fe112deeda0a254f9ef to refactor and support `toDF("...")`; however, https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb replaced the logics here. Seems `_ignore_brackets_split` is not referred anymore. `_BRACKETS` was introduced in https://github.com/apache/spark/commit/880eabec37c69ce4e9594d7babfac291b0f93f50; however, all other usages were removed out in https://github.com/apache/spark/commit/648a8626b82d27d84db3e48bccfd73d020828586. This is rather a followup for https://github.com/apache/spark/commit/ebc124d4c44d4c84f7868f390f778c0ff5cd66cb which I missed in that PR. ## How was this patch tested? Manually tested. Existing tests should cover this. I also double checked by `grep` in the whole repo. Author: hyukjinkwon Closes #20878 from HyukjinKwon/minor-remove-unused. --- python/pyspark/sql/types.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 826aab97e58db..5d5919e451b46 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -752,41 +752,6 @@ def __eq__(self, other): _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)") -_BRACKETS = {'(': ')', '[': ']', '{': '}'} - - -def _ignore_brackets_split(s, separator): - """ - Splits the given string by given separator, but ignore separators inside brackets pairs, e.g. - given "a,b" and separator ",", it will return ["a", "b"], but given "a, d", it will return - ["a", "d"]. - """ - parts = [] - buf = "" - level = 0 - for c in s: - if c in _BRACKETS.keys(): - level += 1 - buf += c - elif c in _BRACKETS.values(): - if level == 0: - raise ValueError("Brackets are not correctly paired: %s" % s) - level -= 1 - buf += c - elif c == separator and level > 0: - buf += c - elif c == separator: - parts.append(buf) - buf = "" - else: - buf += c - - if len(buf) == 0: - raise ValueError("The %s cannot be the last char: %s" % (separator, s)) - parts.append(buf) - return parts - - def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. The data type string format equals From b2edc30db1dcc6102687d20c158a2700965fdf51 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 22 Mar 2018 21:23:25 -0700 Subject: [PATCH 0511/1161] [SPARK-23614][SQL] Fix incorrect reuse exchange when caching is used ## What changes were proposed in this pull request? We should provide customized canonicalize plan for `InMemoryRelation` and `InMemoryTableScanExec`. Otherwise, we can wrongly treat two different cached plans as same result. It causes wrongly reused exchange then. For a test query like this: ```scala val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() val group1 = cached.groupBy("x").agg(min(col("y")) as "value") val group2 = cached.groupBy("x").agg(min(col("z")) as "value") group1.union(group2) ``` Canonicalized plans before: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [x#4253, y#4254, z#4255], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas) +- LocalTableScan [x#4253, y#4254, z#4255] ``` You can find that they have the canonicalized plans are the same, although we use different columns in two `InMemoryTableScan`s. Canonicalized plan after: First exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(1) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(1) InMemoryTableScan [none#0, none#1] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` Second exchange: ``` Exchange hashpartitioning(none#0, 5) +- *(3) HashAggregate(keys=[none#0], functions=[partial_min(none#1)], output=[none#0, none#4]) +- *(3) InMemoryTableScan [none#0, none#2] +- InMemoryRelation [none#0, none#1, none#2], true, 10000, StorageLevel(memory, 1 replicas) +- LocalTableScan [none#0, none#1, none#2] ``` ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20831 from viirya/SPARK-23614. --- .../execution/columnar/InMemoryRelation.scala | 10 ++++++++++ .../columnar/InMemoryTableScanExec.scala | 19 +++++++++++++------ .../org/apache/spark/sql/DatasetSuite.scala | 9 +++++++++ .../spark/sql/execution/ExchangeSuite.scala | 7 +++++++ 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 22e16913d4da9..2579046e30708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} import org.apache.spark.sql.execution.SparkPlan @@ -68,6 +69,15 @@ case class InMemoryRelation( override protected def innerChildren: Seq[SparkPlan] = Seq(child) + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), + storageLevel = StorageLevel.NONE, + child = child.canonicalized, + tableName = None)( + _cachedColumnBuffers, + sizeInBytesStats, + statsOfPlanToCache) + override def producedAttributes: AttributeSet = outputSet @transient val partitionStatistics = new PartitionStatistics(output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index a93e8a1ad954d..e73e1378d52e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -38,6 +38,11 @@ case class InMemoryTableScanExec( override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override def doCanonicalize(): SparkPlan = + copy(attributes = attributes.map(QueryPlan.normalizeExprId(_, relation.output)), + predicates = predicates.map(QueryPlan.normalizeExprId(_, relation.output)), + relation = relation.canonicalized.asInstanceOf[InMemoryRelation]) + override def vectorTypes: Option[Seq[String]] = Option(Seq.fill(attributes.length)( if (!conf.offHeapColumnVectorEnabled) { @@ -169,11 +174,13 @@ case class InMemoryTableScanExec( override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) + // Keeps relation's partition statistics because we don't serialize relation. + private val stats = relation.partitionStatistics + private def statsFor(a: Attribute) = stats.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - @transient val buildFilter: PartialFunction[Expression, Expression] = { + @transient lazy val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -213,14 +220,14 @@ case class InMemoryTableScanExec( l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } - val partitionFilters: Seq[Expression] = { + lazy val partitionFilters: Seq[Expression] = { predicates.flatMap { p => val filter = buildFilter.lift(p) val boundFilter = filter.map( BindReferences.bindReference( _, - relation.partitionStatistics.schema, + stats.schema, allowFailures = true)) boundFilter.foreach(_ => @@ -243,7 +250,7 @@ case class InMemoryTableScanExec( private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. - val schema = relation.partitionStatistics.schema + val schema = stats.schema val schemaIndex = schema.zipWithIndex val buffers = relation.cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 49c59cf695dc1..9b745befcb611 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1446,8 +1446,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val data = Seq(("a", null)) checkDataset(data.toDS(), data: _*) } + + test("SPARK-23614: Union produces incorrect results when caching is used") { + val cached = spark.createDataset(Seq(TestDataUnion(1, 2, 3), TestDataUnion(4, 5, 6))).cache() + val group1 = cached.groupBy("x").agg(min(col("y")) as "value") + val group2 = cached.groupBy("x").agg(min(col("z")) as "value") + checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) + } } +case class TestDataUnion(x: Int, y: Int, z: Int) + case class SingleData(id: Int) case class DoubleData(id: Int, val1: String) case class TripleData(id: Int, val1: String, val2: Long) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 697d7e6520713..bde2de5b39fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -125,4 +125,11 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) } } + + test("SPARK-23614: Fix incorrect reuse exchange when caching is used") { + val cached = spark.createDataset(Seq((1, 2, 3), (4, 5, 6))).cache() + val projection1 = cached.select("_1", "_2").queryExecution.executedPlan + val projection2 = cached.select("_1", "_3").queryExecution.executedPlan + assert(!projection1.sameResult(projection2)) + } } From 5fa438471110afbf4e2174df449ac79e292501f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 23 Mar 2018 13:59:21 +0800 Subject: [PATCH 0512/1161] [SPARK-23361][YARN] Allow AM to restart after initial tokens expire. Currently, the Spark AM relies on the initial set of tokens created by the submission client to be able to talk to HDFS and other services that require delegation tokens. This means that after those tokens expire, a new AM will fail to start (e.g. when there is an application failure and re-attempts are enabled). This PR makes it so that the first thing the AM does when the user provides a principal and keytab is to create new delegation tokens for use. This makes sure that the AM can be started irrespective of how old the original token set is. It also allows all of the token management to be done by the AM - there is no need for the submission client to set configuration values to tell the AM when to renew tokens. Note that even though in this case the AM will not be using the delegation tokens created by the submission client, those tokens still need to be provided to YARN, since they are used to do log aggregation. To be able to re-use the code in the AMCredentialRenewal for the above purposes, I refactored that class a bit so that it can fetch tokens into a pre-defined UGI, insted of always logging in. Another issue with re-attempts is that, after the fix that allows the AM to restart correctly, new executors would get confused about when to update credentials, because the credential updater used the update time initially set up by the submission code. This could make the executor fail to update credentials in time, since that value would be very out of date in the situation described in the bug. To fix that, I changed the YARN code to use the new RPC-based mechanism for distributing tokens to executors. This allowed the old credential updater code to be removed, and a lot of code in the renewer to be simplified. I also made two currently hardcoded values (the renewal time ratio, and the retry wait) configurable; while this probably never needs to be set by anyone in a production environment, it helps with testing; that's also why they're not documented. Tested on real cluster with a specially crafted application to test this functionality: checked proper access to HDFS, Hive and HBase in cluster mode with token renewal on and AM restarts. Tested things still work in client mode too. Author: Marcelo Vanzin Closes #20657 from vanzin/SPARK-23361. --- .../scala/org/apache/spark/SparkConf.scala | 12 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 32 +- .../CoarseGrainedExecutorBackend.scala | 12 - .../spark/internal/config/package.scala | 12 + .../MesosHadoopDelegationTokenManager.scala | 11 +- .../spark/deploy/yarn/ApplicationMaster.scala | 117 +++---- .../org/apache/spark/deploy/yarn/Client.scala | 102 ++---- .../deploy/yarn/YarnSparkHadoopUtil.scala | 20 -- .../org/apache/spark/deploy/yarn/config.scala | 25 -- .../yarn/security/AMCredentialRenewer.scala | 291 +++++++----------- .../yarn/security/CredentialUpdater.scala | 131 -------- .../YARNHadoopDelegationTokenManager.scala | 9 +- .../cluster/YarnClientSchedulerBackend.scala | 9 +- .../cluster/YarnSchedulerBackend.scala | 10 +- ...ARNHadoopDelegationTokenManagerSuite.scala | 7 +- .../apache/spark/streaming/Checkpoint.scala | 3 - 16 files changed, 238 insertions(+), 565 deletions(-) delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index f53b2bed74c6e..129956e9f9ffa 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -603,13 +603,15 @@ private[spark] object SparkConf extends Logging { "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'."), - DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), + DeprecatedConfig("spark.rpc", "2.0", "Not used anymore."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", "Please use the new blacklisting options, spark.blacklist.*"), - DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), - DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used anymore"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used anymore"), DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0", - "Not used any more. Please use spark.shuffle.service.index.cache.size") + "Not used anymore. Please use spark.shuffle.service.index.cache.size"), + DeprecatedConfig("spark.yarn.credentials.file.retention.count", "2.4.0", "Not used anymore."), + DeprecatedConfig("spark.yarn.credentials.file.retention.days", "2.4.0", "Not used anymore.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) @@ -748,7 +750,7 @@ private[spark] object SparkConf extends Logging { } if (key.startsWith("spark.akka") || key.startsWith("spark.ssl.akka")) { logWarning( - s"The configuration key $key is not supported any more " + + s"The configuration key $key is not supported anymore " + s"because Spark doesn't use Akka since 2.0") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 177295fb7af0f..8353e64a619cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -40,6 +40,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -146,7 +147,8 @@ class SparkHadoopUtil extends Logging { private[spark] def addDelegationTokens(tokens: Array[Byte], sparkConf: SparkConf) { UserGroupInformation.setConfiguration(newConfiguration(sparkConf)) val creds = deserialize(tokens) - logInfo(s"Adding/updating delegation tokens ${dumpTokens(creds)}") + logInfo("Updating delegation tokens for current user.") + logDebug(s"Adding/updating delegation tokens ${dumpTokens(creds)}") addCurrentUserCredentials(creds) } @@ -321,19 +323,6 @@ class SparkHadoopUtil extends Logging { } } - /** - * Return a fresh Hadoop configuration, bypassing the HDFS cache mechanism. - * This is to prevent the DFSClient from using an old cached token to connect to the NameNode. - */ - private[spark] def getConfBypassingFSCache( - hadoopConf: Configuration, - scheme: String): Configuration = { - val newConf = new Configuration(hadoopConf) - val confKey = s"fs.${scheme}.impl.disable.cache" - newConf.setBoolean(confKey, true) - newConf - } - /** * Dump the credentials' tokens to string values. * @@ -447,16 +436,17 @@ object SparkHadoopUtil { def get: SparkHadoopUtil = instance /** - * Given an expiration date (e.g. for Hadoop Delegation Tokens) return a the date - * when a given fraction of the duration until the expiration date has passed. - * Formula: current time + (fraction * (time until expiration)) + * Given an expiration date for the current set of credentials, calculate the time when new + * credentials should be created. + * * @param expirationDate Drop-dead expiration date - * @param fraction fraction of the time until expiration return - * @return Date when the fraction of the time until expiration has passed + * @param conf Spark configuration + * @return Timestamp when new credentials should be created. */ - private[spark] def getDateOfNextUpdate(expirationDate: Long, fraction: Double): Long = { + private[spark] def nextCredentialRenewalTime(expirationDate: Long, conf: SparkConf): Long = { val ct = System.currentTimeMillis - (ct + (fraction * (expirationDate - ct))).toLong + val ratio = conf.get(CREDENTIALS_RENEWAL_INTERVAL_RATIO) + (ct + (ratio * (expirationDate - ct))).toLong } /** diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9b62e4b1b7150..48d3630abd1f9 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -213,13 +213,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } - if (driverConf.contains("spark.yarn.credentials.file")) { - logInfo("Will periodically update credentials from: " + - driverConf.get("spark.yarn.credentials.file")) - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("startCredentialUpdater", classOf[SparkConf]) - .invoke(null, driverConf) - } cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) @@ -234,11 +227,6 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() - if (driverConf.contains("spark.yarn.credentials.file")) { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .getMethod("stopCredentialUpdater") - .invoke(null) - } } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a313ad0554a3a..407545aa4a47a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -525,4 +525,16 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("1g") + private[spark] val CREDENTIALS_RENEWAL_INTERVAL_RATIO = + ConfigBuilder("spark.security.credentials.renewalRatio") + .doc("Ratio of the credential's expiration time when Spark should fetch new credentials.") + .doubleConf + .createWithDefault(0.75d) + + private[spark] val CREDENTIALS_RENEWAL_RETRY_WAIT = + ConfigBuilder("spark.security.credentials.retryWait") + .doc("How long to wait before retrying to fetch new credentials after a failure.") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1h") + } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala index 7165bfae18a5e..a1bf4f0c048fe 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosHadoopDelegationTokenManager.scala @@ -29,6 +29,7 @@ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.{config, Logging} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils @@ -63,7 +64,7 @@ private[spark] class MesosHadoopDelegationTokenManager( val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) val rt = tokenManager.obtainDelegationTokens(hadoopConf, creds) logInfo(s"Initialized tokens: ${SparkHadoopUtil.get.dumpTokens(creds)}") - (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.getDateOfNextUpdate(rt, 0.75)) + (SparkHadoopUtil.get.serialize(creds), SparkHadoopUtil.nextCredentialRenewalTime(rt, conf)) } catch { case e: Exception => logError(s"Failed to fetch Hadoop delegation tokens $e") @@ -104,8 +105,10 @@ private[spark] class MesosHadoopDelegationTokenManager( } catch { case e: Exception => // Log the error and try to write new tokens back in an hour - logWarning("Couldn't broadcast tokens, trying again in an hour", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) + val delay = TimeUnit.SECONDS.toMillis(conf.get(config.CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning( + s"Couldn't broadcast tokens, trying again in ${UIUtils.formatDuration(delay)}", e) + credentialRenewerThread.schedule(this, delay, TimeUnit.MILLISECONDS) return } scheduleRenewal(this) @@ -135,7 +138,7 @@ private[spark] class MesosHadoopDelegationTokenManager( "related configurations in the target services.") currTime } else { - SparkHadoopUtil.getDateOfNextUpdate(nextRenewalTime, 0.75) + SparkHadoopUtil.nextCredentialRenewalTime(nextRenewalTime, conf) } logInfo(s"Time of next renewal is in ${timeOfNextRenewal - System.currentTimeMillis()} ms") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 6e35d23def6f0..d04989e138f83 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -29,7 +29,6 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -41,7 +40,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} +import org.apache.spark.deploy.yarn.security.AMCredentialRenewer import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc._ @@ -79,42 +78,43 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends private val yarnConf = new YarnConfiguration(SparkHadoopUtil.newConfiguration(sparkConf)) - private val ugi = { - val original = UserGroupInformation.getCurrentUser() - - // If a principal and keytab were provided, log in to kerberos, and set up a thread to - // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL - // of the TGT, use a configuration to define how often to check that a relogin is necessary. - // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. - val principal = sparkConf.get(PRINCIPAL).orNull - val keytab = sparkConf.get(KEYTAB).orNull - if (principal != null && keytab != null) { - UserGroupInformation.loginUserFromKeytab(principal, keytab) - - val renewer = new Thread() { - override def run(): Unit = Utils.tryLogNonFatalError { - while (true) { - TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) - UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() - } - } + private val userClassLoader = { + val classpath = Client.getUserClasspath(sparkConf) + val urls = classpath.map { entry => + new URL("file:" + new File(entry.getPath()).getAbsolutePath()) + } + + if (isClusterMode) { + if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { + new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) + } else { + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } - renewer.setName("am-kerberos-renewer") - renewer.setDaemon(true) - renewer.start() - - // Transfer the original user's tokens to the new user, since that's needed to connect to - // YARN. It also copies over any delegation tokens that might have been created by the - // client, which will then be transferred over when starting executors (until new ones - // are created by the periodic task). - val newUser = UserGroupInformation.getCurrentUser() - SparkHadoopUtil.get.transferCredentials(original, newUser) - newUser } else { - SparkHadoopUtil.get.createSparkUser() + new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } } + private val credentialRenewer: Option[AMCredentialRenewer] = sparkConf.get(KEYTAB).map { _ => + new AMCredentialRenewer(sparkConf, yarnConf) + } + + private val ugi = credentialRenewer match { + case Some(cr) => + // Set the context class loader so that the token renewer has access to jars distributed + // by the user. + val currentLoader = Thread.currentThread().getContextClassLoader() + Thread.currentThread().setContextClassLoader(userClassLoader) + try { + cr.start() + } finally { + Thread.currentThread().setContextClassLoader(currentLoader) + } + + case _ => + SparkHadoopUtil.get.createSparkUser() + } + private val client = doAsUser { new YarnRMClient() } // Default to twice the number of executors (twice the maximum number of executors if dynamic @@ -148,23 +148,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // A flag to check whether user has initialized spark context @volatile private var registered = false - private val userClassLoader = { - val classpath = Client.getUserClasspath(sparkConf) - val urls = classpath.map { entry => - new URL("file:" + new File(entry.getPath()).getAbsolutePath()) - } - - if (isClusterMode) { - if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) { - new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } else { - new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) - } - } - // Lock for controlling the allocator (heartbeat) thread. private val allocatorLock = new Object() @@ -189,8 +172,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends // In cluster mode, used to tell the AM when the user's SparkContext has been initialized. private val sparkContextPromise = Promise[SparkContext]() - private var credentialRenewer: AMCredentialRenewer = _ - // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. @@ -316,31 +297,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - // If the credentials file config is present, we must periodically renew tokens. So create - // a new AMDelegationTokenRenewer - if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { - // Start a short-lived thread for AMCredentialRenewer, the only purpose is to set the - // classloader so that main jar and secondary jars could be used by AMCredentialRenewer. - val credentialRenewerThread = new Thread { - setName("AMCredentialRenewerStarter") - setContextClassLoader(userClassLoader) - - override def run(): Unit = { - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - yarnConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - - val credentialRenewer = - new AMCredentialRenewer(sparkConf, yarnConf, credentialManager) - credentialRenewer.scheduleLoginFromKeytab() - } - } - - credentialRenewerThread.start() - credentialRenewerThread.join() - } - if (isClusterMode) { runDriver() } else { @@ -409,9 +365,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logDebug("shutting down user thread") userClassThread.interrupt() } - if (!inShutdown && credentialRenewer != null) { - credentialRenewer.stop() - credentialRenewer = null + if (!inShutdown) { + credentialRenewer.foreach(_.stop()) } } } @@ -468,6 +423,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends securityMgr, localResources) + credentialRenewer.foreach(_.setDriverRef(driverRef)) + // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), // the allocator is ready to service requests. diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 28087dee831d1..5763c3dbc5a8a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -93,11 +93,21 @@ private[spark] class Client( private val distCacheMgr = new ClientDistributedCacheManager() - private var loginFromKeytab = false - private var principal: String = null - private var keytab: String = null - private var credentials: Credentials = null - private var amKeytabFileName: String = null + private val principal = sparkConf.get(PRINCIPAL).orNull + private val keytab = sparkConf.get(KEYTAB).orNull + private val loginFromKeytab = principal != null + private val amKeytabFileName: String = { + require((principal == null) == (keytab == null), + "Both principal and keytab must be defined, or neither.") + if (loginFromKeytab) { + logInfo(s"Kerberos credentials: principal = $principal, keytab = $keytab") + // Generate a file name that can be used for the keytab file, that does not conflict + // with any user file. + new File(keytab).getName() + "-" + UUID.randomUUID().toString + } else { + null + } + } private val launcherBackend = new LauncherBackend() { override protected def conf: SparkConf = sparkConf @@ -120,11 +130,6 @@ private[spark] class Client( private val appStagingBaseDir = sparkConf.get(STAGING_DIR).map { new Path(_) } .getOrElse(FileSystem.get(hadoopConf).getHomeDirectory()) - private val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - def reportLauncherState(state: SparkAppHandle.State): Unit = { launcherBackend.setState(state) } @@ -145,9 +150,6 @@ private[spark] class Client( var appId: ApplicationId = null try { launcherBackend.connect() - // Setup the credentials before doing anything else, - // so we have don't have issues at any point. - setupCredentials() yarnClient.init(hadoopConf) yarnClient.start() @@ -288,8 +290,26 @@ private[spark] class Client( appContext } - /** Set up security tokens for launching our ApplicationMaster container. */ + /** + * Set up security tokens for launching our ApplicationMaster container. + * + * This method will obtain delegation tokens from all the registered providers, and set them in + * the AM's launch context. + */ private def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) + credentialManager.obtainDelegationTokens(hadoopConf, credentials) + + // When using a proxy user, copy the delegation tokens to the user's credentials. Avoid + // that for regular users, since in those case the user already has access to the TGT, + // and adding delegation tokens could lead to expired or cancelled tokens being used + // later, as reported in SPARK-15754. + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } + val dob = new DataOutputBuffer credentials.writeTokenStorageToStream(dob) amContainer.setTokens(ByteBuffer.wrap(dob.getData)) @@ -384,36 +404,6 @@ private[spark] class Client( // and add them as local resources to the application master. val fs = destDir.getFileSystem(hadoopConf) - // Merge credentials obtained from registered providers - val nearestTimeOfNextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, credentials) - - if (credentials != null) { - // Add credentials to current user's UGI, so that following operations don't need to use the - // Kerberos tgt to get delegations again in the client side. - val currentUser = UserGroupInformation.getCurrentUser() - if (SparkHadoopUtil.get.isProxyUser(currentUser)) { - currentUser.addCredentials(credentials) - } - logDebug(SparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) - } - - // If we use principal and keytab to login, also credentials can be renewed some time - // after current time, we should pass the next renewal and updating time to credential - // renewer and updater. - if (loginFromKeytab && nearestTimeOfNextRenewal > System.currentTimeMillis() && - nearestTimeOfNextRenewal != Long.MaxValue) { - - // Valid renewal time is 75% of next renewal time, and the valid update time will be - // slightly later then renewal time (80% of next renewal time). This is to make sure - // credentials are renewed and updated before expired. - val currTime = System.currentTimeMillis() - val renewalTime = (nearestTimeOfNextRenewal - currTime) * 0.75 + currTime - val updateTime = (nearestTimeOfNextRenewal - currTime) * 0.8 + currTime - - sparkConf.set(CREDENTIALS_RENEWAL_TIME, renewalTime.toLong) - sparkConf.set(CREDENTIALS_UPDATE_TIME, updateTime.toLong) - } - // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. @@ -793,11 +783,6 @@ private[spark] class Client( populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - if (loginFromKeytab) { - val credentialsFile = "credentials-" + UUID.randomUUID().toString - sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString) - logInfo(s"Credentials file set to: $credentialsFile") - } // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -1014,25 +999,6 @@ private[spark] class Client( amContainer } - def setupCredentials(): Unit = { - loginFromKeytab = sparkConf.contains(PRINCIPAL.key) - if (loginFromKeytab) { - principal = sparkConf.get(PRINCIPAL).get - keytab = sparkConf.get(KEYTAB).orNull - - require(keytab != null, "Keytab must be specified when principal is specified.") - logInfo("Attempting to login to the Kerberos" + - s" using principal: $principal and keytab: $keytab") - val f = new File(keytab) - // Generate a file name that can be used for the keytab file, that does not conflict - // with any user file. - amKeytabFileName = f.getName + "-" + UUID.randomUUID().toString - sparkConf.set(PRINCIPAL.key, principal) - } - // Defensive copy of the credentials - credentials = new Credentials(UserGroupInformation.getCurrentUser.getCredentials) - } - /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index f406fabd61860..8eda6cb1277c5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.CredentialUpdater import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.config._ import org.apache.spark.launcher.YarnCommandBuilderUtils @@ -38,8 +37,6 @@ import org.apache.spark.util.Utils object YarnSparkHadoopUtil { - private var credentialUpdater: CredentialUpdater = _ - // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering // the common cases. Memory overhead tends to grow with container size. @@ -206,21 +203,4 @@ object YarnSparkHadoopUtil { filesystemsToAccess + stagingFS } - def startCredentialUpdater(sparkConf: SparkConf): Unit = { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) - val credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - credentialUpdater = new CredentialUpdater(sparkConf, hadoopConf, credentialManager) - credentialUpdater.start() - } - - def stopCredentialUpdater(): Unit = { - if (credentialUpdater != null) { - credentialUpdater.stop() - credentialUpdater = null - } - } - } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 3ba3ae5ab4401..1a99b3bd57672 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -231,16 +231,6 @@ package object config { /* Security configuration. */ - private[spark] val CREDENTIAL_FILE_MAX_COUNT = - ConfigBuilder("spark.yarn.credentials.file.retention.count") - .intConf - .createWithDefault(5) - - private[spark] val CREDENTIALS_FILE_MAX_RETENTION = - ConfigBuilder("spark.yarn.credentials.file.retention.days") - .intConf - .createWithDefault(5) - private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes") .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " + "fs.defaultFS does not need to be listed here.") @@ -271,11 +261,6 @@ package object config { /* Private configs. */ - private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file") - .internal() - .stringConf - .createWithDefault(null) - // Internal config to propagate the location of the user's jar to the driver/executors private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar") .internal() @@ -329,16 +314,6 @@ package object config { .stringConf .createOptional - private[spark] val CREDENTIALS_RENEWAL_TIME = ConfigBuilder("spark.yarn.credentials.renewalTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - - private[spark] val CREDENTIALS_UPDATE_TIME = ConfigBuilder("spark.yarn.credentials.updateTime") - .internal() - .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(Long.MaxValue) - private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1m") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala index eaf2cff111a49..bc8d47dbd54c6 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala @@ -18,221 +18,160 @@ package org.apache.spark.deploy.yarn.security import java.security.PrivilegedExceptionAction import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.ui.UIUtils import org.apache.spark.util.ThreadUtils /** - * The following methods are primarily meant to make sure long-running apps like Spark - * Streaming apps can run without interruption while accessing secured services. The - * scheduleLoginFromKeytab method is called on the AM to get the new credentials. - * This method wakes up a thread that logs into the KDC - * once 75% of the renewal interval of the original credentials used for the container - * has elapsed. It then obtains new credentials and writes them to HDFS in a - * pre-specified location - the prefix of which is specified in the sparkConf by - * spark.yarn.credentials.file (so the file(s) would be named c-timestamp1-1, c-timestamp2-2 etc. - * - each update goes to a new file, with a monotonically increasing suffix), also the - * timestamp1, timestamp2 here indicates the time of next update for CredentialUpdater. - * After this, the credentials are renewed once 75% of the new tokens renewal interval has elapsed. + * A manager tasked with periodically updating delegation tokens needed by the application. * - * On the executor and driver (yarn client mode) side, the updateCredentialsIfRequired method is - * called once 80% of the validity of the original credentials has elapsed. At that time the - * executor finds the credentials file with the latest timestamp and checks if it has read those - * credentials before (by keeping track of the suffix of the last file it read). If a new file has - * appeared, it will read the credentials and update the currently running UGI with it. This - * process happens again once 80% of the validity of this has expired. + * This manager is meant to make sure long-running apps (such as Spark Streaming apps) can run + * without interruption while accessing secured services. It periodically logs in to the KDC with + * user-provided credentials, and contacts all the configured secure services to obtain delegation + * tokens to be distributed to the rest of the application. + * + * This class will manage the kerberos login, by renewing the TGT when needed. Because the UGI API + * does not expose the TTL of the TGT, a configuration controls how often to check that a relogin is + * necessary. This is done reasonably often since the check is a no-op when the relogin is not yet + * needed. The check period can be overridden in the configuration. + * + * New delegation tokens are created once 75% of the renewal interval of the original tokens has + * elapsed. The new tokens are sent to the Spark driver endpoint once it's registered with the AM. + * The driver is tasked with distributing the tokens to other processes that might need them. */ private[yarn] class AMCredentialRenewer( sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { + hadoopConf: Configuration) extends Logging { - private var lastCredentialsFileSuffix = 0 + private val principal = sparkConf.get(PRINCIPAL).get + private val keytab = sparkConf.get(KEYTAB).get + private val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) - private val credentialRenewerThread: ScheduledExecutorService = + private val renewalExecutor: ScheduledExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("Credential Refresh Thread") - private val hadoopUtil = SparkHadoopUtil.get + private val driverRef = new AtomicReference[RpcEndpointRef]() - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val daysToKeepFiles = sparkConf.get(CREDENTIALS_FILE_MAX_RETENTION) - private val numFilesToKeep = sparkConf.get(CREDENTIAL_FILE_MAX_COUNT) - private val freshHadoopConf = - hadoopUtil.getConfBypassingFSCache(hadoopConf, new Path(credentialsFile).toUri.getScheme) + private val renewalTask = new Runnable() { + override def run(): Unit = { + updateTokensTask() + } + } - @volatile private var timeOfNextRenewal: Long = sparkConf.get(CREDENTIALS_RENEWAL_TIME) + def setDriverRef(ref: RpcEndpointRef): Unit = { + driverRef.set(ref) + } /** - * Schedule a login from the keytab and principal set using the --principal and --keytab - * arguments to spark-submit. This login happens only when the credentials of the current user - * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from - * SparkConf to do the login. This method is a no-op in non-YARN mode. + * Start the token renewer. Upon start, the renewer will: * + * - log in the configured user, and set up a task to keep that user's ticket renewed + * - obtain delegation tokens from all available providers + * - schedule a periodic task to update the tokens when needed. + * + * @return The newly logged in user. */ - private[spark] def scheduleLoginFromKeytab(): Unit = { - val principal = sparkConf.get(PRINCIPAL).get - val keytab = sparkConf.get(KEYTAB).get - - /** - * Schedule re-login and creation of new credentials. If credentials have already expired, this - * method will synchronously create new ones. - */ - def scheduleRenewal(runnable: Runnable): Unit = { - // Run now! - val remainingTime = timeOfNextRenewal - System.currentTimeMillis() - if (remainingTime <= 0) { - logInfo("Credentials have expired, creating new ones now.") - runnable.run() - } else { - logInfo(s"Scheduling login from keytab in $remainingTime millis.") - credentialRenewerThread.schedule(runnable, remainingTime, TimeUnit.MILLISECONDS) + def start(): UserGroupInformation = { + val originalCreds = UserGroupInformation.getCurrentUser().getCredentials() + val ugi = doLogin() + + val tgtRenewalTask = new Runnable() { + override def run(): Unit = { + ugi.checkTGTAndReloginFromKeytab() } } + val tgtRenewalPeriod = sparkConf.get(KERBEROS_RELOGIN_PERIOD) + renewalExecutor.scheduleAtFixedRate(tgtRenewalTask, tgtRenewalPeriod, tgtRenewalPeriod, + TimeUnit.SECONDS) - // This thread periodically runs on the AM to update the credentials on HDFS. - val credentialRenewerRunnable = - new Runnable { - override def run(): Unit = { - try { - writeNewCredentialsToHDFS(principal, keytab) - cleanupOldFiles() - } catch { - case e: Exception => - // Log the error and try to write new tokens back in an hour - logWarning("Failed to write out new credentials to HDFS, will try again in an " + - "hour! If this happens too often tasks will fail.", e) - credentialRenewerThread.schedule(this, 1, TimeUnit.HOURS) - return - } - scheduleRenewal(this) - } - } - // Schedule update of credentials. This handles the case of updating the credentials right now - // as well, since the renewal interval will be 0, and the thread will get scheduled - // immediately. - scheduleRenewal(credentialRenewerRunnable) + val creds = obtainTokensAndScheduleRenewal(ugi) + ugi.addCredentials(creds) + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. Explicitly avoid overwriting tokens that already exist in the current user's + // credentials, since those were freshly obtained above (see SPARK-23361). + val existing = ugi.getCredentials() + existing.mergeAll(originalCreds) + ugi.addCredentials(existing) + + ugi + } + + def stop(): Unit = { + renewalExecutor.shutdown() + } + + private def scheduleRenewal(delay: Long): Unit = { + val _delay = math.max(0, delay) + logInfo(s"Scheduling login from keytab in ${UIUtils.formatDuration(delay)}.") + renewalExecutor.schedule(renewalTask, _delay, TimeUnit.MILLISECONDS) } - // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At - // least numFilesToKeep files are kept for safety - private def cleanupOldFiles(): Unit = { - import scala.concurrent.duration._ + /** + * Periodic task to login to the KDC and create new delegation tokens. Re-schedules itself + * to fetch the next set of tokens when needed. + */ + private def updateTokensTask(): Unit = { try { - val remoteFs = FileSystem.get(freshHadoopConf) - val credentialsPath = new Path(credentialsFile) - val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles.days).toMillis - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .dropRight(numFilesToKeep) - .takeWhile(_.getModificationTime < thresholdTime) - .foreach(x => remoteFs.delete(x.getPath, true)) + val freshUGI = doLogin() + val creds = obtainTokensAndScheduleRenewal(freshUGI) + val tokens = SparkHadoopUtil.get.serialize(creds) + + val driver = driverRef.get() + if (driver != null) { + logInfo("Updating delegation tokens.") + driver.send(UpdateDelegationTokens(tokens)) + } else { + // This shouldn't really happen, since the driver should register way before tokens expire + // (or the AM should time out the application). + logWarning("Delegation tokens close to expiration but no driver has registered yet.") + SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) + } } catch { - // Such errors are not fatal, so don't throw. Make sure they are logged though case e: Exception => - logWarning("Error while attempting to cleanup old credentials. If you are seeing many " + - "such warnings there may be an issue with your HDFS cluster.", e) + val delay = TimeUnit.SECONDS.toMillis(sparkConf.get(CREDENTIALS_RENEWAL_RETRY_WAIT)) + logWarning(s"Failed to update tokens, will try again in ${UIUtils.formatDuration(delay)}!" + + " If this happens too often tasks will fail.", e) + scheduleRenewal(delay) } } - private def writeNewCredentialsToHDFS(principal: String, keytab: String): Unit = { - // Keytab is copied by YARN to the working directory of the AM, so full path is - // not needed. - - // HACK: - // HDFS will not issue new delegation tokens, if the Credentials object - // passed in already has tokens for that FS even if the tokens are expired (it really only - // checks if there are tokens for the service, and not if they are valid). So the only real - // way to get new tokens is to make sure a different Credentials object is used each time to - // get new tokens and then the new tokens are copied over the current user's Credentials. - // So: - // - we login as a different user and get the UGI - // - use that UGI to get the tokens (see doAs block below) - // - copy the tokens over to the current user's credentials (this will overwrite the tokens - // in the current user's Credentials object for this FS). - // The login to KDC happens each time new tokens are required, but this is rare enough to not - // have to worry about (like once every day or so). This makes this code clearer than having - // to login and then relogin every time (the HDFS API may not relogin since we don't use this - // UGI directly for HDFS communication. - logInfo(s"Attempting to login to KDC using principal: $principal") - val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) - logInfo("Successfully logged into KDC.") - val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(credentialsFile) - val dst = credentialsPath.getParent - var nearestNextRenewalTime = Long.MaxValue - keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { - // Get a copy of the credentials - override def run(): Void = { - nearestNextRenewalTime = credentialManager.obtainDelegationTokens( - freshHadoopConf, - tempCreds) - null + /** + * Obtain new delegation tokens from the available providers. Schedules a new task to fetch + * new tokens before the new set expires. + * + * @return Credentials containing the new tokens. + */ + private def obtainTokensAndScheduleRenewal(ugi: UserGroupInformation): Credentials = { + ugi.doAs(new PrivilegedExceptionAction[Credentials]() { + override def run(): Credentials = { + val creds = new Credentials() + val nextRenewal = credentialManager.obtainDelegationTokens(hadoopConf, creds) + + val timeToWait = SparkHadoopUtil.nextCredentialRenewalTime(nextRenewal, sparkConf) - + System.currentTimeMillis() + scheduleRenewal(timeToWait) + creds } }) - - val currTime = System.currentTimeMillis() - val timeOfNextUpdate = if (nearestNextRenewalTime <= currTime) { - // If next renewal time is earlier than current time, we set next renewal time to current - // time, this will trigger next renewal immediately. Also set next update time to current - // time. There still has a gap between token renewal and update will potentially introduce - // issue. - logWarning(s"Next credential renewal time ($nearestNextRenewalTime) is earlier than " + - s"current time ($currTime), which is unexpected, please check your credential renewal " + - "related configurations in the target services.") - timeOfNextRenewal = currTime - currTime - } else { - // Next valid renewal time is about 75% of credential renewal time, and update time is - // slightly later than valid renewal time (80% of renewal time). - timeOfNextRenewal = - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.75) - SparkHadoopUtil.getDateOfNextUpdate(nearestNextRenewalTime, 0.8) - } - - // Add the temp credentials back to the original ones. - UserGroupInformation.getCurrentUser.addCredentials(tempCreds) - val remoteFs = FileSystem.get(freshHadoopConf) - // If lastCredentialsFileSuffix is 0, then the AM is either started or restarted. If the AM - // was restarted, then the lastCredentialsFileSuffix might be > 0, so find the newest file - // and update the lastCredentialsFileSuffix. - if (lastCredentialsFileSuffix == 0) { - hadoopUtil.listFilesSorted( - remoteFs, credentialsPath.getParent, - credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.foreach { status => - lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) - } - } - val nextSuffix = lastCredentialsFileSuffix + 1 - - val tokenPathStr = - credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - timeOfNextUpdate.toLong.toString + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + - nextSuffix - val tokenPath = new Path(tokenPathStr) - val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - - logInfo("Writing out delegation tokens to " + tempTokenPath.toString) - val credentials = UserGroupInformation.getCurrentUser.getCredentials - credentials.writeTokenStorageFile(tempTokenPath, freshHadoopConf) - logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") - remoteFs.rename(tempTokenPath, tokenPath) - logInfo("Delegation token file rename complete.") - lastCredentialsFileSuffix = nextSuffix } - def stop(): Unit = { - credentialRenewerThread.shutdown() + private def doLogin(): UserGroupInformation = { + logInfo(s"Attempting to login to KDC using principal: $principal") + val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + ugi } + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala deleted file mode 100644 index fe173dffc22a8..0000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.concurrent.{Executors, TimeUnit} - -import scala.util.control.NonFatal - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.Logging -import org.apache.spark.util.{ThreadUtils, Utils} - -private[spark] class CredentialUpdater( - sparkConf: SparkConf, - hadoopConf: Configuration, - credentialManager: YARNHadoopDelegationTokenManager) extends Logging { - - @volatile private var lastCredentialsFileSuffix = 0 - - private val credentialsFile = sparkConf.get(CREDENTIALS_FILE_PATH) - private val freshHadoopConf = - SparkHadoopUtil.get.getConfBypassingFSCache( - hadoopConf, new Path(credentialsFile).toUri.getScheme) - - private val credentialUpdater = - Executors.newSingleThreadScheduledExecutor( - ThreadUtils.namedThreadFactory("Credential Refresh Thread")) - - // This thread wakes up and picks up new credentials from HDFS, if any. - private val credentialUpdaterRunnable = - new Runnable { - override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) - } - - /** Start the credential updater task */ - def start(): Unit = { - val startTime = sparkConf.get(CREDENTIALS_UPDATE_TIME) - val remainingTime = startTime - System.currentTimeMillis() - if (remainingTime <= 0) { - credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) - } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") - credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) - } - } - - private def updateCredentialsIfRequired(): Unit = { - val timeToNextUpdate = try { - val credentialsFilePath = new Path(credentialsFile) - val remoteFs = FileSystem.get(freshHadoopConf) - SparkHadoopUtil.get.listFilesSorted( - remoteFs, credentialsFilePath.getParent, - credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) - .lastOption.map { credentialsStatus => - val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) - if (suffix > lastCredentialsFileSuffix) { - logInfo("Reading new credentials from " + credentialsStatus.getPath) - val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) - lastCredentialsFileSuffix = suffix - UserGroupInformation.getCurrentUser.addCredentials(newCredentials) - logInfo("Credentials updated from credentials file.") - - val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis()) - if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime - } else { - // If current credential file is older than expected, sleep 1 hour and check again. - TimeUnit.HOURS.toMillis(1) - } - }.getOrElse { - // Wait for 1 minute to check again if there's no credential file currently - TimeUnit.MINUTES.toMillis(1) - } - } catch { - // Since the file may get deleted while we are reading it, catch the Exception and come - // back in an hour to try again - case NonFatal(e) => - logWarning("Error while trying to update credentials, will try again in 1 hour", e) - TimeUnit.HOURS.toMillis(1) - } - - logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") - credentialUpdater.schedule( - credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) - } - - private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { - val stream = remoteFs.open(tokenPath) - try { - val newCredentials = new Credentials() - newCredentials.readTokenStorageStream(stream) - newCredentials - } finally { - stream.close() - } - } - - private def getTimeOfNextUpdateFromFileName(credentialsPath: Path): Long = { - val name = credentialsPath.getName - val index = name.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - val slice = name.substring(0, index) - val last2index = slice.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) - name.substring(last2index + 1, index).toLong - } - - def stop(): Unit = { - credentialUpdater.shutdown() - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index 163cfb4eb8624..d4eeb6bbcf886 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -22,11 +22,11 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -37,11 +37,10 @@ import org.apache.spark.util.Utils */ private[yarn] class YARNHadoopDelegationTokenManager( sparkConf: SparkConf, - hadoopConf: Configuration, - fileSystems: Configuration => Set[FileSystem]) extends Logging { + hadoopConf: Configuration) extends Logging { - private val delegationTokenManager = - new HadoopDelegationTokenManager(sparkConf, hadoopConf, fileSystems) + private val delegationTokenManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, + conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) // public for testing val credentialProviders = getCredentialProviders diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 0c6206eebe41d..06e54a2eaf95a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -62,12 +62,6 @@ private[spark] class YarnClientSchedulerBackend( super.start() waitForApplication() - // SPARK-8851: In yarn-client mode, the AM still does the credentials refresh. The driver - // reads the credentials from HDFS, just like the executors and updates its own credentials - // cache. - if (conf.contains("spark.yarn.credentials.file")) { - YarnSparkHadoopUtil.startCredentialUpdater(conf) - } monitorThread = asyncMonitorApplication() monitorThread.start() } @@ -153,7 +147,6 @@ private[spark] class YarnClientSchedulerBackend( client.reportLauncherState(SparkAppHandle.State.FINISHED) super.stop() - YarnSparkHadoopUtil.stopCredentialUpdater() client.stop() logInfo("Stopped") } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bb615c36cd97f..63bea3e7a5003 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,9 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rpc._ import org.apache.spark.scheduler._ @@ -70,6 +72,7 @@ private[spark] abstract class YarnSchedulerBackend( /** Scheduler extension services. */ private val services: SchedulerExtensionServices = new SchedulerExtensionServices() + /** * Bind to YARN. This *must* be done before calling [[start()]]. * @@ -263,8 +266,13 @@ private[spark] abstract class YarnSchedulerBackend( logWarning(s"Requesting driver to remove executor $executorId for reason $reason") driverEndpoint.send(r) } - } + case u @ UpdateDelegationTokens(tokens) => + // Add the tokens to the current user and send a message to the scheduler so that it + // notifies all registered executors of the new tokens. + SparkHadoopUtil.get.addDelegationTokens(tokens, sc.conf) + driverEndpoint.send(u) + } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala index 3c7cdc0f1dab8..9fa749b14c98c 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala @@ -22,7 +22,6 @@ import org.apache.hadoop.security.Credentials import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers { private var credentialManager: YARNHadoopDelegationTokenManager = null @@ -36,11 +35,7 @@ class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers } test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager( - sparkConf, - hadoopConf, - conf => YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, conf)) - + credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf) credentialManager.credentialProviders.get("yarn-test") should not be (None) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index aed67a5027433..3703a87cdb9ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -57,9 +57,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.yarn.jars", "spark.yarn.keytab", "spark.yarn.principal", - "spark.yarn.credentials.file", - "spark.yarn.credentials.renewalTime", - "spark.yarn.credentials.updateTime", "spark.ui.filters", "spark.mesos.driver.frameworkId") From 92e952557dbd8a170d66d615e25c6c6a8399dd43 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 23 Mar 2018 21:01:07 +0900 Subject: [PATCH 0513/1161] [MINOR][R] Fix R lint failure ## What changes were proposed in this pull request? The lint failure bugged me: ```R R/SQLContext.R:715:97: style: Trailing whitespace is superfluous. #' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to ^ tests/fulltests/test_streaming.R:239:45: style: Commas should always have a space after. expect_equal(times[order(times$eventTime),][1, 2], 2) ^ lintr checks failed. ``` and I actually saw https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.6-ubuntu-test/500/console too. If I understood correctly, there is a try about moving to Unbuntu one. ## How was this patch tested? Manually tested by `./dev/lint-r`: ``` ... lintr checks passed. ``` Author: hyukjinkwon Closes #20879 from HyukjinKwon/minor-r-lint. --- R/pkg/R/SQLContext.R | 2 +- R/pkg/tests/fulltests/test_streaming.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index ebec0ce3d1920..429dd5d565492 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -712,7 +712,7 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to #' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it #' uses the default value, session local timezone. #' @return SparkDataFrame diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index a354d50c6b54e..bfb1a046490ec 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -236,7 +236,7 @@ test_that("Watermark", { times <- collect(sql("SELECT * FROM times")) # looks like write timing can affect the first bucket; but it should be t - expect_equal(times[order(times$eventTime),][1, 2], 2) + expect_equal(times[order(times$eventTime), ][1, 2], 2) stopQuery(q) unlink(parquetPath) From 6ac4fba69290e1c7de2c0a5863f224981dedb919 Mon Sep 17 00:00:00 2001 From: arucard21 Date: Fri, 23 Mar 2018 21:02:34 +0900 Subject: [PATCH 0514/1161] [SPARK-23769][CORE] Remove comments that unnecessarily disable Scalastyle check ## What changes were proposed in this pull request? We re-enabled the Scalastyle checker on a line of code. It was previously disabled, but it does not violate any of the rules. So there's no reason to disable the Scalastyle checker here. ## How was this patch tested? We tested this by running `build/mvn scalastyle:check` after removing the comments that disable the checker. This check passed with no errors or warnings for Spark Core ``` [INFO] [INFO] ------------------------------------------------------------------------ [INFO] Building Spark Project Core 2.4.0-SNAPSHOT [INFO] ------------------------------------------------------------------------ [INFO] [INFO] --- scalastyle-maven-plugin:1.0.0:check (default-cli) spark-core_2.11 --- Saving to outputFile=/spark/core/target/scalastyle-output.xml Processed 485 file(s) Found 0 errors Found 0 warnings Found 0 infos ``` We did not run all tests (with `dev/run-tests`) since this Scalastyle check seemed sufficient. ## Co-contributors: chialun-yeh Hrayo712 vpourquie Author: arucard21 Closes #20880 from arucard21/scalastyle_util. --- .../org/apache/spark/storage/BlockReplicationPolicy.scala | 4 +--- .../main/scala/org/apache/spark/util/CompletionIterator.scala | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index 353eac60df171..0bacc34cdfd90 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -54,10 +54,9 @@ trait BlockReplicationPolicy { } object BlockReplicationUtils { - // scalastyle:off line.size.limit /** * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage. Please see + * minimizing space usage. Please see * here. * * @param n total number of indices @@ -65,7 +64,6 @@ object BlockReplicationUtils { * @param r random number generator * @return list of m random unique indices */ - // scalastyle:on line.size.limit private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => val t = r.nextInt(i) + 1 diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala index 31d230d0fec8e..21acaa95c5645 100644 --- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -22,9 +22,7 @@ package org.apache.spark.util * through all the elements. */ private[spark] -// scalastyle:off abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] { -// scalastyle:on private[this] var completed = false def next(): A = sub.next() From 8b56f16640fc4156aa7bd529c54469d27635b951 Mon Sep 17 00:00:00 2001 From: bag_of_tricks Date: Fri, 23 Mar 2018 10:36:23 -0700 Subject: [PATCH 0515/1161] [SPARK-23759][UI] Unable to bind Spark UI to specific host name / IP ## What changes were proposed in this pull request? Fixes SPARK-23759 by moving connector.start() after connector.setHost() Problem was created due connector.setHost(hostName) call was after connector.start() ## How was this patch tested? Patch was tested after build and deployment. This patch requires SPARK_LOCAL_IP environment variable to be set on spark-env.sh Author: bag_of_tricks Closes #20883 from felixalbani/SPARK-23759. --- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0adeb4058b6e4..0e8a6307de6a8 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -343,12 +343,13 @@ private[spark] object JettyUtils extends Logging { -1, connectionFactories: _*) connector.setPort(port) - connector.start() + connector.setHost(hostName) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) - connector.setHost(hostName) + + connector.start() // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 From cb43bbe13606673349511829fd71d1f34fc39c45 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 23 Mar 2018 11:42:40 -0700 Subject: [PATCH 0516/1161] [SPARK-21685][PYTHON][ML] PySpark Params isSet state should not change after transform ## What changes were proposed in this pull request? Currently when a PySpark Model is transformed, default params that have not been explicitly set are then set on the Java side on the call to `wrapper._transfer_values_to_java`. This incorrectly changes the state of the Param as it should still be marked as a default value only. ## How was this patch tested? Added a new test to verify that when transferring Params to Java, default params have their state preserved. Author: Bryan Cutler Closes #18982 from BryanCutler/pyspark-ml-param-to-java-defaults-SPARK-21685. --- python/pyspark/ml/tests.py | 20 +++++++++++++++++++- python/pyspark/ml/wrapper.py | 13 ++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index fd45fd00b270b..080119959a4e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -369,7 +369,7 @@ def test_property(self): raise RuntimeError("Test property to raise error when invoked") -class ParamTests(PySparkTestCase): +class ParamTests(SparkSessionTestCase): def test_copy_new_parent(self): testParams = TestParams() @@ -514,6 +514,24 @@ def test_logistic_regression_check_thresholds(self): LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5] ) + def test_preserve_set_state(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + self.assertFalse(binarizer.isSet("threshold")) + binarizer.transform(dataset) + binarizer._transfer_params_from_java() + self.assertFalse(binarizer.isSet("threshold"), + "Params not explicitly set should remain unset after transform") + + def test_default_params_transferred(self): + dataset = self.spark.createDataFrame([(0.5,)], ["data"]) + binarizer = Binarizer(inputCol="data") + # intentionally change the pyspark default, but don't set it + binarizer._defaultParamMap[binarizer.outputCol] = "my_default" + result = binarizer.transform(dataset).select("my_default").collect() + self.assertFalse(binarizer.isSet(binarizer.outputCol)) + self.assertEqual(result[0][0], 1.0) + @staticmethod def check_params(test_self, py_stage, check_params_exist=True): """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 5061f6434794a..d325633195ddb 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -118,11 +118,18 @@ def _transfer_params_to_java(self): """ Transforms the embedded params to the companion Java object. """ - paramMap = self.extractParamMap() + pair_defaults = [] for param in self.params: - if param in paramMap: - pair = self._make_java_param_pair(param, paramMap[param]) + if self.isSet(param): + pair = self._make_java_param_pair(param, self._paramMap[param]) self._java_obj.set(pair) + if self.hasDefault(param): + pair = self._make_java_param_pair(param, self._defaultParamMap[param]) + pair_defaults.append(pair) + if len(pair_defaults) > 0: + sc = SparkContext._active_spark_context + pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults) + self._java_obj.setDefault(pair_defaults_seq) def _transfer_param_map_to_java(self, pyParamMap): """ From 95c03cbd27cea2255d9d748f9a84a0a38e54594d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 23 Mar 2018 11:56:17 -0700 Subject: [PATCH 0517/1161] [SPARK-23783][SPARK-11239][ML] Add PMML export to Spark ML pipelines ## What changes were proposed in this pull request? Adds PMML export support to Spark ML pipelines in the style of Spark's DataSource API to allow library authors to add their own model export formats. Includes a specific implementation for Spark ML linear regression PMML export. In addition to adding PMML to reach parity with our current MLlib implementation, this approach will allow other libraries & formats (like PFA) to implement and export models with a unified API. ## How was this patch tested? Basic unit test. Author: Holden Karau Author: Holden Karau Closes #19876 from holdenk/SPARK-11171-SPARK-11237-Add-PMML-export-for-ML-KMeans-r2. --- .../org.apache.spark.ml.util.MLFormatRegister | 2 + .../ml/regression/LinearRegression.scala | 70 ++++--- .../org/apache/spark/ml/util/ReadWrite.scala | 173 +++++++++++++++++- .../org.apache.spark.ml.util.MLFormatRegister | 3 + .../ml/regression/LinearRegressionSuite.scala | 27 ++- .../spark/ml/util/PMMLReadWriteTest.scala | 55 ++++++ .../org/apache/spark/ml/util/PMMLUtils.scala | 43 +++++ .../apache/spark/ml/util/ReadWriteSuite.scala | 132 +++++++++++++ 8 files changed, 474 insertions(+), 31 deletions(-) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..5e5484fd8784d --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,2 @@ +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92510154d500e..f67d9d831f327 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging -import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.{PipelineStage, PredictorParams} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ @@ -39,10 +39,11 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -643,7 +644,7 @@ class LinearRegressionModel private[ml] ( @Since("1.3.0") val intercept: Double, @Since("2.3.0") val scale: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with MLWritable { + with LinearRegressionParams with GeneralMLWritable { private[ml] def this(uid: String, coefficients: Vector, intercept: Double) = this(uid, coefficients, intercept, 1.0) @@ -710,7 +711,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -718,7 +719,50 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) +} + +/** A writer for LinearRegression that handles the "internal" (or default) format */ +private class InternalLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector, scale: Double) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[LinearRegressionModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients, scale + val data = Data(instance.intercept, instance.coefficients, instance.scale) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for LinearRegression that handles the "pmml" format */ +private class PMMLLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val sc = sparkSession.sparkContext + // Construct the MLLib model which knows how to write to PMML. + val instance = stage.asInstanceOf[LinearRegressionModel] + val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) + // Save PMML + oldModel.toPMML(sc, path) + } } @Since("1.6.0") @@ -730,22 +774,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") override def load(path: String): LinearRegressionModel = super.load(path) - /** [[MLWriter]] instance for [[LinearRegressionModel]] */ - private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends MLWriter with Logging { - - private case class Data(intercept: Double, coefficients: Vector, scale: Double) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients, scale - val data = Data(instance.intercept, instance.coefficients, instance.scale) - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) - } - } - private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a616907800969..7edcd498678cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,9 +18,11 @@ package org.apache.spark.ml.util import java.io.IOException -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.json4s._ @@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -86,7 +88,82 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Abstract class to be implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. + */ + @Since("2.4.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], + stage: PipelineStage): Unit +} + +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * @since 2.4.0 + */ +@InterfaceStability.Unstable +@Since("2.4.0") +trait MLFormatRegister extends MLWriterFormat { + /** + * The string that represents the format that this format provider uses. This is, along with + * stageName, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def format(): String = + * "pmml" + * }}} + * Indicates that this format is capable of saving a pmml model. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def format(): String + + /** + * The string that represents the stage type that this writer supports. This is, along with + * format, is overridden by children to provide a nice alias for the writer. For example: + * + * {{{ + * override def stageName(): String = + * "org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own PMML model. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.4.0 + */ + @Since("2.4.0") + def stageName(): String + + private[ml] def shortName(): String = s"${format()}+${stageName()}" +} + +/** + * Abstract class for utility classes that can save ML instances in Spark's internal format. */ @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { @@ -110,6 +187,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + /** * Map to store extra options for this writer. */ @@ -126,15 +212,73 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + @Since("1.6.0") + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + @Since("1.6.0") + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +@InterfaceStability.Unstable +@Since("2.4.0") +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. "pmml", "internal", or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { - shouldOverwrite = true + @Since("2.4.0") + def format(source: String): this.type = { + this.source = source this } + /** + * Dispatches the save to the correct MLFormat. + */ + @Since("2.4.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String): Unit = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) + val stageName = stage.getClass.getName + val targetName = s"$source+$stageName" + val formats = serviceLoader.asScala.toList + val shortNames = formats.map(_.shortName()) + val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => + Try(loader.loadClass(source)) match { + case Success(writer) => + // Found the ML writer using the fully qualified path + writer + case Failure(error) => + throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) + } + case head :: Nil => + head.getClass + case _ => + // Multiple sources + throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") + } + if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + writer.write(path, sparkSession, optionMap, stage) + } else { + throw new SparkException(s"ML source $source is not a valid MLWriterFormat") + } + } + // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) @@ -162,6 +306,19 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } +/** + * Trait for classes that provide `GeneralMLWriter`. + */ +@Since("2.4.0") +@InterfaceStability.Unstable +trait GeneralMLWritable extends MLWritable { + /** + * Returns an `MLWriter` instance for this ML instance. + */ + @Since("2.4.0") + override def write: GeneralMLWriter +} + /** * :: DeveloperApi :: * diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..100ef2545418f --- /dev/null +++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,3 @@ +org.apache.spark.ml.util.DuplicateLinearRegressionWriter1 +org.apache.spark.ml.util.DuplicateLinearRegressionWriter2 +org.apache.spark.ml.util.FakeLinearRegressionWriterWithName diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 9b19f63eba1bd..90ceb7dee38f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.ml.regression +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random +import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} + import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { + +class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { import testImplicits._ @@ -1052,6 +1057,24 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) + } + testPMMLWrite(sc, model, checkModel) + } + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala new file mode 100644 index 0000000000000..d2c4832b12bac --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. + */ + def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.write.format("pmml").save(path) + intercept[IOException] { + instance.write.format("pmml").save(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark")) + checkModelData(pmmlModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala new file mode 100644 index 0000000000000..dbdc69f95d841 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.util + +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource + +/** + * Testing utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. + */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala new file mode 100644 index 0000000000000..f4c1f0bdb32cd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.SparkException +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.sql.{DataFrame, SparkSession} + +class FakeLinearRegressionWriter extends MLWriterFormat { + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + +class FakeLinearRegressionWriterWithName extends MLFormatRegister { + override def format(): String = "fakeWithName" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + + +class DuplicateLinearRegressionWriter1 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class DuplicateLinearRegressionWriter2 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class ReadWriteSuite extends MLTest { + + import testImplicits._ + + private val seed: Int = 42 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0), + xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF() + } + + test("unsupported/non existent export formats") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + // Does not exist with a long class name + val thrownDNE = intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + assert(thrownDNE.getMessage(). + contains("Could not load requested format")) + + // Does not exist with a short name + val thrownDNEShort = intercept[SparkException] { + model.write.format("boop").save("boop") + } + assert(thrownDNEShort.getMessage(). + contains("Could not load requested format")) + + // Check with a valid class that is not a writer format. + val thrownInvalid = intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + assert(thrownInvalid.getMessage() + .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat")) + } + + test("invalid paths fail") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("pmml").save("") + } + assert(thrown.getMessage().contains("Can not create a Path from an empty string")) + } + + test("dummy export format is called") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name") + } + assert(thrown.getMessage().contains("Fake writer doesn't write")) + val thrownWithName = intercept[Exception] { + model.write.format("fakeWithName").save("name") + } + assert(thrownWithName.getMessage().contains("Fake writer doesn't write")) + } + + test("duplicate format raises error") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("dupe").save("dupepanda") + } + assert(thrown.getMessage().contains("Multiple writers found for")) + } +} From a33655348c4066d9c1d8ad2055aadfbc892ba7fd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 23 Mar 2018 15:58:48 -0700 Subject: [PATCH 0518/1161] [SPARK-23615][ML][PYSPARK] Add maxDF Parameter to Python CountVectorizer ## What changes were proposed in this pull request? The maxDF parameter is for filtering out frequently occurring terms. This param was recently added to the Scala CountVectorizer and needs to be added to Python also. ## How was this patch tested? add test Author: Huaxin Gao Closes #20777 from huaxingao/spark-23615. --- .../spark/ml/feature/CountVectorizer.scala | 20 +++++----- python/pyspark/ml/feature.py | 40 ++++++++++++++----- python/pyspark/ml/tests.py | 25 ++++++++++++ 3 files changed, 67 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 60a4f918790a3..9e0ed437e7bfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,19 +70,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit def getMinDF: Double = $(minDF) /** - * Specifies the maximum number of different documents a term must appear in to be included - * in the vocabulary. - * If this is an integer greater than or equal to 1, this specifies the number of documents - * the term must appear in; if this is a double in [0,1), then this specifies the fraction of - * documents. + * Specifies the maximum number of different documents a term could appear in to be included + * in the vocabulary. A term that appears more than the threshold will be ignored. If this is an + * integer greater than or equal to 1, this specifies the maximum number of documents the term + * could appear in; if this is a double in [0,1), then this specifies the maximum fraction of + * documents the term could appear in. * - * Default: (2^64^) - 1 + * Default: (2^63^) - 1 * @group param */ val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents.", + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an integer >= 1," + + " this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum fraction of" + + " documents the term could appear in.", ParamValidators.gtEq(0.0)) /** @group getParam */ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a1ceb7f02da8b..fcb0dfc563720 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -422,6 +422,14 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): " If this is an integer >= 1, this specifies the number of documents the term must" + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + " Default 1.0", typeConverter=TypeConverters.toFloat) + maxDF = Param( + Params._dummy(), "maxDF", "Specifies the maximum number of" + + " different documents a term could appear in to be included in the vocabulary." + + " A term that appears more than the threshold will be ignored. If this is an" + + " integer >= 1, this specifies the maximum number of documents the term could appear in;" + + " if this is a double in [0,1), then this specifies the maximum" + + " fraction of documents the term could appear in." + + " Default (2^63) - 1", typeConverter=TypeConverters.toFloat) vocabSize = Param( Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", typeConverter=TypeConverters.toInt) @@ -433,7 +441,7 @@ class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): def __init__(self, *args): super(_CountVectorizerParams, self).__init__(*args) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + self._setDefault(minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False) @since("1.6.0") def getMinTF(self): @@ -449,6 +457,13 @@ def getMinDF(self): """ return self.getOrDefault(self.minDF) + @since("2.4.0") + def getMaxDF(self): + """ + Gets the value of maxDF or its default value. + """ + return self.getOrDefault(self.maxDF) + @since("1.6.0") def getVocabSize(self): """ @@ -513,11 +528,11 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav """ @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None,outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", @@ -527,11 +542,11 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, - outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False, + inputCol=None, outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ - outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, maxDF=2 ** 63 - 1, vocabSize=1 << 18, binary=False,\ + inputCol=None, outputCol=None) Set the params for the CountVectorizer """ kwargs = self._input_kwargs @@ -551,6 +566,13 @@ def setMinDF(self, value): """ return self._set(minDF=value) + @since("2.4.0") + def setMaxDF(self, value): + """ + Sets the value of :py:attr:`maxDF`. + """ + return self._set(maxDF=value) + @since("1.6.0") def setVocabSize(self, value): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 080119959a4e8..cf1ffa181ecec 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -697,6 +697,31 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_with_maxDF(self): + dataset = self.spark.createDataFrame([ + (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0}),), + (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + cv = CountVectorizer(inputCol="words", outputCol="features") + model1 = cv.setMaxDF(3).fit(dataset) + self.assertEqual(model1.vocabulary, ['b', 'c', 'd']) + + transformedList1 = model1.transform(dataset).select("features", "expected").collect() + + for r in transformedList1: + feature, expected = r + self.assertEqual(feature, expected) + + model2 = cv.setMaxDF(0.75).fit(dataset) + self.assertEqual(model2.vocabulary, ['b', 'c', 'd']) + + transformedList2 = model2.transform(dataset).select("features", "expected").collect() + + for r in transformedList2: + feature, expected = r + self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", outputCol="features", minTF=2) From 816a5496ba4caac438f70400f72bb10bfcc02418 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Sat, 24 Mar 2018 18:21:01 -0700 Subject: [PATCH 0519/1161] [SPARK-23788][SS] Fix race in StreamingQuerySuite ## What changes were proposed in this pull request? The serializability test uses the same MemoryStream instance for 3 different queries. If any of those queries ask it to commit before the others have run, the rest will see empty dataframes. This can fail the test if q3 is affected. We should use one instance per query instead. ## How was this patch tested? Existing unit test. If I move q2.processAllAvailable() before starting q3, the test always fails without the fix. Author: Jose Torres Closes #20896 from jose-torres/fixrace. --- .../spark/sql/streaming/StreamingQuerySuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ebc9a87b23f84..08749b49997e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -550,22 +550,22 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi .start() } - val input = MemoryStream[Int] - val q1 = startQuery(input.toDS, "stream_serializable_test_1") - val q2 = startQuery(input.toDS.map { i => + val input = MemoryStream[Int] :: MemoryStream[Int] :: MemoryStream[Int] :: Nil + val q1 = startQuery(input(0).toDS, "stream_serializable_test_1") + val q2 = startQuery(input(1).toDS.map { i => // Emulate that `StreamingQuery` get captured with normal usage unintentionally. // It should not fail the query. q1 i }, "stream_serializable_test_2") - val q3 = startQuery(input.toDS.map { i => + val q3 = startQuery(input(2).toDS.map { i => // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear // error message. q1.explain() i }, "stream_serializable_test_3") try { - input.addData(1) + input.foreach(_.addData(1)) // q2 should not fail since it doesn't use `q1` in the closure q2.processAllAvailable() From 5f653d4f7c84e6147cd323cd650da65e0381ebe8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 25 Mar 2018 09:18:26 -0700 Subject: [PATCH 0520/1161] [SPARK-23167][SQL] Add TPCDS queries v2.7 in TPCDSQuerySuite ## What changes were proposed in this pull request? This pr added TPCDS v2.7 (latest) queries in `TPCDSQuerySuite` because the current `TPCDSQuerySuite` tests older one (v1.4) and some queries are different from v1.4 and v2.7. Since the original v2.7 queries have the syntaxes that Spark cannot parse, I changed these queries in a following way: - [date] + 14 days -> date + `INTERVAL` 14 days - [column name] as "30 days" -> [column name] as \`30 days\` - Fix some syntax errors, e.g., missing brackets ## How was this patch tested? Added tests in `TPCDSQuerySuite`. Author: Takeshi Yamamuro Closes #20343 from maropu/TPCDSV2_7. --- .../src/test/resources/tpcds-v2.7.0/q10a.sql | 69 ++++++ .../src/test/resources/tpcds-v2.7.0/q11.sql | 84 +++++++ .../src/test/resources/tpcds-v2.7.0/q12.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q14.sql | 135 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q14a.sql | 215 ++++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q18a.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q20.sql | 19 ++ .../src/test/resources/tpcds-v2.7.0/q22.sql | 15 ++ .../src/test/resources/tpcds-v2.7.0/q22a.sql | 94 ++++++++ .../src/test/resources/tpcds-v2.7.0/q24.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q27a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q34.sql | 37 +++ .../src/test/resources/tpcds-v2.7.0/q35.sql | 65 ++++++ .../src/test/resources/tpcds-v2.7.0/q35a.sql | 62 +++++ .../src/test/resources/tpcds-v2.7.0/q36a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q47.sql | 64 ++++++ .../src/test/resources/tpcds-v2.7.0/q49.sql | 133 +++++++++++ .../src/test/resources/tpcds-v2.7.0/q51a.sql | 103 +++++++++ .../src/test/resources/tpcds-v2.7.0/q57.sql | 57 +++++ .../src/test/resources/tpcds-v2.7.0/q5a.sql | 158 +++++++++++++ .../src/test/resources/tpcds-v2.7.0/q6.sql | 23 ++ .../src/test/resources/tpcds-v2.7.0/q64.sql | 111 +++++++++ .../src/test/resources/tpcds-v2.7.0/q67a.sql | 208 +++++++++++++++++ .../src/test/resources/tpcds-v2.7.0/q70a.sql | 70 ++++++ .../src/test/resources/tpcds-v2.7.0/q72.sql | 40 ++++ .../src/test/resources/tpcds-v2.7.0/q74.sql | 60 +++++ .../src/test/resources/tpcds-v2.7.0/q75.sql | 78 +++++++ .../src/test/resources/tpcds-v2.7.0/q77a.sql | 121 ++++++++++ .../src/test/resources/tpcds-v2.7.0/q78.sql | 75 ++++++ .../src/test/resources/tpcds-v2.7.0/q80a.sql | 147 ++++++++++++ .../src/test/resources/tpcds-v2.7.0/q86a.sql | 61 +++++ .../src/test/resources/tpcds-v2.7.0/q98.sql | 22 ++ .../apache/spark/sql/TPCDSQuerySuite.scala | 38 +++- 33 files changed, 2691 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q11.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q12.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q20.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q22.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q24.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q34.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q35.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q47.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q49.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q57.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q6.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q64.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q72.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q74.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q75.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q78.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql create mode 100644 sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql create mode 100755 sql/core/src/test/resources/tpcds-v2.7.0/q98.sql diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql new file mode 100644 index 0000000000000..50e521567eb3a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q10a.sql @@ -0,0 +1,69 @@ +-- This is a new query in TPCDS v2.7 +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from + customer c,customer_address ca,customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and ca_county in ('Walker County', 'Richland County', 'Gaines County', 'Douglas County', 'Dona Ana County') + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales,date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) + and exists ( + select * + from ( + select + ws_bill_customer_sk as customer_sk, + d_year, + d_moy + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3 + union all + select + cs_ship_customer_sk as customer_sk, + d_year, + d_moy + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4 + 3) x + where c.c_customer_sk = customer_sk) +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql new file mode 100755 index 0000000000000..97bed33721742 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q11.sql @@ -0,0 +1,84 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id + , c_first_name + , c_last_name + , d_year + , c_preferred_cust_flag + , c_birth_country + , c_login + , c_email_address + , d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + GROUP BY + c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + c_login, c_email_address, d_year) +SELECT + -- select list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +FROM year_total t_s_firstyear + , year_total t_s_secyear + , year_total t_w_firstyear + , year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + -- q11 in TPCDS v1.4 used NULL + -- ELSE NULL END + ELSE 0.0 END +ORDER BY + -- order-by list of q11 in TPCDS v1.4 is below: + -- t_s_secyear.customer_preferred_cust_flag + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_email_address +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql new file mode 100755 index 0000000000000..7a6fafd22428a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q12.sql @@ -0,0 +1,23 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ws_ext_sales_price) AS itemrevenue, + sum(ws_ext_sales_price) * 100 / sum(sum(ws_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + web_sales, item, date_dim +WHERE + ws_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql new file mode 100644 index 0000000000000..b2ca3ddaf2baf --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14.sql @@ -0,0 +1,135 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14a.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1998 AND 1998 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1998 AND 1998 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1998 AND 1998 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x) +select + * +from ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + 1 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity*ss_list_price) > (select average_sales from avg_sales)) this_year, + ( + select + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_week_seq = ( + select d_week_seq + from date_dim + where d_year = 1998 + and d_moy = 12 + and d_dom = 16) + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales)) last_year +where + this_year.i_brand_id = last_year.i_brand_id + and this_year.i_class_id = last_year.i_class_id + and this_year.i_category_id = last_year.i_category_id +order by + this_year.channel, + this_year.i_brand_id, + this_year.i_class_id, + this_year.i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql new file mode 100644 index 0000000000000..bfa70fe62d8d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q14a.sql @@ -0,0 +1,215 @@ +-- This query is the alternative form of sql/core/src/test/resources/tpcds/q14b.sql +with cross_items as ( + select + i_item_sk ss_item_sk + from item, ( + select + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + from + store_sales, item iss, date_dim d1 + where + ss_item_sk = iss.i_item_sk + and ss_sold_date_sk = d1.d_date_sk + and d1.d_year between 1999 AND 1999 + 2 + intersect + select + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + from + catalog_sales, item ics, date_dim d2 + where + cs_item_sk = ics.i_item_sk + and cs_sold_date_sk = d2.d_date_sk + and d2.d_year between 1999 AND 1999 + 2 + intersect + select + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + from + web_sales, item iws, date_dim d3 + where + ws_item_sk = iws.i_item_sk + and ws_sold_date_sk = d3.d_date_sk + and d3.d_year between 1999 AND 1999 + 2) x + where + i_brand_id = brand_id + and i_class_id = class_id + and i_category_id = category_id), +avg_sales as ( + select + avg(quantity*list_price) average_sales + from ( + select + ss_quantity quantity, + ss_list_price list_price + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_year between 1999 and 2001 + union all + select + cs_quantity quantity, + cs_list_price list_price + from + catalog_sales, date_dim + where + cs_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2 + union all + select + ws_quantity quantity, + ws_list_price list_price + from + web_sales, date_dim + where + ws_sold_date_sk = d_date_sk + and d_year between 1998 and 1998 + 2) x), +results AS ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum(sales) sum_sales, + sum(number_sales) number_sales + from ( + select + 'store' channel, + i_brand_id,i_class_id, + i_category_id, + sum(ss_quantity*ss_list_price) sales, + count(*) number_sales + from + store_sales, item, date_dim + where + ss_item_sk in (select ss_item_sk from cross_items) + and ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ss_quantity * ss_list_price) > (select average_sales from avg_sales) + union all + select + 'catalog' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(cs_quantity*cs_list_price) sales, + count(*) number_sales + from + catalog_sales, item, date_dim + where + cs_item_sk in (select ss_item_sk from cross_items) + and cs_item_sk = i_item_sk + and cs_sold_date_sk = d_date_sk + and d_year = 1998+2 + and d_moy = 11 + group by + i_brand_id,i_class_id,i_category_id + having + sum(cs_quantity*cs_list_price) > (select average_sales from avg_sales) + union all + select + 'web' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ws_quantity*ws_list_price) sales, + count(*) number_sales + from + web_sales, item, date_dim + where + ws_item_sk in (select ss_item_sk from cross_items) + and ws_item_sk = i_item_sk + and ws_sold_date_sk = d_date_sk + and d_year = 1998 + 2 + and d_moy = 11 + group by + i_brand_id, + i_class_id, + i_category_id + having + sum(ws_quantity*ws_list_price) > (select average_sales from avg_sales)) y + group by + channel, + i_brand_id, + i_class_id, + i_category_id) +select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales +from ( + select + channel, + i_brand_id, + i_class_id, + i_category_id, + sum_sales, + number_sales + from + results + union + select + channel, + i_brand_id, + i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id, + i_class_id + union + select + channel, + i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel, + i_brand_id + union + select + channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results + group by + channel + union + select + null as channel, + null as i_brand_id, + null as i_class_id, + null as i_category_id, + sum(sum_sales), + sum(number_sales) + from results) z +order by + channel, + i_brand_id, + i_class_id, + i_category_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql new file mode 100644 index 0000000000000..2201a302ab352 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q18a.sql @@ -0,0 +1,133 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + cast(cs_quantity as decimal(12,2)) agg1, + cast(cs_list_price as decimal(12,2)) agg2, + cast(cs_coupon_amt as decimal(12,2)) agg3, + cast(cs_sales_price as decimal(12,2)) agg4, + cast(cs_net_profit as decimal(12,2)) agg5, + cast(c_birth_year as decimal(12,2)) agg6, + cast(cd1.cd_dep_count as decimal(12,2)) agg7 + from + catalog_sales, customer_demographics cd1, customer_demographics cd2, customer, + customer_address, date_dim, item + where + cs_sold_date_sk = d_date_sk + and cs_item_sk = i_item_sk + and cs_bill_cdemo_sk = cd1.cd_demo_sk + and cs_bill_customer_sk = c_customer_sk + and cd1.cd_gender = 'M' + and cd1.cd_education_status = 'College' + and c_current_cdemo_sk = cd2.cd_demo_sk + and c_current_addr_sk = ca_address_sk + and c_birth_month in (9,5,12,4,1,10) + and d_year = 2001 + and ca_state in ('ND','WI','AL','NC','OK','MS','TN')) +select + i_item_id, + ca_country, + ca_state, + ca_county, + agg1, + agg2, + agg3, + agg4, + agg5, + agg6, + agg7 +from ( + select + i_item_id, + ca_country, + ca_state, + ca_county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state, + ca_county + union all + select + i_item_id, + ca_country, + ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from + results + group by + i_item_id, + ca_country, + ca_state + union all + select + i_item_id, + ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id, + ca_country + union all + select + i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as ca_country, + NULL as ca_state, + NULL as county, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4, + avg(agg5) agg5, + avg(agg6) agg6, + avg(agg7) agg7 + from results) foo +order by + ca_country, + ca_state, + ca_county, + i_item_id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql new file mode 100755 index 0000000000000..34d46b1394d8f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q20.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(cs_ext_sales_price) AS itemrevenue, + sum(cs_ext_sales_price) * 100 / sum(sum(cs_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM catalog_sales, item, date_dim +WHERE cs_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) +AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql new file mode 100755 index 0000000000000..e7bea0804f162 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22.sql @@ -0,0 +1,15 @@ +SELECT + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh +FROM inventory, date_dim, item, warehouse +WHERE inv_date_sk = d_date_sk + AND inv_item_sk = i_item_sk + -- q22 in TPCDS v1.4 had a condition below: + -- AND inv_warehouse_sk = w_warehouse_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 +GROUP BY ROLLUP (i_product_name, i_brand, i_class, i_category) +ORDER BY qoh, i_product_name, i_brand, i_class, i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql new file mode 100644 index 0000000000000..c886e6271511b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q22a.sql @@ -0,0 +1,94 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh + from + inventory, date_dim, item, warehouse + where + inv_date_sk = d_date_sk + and inv_item_sk = i_item_sk + and inv_warehouse_sk = w_warehouse_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_product_name, + i_brand, + i_class, + i_category), +results_rollup as ( + select + i_product_name, + i_brand, + i_class, + i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class, + i_category + union all + select + i_product_name, + i_brand, + i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand, + i_class + union all + select + i_product_name, + i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name, + i_brand + union all + select + i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results + group by + i_product_name + union all + select + null i_product_name, + null i_brand, + null i_class, + null i_category, + avg(qoh) qoh + from + results) +select + i_product_name, + i_brand, + i_class, + i_category, + qoh +from + results_rollup +order by + qoh, + i_product_name, + i_brand, + i_class, + i_category +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql new file mode 100755 index 0000000000000..92d64bc7eba78 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q24.sql @@ -0,0 +1,40 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_current_addr_sk = ca_address_sk -- This condition did not exist in TPCDS v1.4 + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color, + i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'pale' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) +-- no order-by exists in q24a of TPCDS v1.4 +ORDER BY + c_last_name, + c_first_name, + s_store_name diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql new file mode 100644 index 0000000000000..c70a2420e8387 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q27a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_item_id, + s_state, 0 as g_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + from + store_sales, customer_demographics, date_dim, store, item + where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_store_sk = s_store_sk + and ss_cdemo_sk = cd_demo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and d_year = 1998 + and s_state in ('TN','TN', 'TN', 'TN', 'TN', 'TN')) +select + i_item_id, + s_state, + g_state, + agg1, + agg2, + agg3, + agg4 +from ( + select + i_item_id, + s_state, + 0 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results + group by + i_item_id, + s_state + union all + select + i_item_id, + NULL AS s_state, + 1 AS g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from results + group by + i_item_id + union all + select + NULL AS i_item_id, + NULL as s_state, + 1 as g_state, + avg(agg1) agg1, + avg(agg2) agg2, + avg(agg3) agg3, + avg(agg4) agg4 + from + results) foo +order by + i_item_id, + s_state +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql new file mode 100755 index 0000000000000..bbede62acc9a7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q34.sql @@ -0,0 +1,37 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (date_dim.d_dom BETWEEN 1 AND 3 OR date_dim.d_dom BETWEEN 25 AND 28) + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND (CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL + END) > 1.2 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', + 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + GROUP BY ss_ticket_number, ss_customer_sk) dn, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 15 AND 20 +ORDER BY + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag DESC, + ss_ticket_number -- This order-by condition did not exist in TPCDS v1.4 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql new file mode 100755 index 0000000000000..27116a563d5c6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35.sql @@ -0,0 +1,65 @@ +SELECT + -- select list of q35 in TPCDS v1.4 is below: + -- ca_state, + -- cd_gender, + -- cd_marital_status, + -- count(*) cnt1, + -- min(cd_dep_count), + -- max(cd_dep_count), + -- avg(cd_dep_count), + -- cd_dep_employed_count, + -- count(*) cnt2, + -- min(cd_dep_employed_count), + -- max(cd_dep_employed_count), + -- avg(cd_dep_employed_count), + -- cd_dep_college_count, + -- count(*) cnt3, + -- min(cd_dep_college_count), + -- max(cd_dep_college_count), + -- avg(cd_dep_college_count) + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4)) +GROUP BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +ORDER BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql new file mode 100644 index 0000000000000..1c1463e44777f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q35a.sql @@ -0,0 +1,62 @@ +-- This is a new query in TPCDS v2.7 +select + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + count(*) cnt1, + avg(cd_dep_count), + max(cd_dep_count), + sum(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + avg(cd_dep_employed_count), + max(cd_dep_employed_count), + sum(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + avg(cd_dep_college_count), + max(cd_dep_college_count), + sum(cd_dep_college_count) +from + customer c, customer_address ca, customer_demographics +where + c.c_current_addr_sk = ca.ca_address_sk + and cd_demo_sk = c.c_current_cdemo_sk + and exists ( + select * + from store_sales, date_dim + where c.c_customer_sk = ss_customer_sk + and ss_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) + and exists ( + select * + from ( + select ws_bill_customer_sk customsk + from web_sales, date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4 + union all + select cs_ship_customer_sk customsk + from catalog_sales, date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 1999 + and d_qoy < 4) x + where x.customsk = c.c_customer_sk) +group by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + ca_state, + cd_gender, + cd_marital_status, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql new file mode 100644 index 0000000000000..9d98f32add508 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q36a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as ss_net_profit, + sum(ss_ext_sales_price) as ss_ext_sales_price, + sum(ss_net_profit)/sum(ss_ext_sales_price) as gross_margin, + i_category, + i_class, + 0 as g_category, + 0 as g_class + from + store_sales, + date_dim d1, + item, + store + where + d1.d_year = 2001 + and d1.d_date_sk = ss_sold_date_sk + and i_item_sk = ss_item_sk + and s_store_sk = ss_store_sk + and s_state in ('TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN') + group by + i_category, + i_class), + results_rollup as ( + select + gross_margin, + i_category, + i_class, + 0 as t_category, + 0 as t_class, + 0 as lochierarchy + from + results + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + i_category, NULL AS i_class, + 0 as t_category, + 1 as t_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(ss_net_profit) / sum(ss_ext_sales_price) as gross_margin, + NULL AS i_category, + NULL AS i_class, + 1 as t_category, + 1 as t_class, + 2 as lochierarchy + from + results) +select + gross_margin, + i_category, + i_class, + lochierarchy, + rank() over ( + partition by lochierarchy, case when t_class = 0 then i_category end + order by gross_margin asc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql new file mode 100755 index 0000000000000..9f7ee457ea45f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q47.sql @@ -0,0 +1,64 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + s_store_name, + s_company_name, + d_year, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name + ORDER BY d_year, d_moy) rn + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + -- q47 in TPCDS v1.4 had more columns below: + -- v1.i_brand, + -- v1.s_store_name, + -- v1.s_company_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.s_store_name = v1_lag.s_store_name AND + v1.s_store_name = v1_lead.s_store_name AND + v1.s_company_name = v1_lag.s_company_name AND + v1.s_company_name = v1_lead.s_company_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql new file mode 100755 index 0000000000000..e8061bde4159e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q49.sql @@ -0,0 +1,133 @@ +-- The first SELECT query below is different from q49 of TPCDS v1.4 +SELECT + channel, + item, + return_ratio, + return_rank, + currency_rank +FROM ( + SELECT + 'web' as channel, + in_web.item, + in_web.return_ratio, + in_web.return_rank, + in_web.currency_rank + FROM + (SELECT + item, + return_ratio, + currency_ratio, + rank() over (ORDER BY return_ratio) AS return_rank, + rank() over (ORDER BY currency_ratio) AS currency_rank + FROM ( + SELECT + ws.ws_item_sk AS item, + CAST(SUM(COALESCE(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_quantity, 0)) AS DECIMAL(15, 4)) AS return_ratio, + CAST(SUM(COALESCE(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) / + CAST(SUM(COALESCE(ws.ws_net_paid, 0)) AS DECIMAL(15, 4)) AS currency_ratio + FROM + web_sales ws LEFT OUTER JOIN web_returns wr + ON (ws.ws_order_number = wr.wr_order_number AND ws.ws_item_sk = wr.wr_item_sk), + date_dim + WHERE + wr.wr_return_amt > 10000 + AND ws.ws_net_profit > 1 + AND ws.ws_net_paid > 0 + AND ws.ws_quantity > 0 + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY + ws.ws_item_sk) + ) in_web + ) web +WHERE (web.return_rank <= 10 OR web.currency_rank <= 10) +UNION +SELECT + 'catalog' AS channel, + catalog.item, + catalog.return_ratio, + catalog.return_rank, + catalog.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + cs.cs_item_sk AS item, + (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + catalog_sales cs LEFT OUTER JOIN catalog_returns cr + ON (cs.cs_order_number = cr.cr_order_number AND + cs.cs_item_sk = cr.cr_item_sk) + , date_dim + WHERE + cr.cr_return_amount > 10000 + AND cs.cs_net_profit > 1 + AND cs.cs_net_paid > 0 + AND cs.cs_quantity > 0 + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY cs.cs_item_sk + ) in_cat + ) catalog +WHERE (catalog.return_rank <= 10 OR catalog.currency_rank <= 10) +UNION +SELECT + 'store' AS channel, + store.item, + store.return_ratio, + store.return_rank, + store.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + sts.ss_item_sk AS item, + (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + store_sales sts LEFT OUTER JOIN store_returns sr + ON (sts.ss_ticket_number = sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk) + , date_dim + WHERE + sr.sr_return_amt > 10000 + AND sts.ss_net_profit > 1 + AND sts.ss_net_paid > 0 + AND sts.ss_quantity > 0 + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY sts.ss_item_sk + ) in_store + ) store +WHERE (store.return_rank <= 10 OR store.currency_rank <= 10) +ORDER BY + -- order-by list of q49 in TPCDS v1.4 is below: + -- 1, 4, 5 + 1, 4, 5, 2 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql new file mode 100644 index 0000000000000..b8cbbbc8ef7d5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q51a.sql @@ -0,0 +1,103 @@ +-- This is a new query in TPCDS v2.7 +WITH web_tv as ( + select + ws_item_sk item_sk, + d_date, + sum(ws_sales_price) sumws, + row_number() over (partition by ws_item_sk order by d_date) rk + from + web_sales, date_dim + where + ws_sold_date_sk=d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ws_item_sk is not NULL + group by + ws_item_sk, d_date), +web_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumws, + sum(v2.sumws) cume_sales + from + web_tv v1, web_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumws), +store_tv as ( + select + ss_item_sk item_sk, + d_date, + sum(ss_sales_price) sumss, + row_number() over (partition by ss_item_sk order by d_date) rk + from + store_sales, date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_item_sk is not NULL + group by ss_item_sk, d_date), +store_v1 as ( + select + v1.item_sk, + v1.d_date, + v1.sumss, + sum(v2.sumss) cume_sales + from + store_tv v1, store_tv v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.sumss), +v as ( + select + item_sk, + d_date, + web_sales, + store_sales, + row_number() over (partition by item_sk order by d_date) rk + from ( + select + case when web.item_sk is not null + then web.item_sk + else store.item_sk end item_sk, + case when web.d_date is not null + then web.d_date + else store.d_date end d_date, + web.cume_sales web_sales, + store.cume_sales store_sales + from + web_v1 web full outer join store_v1 store + on (web.item_sk = store.item_sk and web.d_date = store.d_date))) +select * +from ( + select + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales, + max(v2.web_sales) web_cumulative, + max(v2.store_sales) store_cumulative + from + v v1, v v2 + where + v1.item_sk = v2.item_sk + and v1.rk >= v2.rk + group by + v1.item_sk, + v1.d_date, + v1.web_sales, + v1.store_sales) x +where + web_cumulative > store_cumulative +order by + item_sk, + d_date +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql new file mode 100755 index 0000000000000..ccefaac3c12ca --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q57.sql @@ -0,0 +1,57 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + cc_name, + d_year, + d_moy, + sum(cs_sales_price) sum_sales, + avg(sum(cs_sales_price)) + OVER + (PARTITION BY i_category, i_brand, cc_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, cc_name + ORDER BY d_year, d_moy) rn + FROM item, catalog_sales, date_dim, call_center + WHERE cs_item_sk = i_item_sk AND + cs_sold_date_sk = d_date_sk AND + cc_call_center_sk = cs_call_center_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + cc_name, d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + -- q57 in TPCDS v1.4 had a column below: + -- v1.cc_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.cc_name = v1_lag.cc_name AND + v1.cc_name = v1_lead.cc_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql new file mode 100644 index 0000000000000..42bcf59c2aeb1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q5a.sql @@ -0,0 +1,158 @@ +-- This is a new query in TPCDS v2.7 +with ssr as( + select + s_store_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ss_store_sk as store_sk, + ss_sold_date_sk as date_sk, + ss_ext_sales_price as sales_price, + ss_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + store_sales + union all + select + sr_store_sk as store_sk, + sr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + sr_return_amt as return_amt, + sr_net_loss as net_loss + from + store_returns) salesreturns, + date_dim, + store + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and store_sk = s_store_sk + group by + s_store_id), +csr as ( + select + cp_catalog_page_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + cs_catalog_page_sk as page_sk, + cs_sold_date_sk as date_sk, + cs_ext_sales_price as sales_price, + cs_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from catalog_sales + union all + select + cr_catalog_page_sk as page_sk, + cr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + cr_return_amount as return_amt, + cr_net_loss as net_loss + from catalog_returns) salesreturns, + date_dim, + catalog_page + where + date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and page_sk = cp_catalog_page_sk + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(sales_price) as sales, + sum(profit) as profit, + sum(return_amt) as returns, + sum(net_loss) as profit_loss + from ( + select + ws_web_site_sk as wsr_web_site_sk, + ws_sold_date_sk as date_sk, + ws_ext_sales_price as sales_price, + ws_net_profit as profit, + cast(0 as decimal(7,2)) as return_amt, + cast(0 as decimal(7,2)) as net_loss + from + web_sales + union all + select + ws_web_site_sk as wsr_web_site_sk, + wr_returned_date_sk as date_sk, + cast(0 as decimal(7,2)) as sales_price, + cast(0 as decimal(7,2)) as profit, + wr_return_amt as return_amt, + wr_net_loss as net_loss + from + web_returns + left outer join web_sales on ( + wr_item_sk = ws_item_sk and wr_order_number = ws_order_number) + ) salesreturns, + date_dim, + web_site + where + date_sk = d_date_sk and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + INTERVAL 14 days) + and wsr_web_site_sk = web_site_sk + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || s_store_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || cp_catalog_page_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + (profit - profit_loss) as profit + from + wsr) x + group by + channel, id) +select + channel, id, sales, returns, profit +from ( + select channel, id, sales, returns, profit + from results + union + select channel, null as id, sum(sales), sum(returns), sum(profit) + from results + group by channel + union + select null as channel, null as id, sum(sales), sum(returns), sum(profit) + from results) foo + order by channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql new file mode 100755 index 0000000000000..c0bfa40ad44a8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q6.sql @@ -0,0 +1,23 @@ +SELECT + a.ca_state state, + count(*) cnt +FROM + customer_address a, customer c, store_sales s, date_dim d, item i +WHERE a.ca_address_sk = c.c_current_addr_sk + AND c.c_customer_sk = s.ss_customer_sk + AND s.ss_sold_date_sk = d.d_date_sk + AND s.ss_item_sk = i.i_item_sk + AND d.d_month_seq = + (SELECT DISTINCT (d_month_seq) + FROM date_dim + WHERE d_year = 2000 AND d_moy = 1) + AND i.i_current_price > 1.2 * + (SELECT avg(j.i_current_price) + FROM item j + WHERE j.i_category = i.i_category) +GROUP BY a.ca_state +HAVING count(*) >= 10 +-- order-by list of q6 in TPCDS v1.4 is below: +-- order by cnt +order by cnt, a.ca_state +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql new file mode 100755 index 0000000000000..cdcd8486b363d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q64.sql @@ -0,0 +1,111 @@ +WITH cs_ui AS +(SELECT + cs_item_sk, + sum(cs_ext_list_price) AS sale, + sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit) AS refund + FROM catalog_sales + , catalog_returns + WHERE cs_item_sk = cr_item_sk + AND cs_order_number = cr_order_number + GROUP BY cs_item_sk + HAVING sum(cs_ext_list_price) > 2 * sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit)), + cross_sales AS + (SELECT + i_product_name product_name, + i_item_sk item_sk, + s_store_name store_name, + s_zip store_zip, + ad1.ca_street_number b_street_number, + ad1.ca_street_name b_streen_name, + ad1.ca_city b_city, + ad1.ca_zip b_zip, + ad2.ca_street_number c_street_number, + ad2.ca_street_name c_street_name, + ad2.ca_city c_city, + ad2.ca_zip c_zip, + d1.d_year AS syear, + d2.d_year AS fsyear, + d3.d_year s2year, + count(*) cnt, + sum(ss_wholesale_cost) s1, + sum(ss_list_price) s2, + sum(ss_coupon_amt) s3 + FROM store_sales, store_returns, cs_ui, date_dim d1, date_dim d2, date_dim d3, + store, customer, customer_demographics cd1, customer_demographics cd2, + promotion, household_demographics hd1, household_demographics hd2, + customer_address ad1, customer_address ad2, income_band ib1, income_band ib2, item + WHERE ss_store_sk = s_store_sk AND + ss_sold_date_sk = d1.d_date_sk AND + ss_customer_sk = c_customer_sk AND + ss_cdemo_sk = cd1.cd_demo_sk AND + ss_hdemo_sk = hd1.hd_demo_sk AND + ss_addr_sk = ad1.ca_address_sk AND + ss_item_sk = i_item_sk AND + ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number AND + ss_item_sk = cs_ui.cs_item_sk AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_hdemo_sk = hd2.hd_demo_sk AND + c_current_addr_sk = ad2.ca_address_sk AND + c_first_sales_date_sk = d2.d_date_sk AND + c_first_shipto_date_sk = d3.d_date_sk AND + ss_promo_sk = p_promo_sk AND + hd1.hd_income_band_sk = ib1.ib_income_band_sk AND + hd2.hd_income_band_sk = ib2.ib_income_band_sk AND + cd1.cd_marital_status <> cd2.cd_marital_status AND + i_color IN ('purple', 'burlywood', 'indian', 'spring', 'floral', 'medium') AND + i_current_price BETWEEN 64 AND 64 + 10 AND + i_current_price BETWEEN 64 + 1 AND 64 + 15 + GROUP BY + i_product_name, + i_item_sk, + s_store_name, + s_zip, + ad1.ca_street_number, + ad1.ca_street_name, + ad1.ca_city, + ad1.ca_zip, + ad2.ca_street_number, + ad2.ca_street_name, + ad2.ca_city, + ad2.ca_zip, + d1.d_year, + d2.d_year, + d3.d_year + ) +SELECT + cs1.product_name, + cs1.store_name, + cs1.store_zip, + cs1.b_street_number, + cs1.b_streen_name, + cs1.b_city, + cs1.b_zip, + cs1.c_street_number, + cs1.c_street_name, + cs1.c_city, + cs1.c_zip, + cs1.syear, + cs1.cnt, + cs1.s1, + cs1.s2, + cs1.s3, + cs2.s1, + cs2.s2, + cs2.s3, + cs2.syear, + cs2.cnt +FROM cross_sales cs1, cross_sales cs2 +WHERE cs1.item_sk = cs2.item_sk AND + cs1.syear = 1999 AND + cs2.syear = 1999 + 1 AND + cs2.cnt <= cs1.cnt AND + cs1.store_name = cs2.store_name AND + cs1.store_zip = cs2.store_zip +ORDER BY + cs1.product_name, + cs1.store_name, + cs2.cnt, + -- The two columns below are newly added in TPCDS v2.7 + cs1.s1, + cs2.s1 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql new file mode 100644 index 0000000000000..70a14043bbb3d --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q67a.sql @@ -0,0 +1,208 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sum(coalesce(ss_sales_price * ss_quantity, 0)) sumsales + from + store_sales, date_dim, store, item + where + ss_sold_date_sk=d_date_sk + and ss_item_sk=i_item_sk + and ss_store_sk = s_store_sk + and d_month_seq between 1212 and 1212 + 11 + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id), +results_rollup as ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales + from + results + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy + union all + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name, + d_year + union all + select + i_category, + i_class, + i_brand, + i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand, + i_product_name + union all + select + i_category, + i_class, + i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class, + i_brand + union all + select + i_category, + i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results + group by + i_category, + i_class + union all + select + i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from results + group by + i_category + union all + select + null i_category, + null i_class, + null i_brand, + null i_product_name, + null d_year, + null d_qoy, + null d_moy, + null s_store_id, + sum(sumsales) sumsales + from + results) +select + * +from ( + select + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rank() over (partition by i_category order by sumsales desc) rk + from results_rollup) dw2 +where + rk <= 100 +order by + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rk +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql new file mode 100644 index 0000000000000..4aec9c7fd1fd6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q70a.sql @@ -0,0 +1,70 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ss_net_profit) as total_sum, + s_state ,s_county, + 0 as gstate, + 0 as g_county + from + store_sales, date_dim d1, store + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_state in ( + select s_state + from ( + select + s_state as s_state, + rank() over (partition by s_state order by sum(ss_net_profit) desc) as ranking + from store_sales, store, date_dim + where d_month_seq between 1212 and 1212 + 11 + and d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + group by s_state) tmp1 + where ranking <= 5) + group by + s_state, s_county), +results_rollup as ( + select + total_sum, + s_state, + s_county, + 0 as g_state, + 0 as g_county, + 0 as lochierarchy + from results + union + select + sum(total_sum) as total_sum,s_state, + NULL as s_county, + 0 as g_state, + 1 as g_county, + 1 as lochierarchy + from results + group by s_state + union + select + sum(total_sum) as total_sum, + NULL as s_state, + NULL as s_county, + 1 as g_state, + 1 as g_county, + 2 as lochierarchy + from results) +select + total_sum, + s_state, + s_county, + lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_county = 0 then s_state end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then s_state end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql new file mode 100755 index 0000000000000..066d6a587e917 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q72.sql @@ -0,0 +1,40 @@ +SELECT + i_item_desc, + w_warehouse_name, + d1.d_week_seq, + count(CASE WHEN p_promo_sk IS NULL + THEN 1 + ELSE 0 END) no_promo, + count(CASE WHEN p_promo_sk IS NOT NULL + THEN 1 + ELSE 0 END) promo, + count(*) total_cnt +FROM catalog_sales + JOIN inventory ON (cs_item_sk = inv_item_sk) + JOIN warehouse ON (w_warehouse_sk = inv_warehouse_sk) + JOIN item ON (i_item_sk = cs_item_sk) + JOIN customer_demographics ON (cs_bill_cdemo_sk = cd_demo_sk) + JOIN household_demographics ON (cs_bill_hdemo_sk = hd_demo_sk) + JOIN date_dim d1 ON (cs_sold_date_sk = d1.d_date_sk) + JOIN date_dim d2 ON (inv_date_sk = d2.d_date_sk) + JOIN date_dim d3 ON (cs_ship_date_sk = d3.d_date_sk) + LEFT OUTER JOIN promotion ON (cs_promo_sk = p_promo_sk) + LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk AND cr_order_number = cs_order_number) +-- q72 in TPCDS v1.4 had conditions below: +-- WHERE d1.d_week_seq = d2.d_week_seq +-- AND inv_quantity_on_hand < cs_quantity +-- AND d3.d_date > (cast(d1.d_date AS DATE) + interval 5 days) +-- AND hd_buy_potential = '>10000' +-- AND d1.d_year = 1999 +-- AND hd_buy_potential = '>10000' +-- AND cd_marital_status = 'D' +-- AND d1.d_year = 1999 +WHERE d1.d_week_seq = d2.d_week_seq + AND inv_quantity_on_hand < cs_quantity + AND d3.d_date > d1.d_date + INTERVAL 5 days + AND hd_buy_potential = '1001-5000' + AND d1.d_year = 2001 + AND cd_marital_status = 'M' +GROUP BY i_item_desc, w_warehouse_name, d1.d_week_seq +ORDER BY total_cnt DESC, i_item_desc, w_warehouse_name, d_week_seq +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql new file mode 100755 index 0000000000000..94a0063b36c0c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q74.sql @@ -0,0 +1,60 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ss_net_paid) year_total, + 's' sale_type + FROM + customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ws_net_paid) year_total, + 'w' sale_type + FROM + customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name +FROM + year_total t_s_firstyear, year_total t_s_secyear, + year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.year = 2001 + AND t_s_secyear.year = 2001 + 1 + AND t_w_firstyear.year = 2001 + AND t_w_secyear.year = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +-- order-by list of q74 in TPCDS v1.4 is below: +-- ORDER BY 1, 1, 1 +ORDER BY 2, 1, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql new file mode 100755 index 0000000000000..ae5dc97ef2317 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q75.sql @@ -0,0 +1,78 @@ +WITH all_sales AS ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + SUM(sales_cnt) AS sales_cnt, + SUM(sales_amt) AS sales_amt + FROM ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + cs_quantity - COALESCE(cr_return_quantity, 0) AS sales_cnt, + cs_ext_sales_price - COALESCE(cr_return_amount, 0.0) AS sales_amt + FROM catalog_sales + JOIN item ON i_item_sk = cs_item_sk + JOIN date_dim ON d_date_sk = cs_sold_date_sk + LEFT JOIN catalog_returns ON (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ss_quantity - COALESCE(sr_return_quantity, 0) AS sales_cnt, + ss_ext_sales_price - COALESCE(sr_return_amt, 0.0) AS sales_amt + FROM store_sales + JOIN item ON i_item_sk = ss_item_sk + JOIN date_dim ON d_date_sk = ss_sold_date_sk + LEFT JOIN store_returns ON (ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ws_quantity - COALESCE(wr_return_quantity, 0) AS sales_cnt, + ws_ext_sales_price - COALESCE(wr_return_amt, 0.0) AS sales_amt + FROM web_sales + JOIN item ON i_item_sk = ws_item_sk + JOIN date_dim ON d_date_sk = ws_sold_date_sk + LEFT JOIN web_returns ON (ws_order_number = wr_order_number + AND ws_item_sk = wr_item_sk) + WHERE i_category = 'Books') sales_detail + GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) +SELECT + prev_yr.d_year AS prev_year, + curr_yr.d_year AS year, + curr_yr.i_brand_id, + curr_yr.i_class_id, + curr_yr.i_category_id, + curr_yr.i_manufact_id, + prev_yr.sales_cnt AS prev_yr_cnt, + curr_yr.sales_cnt AS curr_yr_cnt, + curr_yr.sales_cnt - prev_yr.sales_cnt AS sales_cnt_diff, + curr_yr.sales_amt - prev_yr.sales_amt AS sales_amt_diff +FROM all_sales curr_yr, all_sales prev_yr +WHERE curr_yr.i_brand_id = prev_yr.i_brand_id + AND curr_yr.i_class_id = prev_yr.i_class_id + AND curr_yr.i_category_id = prev_yr.i_category_id + AND curr_yr.i_manufact_id = prev_yr.i_manufact_id + AND curr_yr.d_year = 2002 + AND prev_yr.d_year = 2002 - 1 + AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9 +ORDER BY + sales_cnt_diff, + sales_amt_diff -- This order-by condition did not exist in TPCDS v1.4 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql new file mode 100644 index 0000000000000..fc69c43470f1e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q77a.sql @@ -0,0 +1,121 @@ +-- This is a new query in TPCDS v2.7 +with ss as ( + select + s_store_sk, + sum(ss_ext_sales_price) as sales, + sum(ss_net_profit) as profit + from + store_sales, date_dim, store + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + group by + s_store_sk), +sr as ( + select + s_store_sk, + sum(sr_return_amt) as returns, + sum(sr_net_loss) as profit_loss + from + store_returns, date_dim, store + where + sr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and sr_store_sk = s_store_sk + group by + s_store_sk), +cs as ( + select + cs_call_center_sk, + sum(cs_ext_sales_price) as sales, + sum(cs_net_profit) as profit + from + catalog_sales, + date_dim + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + group by + cs_call_center_sk), + cr as ( + select + sum(cr_return_amount) as returns, + sum(cr_net_loss) as profit_loss + from catalog_returns, + date_dim + where + cr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days)), +ws as ( select wp_web_page_sk, + sum(ws_ext_sales_price) as sales, + sum(ws_net_profit) as profit + from web_sales, + date_dim, + web_page + where ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_page_sk = wp_web_page_sk + group by wp_web_page_sk), + wr as + (select wp_web_page_sk, + sum(wr_return_amt) as returns, + sum(wr_net_loss) as profit_loss + from web_returns, + date_dim, + web_page + where wr_returned_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and wr_web_page_sk = wp_web_page_sk + group by wp_web_page_sk) + , + results as + (select channel + , id + , sum(sales) as sales + , sum(returns) as returns + , sum(profit) as profit + from + (select 'store channel' as channel + , ss.s_store_sk as id + , sales + , coalesce(returns, 0) as returns + , (profit - coalesce(profit_loss,0)) as profit + from ss left join sr + on ss.s_store_sk = sr.s_store_sk + union all + select 'catalog channel' as channel + , cs_call_center_sk as id + , sales + , returns + , (profit - profit_loss) as profit + from cs + , cr + union all + select 'web channel' as channel + , ws.wp_web_page_sk as id + , sales + , coalesce(returns, 0) returns + , (profit - coalesce(profit_loss,0)) as profit + from ws left join wr + on ws.wp_web_page_sk = wr.wp_web_page_sk + ) x + group by channel, id ) + + select * + from ( + select channel, id, sales, returns, profit from results + union + select channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results group by channel + union + select NULL AS channel, NULL AS id, sum(sales) as sales, sum(returns) as returns, sum(profit) as profit from results +) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql new file mode 100755 index 0000000000000..d03d8af77174c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q78.sql @@ -0,0 +1,75 @@ +WITH ws AS +(SELECT + d_year AS ws_sold_year, + ws_item_sk, + ws_bill_customer_sk ws_customer_sk, + sum(ws_quantity) ws_qty, + sum(ws_wholesale_cost) ws_wc, + sum(ws_sales_price) ws_sp + FROM web_sales + LEFT JOIN web_returns ON wr_order_number = ws_order_number AND ws_item_sk = wr_item_sk + JOIN date_dim ON ws_sold_date_sk = d_date_sk + WHERE wr_order_number IS NULL + GROUP BY d_year, ws_item_sk, ws_bill_customer_sk +), + cs AS + (SELECT + d_year AS cs_sold_year, + cs_item_sk, + cs_bill_customer_sk cs_customer_sk, + sum(cs_quantity) cs_qty, + sum(cs_wholesale_cost) cs_wc, + sum(cs_sales_price) cs_sp + FROM catalog_sales + LEFT JOIN catalog_returns ON cr_order_number = cs_order_number AND cs_item_sk = cr_item_sk + JOIN date_dim ON cs_sold_date_sk = d_date_sk + WHERE cr_order_number IS NULL + GROUP BY d_year, cs_item_sk, cs_bill_customer_sk + ), + ss AS + (SELECT + d_year AS ss_sold_year, + ss_item_sk, + ss_customer_sk, + sum(ss_quantity) ss_qty, + sum(ss_wholesale_cost) ss_wc, + sum(ss_sales_price) ss_sp + FROM store_sales + LEFT JOIN store_returns ON sr_ticket_number = ss_ticket_number AND ss_item_sk = sr_item_sk + JOIN date_dim ON ss_sold_date_sk = d_date_sk + WHERE sr_ticket_number IS NULL + GROUP BY d_year, ss_item_sk, ss_customer_sk + ) +SELECT + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) ratio, + ss_qty store_qty, + ss_wc store_wholesale_cost, + ss_sp store_sales_price, + coalesce(ws_qty, 0) + coalesce(cs_qty, 0) other_chan_qty, + coalesce(ws_wc, 0) + coalesce(cs_wc, 0) other_chan_wholesale_cost, + coalesce(ws_sp, 0) + coalesce(cs_sp, 0) other_chan_sales_price +FROM ss + LEFT JOIN ws + ON (ws_sold_year = ss_sold_year AND ws_item_sk = ss_item_sk AND ws_customer_sk = ss_customer_sk) + LEFT JOIN cs + ON (cs_sold_year = ss_sold_year AND cs_item_sk = ss_item_sk AND cs_customer_sk = ss_customer_sk) +WHERE coalesce(ws_qty, 0) > 0 AND coalesce(cs_qty, 0) > 0 AND ss_sold_year = 2000 +ORDER BY + -- order-by list of q78 in TPCDS v1.4 is below: + -- ratio, + -- ss_qty DESC, ss_wc DESC, ss_sp DESC, + -- other_chan_qty, + -- other_chan_wholesale_cost, + -- other_chan_sales_price, + -- round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) + ss_sold_year, + ss_item_sk, + ss_customer_sk, + ss_qty desc, + ss_wc desc, + ss_sp desc, + other_chan_qty, + other_chan_wholesale_cost, + other_chan_sales_price, + ratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql new file mode 100644 index 0000000000000..686e03ba2a6d0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q80a.sql @@ -0,0 +1,147 @@ +-- This is a new query in TPCDS v2.7 +with ssr as ( + select + s_store_id as store_id, + sum(ss_ext_sales_price) as sales, + sum(coalesce(sr_return_amt, 0)) as returns, + sum(ss_net_profit - coalesce(sr_net_loss, 0)) as profit + from + store_sales left outer join store_returns on ( + ss_item_sk = sr_item_sk and ss_ticket_number = sr_ticket_number), + date_dim, + store, + item, + promotion + where + ss_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ss_store_sk = s_store_sk + and ss_item_sk = i_item_sk + and i_current_price > 50 + and ss_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + s_store_id), +csr as ( + select + cp_catalog_page_id as catalog_page_id, + sum(cs_ext_sales_price) as sales, + sum(coalesce(cr_return_amount, 0)) as returns, + sum(cs_net_profit - coalesce(cr_net_loss, 0)) as profit + from + catalog_sales left outer join catalog_returns on + (cs_item_sk = cr_item_sk and cs_order_number = cr_order_number), + date_dim, + catalog_page, + item, + promotion + where + cs_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and cs_catalog_page_sk = cp_catalog_page_sk + and cs_item_sk = i_item_sk + and i_current_price > 50 + and cs_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + cp_catalog_page_id), +wsr as ( + select + web_site_id, + sum(ws_ext_sales_price) as sales, + sum(coalesce(wr_return_amt, 0)) as returns, + sum(ws_net_profit - coalesce(wr_net_loss, 0)) as profit + from + web_sales left outer join web_returns on ( + ws_item_sk = wr_item_sk and ws_order_number = wr_order_number), + date_dim, + web_site, + item, + promotion + where + ws_sold_date_sk = d_date_sk + and d_date between cast('1998-08-04' as date) + and (cast('1998-08-04' as date) + interval 30 days) + and ws_web_site_sk = web_site_sk + and ws_item_sk = i_item_sk + and i_current_price > 50 + and ws_promo_sk = p_promo_sk + and p_channel_tv = 'N' + group by + web_site_id), +results as ( + select + channel, + id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from ( + select + 'store channel' as channel, + 'store' || store_id as id, + sales, + returns, + profit + from + ssr + union all + select + 'catalog channel' as channel, + 'catalog_page' || catalog_page_id as id, + sales, + returns, + profit + from + csr + union all + select + 'web channel' as channel, + 'web_site' || web_site_id as id, + sales, + returns, + profit + from + wsr) x + group by + channel, id) +select + channel, + id, + sales, + returns, + profit +from ( + select + channel, + id, + sales, + returns, + profit + from + results + union + select + channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results + group by + channel + union + select + NULL AS channel, + NULL AS id, + sum(sales) as sales, + sum(returns) as returns, + sum(profit) as profit + from + results) foo +order by + channel, id +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql new file mode 100644 index 0000000000000..fff76b08d4ba0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q86a.sql @@ -0,0 +1,61 @@ +-- This is a new query in TPCDS v2.7 +with results as ( + select + sum(ws_net_paid) as total_sum, + i_category, i_class, + 0 as g_category, + 0 as g_class + from + web_sales, date_dim d1, item + where + d1.d_month_seq between 1212 and 1212 + 11 + and d1.d_date_sk = ws_sold_date_sk + and i_item_sk = ws_item_sk + group by + i_category, i_class), +results_rollup as( + select + total_sum, + i_category, + i_class, + g_category, + g_class, + 0 as lochierarchy + from + results + union + select + sum(total_sum) as total_sum, + i_category, + NULL as i_class, + 0 as g_category, + 1 as g_class, + 1 as lochierarchy + from + results + group by + i_category + union + select + sum(total_sum) as total_sum, + NULL as i_category, + NULL as i_class, + 1 as g_category, + 1 as g_class, + 2 as lochierarchy + from + results) +select + total_sum, + i_category ,i_class, lochierarchy, + rank() over ( + partition by lochierarchy, + case when g_class = 0 then i_category end + order by total_sum desc) as rank_within_parent +from + results_rollup +order by + lochierarchy desc, + case when lochierarchy = 0 then i_category end, + rank_within_parent +limit 100 diff --git a/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql new file mode 100755 index 0000000000000..771117add2ed2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-v2.7.0/q98.sql @@ -0,0 +1,22 @@ +SELECT + i_item_id, -- This column did not exist in TPCDS v1.4 + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) AS itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + store_sales, item, date_dim +WHERE + ss_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index 1a584187a06e5..bc95b4696190d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -62,7 +62,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, - |`c_email_address` STRING, `c_last_review_date` STRING) + |`c_email_address` STRING, `c_last_review_date` INT) |USING parquet """.stripMargin) @@ -88,7 +88,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `date_dim` ( - |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_date_sk` INT, `d_date_id` STRING, `d_date` DATE, |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, @@ -115,8 +115,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ - |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, - |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` DATE, + |`i_rec_end_date` DATE, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, @@ -139,8 +139,8 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { sql( """ |CREATE TABLE `store` ( - |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, - |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` DATE, + |`s_rec_end_date` DATE, `s_closed_date_sk` INT, `s_store_name` STRING, |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, @@ -157,7 +157,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, - |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_quantity` INT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), @@ -225,7 +225,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, - |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` INT, |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), @@ -244,7 +244,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, - |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |`web_country` STRING, `web_gmt_offset` DECIMAL(5,2), `web_tax_percentage` DECIMAL(5,2)) |USING parquet """.stripMargin) @@ -315,6 +315,7 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { """.stripMargin) } + // The TPCDS queries below are based on v1.4 val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -339,6 +340,25 @@ class TPCDSQuerySuite extends BenchmarkQueryTest { } } + // This list only includes TPCDS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7_0 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + + tpcdsQueriesV2_7_0.foreach { name => + val queryString = resourceToString(s"tpcds-v2.7.0/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"$name-v2.7") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // check the plans can be properly generated + val plan = sql(queryString).queryExecution.executedPlan + checkGeneratedCode(plan) + } + } + } + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries val modifiedTPCDSQueries = Seq( "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", From e4bec7cb88b9ee63f8497e3f9e0ab0bfa5d5a77c Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 25 Mar 2018 16:38:49 -0700 Subject: [PATCH 0521/1161] [SPARK-23549][SQL] Cast to timestamp when comparing timestamp with date ## What changes were proposed in this pull request? This PR fixes an incorrect comparison in SQL between timestamp and date. This is because both of them are casted to `string` and then are compared lexicographically. This implementation shows `false` regarding this query `spark.sql("select cast('2017-03-01 00:00:00' as timestamp) between cast('2017-02-28' as date) and cast('2017-03-01' as date)").show`. This PR shows `true` for this query by casting `date("2017-03-01")` to `timestamp("2017-03-01 00:00:00")`. (Please fill in changes proposed in this fix) ## How was this patch tested? Added new UTs to `TypeCoercionSuite`. Author: Kazuaki Ishizaki Closes #20774 from kiszk/SPARK-23549. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 ++++++++----- .../apache/spark/sql/internal/SQLConf.scala | 13 ++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 34 ++++++++++++--- .../sql-tests/inputs/predicate-functions.sql | 7 ++++ .../results/predicate-functions.sql.out | 42 ++++++++++++++++++- 6 files changed, 108 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 421e2eaf62bfb..2b393f30d1435 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1808,6 +1808,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e8669c4637d06..ec7e7761dc4c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -47,9 +47,9 @@ import org.apache.spark.sql.types._ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = - InConversion :: + InConversion(conf) :: WidenSetOperationTypes :: - PromoteStrings :: + PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: FunctionArgumentConversion :: @@ -127,7 +127,8 @@ object TypeCoercion { * is a String and the other is not. It also handles when one op is a Date and the * other is a Timestamp by making the target type to be String. */ - val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { + private def findCommonTypeForBinaryComparison( + dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match { // We should cast all relative timestamp/date/string comparison into string comparisons // This behaves as a user would expect because timestamp strings sort lexicographically. // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true @@ -135,11 +136,17 @@ object TypeCoercion { case (DateType, StringType) => Some(StringType) case (StringType, TimestampType) => Some(StringType) case (TimestampType, StringType) => Some(StringType) - case (TimestampType, DateType) => Some(StringType) - case (DateType, TimestampType) => Some(StringType) case (StringType, NullType) => Some(StringType) case (NullType, StringType) => Some(StringType) + // Cast to TimestampType when we compare DateType with TimestampType + // if conf.compareDateTimestampInTimestamp is true + // i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true + case (TimestampType, DateType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + case (DateType, TimestampType) + => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType) + // There is no proper decimal type we can pick, // using double type is the best we can do. // See SPARK-22469 for details. @@ -147,7 +154,7 @@ object TypeCoercion { case (s: StringType, n: DecimalType) => Some(DoubleType) case (l: StringType, r: AtomicType) if r != StringType => Some(r) - case (l: AtomicType, r: StringType) if (l != StringType) => Some(l) + case (l: AtomicType, r: StringType) if l != StringType => Some(l) case (l, r) => None } @@ -313,7 +320,7 @@ object TypeCoercion { /** * Promotes strings that appear in arithmetic expressions. */ - object PromoteStrings extends TypeCoercionRule { + case class PromoteStrings(conf: SQLConf) extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { (expr.dataType, targetType) match { case (NullType, dt) => Literal.create(null, targetType) @@ -342,8 +349,8 @@ object TypeCoercion { p.makeCopy(Array(left, Cast(right, TimestampType))) case p @ BinaryComparison(left, right) - if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => - val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get + if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => + val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) case Abs(e @ StringType()) => Abs(Cast(e, DoubleType)) @@ -374,7 +381,7 @@ object TypeCoercion { * operator type is found the original expression will be returned and an * Analysis Exception will be raised at the type checking phase. */ - object InConversion extends TypeCoercionRule { + case class InConversion(conf: SQLConf) extends TypeCoercionRule { private def flattenExpr(expr: Expression): Seq[Expression] = { expr match { // Multi columns in IN clause is represented as a CreateNamedStruct. @@ -400,7 +407,7 @@ object TypeCoercion { val rhs = sub.output val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => - findCommonTypeForBinaryComparison(l.dataType, r.dataType) + findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) .orElse(findTightestCommonType(l.dataType, r.dataType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 11864bd1b1847..9cb03b5bb6152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -479,6 +479,16 @@ object SQLConf { .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString) + val TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP = + buildConf("spark.sql.typeCoercion.compareDateTimestampInTimestamp") + .internal() + .doc("When true (default), compare Date with Timestamp after converting both sides to " + + "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " + + "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " + + "converting both sides to string. This config will be removed in spark 3.0") + .booleanConf + .createWithDefault(true) + val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly") .doc("When true, enable the metadata-only query optimization that use the table's metadata " + "to produce the partition columns instead of table scans. It applies when all the columns " + @@ -1332,6 +1342,9 @@ class SQLConf extends Serializable with Logging { def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + def compareDateTimestampInTimestamp : Boolean = + getConf(TYPECOERCION_COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 52a7ebdafd7c7..8ac49dc05e3cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1207,7 +1207,7 @@ class TypeCoercionSuite extends AnalysisTest { */ test("make sure rules do not fire early") { // InConversion - val inConversion = TypeCoercion.InConversion + val inConversion = TypeCoercion.InConversion(conf) ruleTest(inConversion, In(UnresolvedAttribute("a"), Seq(Literal(1))), In(UnresolvedAttribute("a"), Seq(Literal(1))) @@ -1251,18 +1251,40 @@ class TypeCoercionSuite extends AnalysisTest { } test("binary comparison with string promotion") { - ruleTest(PromoteStrings, + val rule = TypeCoercion.PromoteStrings(conf) + ruleTest(rule, GreaterThan(Literal("123"), Literal(1)), GreaterThan(Cast(Literal("123"), IntegerType), Literal(1))) - ruleTest(PromoteStrings, + ruleTest(rule, LessThan(Literal(true), Literal("123")), LessThan(Literal(true), Cast(Literal("123"), BooleanType))) - ruleTest(PromoteStrings, + ruleTest(rule, EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) - ruleTest(PromoteStrings, + ruleTest(rule, GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))), - GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), DoubleType))) + GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")), + DoubleType))) + Seq(true, false).foreach { convertToTS => + withSQLConf( + "spark.sql.typeCoercion.compareDateTimestampInTimestamp" -> convertToTS.toString) { + val date0301 = Literal(java.sql.Date.valueOf("2017-03-01")) + val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00")) + val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01")) + if (convertToTS) { + // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549 + ruleTest(rule, EqualTo(date0301, timestamp0301000000), + EqualTo(Cast(date0301, TimestampType), timestamp0301000000)) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, TimestampType), timestamp0301000001)) + } else { + ruleTest(rule, LessThan(date0301, timestamp0301000000), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType))) + ruleTest(rule, LessThan(date0301, timestamp0301000001), + LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType))) + } + } + } } test("cast WindowFrame boundaries to the type they operate upon") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql index e99d5cef81f64..fadb4bb27fa13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql @@ -39,3 +39,10 @@ select 2.0 <= '2.2'; select 0.5 <= '1.5'; select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52'); select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'; + +-- SPARK-23549: Cast to timestamp when comparing timestamp with date +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00'); +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01'); +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01'); +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01'); +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01'); diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out index d51f6d37e4b41..cf828c69af62a 100644 --- a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 32 +-- Number of queries: 37 -- !query 0 @@ -256,3 +256,43 @@ select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52' struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean> -- !query 31 output true + + +-- !query 32 +select to_date('2017-03-01') = to_timestamp('2017-03-01 00:00:00') +-- !query 32 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) = to_timestamp('2017-03-01 00:00:00')):boolean> +-- !query 32 output +true + + +-- !query 33 +select to_timestamp('2017-03-01 00:00:01') > to_date('2017-03-01') +-- !query 33 schema +struct<(to_timestamp('2017-03-01 00:00:01') > CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 33 output +true + + +-- !query 34 +select to_timestamp('2017-03-01 00:00:01') >= to_date('2017-03-01') +-- !query 34 schema +struct<(to_timestamp('2017-03-01 00:00:01') >= CAST(to_date('2017-03-01') AS TIMESTAMP)):boolean> +-- !query 34 output +true + + +-- !query 35 +select to_date('2017-03-01') < to_timestamp('2017-03-01 00:00:01') +-- !query 35 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) < to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 35 output +true + + +-- !query 36 +select to_date('2017-03-01') <= to_timestamp('2017-03-01 00:00:01') +-- !query 36 schema +struct<(CAST(to_date('2017-03-01') AS TIMESTAMP) <= to_timestamp('2017-03-01 00:00:01')):boolean> +-- !query 36 output +true From a9350d7095b79c8374fb4a06fd3f1a1a67615f6f Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 26 Mar 2018 12:42:32 +0900 Subject: [PATCH 0522/1161] [SPARK-23700][PYTHON] Cleanup imports in pyspark.sql ## What changes were proposed in this pull request? This cleans up unused imports, mainly from pyspark.sql module. Added a note in function.py that imports `UserDefinedFunction` only to maintain backwards compatibility for using `from pyspark.sql.function import UserDefinedFunction`. ## How was this patch tested? Existing tests and built docs. Author: Bryan Cutler Closes #20892 from BryanCutler/pyspark-cleanup-imports-SPARK-23700. --- python/pyspark/sql/column.py | 1 - python/pyspark/sql/conf.py | 1 - python/pyspark/sql/functions.py | 3 +-- python/pyspark/sql/group.py | 3 +-- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 2 -- python/pyspark/sql/types.py | 1 - python/pyspark/sql/udf.py | 6 ++---- python/pyspark/util.py | 2 -- 9 files changed, 5 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index e05a7b33c11a7..922c7cf288f8f 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -16,7 +16,6 @@ # import sys -import warnings import json if sys.version >= '3': diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index b82224b6194ed..db49040e17b63 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -67,7 +67,6 @@ def _checkType(self, obj, identifier): def _test(): import os import doctest - from pyspark.context import SparkContext from pyspark.sql.session import SparkSession import pyspark.sql.conf diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dff590983b4d9..a4edb1e27b599 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,7 +18,6 @@ """ A collections of builtin functions """ -import math import sys import functools import warnings @@ -28,10 +27,10 @@ from pyspark import since, SparkContext from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType +# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_udf diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 35cac406e0965..3505065b648f2 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,9 +19,8 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix, PythonEvalType -from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal +from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame -from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e5288636c596e..4f9b9383a5ef4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,7 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD, since, keyword_only +from pyspark import RDD, since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 07f9ac1b5aa9e..c7907aaaf1f7b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,8 +24,6 @@ else: intlike = (int, long) -from abc import ABCMeta, abstractmethod - from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5d5919e451b46..1f6534836d64a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -35,7 +35,6 @@ from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer -from pyspark.util import _exception_message __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 24dd06c26089c..9dbe49b831cef 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,16 +17,14 @@ """ User-defined function related classes and functions """ -import sys -import inspect import functools import sys from pyspark import SparkContext, since from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix from pyspark.sql.column import Column, _to_java_column, _to_seq -from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ - _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ + to_arrow_type, to_arrow_schema from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ed1bdd0e4be83..49afc13640332 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -22,8 +22,6 @@ __all__ = [] -import sys - def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both From 087fb3142028d679524e22596b0ad4f74ff47e8d Mon Sep 17 00:00:00 2001 From: "Michael (Stu) Stewart" Date: Mon, 26 Mar 2018 12:45:45 +0900 Subject: [PATCH 0523/1161] [SPARK-23645][MINOR][DOCS][PYTHON] Add docs RE `pandas_udf` with keyword args ## What changes were proposed in this pull request? Add documentation about the limitations of `pandas_udf` with keyword arguments and related concepts, like `functools.partial` fn objects. NOTE: intermediate commits on this PR show some of the steps that can be taken to fix some (but not all) of these pain points. ### Survey of problems we face today: (Initialize) Note: python 3.6 and spark 2.4snapshot. ``` from pyspark.sql import SparkSession import inspect, functools from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, udf spark = SparkSession.builder.getOrCreate() print(spark.version) df = spark.range(1,6).withColumn('b', col('id') * 2) def ok(a,b): return a+b ``` Using a keyword argument at the call site `b=...` (and yes, *full* stack trace below, haha): ``` ---> 14 df.withColumn('ok', pandas_udf(f=ok, returnType='bigint')('id', b='id')).show() # no kwargs TypeError: wrapper() got an unexpected keyword argument 'b' ``` Using partial with a keyword argument where the kw-arg is the first argument of the fn: *(Aside: kind of interesting that lines 15,16 work great and then 17 explodes)* ``` --------------------------------------------------------------------------- ValueError Traceback (most recent call last) in () 15 df.withColumn('ok', pandas_udf(f=functools.partial(ok, 7), returnType='bigint')('id')).show() 16 df.withColumn('ok', pandas_udf(f=functools.partial(ok, b=7), returnType='bigint')('id')).show() ---> 17 df.withColumn('ok', pandas_udf(f=functools.partial(ok, a=7), returnType='bigint')('id')).show() /Users/stu/ZZ/spark/python/pyspark/sql/functions.py in pandas_udf(f, returnType, functionType) 2378 return functools.partial(_create_udf, returnType=return_type, evalType=eval_type) 2379 else: -> 2380 return _create_udf(f=f, returnType=return_type, evalType=eval_type) 2381 2382 /Users/stu/ZZ/spark/python/pyspark/sql/udf.py in _create_udf(f, returnType, evalType) 54 argspec.varargs is None: 55 raise ValueError( ---> 56 "Invalid function: 0-arg pandas_udfs are not supported. " 57 "Instead, create a 1-arg pandas_udf and ignore the arg in your function." 58 ) ValueError: Invalid function: 0-arg pandas_udfs are not supported. Instead, create a 1-arg pandas_udf and ignore the arg in your function. ``` Author: Michael (Stu) Stewart Closes #20900 from mstewart141/udfkw2. --- python/pyspark/sql/functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4edb1e27b599..ad3e37c872628 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2154,6 +2154,8 @@ def udf(f=None, returnType=StringType()): in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + .. note:: The user-defined functions do not take keyword arguments on the calling side. + :param f: python function if used as a standalone function :param returnType: the return type of the user-defined function. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. @@ -2337,6 +2339,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions can fail on special rows, the workaround is to incorporate the condition into the functions. + + .. note:: The user-defined functions do not take keyword arguments on the calling side. """ # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) From eb48edf9ca4f4b42c63f145718696472cb6a31ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 14:01:04 +0800 Subject: [PATCH 0524/1161] [SPARK-23787][TESTS] Fix file download test in SparkSubmitSuite for Hadoop 2.9. This particular test assumed that Hadoop libraries did not support http as a file system. Hadoop 2.9 does, so the test failed. The test now forces a non-existent implementation for the http fs, which forces the expected error. There were also a couple of other issues in the same test: SparkSubmit arguments in the wrong order, and the wrong check later when asserting, which was being masked by the previous issues. Author: Marcelo Vanzin Closes #20895 from vanzin/SPARK-23787. --- .../spark/deploy/SparkSubmitSuite.scala | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 2d0c192db4915..d86ef907b4492 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -959,25 +959,28 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) } test("force download from blacklisted schemes") { - testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) } - private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, - supportMockHttpFs: Boolean): Unit = { + private def testRemoteResources( + enableHttpFs: Boolean, + blacklistHttpFs: Boolean): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) - if (supportMockHttpFs) { + if (enableHttpFs) { hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) - hadoopConf.set("fs.http.impl.disable.cache", "true") + } else { + hadoopConf.set("fs.http.impl", getClass().getName() + ".DoesNotExist") } + hadoopConf.set("fs.http.impl.disable.cache", "true") val tmpDir = Utils.createTempDir() val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) @@ -986,20 +989,19 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + val forceDownloadArgs = if (blacklistHttpFs) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + } else { + Nil + } + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", "--master", "yarn", "--deploy-mode", "client", - "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", - s"s3a://$mainResource" - ) ++ ( - if (isHttpSchemeBlacklisted) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") - } else { - Nil - } - ) + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath" + ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) @@ -1009,7 +1011,7 @@ class SparkSubmitSuite // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) - if (supportMockHttpFs) { + if (enableHttpFs && !blacklistHttpFs) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) From b30a7d28b399950953d4b112c57d4c9b9ab223e9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 26 Mar 2018 12:45:45 -0700 Subject: [PATCH 0525/1161] [SPARK-23572][DOCS] Bring "security.md" up to date. This change basically rewrites the security documentation so that it's up to date with new features, more correct, and more complete. Because security is such an important feature, I chose to move all the relevant configuration documentation to the security page, instead of having them peppered all over the place in the configuration page. This allows an almost one-stop shop for security configuration in Spark. The only exceptions are some YARN-specific minor features which I left in the YARN page. I also re-organized the page's topics, since they didn't make a lot of sense. You had kerberos features described inside paragraphs talking about UI access control, and other oddities. It should be easier now to find information about specific Spark security features. I also enabled TOCs for both the Security and YARN pages, since that makes it easier to see what is covered. I removed most of the comments from the SecurityManager javadoc since they just replicated information in the security doc, with different levels of out-of-dateness. Author: Marcelo Vanzin Closes #20742 from vanzin/SPARK-23572. --- .gitignore | 1 + .../org/apache/spark/SecurityManager.scala | 144 +--- docs/configuration.md | 359 +--------- docs/monitoring.md | 40 +- docs/running-on-yarn.md | 203 +++--- docs/security.md | 629 +++++++++++++++--- 6 files changed, 673 insertions(+), 703 deletions(-) diff --git a/.gitignore b/.gitignore index 39085904e324c..e4c44d0590d59 100644 --- a/.gitignore +++ b/.gitignore @@ -76,6 +76,7 @@ streaming-tests.log target/ unit-tests.log work/ +docs/.jekyll-metadata # For Hive TempStatsStore/ diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index da1c89cd78901..09ec8932353a0 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -42,148 +42,10 @@ import org.apache.spark.util.Utils * should access it from that. There are some cases where the SparkEnv hasn't been * initialized yet and this class must be instantiated directly. * - * Spark currently supports authentication via a shared secret. - * Authentication can be configured to be on via the 'spark.authenticate' configuration - * parameter. This parameter controls whether the Spark communication protocols do - * authentication using the shared secret. This authentication is a basic handshake to - * make sure both sides have the same shared secret and are allowed to communicate. - * If the shared secret is not identical they will not be allowed to communicate. - * - * The Spark UI can also be secured by using javax servlet filters. A user may want to - * secure the UI if it has data that other users should not be allowed to see. The javax - * servlet filter specified by the user can authenticate the user and then once the user - * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.acls.enable', 'spark.ui.view.acls' and - * 'spark.ui.view.acls.groups' control the behavior of the acls. Note that the person who - * started the application always has view access to the UI. - * - * Spark has a set of individual and group modify acls (`spark.modify.acls`) and - * (`spark.modify.acls.groups`) that controls which users and groups have permission to - * modify a single application. This would include things like killing the application. - * By default the person who started the application has modify access. For modify access - * through the UI, you must have a filter that does authentication in place for the modify - * acls to work properly. - * - * Spark also has a set of individual and group admin acls (`spark.admin.acls`) and - * (`spark.admin.acls.groups`) which is a set of users/administrators and admin groups - * who always have permission to view or modify the Spark application. - * - * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. - * - * At this point spark has multiple communication protocols that need to be secured and - * different underlying mechanisms are used depending on the protocol: - * - * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty - * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spnego, etc. It also supports multiple different login - * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService - * to authenticate using DIGEST-MD5 via a single user and the shared secret. - * Since we are using DIGEST-MD5, the shared secret is not passed on the wire - * in plaintext. - * - * We currently support SSL (https) for this communication protocol (see the details - * below). - * - * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. - * Any clients must specify the user and password. There is a default - * Authenticator installed in the SecurityManager to how it does the authentication - * and in this case gets the user name and password from the request. - * - * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously - * exchange messages. For this we use the Java SASL - * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 - * as the authentication mechanism. This means the shared secret is not passed - * over the wire in plaintext. - * Note that SASL is pluggable as to what mechanism it uses. We currently use - * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. - * Spark currently supports "auth" for the quality of protection, which means - * the connection does not support integrity or privacy protection (encryption) - * after authentication. SASL also supports "auth-int" and "auth-conf" which - * SPARK could support in the future to allow the user to specify the quality - * of protection they want. If we support those, the messages will also have to - * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. - * - * Since the NioBlockTransferService does asynchronous messages passing, the SASL - * authentication is a bit more complex. A ConnectionManager can be both a client - * and a Server, so for a particular connection it has to determine what to do. - * A ConnectionId was added to be able to track connections and is used to - * match up incoming messages with connections waiting for authentication. - * The ConnectionManager tracks all the sendingConnections using the ConnectionId, - * waits for the response from the server, and does the handshake before sending - * the real message. - * - * The NettyBlockTransferService ensures that SASL authentication is performed - * synchronously prior to any other communication on a connection. This is done in - * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. - * - * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters - * can be used. Yarn requires a specific AmIpFilter be installed for security to work - * properly. For non-Yarn deployments, users can write a filter to go through their - * organization's normal login service. If an authentication filter is in place then the - * SparkUI can be configured to check the logged in user against the list of users who - * have view acls to see if that user is authorized. - * The filters can also be used for many different purposes. For instance filters - * could be used for logging, encryption, or compression. - * - * The exact mechanisms used to generate/distribute the shared secret are deployment-specific. - * - * For YARN deployments, the secret is automatically generated. The secret is placed in the Hadoop - * UGI which gets passed around via the Hadoop RPC mechanism. Hadoop RPC can be configured to - * support different levels of protection. See the Hadoop documentation for more details. Each - * Spark application on YARN gets a different shared secret. - * - * On YARN, the Spark UI gets configured to use the Hadoop YARN AmIpFilter which requires the user - * to go through the ResourceManager Proxy. That proxy is there to reduce the possibility of web - * based attacks through YARN. Hadoop can be configured to use filters to do authentication. That - * authentication then happens via the ResourceManager Proxy and Spark will use that to do - * authorization against the view acls. - * - * For other Spark deployments, the shared secret must be specified via the - * spark.authenticate.secret config. - * All the nodes (Master and Workers) and the applications need to have the same shared secret. - * This again is not ideal as one user could potentially affect another users application. - * This should be enhanced in the future to provide better protection. - * If the UI needs to be secure, the user needs to install a javax servlet filter to do the - * authentication. Spark will then use that user to compare against the view acls to do - * authorization. If not filter is in place the user is generally null and no authorization - * can take place. - * - * When authentication is being used, encryption can also be enabled by setting the option - * spark.authenticate.enableSaslEncryption to true. This is only supported by communication - * channels that use the network-common library, and can be used as an alternative to SSL in those - * cases. - * - * SSL can be used for encryption for certain communication channels. The user can configure the - * default SSL settings which will be used for all the supported communication protocols unless - * they are overwritten by protocol specific settings. This way the user can easily provide the - * common settings for all the protocols without disabling the ability to configure each one - * individually. - * - * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property, - * denote the global configuration for all the supported protocols. In order to override the global - * configuration for the particular protocol, the properties must be overwritten in the - * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global - * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be only`fs` for - * broadcast and file server. - * - * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of - * options that can be specified. - * - * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions - * object parses Spark configuration at a given namespace and builds the common representation - * of SSL settings. SSLOptions is then used to provide protocol-specific SSLContextFactory for - * Jetty. - * - * SSL must be configured on each node and configured for each component involved in - * communication using the particular protocol. In YARN clusters, the key-store can be prepared on - * the client side then distributed and used by the executors as the part of the application - * (YARN allows the user to deploy files before the application is started). - * In standalone deployment, the user needs to provide key-stores and configuration - * options for master and workers. In this mode, the user may allow the executors to use the SSL - * settings inherited from the worker which spawned that executor. It can be accomplished by - * setting `spark.ssl.useNodeLocalConf` to `true`. + * This class implements all of the configuration related to security features described + * in the "Security" document. Please refer to that document for specific features implemented + * here. */ - private[spark] class SecurityManager( sparkConf: SparkConf, val ioEncryptionKey: Option[Array[Byte]] = None) diff --git a/docs/configuration.md b/docs/configuration.md index e7f2419cc2fa4..2eb6a77434ea6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -712,30 +712,6 @@ Apart from these, the following properties are also available, and may be useful When we fail to register to the external shuffle service, we will retry for maxAttempts times. - - spark.io.encryption.enabled - false - - Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption - be enabled when using this feature. - - - - spark.io.encryption.keySizeBits - 128 - - IO encryption key size in bits. Supported values are 128, 192 and 256. - - - - spark.io.encryption.keygen.algorithm - HmacSHA1 - - The algorithm to use when generating the IO encryption key. The supported algorithms are - described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm - Name Documentation. - - ### Spark UI @@ -893,6 +869,23 @@ Apart from these, the following properties are also available, and may be useful How many dead executors the Spark UI and status APIs remember before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark Web UI. The filter should be a + standard + javax servlet Filter. + +
    Filter parameters can also be specified in the configuration, by setting config entries + of the form spark.<class name of filter>.param.<param name>=<value> + +
    For example: +
    spark.ui.filters=com.test.filter1 +
    spark.com.test.filter1.param.name1=foo +
    spark.com.test.filter1.param.name2=bar + + ### Compression and Serialization @@ -1446,6 +1439,15 @@ Apart from these, the following properties are also available, and may be useful Duration for an RPC remote endpoint lookup operation to wait before timing out. + + spark.core.connection.ack.wait.timeout + spark.network.timeout + + How long for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + ### Scheduling @@ -1817,313 +1819,8 @@ Apart from these, the following properties are also available, and may be useful ### Security - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.acls.enablefalse - Whether Spark acls should be enabled. If enabled, this checks to see if the user has - access permissions to view or modify the job. Note this requires the user to be known, - so if the user comes across as null no checks are done. Filters can be used with the UI - to authenticate and set the user. -
    spark.admin.aclsEmpty - Comma separated list of users/administrators that have view and modify access to all Spark jobs. - This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things do not work. Putting a "*" in the list means any user can have the - privilege of admin. -
    spark.admin.acls.groupsEmpty - Comma separated list of groups that have view and modify access to all Spark jobs. - This can be used if you have a set of administrators or developers who help maintain and debug - the underlying infrastructure. Putting a "*" in the list means any user in any group can have - the privilege of admin. The user groups are obtained from the instance of the groups mapping - provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider - The list of groups for a user is determined by a group mapping service defined by the trait - org.apache.spark.security.GroupMappingServiceProvider which can be configured by this property. - A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider - which can be specified to resolve a list of groups for a user. - Note: This implementation supports only a Unix/Linux based environment. Windows environment is - currently not supported. However, a new platform/protocol can be supported by implementing - the trait org.apache.spark.security.GroupMappingServiceProvider. -
    spark.authenticatefalse - Whether Spark authenticates its internal connections. See - spark.authenticate.secret if not running on YARN. -
    spark.authenticate.secretNone - Set the secret key used for Spark to authenticate between components. This needs to be set if - not running on YARN and authentication is enabled. -
    spark.network.crypto.enabledfalse - Enable encryption using the commons-crypto library for RPC and block transfer service. - Requires spark.authenticate to be enabled. -
    spark.network.crypto.keyLength128 - The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. -
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 - The key factory algorithm to use when generating encryption keys. Should be one of the - algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. -
    spark.network.crypto.saslFallbacktrue - Whether to fall back to SASL authentication if authentication fails using Spark's internal - mechanism. This is useful when the application is connecting to old shuffle services that - do not support the internal Spark authentication protocol. On the server side, this can be - used to block older clients from authenticating against a new shuffle service. -
    spark.network.crypto.config.*None - Configuration values for the commons-crypto library, such as which cipher implementations to - use. The config name should be the name of commons-crypto configuration without the - "commons.crypto" prefix. -
    spark.authenticate.enableSaslEncryptionfalse - Enable encrypted communication when authentication is - enabled. This is supported by the block transfer service and the - RPC endpoints. -
    spark.network.sasl.serverAlwaysEncryptfalse - Disable unencrypted connections for services that support SASL authentication. -
    spark.core.connection.ack.wait.timeoutspark.network.timeout - How long for the connection to wait for ack to occur before timing - out and giving up. To avoid unwilling timeout caused by long pause like GC, - you can set larger value. -
    spark.modify.aclsEmpty - Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). Putting a "*" in - the list means any user can have access to modify it. -
    spark.modify.acls.groupsEmpty - Comma separated list of groups that have modify access to the Spark job. This can be used if you - have a set of administrators or developers from the same team to have access to control the job. - Putting a "*" in the list means any user in any group has the access to modify the Spark job. - The user groups are obtained from the instance of the groups mapping provider specified by - spark.user.groups.mapping. Check the entry spark.user.groups.mapping - for more details. -
    spark.ui.filtersNone - Comma separated list of filter class names to apply to the Spark web UI. The filter should be a - standard - javax servlet Filter. Parameters to each filter can also be specified by setting a - java system property of:
    - spark.<class name of filter>.params='param1=value1,param2=value2'
    - For example:
    - -Dspark.ui.filters=com.test.filter1
    - -Dspark.com.test.filter1.params='param1=foo,param2=testing' -
    spark.ui.view.aclsEmpty - Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. Putting a "*" in the list means any user can - have view access to this Spark job. -
    spark.ui.view.acls.groupsEmpty - Comma separated list of groups that have view access to the Spark web ui to view the Spark Job - details. This can be used if you have a set of administrators or developers or users who can - monitor the Spark job submitted. Putting a "*" in the list means any user in any group can view - the Spark job details on the Spark web ui. The user groups are obtained from the instance of the - groups mapping provider specified by spark.user.groups.mapping. Check the entry - spark.user.groups.mapping for more details. -
    - -### TLS / SSL - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Property NameDefaultMeaning
    spark.ssl.enabledfalse - Whether to enable SSL connections on all supported protocols. - -
    When spark.ssl.enabled is configured, spark.ssl.protocol - is required. - -
    All the SSL settings like spark.ssl.xxx where xxx is a - particular configuration property, denote the global configuration for all the supported - protocols. In order to override the global configuration for the particular protocol, - the properties must be overwritten in the protocol-specific namespace. - -
    Use spark.ssl.YYY.XXX settings to overwrite the global configuration for - particular protocol denoted by YYY. Example values for YYY - include fs, ui, standalone, and - historyServer. See SSL - Configuration for details on hierarchical SSL configuration for services. -
    spark.ssl.[namespace].portNone - The port where the SSL service will listen on. - -
    The port must be defined within a namespace configuration; see - SSL Configuration for the available - namespaces. - -
    When not set, the SSL port will be derived from the non-SSL port for the - same service. A value of "0" will make the service bind to an ephemeral port. -
    spark.ssl.enabledAlgorithmsEmpty - A comma separated list of ciphers. The specified ciphers must be supported by JVM. - The reference list of protocols one can find on - this - page. - Note: If not set, it will use the default cipher suites of JVM. -
    spark.ssl.keyPasswordNone - A password to the private key in key-store. -
    spark.ssl.keyStoreNone - A path to a key-store file. The path can be absolute or relative to the directory where - the component is started in. -
    spark.ssl.keyStorePasswordNone - A password to the key-store. -
    spark.ssl.keyStoreTypeJKS - The type of the key-store. -
    spark.ssl.protocolNone - A protocol name. The protocol must be supported by JVM. The reference list of protocols - one can find on this - page. -
    spark.ssl.needClientAuthfalse - Set true if SSL needs client authentication. -
    spark.ssl.trustStoreNone - A path to a trust-store file. The path can be absolute or relative to the directory - where the component is started in. -
    spark.ssl.trustStorePasswordNone - A password to the trust-store. -
    spark.ssl.trustStoreTypeJKS - The type of the trust-store. -
    - +Please refer to the [Security](security.html) page for available options on how to secure different +Spark subsystems. ### Spark SQL diff --git a/docs/monitoring.md b/docs/monitoring.md index d5f7ffcc260a1..01736c77b0979 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -80,7 +80,10 @@ The history server can be configured as follows: -### Spark configuration options +### Spark History Server Configuration Options + +Security options for the Spark History Server are covered more detail in the +[Security](security.html#web-ui) page. @@ -160,41 +163,6 @@ The history server can be configured as follows: Location of the kerberos keytab file for the History Server. - - - - - - - - - - - - - - - diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c010af35f8d2e..e07759a4dba87 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -2,6 +2,8 @@ layout: global title: Running Spark on YARN --- +* This will become a table of contents (this text will be scraped). +{:toc} Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) @@ -217,8 +219,8 @@ To use a custom metrics.properties for the application master and executors, upd @@ -265,19 +267,6 @@ To use a custom metrics.properties for the application master and executors, upd distribution. - - - - - @@ -373,31 +362,6 @@ To use a custom metrics.properties for the application master and executors, upd in YARN ApplicationReports, which can be used for filtering when querying YARN apps. - - - - - - - - - - - - - - - @@ -424,17 +388,6 @@ To use a custom metrics.properties for the application master and executors, upd See spark.yarn.config.gatewayPath. - - - - - @@ -468,48 +421,104 @@ To use a custom metrics.properties for the application master and executors, upd - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. -# Running in a Secure Cluster +# Kerberos + +Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. + +In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home +directory, Spark will also automatically obtain delegation tokens for the service hosting the +staging directory of the Spark application. + +If an application needs to interact with other secure Hadoop filesystems, their URIs need to be +explicitly provided to Spark at launch time. This is done by listing them in the +`spark.yarn.access.hadoopFileSystems` property, described in the configuration section below. -As covered in [security](security.html), Kerberos is used in a secure Hadoop cluster to -authenticate principals associated with services and clients. This allows clients to -make requests of these authenticated services; the services to grant rights -to the authenticated principals. +The YARN integration also supports custom delegation token providers using the Java Services +mechanism (see `java.util.ServiceLoader`). Implementations of +`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark +by listing their names in the corresponding file in the jar's `META-INF/services` directory. These +providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to +`false`, where `{service}` is the name of the credential provider. + +## YARN-specific Kerberos Configuration + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse - Specifies whether acls should be checked to authorize users viewing the applications. - If enabled, access control checks are made regardless of what the individual application had - set for spark.ui.acls.enable when the application was run. The application owner - will always have authorization to view their own application and any users specified via - spark.ui.view.acls and groups specified via spark.ui.view.acls.groups - when the application was run will also have authorization to view that application. - If disabled, no access control checks are made. -
    spark.history.ui.admin.aclsempty - Comma separated list of users/administrators that have view access to all the Spark applications in - history server. By default only the users permitted to view the application at run-time could - access the related application history, with this, configured users/administrators could also - have the permission to access it. - Putting a "*" in the list means any user can have the privilege of admin. -
    spark.history.ui.admin.acls.groupsempty - Comma separated list of groups that have view access to all the Spark applications in - history server. By default only the groups permitted to view the application at run-time could - access the related application history, with this, configured groups could also - have the permission to access it. - Putting a "*" in the list means any group can have the privilege of admin. -
    spark.history.fs.cleaner.enabled falsespark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to - being added to YARN's distributed cache. For use in cases where the YARN service does not + Comma-separated list of schemes for which files will be downloaded to the local disk prior to + being added to YARN's distributed cache. For use in cases where the YARN service does not support schemes that are supported by Spark, like http, https and ftp.
    spark.yarn.access.hadoopFileSystems(none) - A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For - example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, - webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed - and Kerberos must be properly configured to be able to access them (either in the same realm - or in a trusted realm). Spark acquires security tokens for each of the filesystems so that - the Spark application can access those remote Hadoop filesystems. spark.yarn.access.namenodes - is deprecated, please use this instead. -
    spark.yarn.appMasterEnv.[EnvironmentVariableName] (none)
    spark.yarn.keytab(none) - The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, - for renewing the login tickets and the delegation tokens periodically. (Works also with the "local" master) -
    spark.yarn.principal(none) - Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) -
    spark.yarn.kerberos.relogin.period1m - How often to check whether the kerberos TGT should be renewed. This should be set to a value - that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). - The default value should be enough for most deployments. -
    spark.yarn.config.gatewayPath (none)
    spark.security.credentials.${service}.enabledtrue - Controls whether to obtain credentials for services when security is enabled. - By default, credentials for all supported services are retrieved when those services are - configured, but it's possible to disable that behavior if it somehow conflicts with the - application being run. For further details please see - [Running in a Secure Cluster](running-on-yarn.html#running-in-a-secure-cluster) -
    spark.yarn.rolledLog.includePattern (none)
    + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.yarn.keytab(none) + The full path to the file that contains the keytab for the principal specified above. This keytab + will be copied to the node running the YARN Application Master via the YARN Distributed Cache, and + will be used for renewing the login tickets and the delegation tokens periodically. Equivalent to + the --keytab command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.principal(none) + Principal to be used to login to KDC, while running on secure clusters. Equivalent to the + --principal command line argument. + +
    (Works also with the "local" master.) +
    spark.yarn.access.hadoopFileSystems(none) + A comma-separated list of secure Hadoop filesystems your Spark application is going to access. For + example, spark.yarn.access.hadoopFileSystems=hdfs://nn1.com:8032,hdfs://nn2.com:8032, + webhdfs://nn3.com:50070. The Spark application must have access to the filesystems listed + and Kerberos must be properly configured to be able to access them (either in the same realm + or in a trusted realm). Spark acquires security tokens for each of the filesystems so that + the Spark application can access those remote Hadoop filesystems. +
    spark.yarn.kerberos.relogin.period1m + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. +
    -Hadoop services issue *hadoop tokens* to grant access to the services and data. -Clients must first acquire tokens for the services they will access and pass them along with their -application as it is launched in the YARN cluster. +## Troubleshooting Kerberos -For a Spark application to interact with any of the Hadoop filesystem (for example hdfs, webhdfs, etc), HBase and Hive, it must acquire the relevant tokens -using the Kerberos credentials of the user launching the application -—that is, the principal whose identity will become that of the launched Spark application. +Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to +enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` +environment variable. -This is normally done at launch time: in a secure cluster Spark will automatically obtain a -token for the cluster's default Hadoop filesystem, and potentially for HBase and Hive. +```bash +export HADOOP_JAAS_DEBUG=true +``` -An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares -the application is secure (i.e. `hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), -and `spark.security.credentials.hbase.enabled` is not set to `false`. +The JDK classes can be configured to enable extra logging of their Kerberos and +SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` +and `sun.security.spnego.debug=true` -Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration -includes a URI of the metadata store in `"hive.metastore.uris`, and -`spark.security.credentials.hive.enabled` is not set to `false`. +``` +-Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true +``` -If an application needs to interact with other secure Hadoop filesystems, then -the tokens needed to access these clusters must be explicitly requested at -launch time. This is done by listing them in the `spark.yarn.access.hadoopFileSystems` property. +All these options can be enabled in the Application Master: ``` -spark.yarn.access.hadoopFileSystems hdfs://ireland.example.org:8020/,webhdfs://frankfurt.example.org:50070/ +spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true +spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true ``` -Spark supports integrating with other security-aware services through Java Services mechanism (see -`java.util.ServiceLoader`). To do that, implementations of `org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` -should be available to Spark by listing their names in the corresponding file in the jar's -`META-INF/services` directory. These plug-ins can be disabled by setting -`spark.security.credentials.{service}.enabled` to `false`, where `{service}` is the name of -credential provider. +Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log +will include a list of all tokens obtained, and their expiry details -## Configuring the External Shuffle Service + +# Configuring the External Shuffle Service To start the Spark Shuffle Service on each `NodeManager` in your YARN cluster, follow these instructions: @@ -542,7 +551,7 @@ The following extra configuration options are available when the shuffle service -## Launching your application with Apache Oozie +# Launching your application with Apache Oozie Apache Oozie can launch Spark applications as part of a workflow. In a secure cluster, the launched application will need the relevant tokens to access the cluster's @@ -576,35 +585,7 @@ spark.security.credentials.hbase.enabled false The configuration option `spark.yarn.access.hadoopFileSystems` must be unset. -## Troubleshooting Kerberos - -Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to -enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` -environment variable. - -```bash -export HADOOP_JAAS_DEBUG=true -``` - -The JDK classes can be configured to enable extra logging of their Kerberos and -SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` -and `sun.security.spnego.debug=true` - -``` --Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -All these options can be enabled in the Application Master: - -``` -spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true -spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true -``` - -Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log -will include a list of all tokens obtained, and their expiry details - -## Using the Spark History Server to replace the Spark Web UI +# Using the Spark History Server to replace the Spark Web UI It is possible to use the Spark History Server application page as the tracking URL for running applications when the application UI is disabled. This may be desirable on secure clusters, or to diff --git a/docs/security.md b/docs/security.md index 913d9df50eb1c..3e5607a9a0d67 100644 --- a/docs/security.md +++ b/docs/security.md @@ -3,47 +3,336 @@ layout: global displayTitle: Spark Security title: Security --- +* This will become a table of contents (this text will be scraped). +{:toc} -Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The shared secret is created as follows: +# Spark RPC -* For Spark on [YARN](running-on-yarn.html) and local deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. -* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. +## Authentication -## Web UI +Spark currently supports authentication for RPC channels using a shared secret. Authentication can +be turned on by setting the `spark.authenticate` configuration parameter. -The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). +The exact mechanism used to generate and distribute the shared secret is deployment-specific. -### Authentication +For Spark on [YARN](running-on-yarn.html) and local deployments, Spark will automatically handle +generating and distributing the shared secret. Each application will use a unique shared secret. In +the case of YARN, this feature relies on YARN RPC encryption being enabled for the distribution of +secrets to be secure. -A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable`, `spark.ui.view.acls` and `spark.ui.view.acls.groups` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +For other resource managers, `spark.authenticate.secret` must be configured on each of the nodes. +This secret will be shared by all the daemons and applications, so this deployment configuration is +not as secure as the above, especially when considering multi-tenant clusters. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable`, `spark.modify.acls` and `spark.modify.acls.groups`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. -Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the configs `spark.admin.acls` and `spark.admin.acls.groups`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.authenticatefalseWhether Spark authenticates its internal connections.
    spark.authenticate.secretNone + The secret key used authentication. See above for when this configuration should be set. +
    + +## Encryption -## Event Logging +Spark supports AES-based encryption for RPC connections. For encryption to be enabled, RPC +authentication must also be enabled and properly configured. AES encryption uses the +[Apache Commons Crypto](http://commons.apache.org/proper/commons-crypto/) library, and Spark's +configuration system allows access to that library's configuration for advanced users. -If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. +There is also support for SASL-based encryption, although it should be considered deprecated. It +is still required when talking to shuffle services from Spark versions older than 2.2.0. -## Encryption +The following table describes the different options available for configuring this feature. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.network.crypto.enabledfalse + Enable AES-based RPC encryption, including the new authentication protocol added in 2.2.0. +
    spark.network.crypto.keyLength128 + The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. +
    spark.network.crypto.keyFactoryAlgorithmPBKDF2WithHmacSHA1 + The key factory algorithm to use when generating encryption keys. Should be one of the + algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. +
    spark.network.crypto.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    spark.network.crypto.saslFallbacktrue + Whether to fall back to SASL authentication if authentication fails using Spark's internal + mechanism. This is useful when the application is connecting to old shuffle services that + do not support the internal Spark authentication protocol. On the shuffle service side, + disabling this feature will block older clients from authenticating. +
    spark.authenticate.enableSaslEncryptionfalse + Enable SASL-based encrypted communication. +
    spark.network.sasl.serverAlwaysEncryptfalse + Disable unencrypted connections for ports using SASL authentication. This will deny connections + from clients that have authentication enabled, but do not request SASL-based encryption. +
    + + +# Local Storage Encryption + +Spark supports encrypting temporary data written to local disks. This covers shuffle files, shuffle +spills and data blocks stored on disk (for both caching and broadcast variables). It does not cover +encrypting output data generated by applications with APIs such as `saveAsHadoopFile` or +`saveAsTable`. + +The following settings cover enabling encryption for data written to disk: + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.io.encryption.enabledfalse + Enable local disk I/O encryption. Currently supported by all modes except Mesos. It's strongly + recommended that RPC encryption be enabled when using this feature. +
    spark.io.encryption.keySizeBits128 + IO encryption key size in bits. Supported values are 128, 192 and 256. +
    spark.io.encryption.keygen.algorithmHmacSHA1 + The algorithm to use when generating the IO encryption key. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. +
    spark.io.encryption.commons.config.*None + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + commons.crypto prefix. +
    + + +# Web UI + +## Authentication and Authorization + +Enabling authentication for the Web UIs is done using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html). +You will need a filter that implements the authentication method you want to deploy. Spark does not +provide any built-in authentication filters. + +Spark also supports access control to the UI when an authentication filter is present. Each +application can be configured with its own separate access control lists (ACLs). Spark +differentiates between "view" permissions (who is allowed to see the application's UI), and "modify" +permissions (who can do things like kill jobs in a running application). + +ACLs can be configured for either users or groups. Configuration entries accept comma-separated +lists as input, meaning multiple users or groups can be given the desired privileges. This can be +used if you run on a shared cluster and have a set of administrators or developers who need to +monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL +means that all users will have the respective pivilege. By default, only the user submitting the +application is added to the ACLs. + +Group membership is established by using a configurable group mapping provider. The mapper is +configured using the spark.user.groups.mapping config option, described in the table +below. + +The following options control the authentication of Web UIs: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.filtersNone + See the Spark UI configuration for how to configure + filters. +
    spark.acls.enablefalse + Whether UI ACLs should be enabled. If enabled, this checks to see if the user has access + permissions to view or modify the application. Note this requires the user to be authenticated, + so if no authentication filter is installed, this option does not do anything. +
    spark.admin.aclsNone + Comma-separated list of users that have view and modify access to the Spark application. +
    spark.admin.acls.groupsNone + Comma-separated list of groups that have view and modify access to the Spark application. +
    spark.modify.aclsNone + Comma-separated list of users that have modify access to the Spark application. +
    spark.modify.acls.groupsNone + Comma-separated list of groups that have modify access to the Spark application. +
    spark.ui.view.aclsNone + Comma-separated list of users that have view access to the Spark application. +
    spark.ui.view.acls.groupsNone + Comma-separated list of groups that have view access to the Spark application. +
    spark.user.groups.mappingorg.apache.spark.security.ShellBasedGroupsMappingProvider + The list of groups for a user is determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider, which can be configured by + this property. + +
    By default, a Unix shell-based implementation is used, which collects this information + from the host OS. + +
    Note: This implementation supports only Unix/Linux-based environments. + Windows environment is currently not supported. However, a new platform/protocol can + be supported by implementing the trait mentioned above. +
    + +On YARN, the view and modify ACLs are provided to the YARN service when submitting applications, and +control who has the respective privileges via YARN interfaces. + +## Spark History Server ACLs -Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service -and the RPC endpoints. Shuffle files can also be encrypted if desired. +Authentication for the SHS Web UI is enabled the same way as for regular applications, using +servlet filters. -### SSL Configuration +To enable authorization in the SHS, a few extra options are used: + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.history.ui.acls.enablefalse + Specifies whether ACLs should be checked to authorize users viewing the applications in + the history server. If enabled, access control checks are performed regardless of what the + individual applications had set for spark.ui.acls.enable. The application owner + will always have authorization to view their own application and any users specified via + spark.ui.view.acls and groups specified via spark.ui.view.acls.groups + when the application was run will also have authorization to view that application. + If disabled, no access control checks are made for any application UIs available through + the history server. +
    spark.history.ui.admin.aclsNone + Comma separated list of users that have view access to all the Spark applications in history + server. +
    spark.history.ui.admin.acls.groupsNone + Comma separated list of groups that have view access to all the Spark applications in history + server. +
    + +The SHS uses the same options to configure the group mapping provider as regular applications. +In this case, the group mapping provider will apply to all UIs server by the SHS, and individual +application configurations will be ignored. + +## SSL Configuration Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the -protocols without disabling the ability to configure each one individually. The common SSL settings -are at `spark.ssl` namespace in Spark configuration. The following table describes the -component-specific configuration namespaces used to override the default settings: +protocols without disabling the ability to configure each one individually. The following table +describes the the SSL configuration namespaces: + + + + @@ -58,49 +347,205 @@ component-specific configuration namespaces used to override the default setting
    Config Namespace Component
    spark.ssl + The default SSL configuration. These values will apply to all namespaces below, unless + explicitly overridden at the namespace level. +
    spark.ssl.ui Spark application Web UI
    -The full breakdown of available SSL options can be found on the [configuration page](configuration.html). -SSL must be configured on each node and configured for each component involved in communication using the particular protocol. +The full breakdown of available SSL options can be found below. The `${ns}` placeholder should be +replaced with one of the above namespaces. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    ${ns}.enabledfalseEnables SSL. When enabled, ${ns}.ssl.protocol is required.
    ${ns}.portNone + The port where the SSL service will listen on. + +
    The port must be defined within a specific namespace configuration. The default + namespace is ignored when reading this configuration. + +
    When not set, the SSL port will be derived from the non-SSL port for the + same service. A value of "0" will make the service bind to an ephemeral port. +
    ${ns}.enabledAlgorithmsNone + A comma separated list of ciphers. The specified ciphers must be supported by JVM. + +
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section + of the Java security guide. The list for Java 8 can be found at + this + page. + +
    Note: If not set, the default cipher suite for the JRE will be used. +
    ${ns}.keyPasswordNone + The password to the private key in the key store. +
    ${ns}.keyStoreNone + Path to the key store file. The path can be absolute or relative to the directory in which the + process is started. +
    ${ns}.keyStorePasswordNonePassword to the key store.
    ${ns}.keyStoreTypeJKSThe type of the key store.
    ${ns}.protocolNone + TLS protocol to use. The protocol must be supported by JVM. + +
    The reference list of protocols can be found in the "Additional JSSE Standard Names" + section of the Java security guide. For Java 8, the list can be found at + this + page. +
    ${ns}.needClientAuthfalseWhether to require client authentication.
    ${ns}.trustStoreNone + Path to the trust store file. The path can be absolute or relative to the directory in which + the process is started. +
    ${ns}.trustStorePasswordNonePassword for the trust store.
    ${ns}.trustStoreTypeJKSThe type of the trust store.
    + +## Preparing the key stores + +Key stores can be generated by `keytool` program. The reference documentation for this tool for +Java 8 is [here](https://docs.oracle.com/javase/8/docs/technotes/tools/unix/keytool.html). +The most basic steps to configure the key stores and the trust store for a Spark Standalone +deployment mode is as follows: + +* Generate a key pair for each node +* Export the public key of the key pair to a file on each node +* Import all exported public keys into a single trust store +* Distribute the trust store to the cluster nodes ### YARN mode -The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. -For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. +To provide a local trust store or key store file to drivers running in cluster mode, they can be +distributed with the application using the `--files` command line argument (or the equivalent +`spark.files` configuration). The files will be placed on the driver's working directory, so the TLS +configuration should just reference the file name with no absolute path. + +Distributing local key stores this way may require the files to be staged in HDFS (or other similar +distributed file system used by the cluster), so it's recommended that the undelying file system be +configured with security in mind (e.g. by enabling authentication and wire encryption). ### Standalone mode -The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. + +The user needs to provide key stores and configuration options for master and workers. They have to +be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in +`SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. + +The user may allow the executors to use the SSL settings inherited from the worker process. That +can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. In that case, the settings +provided by the user on the client side are not used. ### Mesos mode -Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based secrets. Spark allows the specification of file-based and environment variable based secrets with the `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. Depending on the secret store backend secrets can be passed by reference or by value with the `spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, respectively. Reference type secrets are served by the secret store and referred to by name, for example `/mysecret`. Value type secrets are passed on the command line and translated into their appropriate files or environment variables. +Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment based +secrets. Spark allows the specification of file-based and environment variable based secrets with +`spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys`, respectively. -### Preparing the key-stores -Key-stores can be generated by `keytool` program. The reference documentation for this tool is -[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic -steps to configure the key-stores and the trust-store for the standalone deployment mode is as -follows: +Depending on the secret store backend secrets can be passed by reference or by value with the +`spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, +respectively. -* Generate a keys pair for each node -* Export the public key of the key pair to a file on each node -* Import all exported public keys into a single trust-store -* Distribute the trust-store over the nodes +Reference type secrets are served by the secret store and referred to by name, for example +`/mysecret`. Value type secrets are passed on the command line and translated into their +appropriate files or environment variables. -### Configuring SASL Encryption +## HTTP Security Headers -SASL encryption is currently supported for the block transfer service when authentication -(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set -`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. +Apache Spark can be configured to include HTTP headers to aid in preventing Cross Site Scripting +(XSS), Cross-Frame Scripting (XFS), MIME-Sniffing, and also to enforce HTTP Strict Transport +Security. -When using an external shuffle service, it's possible to disable unencrypted connections by setting -`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that -option is enabled, applications that are not set up to use SASL encryption will fail to connect to -the shuffle service. + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block + Value for HTTP X-XSS-Protection response header. You can choose appropriate value + from below: +
      +
    • 0 (Disables XSS filtering)
    • +
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, + the browser will sanitize the page.)
    • +
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering + of the page if an attack is detected.)
    • +
    +
    spark.ui.xContentTypeOptions.enabledtrue + When enabled, X-Content-Type-Options HTTP response header will be set to "nosniff". +
    spark.ui.strictTransportSecurityNone + Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate + value from below and set expire-time accordingly. This option is only used when + SSL/TLS is enabled. +
      +
    • max-age=<expire-time>
    • +
    • max-age=<expire-time>; includeSubDomains
    • +
    • max-age=<expire-time>; preload
    • +
    +
    -## Configuring Ports for Network Security + +# Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight firewall settings. Below are the primary ports that Spark uses for its communication and how to configure those ports. -### Standalone mode only +## Standalone mode only @@ -141,7 +586,7 @@ configure those ports.
    -### All cluster managers +## All cluster managers @@ -182,54 +627,70 @@ configure those ports.
    -### HTTP Security Headers -Apache Spark can be configured to include HTTP Headers which aids in preventing Cross -Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP -Strict Transport Security. +# Kerberos + +Spark supports submitting applications in environments that use Kerberos for authentication. +In most cases, Spark relies on the credentials of the current logged in user when authenticating +to Kerberos-aware services. Such credentials can be obtained by logging in to the configured KDC +with tools like `kinit`. + +When talking to Hadoop-based services, Spark needs to obtain delegation tokens so that non-local +processes can authenticate. Spark ships with support for HDFS and other Hadoop file systems, Hive +and HBase. + +When using a Hadoop filesystem (such HDFS or WebHDFS), Spark will acquire the relevant tokens +for the service hosting the user's home directory. + +An HBase token will be obtained if HBase is in the application's classpath, and the HBase +configuration has Kerberos authentication turned (`hbase.security.authentication=kerberos`). + +Similarly, a Hive token will be obtained if Hive is in the classpath, and the configuration includes +URIs for remote metastore services (`hive.metastore.uris` is not empty). + +Delegation token support is currently only supported in YARN and Mesos modes. Consult the +deployment-specific page for more information. + +The following options provides finer-grained control for this feature: - - - - - - + - - - - -
    Property NameDefaultMeaning
    spark.ui.xXssProtection1; mode=block - Value for HTTP X-XSS-Protection response header. You can choose appropriate value - from below: -
      -
    • 0 (Disables XSS filtering)
    • -
    • 1 (Enables XSS filtering. If a cross-site scripting attack is detected, - the browser will sanitize the page.)
    • -
    • 1; mode=block (Enables XSS filtering. The browser will prevent rendering - of the page if an attack is detected.)
    • -
    -
    spark.ui.xContentTypeOptions.enabledspark.security.credentials.${service}.enabled true - When value is set to "true", X-Content-Type-Options HTTP response header will be set - to "nosniff". Set "false" to disable. -
    spark.ui.strictTransportSecurityNone - Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate - value from below and set expire-time accordingly, when Spark is SSL/TLS enabled. -
      -
    • max-age=<expire-time>
    • -
    • max-age=<expire-time>; includeSubDomains
    • -
    • max-age=<expire-time>; preload
    • -
    + Controls whether to obtain credentials for services when security is enabled. + By default, credentials for all supported services are retrieved when those services are + configured, but it's possible to disable that behavior if it somehow conflicts with the + application being run.
    - -See the [configuration page](configuration.html) for more details on the security configuration -parameters, and -org.apache.spark.SecurityManager for implementation details about security. +## Long-Running Applications + +Long-running applications may run into issues if their run time exceeds the maximum delegation +token lifetime configured in services it needs to access. + +Spark supports automatically creating new tokens for these applications when running in YARN mode. +Kerberos credentials need to be provided to the Spark application via the `spark-submit` command, +using the `--principal` and `--keytab` parameters. + +The provided keytab will be copied over to the machine running the Application Master via the Hadoop +Distributed Cache. For this reason, it's strongly recommended that both YARN and HDFS be secured +with encryption, at least. + +The Kerberos login will be periodically renewed using the provided credentials, and new delegation +tokens for supported will be created. + + +# Event Logging + +If your applications are using event logging, the directory where the event logs go +(`spark.eventLog.dir`) should be manually created with proper permissions. To secure the log files, +the directory permissions should be set to `drwxrwxrwxt`. The owner and group of the directory +should correspond to the super user who is running the Spark History Server. +This will allow all users to write to the directory but will prevent unprivileged users from +reading, removing or renaming a file unless they own it. The event log files will be created by +Spark with permissions such that only the user and group have read and write access. From 3e778f5a91b0553b09fe0e0ee84d771a71504960 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Mon, 26 Mar 2018 15:45:27 -0700 Subject: [PATCH 0526/1161] [SPARK-23162][PYSPARK][ML] Add r2adj into Python API in LinearRegressionSummary ## What changes were proposed in this pull request? Adding r2adj in LinearRegressionSummary for Python API. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes in tests.py. Author: Kevin Yu Closes #20842 from kevinyu98/spark-23162. --- python/pyspark/ml/regression.py | 18 ++++++++++++++++-- python/pyspark/ml/tests.py | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de0a0fa9f3bf8..9a66d87d7f211 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -336,10 +336,10 @@ def rootMeanSquaredError(self): @since("2.0.0") def r2(self): """ - Returns R^2^, the coefficient of determination. + Returns R^2, the coefficient of determination. .. seealso:: `Wikipedia coefficient of determination \ - ` + `_ .. note:: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -347,6 +347,20 @@ def r2(self): """ return self._call_java("r2") + @property + @since("2.4.0") + def r2adj(self): + """ + Returns Adjusted R^2, the adjusted coefficient of determination. + + .. seealso:: `Wikipedia coefficient of determination, Adjusted R^2 \ + `_ + + .. note:: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark versions. + """ + return self._call_java("r2adj") + @property @since("2.0.0") def residuals(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index cf1ffa181ecec..6b4376cbf14e8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1559,6 +1559,7 @@ def test_linear_regression_summary(self): self.assertAlmostEqual(s.meanSquaredError, 0.0) self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertAlmostEqual(s.r2adj, 1.0, 2) self.assertTrue(isinstance(s.residuals, DataFrame)) self.assertEqual(s.numInstances, 2) self.assertEqual(s.degreesOfFreedom, 1) From 35997b59f3116830af06b3d40a7675ef0dbf7091 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 27 Mar 2018 14:49:50 +0200 Subject: [PATCH 0527/1161] [SPARK-23794][SQL] Make UUID as stateful expression ## What changes were proposed in this pull request? The UUID() expression is stateful and should implement the `Stateful` trait instead of the `Nondeterministic` trait. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20912 from viirya/SPARK-23794. --- .../org/apache/spark/sql/catalyst/expressions/misc.scala | 4 +++- .../sql/catalyst/expressions/MiscExpressionsSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ec93620038cff..a390f8ef7fd9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -123,7 +123,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { 46707d92-02f4-4817-8116-a4c3b23e6266 """) // scalastyle:on line.size.limit -case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic { +case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { def this() = this(None) @@ -152,4 +152,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = "false") } + + override def freshCopy(): Uuid = Uuid(randomSeed) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 3383d421f5616..b6c269348b002 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -59,6 +59,12 @@ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluateWithGeneratedMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) + + val uuid = Uuid(seed1) + assert(uuid.fastEquals(uuid)) + assert(!uuid.fastEquals(Uuid(seed1))) + assert(!uuid.fastEquals(uuid.freshCopy())) + assert(!uuid.fastEquals(Uuid(seed2))) } test("PrintToStderr") { From c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 27 Mar 2018 14:39:05 -0700 Subject: [PATCH 0528/1161] [SPARK-23096][SS] Migrate rate source to V2 ## What changes were proposed in this pull request? This PR migrate micro batch rate source to V2 API and rewrite UTs to suite V2 test. ## How was this patch tested? UTs. Author: jerryshao Closes #20688 from jerryshao/SPARK-23096. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------- .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++ .../sources/RateStreamSourceV2.scala | 187 ---------- .../execution/streaming/RateSourceSuite.scala | 194 ---------- .../streaming/RateSourceV2Suite.scala | 191 ---------- .../sources/RateStreamProviderSuite.scala | 344 ++++++++++++++++++ 10 files changed, 715 insertions(+), 844 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala deleted file mode 100644 index 03d0f63fa4d7f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.TimeUnit - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) - } - } - - test("basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("valueAtSecond") { - import RateStreamSource._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[ArithmeticException](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) - for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala new file mode 100644 index 0000000000000..9149e50962255 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.nio.file.Files +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + protected override def beforeAll(): Unit = { + super.beforeAll() + SparkSession.setActiveSession(spark) + } + + override def afterAll(): Unit = { + SparkSession.clearActiveSession() + super.afterAll() + } + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("microbatch - basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("valueAtSecond") { + import RateStreamProvider._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[TreeNodeException[_]](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getCause.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[IllegalArgumentException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + for (msg <- expectedMessages) { + assert(e.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} From ed72badb04a56d8046bbd185245abf5ae265ccfd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 27 Mar 2018 20:06:12 -0700 Subject: [PATCH 0529/1161] [SPARK-23699][PYTHON][SQL] Raise same type of error caught with Arrow enabled ## What changes were proposed in this pull request? When using Arrow for createDataFrame or toPandas and an error is encountered with fallback disabled, this will raise the same type of error instead of a RuntimeError. This change also allows for the traceback of the error to be retained and prevents the accidental chaining of exceptions with Python 3. ## How was this patch tested? Updated existing tests to verify error type. Author: Bryan Cutler Closes #20839 from BryanCutler/arrow-raise-same-error-SPARK-23699. --- python/pyspark/sql/dataframe.py | 25 +++++++++++++------------ python/pyspark/sql/session.py | 13 +++++++------ python/pyspark/sql/tests.py | 10 +++++----- python/pyspark/sql/utils.py | 6 ++++++ 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3fc194d8ec1d1..16f8e52dead7b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2007,7 +2007,7 @@ def toPandas(self): "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) @@ -2015,11 +2015,12 @@ def toPandas(self): else: msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Try to use Arrow optimization when the schema is supported and the required version # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. @@ -2042,12 +2043,12 @@ def toPandas(self): # be executed. So, simply fail in this case for now. msg = ( "toPandas attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed unexpectedly:\n %s\n" - "Note that 'spark.sql.execution.arrow.fallback.enabled' does " - "not have an effect in such failure in the middle of " - "computation." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and can not continue. Note that " + "'spark.sql.execution.arrow.fallback.enabled' does not have an effect " + "on failures in the middle of computation.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise # Below is toPandas without Arrow optimization. pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e82a9750a0014..13d6e2e53dbd0 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -674,18 +674,19 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr "createDataFrame attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " "failed by the reason below:\n %s\n" - "Attempts non-optimization as " + "Attempting non-optimization as " "'spark.sql.execution.arrow.fallback.enabled' is set to " "true." % _exception_message(e)) warnings.warn(msg) else: msg = ( "createDataFrame attempted Arrow optimization because " - "'spark.sql.execution.arrow.enabled' is set to true; however, " - "failed by the reason below:\n %s\n" - "For fallback to non-optimization automatically, please set true to " - "'spark.sql.execution.arrow.fallback.enabled'." % _exception_message(e)) - raise RuntimeError(msg) + "'spark.sql.execution.arrow.enabled' is set to true, but has reached " + "the error below and will not continue because automatic fallback " + "with 'spark.sql.execution.arrow.fallback.enabled' has been set to " + "false.\n %s" % _exception_message(e)) + warnings.warn(msg) + raise data = self._convert_from_pandas(data, schema, timezone) if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 967cc83166f3f..01c5dd6ff8c3f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3559,7 +3559,7 @@ def test_toPandas_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): @@ -3682,7 +3682,7 @@ def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() wrong_schema = StructType(list(reversed(self.schema))) with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*No cast.*string.*timestamp.*"): + with self.assertRaisesRegexp(Exception, ".*No cast.*string.*timestamp.*"): self.spark.createDataFrame(pdf, schema=wrong_schema) def test_createDataFrame_with_names(self): @@ -3707,7 +3707,7 @@ def test_createDataFrame_column_name_encoding(self): def test_createDataFrame_with_single_data_type(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(RuntimeError, ".*IntegerType.*not supported.*"): + with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"): self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int") def test_createDataFrame_does_not_modify_input(self): @@ -3775,14 +3775,14 @@ def test_createDataFrame_fallback_enabled(self): warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( - "Attempts non-optimization" in _exception_message(user_warns[-1])) + "Attempting non-optimization" in _exception_message(user_warns[-1])) self.assertEqual(df.collect(), [Row(a={u'a': 1})]) def test_createDataFrame_fallback_disabled(self): import pandas as pd with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Unsupported type'): + with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( pd.DataFrame([[{u'a': 1}]]), "a: map") diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 578298632dd4c..45363f089a73d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -121,7 +121,10 @@ def require_minimum_pandas_version(): from distutils.version import LooseVersion try: import pandas + have_pandas = True except ImportError: + have_pandas = False + if not have_pandas: raise ImportError("Pandas >= %s must be installed; however, " "it was not found." % minimum_pandas_version) if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): @@ -138,7 +141,10 @@ def require_minimum_pyarrow_version(): from distutils.version import LooseVersion try: import pyarrow + have_arrow = True except ImportError: + have_arrow = False + if not have_arrow: raise ImportError("PyArrow >= %s must be installed; however, " "it was not found." % minimum_pyarrow_version) if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): From 34c4b9c57e114cdb390e4dbc7383284d82fea317 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 28 Mar 2018 19:49:27 +0800 Subject: [PATCH 0530/1161] [SPARK-23765][SQL] Supports custom line separator for json datasource ## What changes were proposed in this pull request? This PR proposes to add lineSep option for a configurable line separator in text datasource. It supports this option by using `LineRecordReader`'s functionality with passing it to the constructor. The approach is similar with https://github.com/apache/spark/pull/20727; however, one main difference is, it uses text datasource's `lineSep` option to parse line by line in JSON's schema inference. ## How was this patch tested? Manually tested and unit tests were added. Author: hyukjinkwon Author: hyukjinkwon Closes #20877 from HyukjinKwon/linesep-json. --- python/pyspark/sql/readwriter.py | 14 ++-- python/pyspark/sql/streaming.py | 6 +- python/pyspark/sql/tests.py | 17 +++++ .../spark/sql/catalyst/json/JSONOptions.scala | 11 ++++ .../sql/catalyst/json/JacksonGenerator.scala | 8 ++- .../apache/spark/sql/DataFrameReader.scala | 2 + .../apache/spark/sql/DataFrameWriter.scala | 2 + .../datasources/json/JsonDataSource.scala | 17 +++-- .../datasources/text/TextOptions.scala | 2 +- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 66 ++++++++++++++++++- .../datasources/text/TextSuite.scala | 4 +- 12 files changed, 136 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 4f9b9383a5ef4..6bd79bc2f43e5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +237,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -254,7 +256,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -746,7 +748,8 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, + lineSep=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -770,12 +773,15 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param lineSep: defines the line separator that should be used for writing. If None is + set, it uses the default value, ``\\n``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) self._set_opts( - compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, + lineSep=lineSep) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index c7907aaaf1f7b..15f9407389864 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -405,7 +405,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None): """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. @@ -468,6 +468,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param lineSep: defines the line separator that should be used for parsing. If None is + set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -482,7 +484,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 01c5dd6ff8c3f..5181053a0d318 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -676,6 +676,23 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_linesep_json(self): + df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") + expected = [Row(_corrupt_record=None, name=u'Michael'), + Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None), + Row(_corrupt_record=u' "age":19}\n', name=None)] + self.assertEqual(df.collect(), expected) + + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + df = self.spark.read.json("python/test_support/sql/people.json") + df.write.json(tpath, lineSep="!!") + readback = self.spark.read.json(tpath, lineSep="!!") + self.assertEqual(readback.collect(), df.collect()) + finally: + shutil.rmtree(tpath) + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 652412b34478a..5c9adc3332bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.json +import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} @@ -85,6 +86,16 @@ private[sql] class JSONOptions( val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + require(sep.nonEmpty, "'lineSep' cannot be an empty string.") + sep + } + // Note that the option 'lineSep' uses a different default value in read and write. + val lineSeparatorInRead: Option[Array[Byte]] = + lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") + /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index eb06e4f304f0a..9c413de752a8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.json import java.io.Writer +import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core._ @@ -74,6 +75,8 @@ private[sql] class JacksonGenerator( private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + private val lineSeparator: String = options.lineSeparatorInWrite + private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => (row: SpecializedGetters, ordinal: Int) => @@ -251,5 +254,8 @@ private[sql] class JacksonGenerator( mapType = dataType.asInstanceOf[MapType])) } - def writeLineEnding(): Unit = gen.writeRaw('\n') + def writeLineEnding(): Unit = { + // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + gen.writeRaw(lineSeparator) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1a5e47508c070..ae3ba1690f696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -366,6 +366,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bb93889dc55e9..bbc063148a72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,6 +518,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should + * be used for writing.
  • * * * @since 1.4.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 77e7edc8e7a20..5769c09c9a1d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,7 +92,8 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths) + val json: Dataset[String] = createBaseDataset( + sparkSession, inputPaths, parsedOptions.lineSeparator) inferFromDataset(json, parsedOptions) } @@ -104,13 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { private def createBaseDataset( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): Dataset[String] = { + inputPaths: Seq[FileStatus], + lineSeparator: Option[String]): Dataset[String] = { + val textOptions = lineSeparator.map { lineSep => + Map(TextOptions.LINE_SEPARATOR -> lineSep) + }.getOrElse(Map.empty[String, String]) + val paths = inputPaths.map(_.getPath.toString) sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = textOptions ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -120,7 +127,7 @@ object TextInputJsonDataSource extends JsonDataSource { file: PartitionedFile, parser: JacksonParser, schema: StructType): Iterator[InternalRow] = { - val linesReader = new HadoopFileLinesReader(file, conf) + val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) val safeParser = new FailureSafeParser[Text]( input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 18698df9fd8e5..5c1a35434f7b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -52,7 +52,7 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } -private[text] object TextOptions { +private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" val LINE_SEPARATOR = "lineSep" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9b17406a816b5..ae93965bc50ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -268,6 +268,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator + * that should be used for parsing.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 8c8d41ebf115a..10bac0554484a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -27,7 +28,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} @@ -2063,4 +2064,67 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + def testLineSeparator(lineSep: String): Unit = { + test(s"SPARK-21289: Support line separator - lineSep: '$lineSep'") { + // Read + val data = + s""" + | {"f": + |"a", "f0": 1}$lineSep{"f": + | + |"c", "f0": 2}$lineSep{"f": "d", "f0": 3} + """.stripMargin + val dataWithTrailingLineSep = s"$data$lineSep" + + Seq(data, dataWithTrailingLineSep).foreach { lines => + withTempPath { path => + Files.write(path.toPath, lines.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } + + // Write + withTempPath { path => + Seq("a", "b", "c").toDF("value").coalesce(1) + .write.option("lineSep", lineSep).json(path.getAbsolutePath) + val partFile = TestUtils.recursiveList(path).filter(f => f.getName.startsWith("part-")).head + val readBack = new String(Files.readAllBytes(partFile.toPath), StandardCharsets.UTF_8) + assert( + readBack === s"""{"value":"a"}$lineSep{"value":"b"}$lineSep{"value":"c"}$lineSep""") + } + + // Roundtrip + withTempPath { path => + val df = Seq("a", "b", "c").toDF() + df.write.option("lineSep", lineSep).json(path.getAbsolutePath) + val readBack = spark.read.option("lineSep", lineSep).json(path.getAbsolutePath) + checkAnswer(df, readBack) + } + } + } + + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => + testLineSeparator(lineSep) + } + // scalastyle:on nonascii + + test("""SPARK-21289: Support line separator - default value \r, \r\n and \n""") { + val data = + "{\"f\": \"a\", \"f0\": 1}\r{\"f\": \"c\", \"f0\": 2}\r\n{\"f\": \"d\", \"f0\": 3}\n" + + withTempPath { path => + Files.write(path.toPath, data.getBytes(StandardCharsets.UTF_8)) + val df = spark.read.json(path.getAbsolutePath) + val expectedSchema = + StructType(StructField("f", StringType) :: StructField("f0", LongType) :: Nil) + checkAnswer(df, Seq(("a", 1), ("c", 2), ("d", 3)).toDF()) + assert(df.schema === expectedSchema) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index e8a5299d6ba9d..0e7f3afa9c3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -208,9 +208,11 @@ class TextSuite extends QueryTest with SharedSQLContext { } } - Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString).foreach { lineSep => + // scalastyle:off nonascii + Seq("|", "^", "::", "!!!@3", 0x1E.toChar.toString, "아").foreach { lineSep => testLineSeparator(lineSep) } + // scalastyle:on nonascii private def testFile: String = { Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString From 761565a3ccbf7f083e587fee14a27b61867a3886 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 28 Mar 2018 09:11:52 -0700 Subject: [PATCH 0531/1161] Revert "[SPARK-23096][SS] Migrate rate source to V2" This reverts commit c68ec4e6a1ed9ea13345c7705ea60ff4df7aec7b. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 +++++++++++++ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 ----------- .../sources/RateStreamProvider.scala | 125 ------- .../sources/RateStreamSourceV2.scala | 187 ++++++++++ .../execution/streaming/RateSourceSuite.scala | 194 ++++++++++ .../streaming/RateSourceV2Suite.scala | 191 ++++++++++ .../sources/RateStreamProviderSuite.scala | 344 ------------------ 10 files changed, 844 insertions(+), 715 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1b37905543b4e..1fe9c093af99f 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,5 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..31fa89b4570a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} +import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,7 +566,6 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName - val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -588,8 +587,7 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, - "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..649fbbfa184ec --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister + with DataSourceV2 with ContinuousReadSupport { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + if (schema.nonEmpty) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + (shortName(), RateSourceProvider.SCHEMA) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + new RateStreamContinuousReader(options) + } + + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame( + sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..20d90069163a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider +import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA + override def readSchema(): StructType = RateSourceProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,19 +98,6 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} - private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } - } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala deleted file mode 100644 index 6cf8520fc544f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ManualClock, SystemClock} - -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { - import RateStreamProvider._ - - private[sources] val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock - } - - private val rowsPerSecond = - options.get(ROWS_PER_SECOND).orElse("1").toLong - - private val rampUpTimeSeconds = - Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) - .map(JavaUtils.timeStringAsSec(_)) - .getOrElse(0L) - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private[sources] val creationTimeMs = { - val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) - require(session.isDefined) - - val metadataLog = - new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - @volatile private var lastTimeMs: Long = creationTimeMs - - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA - - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - LongOffset(json.toLong) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return List.empty.asJava - } - - val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } - - (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( - p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - - override def stop(): Unit = {} - - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" -} - -class RateStreamMicroBatchDataReaderFactory( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { - - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) -} - -class RateStreamMicroBatchDataReader( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { - private var count = 0 - - override def next(): Boolean = { - rangeStart + partitionId + numPartitions * count < rangeEnd - } - - override def get(): Row = { - val currValue = rangeStart + partitionId + numPartitions * count - count += 1 - val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - Row( - DateTimeUtils.toJavaTimestamp( - DateTimeUtils.fromMillis(relative + localStartTimeMs)), - currValue - ) - } - - override def close(): Unit = {} -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala deleted file mode 100644 index 6bdd492f0cb35..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} -import org.apache.spark.sql.types._ - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { - import RateStreamProvider._ - - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - if (options.get(ROWS_PER_SECOND).isPresent) { - val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") - } - } - - if (options.get(RAMP_UP_TIME).isPresent) { - val rampUpTimeSeconds = - JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") - } - } - - if (options.get(NUM_PARTITIONS).isPresent) { - val numPartitions = options.get(NUM_PARTITIONS).get().toInt - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") - } - } - - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) - - override def shortName(): String = "rate" -} - -object RateStreamProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 - - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - val RAMP_UP_TIME = "rampUpTime" - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala new file mode 100644 index 0000000000000..4e2459bb05bd6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.json4s.DefaultFormats +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * This is a temporary register as we build out v2 migration. Microbatch read support should + * be implemented in the same register as v1. + */ +class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + new RateStreamMicroBatchReader(options) + } + + override def shortName(): String = "ratev2" +} + +class RateStreamMicroBatchReader(options: DataSourceOptions) + extends MicroBatchReader { + implicit val defaultFormats: DefaultFormats = DefaultFormats + + val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock + else new SystemClock + } + + private val numPartitions = + options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt + private val rowsPerSecond = + options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + + // The interval (in milliseconds) between rows in each partition. + // e.g. if there are 4 global rows per second, and 2 partitions, each partition + // should output rows every (1000 * 2 / 4) = 500 ms. + private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond + + override def readSchema(): StructType = { + StructType( + StructField("timestamp", TimestampType, false) :: + StructField("value", LongType, false) :: Nil) + } + + val creationTimeMs = clock.getTimeMillis() + + private var start: RateStreamOffset = _ + private var end: RateStreamOffset = _ + + override def setOffsetRange( + start: Optional[Offset], + end: Optional[Offset]): Unit = { + this.start = start.orElse( + RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) + .asInstanceOf[RateStreamOffset] + + this.end = end.orElse { + val currentTime = clock.getTimeMillis() + RateStreamOffset( + this.start.partitionToValueAndRunTimeMs.map { + case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + // Calculate the number of rows we should advance in this partition (based on the + // current time), and output a corresponding offset. + val readInterval = currentTime - currentReadTime + val numNewRows = readInterval / msPerPartitionBetweenRows + if (numNewRows <= 0) { + startOffset + } else { + (part, ValueRunTimeMsPair( + currentVal + (numNewRows * numPartitions), + currentReadTime + (numNewRows * msPerPartitionBetweenRows))) + } + } + ) + }.asInstanceOf[RateStreamOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startMap = start.partitionToValueAndRunTimeMs + val endMap = end.partitionToValueAndRunTimeMs + endMap.keys.toSeq.map { part => + val ValueRunTimeMsPair(endVal, _) = endMap(part) + val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) + + val packedRows = mutable.ListBuffer[(Long, Long)]() + var outVal = startVal + numPartitions + var outTimeMs = startTimeMs + while (outVal <= endVal) { + packedRows.append((outTimeMs, outVal)) + outVal += numPartitions + outTimeMs += msPerPartitionBetweenRows + } + + RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + override def stop(): Unit = {} +} + +case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { + override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) +} + +class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { + private var currentIndex = -1 + + override def next(): Boolean = { + // Return true as long as the new index is in the seq. + currentIndex += 1 + currentIndex < vals.size + } + + override def get(): Row = { + Row( + DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), + vals(currentIndex)._2) + } + + override def close(): Unit = {} +} + +object RateStreamSourceV2 { + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + + private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..03d0f63fa4d7f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } + + test("user-specified schema given") { + val exception = intercept[AnalysisException] { + spark.readStream + .format("rate") + .schema(spark.range(1).schema) + .load() + } + assert(exception.getMessage.contains( + "rate source does not support a user-specified schema")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala new file mode 100644 index 0000000000000..983ba1668f58f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.util.ManualClock + +class RateSourceV2Suite extends StreamTest { + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + rateSource.setOffsetRange(Optional.empty(), Optional.empty()) + (rateSource, rateSource.getEndOffset()) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("basic microbatch execution") { + val input = spark.readStream + .format("rateV2") + .option("numPartitions", "1") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input, useV2Sink = true)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("microbatch - numPartitions propagated") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + } + + test("microbatch - set offset") { + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) + reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: RateStreamOffset => + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: RateStreamOffset => + // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted + // longer than 100ms. It should never be early. + assert(r.partitionToValueAndRunTimeMs(0).value >= 9) + assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) + + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) + val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) + val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) + } + + test("microbatch - data read") { + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) + val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) + val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { + case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => + (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) + }.toMap) + + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find v2 read support for rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala deleted file mode 100644 index 9149e50962255..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.nio.file.Files -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.Offset -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceSuite extends StreamTest { - - import testImplicits._ - - protected override def beforeAll(): Unit = { - super.beforeAll() - SparkSession.setActiveSession(spark) - } - - override def afterAll(): Unit = { - SparkSession.clearActiveSession() - super.afterAll() - } - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( - rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) - (rateSource, offset) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("compatible with old path in registry") { - DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", - spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - assert(ds.isInstanceOf[RateStreamProvider]) - case _ => - throw new IllegalStateException("Could not find read support for rate") - } - } - - test("microbatch - basic") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - uniform distribution of event timestamps") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "1500") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - val expectedAnswer = (0 until 1500).map { v => - (math.round(v * (1000.0 / 1500)), v) - } - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch(expectedAnswer: _*) - ) - } - - test("microbatch - set offset") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val tempFolder = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions( - Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), - tempFolder) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: LongOffset => assert(r.offset === 0L) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: LongOffset => assert(r.offset >= 100) - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() - val data = ArrayBuffer[Row]() - while (dataReader.next()) { - data.append(dataReader.get()) - } - assert(data.size === 20) - } - - test("microbatch - data read") { - val temp = Files.createTempDirectory("dummy").toString - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("valueAtSecond") { - import RateStreamProvider._ - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) - - assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) - assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) - assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) - assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) - assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) - assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) - } - - test("rampUpTime") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("rampUpTime", "4s") - .option("useManualClock", "true") - .load() - .as[(java.sql.Timestamp, Long)] - .map(v => (v._1.getTime, v._2)) - testStream(input)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch({ - Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) - }: _*), // speed = 6 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 - AdvanceRateManualClock(seconds = 1), - // Now we should reach full speed - CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 - ) - } - - test("numPartitions") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", "10") - .option("numPartitions", "6") - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(1), - CheckLastBatch((0 until 6): _*) - ) - } - - testQuietly("overflow") { - val input = spark.readStream - .format("rate") - .option("rowsPerSecond", Long.MaxValue.toString) - .option("useManualClock", "true") - .load() - .select(spark_partition_id()) - .distinct() - testStream(input)( - AdvanceRateManualClock(2), - ExpectFailure[TreeNodeException[_]](t => { - Seq("overflow", "rowsPerSecond").foreach { msg => - assert(t.getCause.getMessage.contains(msg)) - } - }) - ) - } - - testQuietly("illegal option values") { - def testIllegalOptionValue( - option: String, - value: String, - expectedMessages: Seq[String]): Unit = { - val e = intercept[IllegalArgumentException] { - spark.readStream - .format("rate") - .option(option, value) - .load() - .writeStream - .format("console") - .start() - .awaitTermination() - } - for (msg <- expectedMessages) { - assert(e.getMessage.contains(msg)) - } - } - - testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) - testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) - } - - test("user-specified schema given") { - val exception = intercept[AnalysisException] { - spark.readStream - .format("rate") - .schema(spark.range(1).schema) - .load() - } - assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find read support for continuous rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} From ea2fdc0d286e449884de44f22a908a26ab1248a5 Mon Sep 17 00:00:00 2001 From: guoxiaolong Date: Wed, 28 Mar 2018 19:49:32 -0500 Subject: [PATCH 0532/1161] [SPARK-23675][WEB-UI] Title add spark logo, use spark logo image ## What changes were proposed in this pull request? Title add spark logo, use spark logo image. reference other big data system ui, so i think spark should add it. spark fix before: ![spark_fix_before](https://user-images.githubusercontent.com/26266482/37387866-2d5add0e-2799-11e8-9165-250f2b59df3f.png) spark fix after: ![spark_fix_after](https://user-images.githubusercontent.com/26266482/37387874-329e1876-2799-11e8-8bc5-c619fc1e680e.png) reference kafka ui: ![kafka](https://user-images.githubusercontent.com/26266482/37387878-35ca89d0-2799-11e8-834e-1598ae7158e1.png) reference storm ui: ![storm](https://user-images.githubusercontent.com/26266482/37387880-3854f12c-2799-11e8-8968-b428ba361995.png) reference yarn ui: ![yarn](https://user-images.githubusercontent.com/26266482/37387881-3a72e130-2799-11e8-97bb-dea85f573e95.png) reference nifi ui: ![nifi](https://user-images.githubusercontent.com/26266482/37387887-3cecfea0-2799-11e8-9a71-6c454d25840b.png) reference flink ui: ![flink](https://user-images.githubusercontent.com/26266482/37387888-3f16b1ee-2799-11e8-9d37-8355f0100548.png) ## How was this patch tested? manual tests Please review http://spark.apache.org/contributing.html before opening a pull request. Author: guoxiaolong Closes #20818 from guoxiaolongzte/SPARK-23675. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ba798df13c95d..02cf19e00ecde 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -224,6 +224,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (showVisualization) vizHeaderNodes else Seq.empty} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {appName} - {title} @@ -265,6 +266,7 @@ private[spark] object UIUtils extends Logging { {commonHeaderNodes} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} + {title} From 641aec68e8167546dbb922874c086c9b90198f08 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 29 Mar 2018 16:37:46 +0800 Subject: [PATCH 0533/1161] =?UTF-8?q?[SPARK-23806]=20Broadcast.unpersist?= =?UTF-8?q?=20can=20cause=20fatal=20exception=20when=20used=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … with dynamic allocation ## What changes were proposed in this pull request? ignore errors when you are waiting for a broadcast.unpersist. This is handling it the same way as doing rdd.unpersist in https://issues.apache.org/jira/browse/SPARK-22618 ## How was this patch tested? Patch was tested manually against a couple jobs that exhibit this behavior, with the change the application no longer dies due to this and just prints the warning. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Thomas Graves Closes #20924 from tgravescs/SPARK-23806. --- .../spark/storage/BlockManagerMasterEndpoint.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 89a6a71a589a1..56b95c31eb4c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -192,11 +192,15 @@ class BlockManagerMasterEndpoint( val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } - Future.sequence( - requiredBlockManagers.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg) - }.toSeq - ) + val futures = requiredBlockManagers.map { bm => + bm.slaveEndpoint.ask[Int](removeMsg).recover { + case e: IOException => + logWarning(s"Error trying to remove broadcast $broadcastId", e) + 0 // zero blocks were removed + } + }.toSeq + + Future.sequence(futures) } private def removeBlockManager(blockManagerId: BlockManagerId) { From 505480cb578af9f23acc77bc82348afc9d8468e8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 29 Mar 2018 19:38:28 +0900 Subject: [PATCH 0534/1161] [SPARK-23770][R] Exposes repartitionByRange in SparkR ## What changes were proposed in this pull request? This PR proposes to expose `repartitionByRange`. ```R > df <- createDataFrame(iris) ... > getNumPartitions(repartitionByRange(df, 3, col = df$Species)) [1] 3 ``` ## How was this patch tested? Manually tested and the unit tests were added. The diff with `repartition` can be checked as below: ```R > df <- createDataFrame(mtcars) > take(repartition(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 14.3 8 360.0 245 3.21 3.570 15.84 0 0 3 4 2 10.4 8 460.0 215 3.00 5.424 17.82 0 0 3 4 3 32.4 4 78.7 66 4.08 2.200 19.47 1 1 4 1 > take(repartitionByRange(df, 10, df$wt), 3) mpg cyl disp hp drat wt qsec vs am gear carb 1 30.4 4 75.7 52 4.93 1.615 18.52 1 1 4 2 2 33.9 4 71.1 65 4.22 1.835 19.90 1 1 4 1 3 27.3 4 79.0 66 4.08 1.935 18.90 1 1 4 1 ``` Author: hyukjinkwon Closes #20902 from HyukjinKwon/r-repartitionByRange. --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 65 ++++++++++++++++++++++++++- R/pkg/R/generics.R | 3 ++ R/pkg/tests/fulltests/test_sparkSQL.R | 45 +++++++++++++++++++ 4 files changed, 112 insertions(+), 2 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c51eb0f39c4b1..190c50ea10482 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -151,6 +151,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "repartitionByRange", "rollup", "sample", "sample_frac", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c4852024c0f49..a1c9495b0795e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -687,7 +687,7 @@ setMethod("storageLevel", #' @rdname coalesce #' @name coalesce #' @aliases coalesce,SparkDataFrame-method -#' @seealso \link{repartition} +#' @seealso \link{repartition}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -723,7 +723,7 @@ setMethod("coalesce", #' @rdname repartition #' @name repartition #' @aliases repartition,SparkDataFrame-method -#' @seealso \link{coalesce} +#' @seealso \link{coalesce}, \link{repartitionByRange} #' @examples #'\dontrun{ #' sparkR.session() @@ -759,6 +759,67 @@ setMethod("repartition", dataFrame(sdf) }) + +#' Repartition by range +#' +#' The following options for repartition by range are possible: +#' \itemize{ +#' \item{1.} {Return a new SparkDataFrame range partitioned by +#' the given columns into \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame range partitioned by the given column(s), +#' using \code{spark.sql.shuffle.partitions} as number of partitions.} +#'} +#' +#' @param x a SparkDataFrame. +#' @param numPartitions the number of partitions to use. +#' @param col the column by which the range partitioning will be performed. +#' @param ... additional column(s) to be used in the range partitioning. +#' +#' @family SparkDataFrame functions +#' @rdname repartitionByRange +#' @name repartitionByRange +#' @aliases repartitionByRange,SparkDataFrame-method +#' @seealso \link{repartition}, \link{coalesce} +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' newDF <- repartitionByRange(df, col = df$col1, df$col2) +#' newDF <- repartitionByRange(df, 3L, col = df$col1, df$col2) +#'} +#' @note repartitionByRange since 2.4.0 +setMethod("repartitionByRange", + signature(x = "SparkDataFrame"), + function(x, numPartitions = NULL, col = NULL, ...) { + if (!is.null(numPartitions) && !is.null(col)) { + # number of partitions and columns both are specified + if (is.numeric(numPartitions) && class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", numToInt(numPartitions), jcol) + } else { + stop(paste("numPartitions and col must be numeric and Column; however, got", + class(numPartitions), "and", class(col))) + } + } else if (!is.null(col)) { + # only columns are specified + if (class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartitionByRange", jcol) + } else { + stop(paste("col must be Column; however, got", class(col))) + } + } else if (!is.null(numPartitions)) { + # only numPartitions is specified + stop("At least one partition-by column must be specified.") + } else { + stop("Please, specify a column(s) or the number of partitions with a column(s)") + } + dataFrame(sdf) + }) + #' toJSON #' #' Converts a SparkDataFrame into a SparkDataFrame of JSON string. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6fba4b6c761dd..974beff1a3d76 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -531,6 +531,9 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @rdname repartition setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) +#' @rdname repartitionByRange +setGeneric("repartitionByRange", function(x, ...) { standardGeneric("repartitionByRange") }) + #' @rdname sample setGeneric("sample", function(x, withReplacement = FALSE, fraction, seed) { diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 439191adb23ea..7105469ffc242 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3104,6 +3104,51 @@ test_that("repartition by columns on DataFrame", { }) }) +test_that("repartitionByRange on a DataFrame", { + # The tasks here launch R workers with shuffles. So, we decrease the number of shuffle + # partitions to reduce the number of the tasks to speed up the test. This is particularly + # slow on Windows because the R workers are unable to be forked. See also SPARK-21693. + conf <- callJMethod(sparkSession, "conf") + shufflepartitionsvalue <- callJMethod(conf, "get", "spark.sql.shuffle.partitions") + callJMethod(conf, "set", "spark.sql.shuffle.partitions", "5") + tryCatch({ + df <- createDataFrame(mtcars) + expect_error(repartitionByRange(df, "haha", df$mpg), + "numPartitions and col must be numeric and Column.*") + expect_error(repartitionByRange(df), + ".*specify a column.*or the number of partitions with a column.*") + expect_error(repartitionByRange(df, col = "haha"), + "col must be Column; however, got.*") + expect_error(repartitionByRange(df, 3), + "At least one partition-by column must be specified.") + + # The order of rows should be different with a normal repartition. + actual <- repartitionByRange(df, 3, df$mpg) + expect_equal(getNumPartitions(actual), 3) + expect_false(identical(collect(actual), collect(repartition(df, 3, df$mpg)))) + + actual <- repartitionByRange(df, col = df$mpg) + expect_false(identical(collect(actual), collect(repartition(df, col = df$mpg)))) + + # They should have same data. + actual <- collect(repartitionByRange(df, 3, df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, 3, df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + + actual <- collect(repartitionByRange(df, col = df$mpg)) + actual <- actual[order(actual$mpg), ] + expected <- collect(repartition(df, col = df$mpg)) + expected <- expected[order(expected$mpg), ] + expect_true(all(actual == expected)) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.shuffle.partitions", shufflepartitionsvalue) + }) +}) + test_that("coalesce, repartition, numPartitions", { df <- as.DataFrame(cars, numPartitions = 5) expect_equal(getNumPartitions(df), 5) From 491ec114fd3886ebd9fa29a482e3d112fb5a088c Mon Sep 17 00:00:00 2001 From: Sahil Takiar Date: Thu, 29 Mar 2018 10:23:23 -0700 Subject: [PATCH 0535/1161] [SPARK-23785][LAUNCHER] LauncherBackend doesn't check state of connection before setting state ## What changes were proposed in this pull request? Changed `LauncherBackend` `set` method so that it checks if the connection is open or not before writing to it (uses `isConnected`). ## How was this patch tested? None Author: Sahil Takiar Closes #20893 from sahilTakiar/master. --- .../spark/launcher/LauncherBackend.scala | 6 +++--- .../spark/launcher/LauncherServerSuite.java | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala index aaae33ca4e6f3..1b049b786023a 100644 --- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -67,13 +67,13 @@ private[spark] abstract class LauncherBackend { } def setAppId(appId: String): Unit = { - if (connection != null) { + if (connection != null && isConnected) { connection.send(new SetAppId(appId)) } } def setState(state: SparkAppHandle.State): Unit = { - if (connection != null && lastState != state) { + if (connection != null && isConnected && lastState != state) { connection.send(new SetState(state)) lastState = state } @@ -114,10 +114,10 @@ private[spark] abstract class LauncherBackend { override def close(): Unit = { try { + _isConnected = false super.close() } finally { onDisconnected() - _isConnected = false } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index d16337a319be3..5413d3a416545 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -185,6 +185,26 @@ public void testStreamFiltering() throws Exception { } } + @Test + public void testAppHandleDisconnect() throws Exception { + LauncherServer server = LauncherServer.getOrCreateServer(); + ChildProcAppHandle handle = new ChildProcAppHandle(server); + String secret = server.registerHandle(handle); + + TestClient client = null; + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); + client = new TestClient(s); + client.send(new Hello(secret, "1.4.0")); + handle.disconnect(); + waitForError(client, secret); + } finally { + handle.kill(); + close(client); + client.clientThread.join(); + } + } + private void close(Closeable c) { if (c != null) { try { From a7755fd8ce2f022118b9827aaac7d5d59f0f297a Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 29 Mar 2018 10:46:28 -0700 Subject: [PATCH 0536/1161] [SPARK-23639][SQL] Obtain token before init metastore client in SparkSQL CLI ## What changes were proposed in this pull request? In SparkSQLCLI, SessionState generates before SparkContext instantiating. When we use --proxy-user to impersonate, it's unable to initializing a metastore client to talk to the secured metastore for no kerberos ticket. This PR use real user ugi to obtain token for owner before talking to kerberized metastore. ## How was this patch tested? Manually verified with kerberized hive metasotre / hdfs. Author: Kent Yao Closes #20784 from yaooqinn/SPARK-23639. --- .../deploy/security/HiveDelegationTokenProvider.scala | 8 ++++---- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 9 +++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index ece5ce79c650d..7249eb85ac7c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils -private[security] class HiveDelegationTokenProvider +private[spark] class HiveDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { override def serviceName: String = "hive" @@ -124,9 +124,9 @@ private[security] class HiveDelegationTokenProvider val currentUser = UserGroupInformation.getCurrentUser() val realUser = Option(currentUser.getRealUser()).getOrElse(currentUser) - // For some reason the Scala-generated anonymous class ends up causing an - // UndeclaredThrowableException, even if you annotate the method with @throws. - try { + // For some reason the Scala-generated anonymous class ends up causing an + // UndeclaredThrowableException, even if you annotate the method with @throws. + try { realUser.doAs(new PrivilegedExceptionAction[T]() { override def run(): T = fn }) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 832a15d09599f..084f8200102ba 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -34,11 +34,13 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.security.HiveDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -121,6 +123,13 @@ private[hive] object SparkSQLCLIDriver extends Logging { } } + val tokenProvider = new HiveDelegationTokenProvider() + if (tokenProvider.delegationTokensRequired(sparkConf, hadoopConf)) { + val credentials = new Credentials() + tokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials) + UserGroupInformation.getCurrentUser.addCredentials(credentials) + } + SessionState.start(sessionState) // Clean up after we exit From b348901192b231153b58fe5720253168c87963d4 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 29 Mar 2018 21:36:56 -0700 Subject: [PATCH 0537/1161] [SPARK-23808][SQL] Set default Spark session in test-only spark sessions. ## What changes were proposed in this pull request? Set default Spark session in the TestSparkSession and TestHiveSparkSession constructors. ## How was this patch tested? new unit tests Author: Jose Torres Closes #20926 from jose-torres/test3. --- .../spark/sql/test/TestSQLContext.scala | 2 ++ .../sql/test/TestSparkSessionSuite.scala | 29 +++++++++++++++++++ .../apache/spark/sql/hive/test/TestHive.scala | 4 +++ 3 files changed, 35 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 4286e8a6ca2c8..3038b822beb4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -34,6 +34,8 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) this(new SparkConf) } + SparkSession.setDefaultSession(this) + @transient override lazy val sessionState: SessionState = { new TestSQLSessionStateBuilder(this, None).build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala new file mode 100644 index 0000000000000..4019c6888da98 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSparkSessionSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession + +class TestSparkSessionSuite extends SparkFunSuite { + test("default session is set in constructor") { + val session = new TestSparkSession() + assert(SparkSession.getDefaultSession.contains(session)) + session.stop() + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index fcf2025d34432..814038d4ef7af 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,6 +159,10 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => + // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as + // TestSparkSession, but doing this the same way breaks many tests in the package. We need + // to investigate and find a different strategy. + def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From df05fb63abe6018ccbe572c34cf65fc3ecbf1166 Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 30 Mar 2018 14:07:35 +0800 Subject: [PATCH 0538/1161] [SPARK-23743][SQL] Changed a comparison logic from containing 'slf4j' to starting with 'org.slf4j' ## What changes were proposed in this pull request? isSharedClass returns if some classes can/should be shared or not. It checks if the classes names have some keywords or start with some names. Following the logic, it can occur unintended behaviors when a custom package has `slf4j` inside the package or class name. As I guess, the first intention seems to figure out the class containing `org.slf4j`. It would be better to change the comparison logic to `name.startsWith("org.slf4j")` ## How was this patch tested? This patch should pass all of the current tests and keep all of the current behaviors. In my case, I'm using ProtobufDeserializer to get a table schema from hive tables. Thus some Protobuf packages and names have `slf4j` inside. Without this patch, it cannot be resolved because of ClassCastException from different classloaders. Author: Jongyoul Lee Closes #20860 from jongyoul/SPARK-23743. --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 12975bc85b971..c2690ec32b9e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -179,8 +179,9 @@ private[hive] class IsolatedClientLoader( val isHadoopClass = name.startsWith("org.apache.hadoop.") && !name.startsWith("org.apache.hadoop.hive.") - name.contains("slf4j") || - name.contains("log4j") || + name.startsWith("org.slf4j") || + name.startsWith("org.apache.log4j") || // log4j1.x + name.startsWith("org.apache.logging.log4j") || // log4j2 name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || From b02e76cbffe9e589b7a4e60f91250ca12a4420b2 Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 30 Mar 2018 15:07:38 +0800 Subject: [PATCH 0539/1161] [SPARK-23727][SQL] Support for pushing down filters for DateType in parquet ## What changes were proposed in this pull request? This PR supports for pushing down filters for DateType in parquet ## How was this patch tested? Added UT and tested in local. Author: yucai Closes #20851 from yucai/SPARK-23727. --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../datasources/parquet/ParquetFilters.scala | 33 ++++++++++++ .../parquet/ParquetFilterSuite.scala | 50 +++++++++++++++++-- 3 files changed, 89 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9cb03b5bb6152..13f31a6b2eb93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -353,6 +353,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_DATE_ENABLED = buildConf("spark.sql.parquet.filterPushdown.date") + .doc("If true, enables Parquet filter push-down optimization for Date. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1329,6 +1336,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) + def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 763841efbd9f3..ccc8306866d68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,10 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.sql.Date + import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ @@ -29,6 +34,10 @@ import org.apache.spark.sql.types._ */ private[parquet] object ParquetFilters { + private def dateToDays(date: Date): SQLDate = { + DateTimeUtils.fromJavaDate(date) + } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) @@ -50,6 +59,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -72,6 +85,10 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -91,6 +108,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.lt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -110,6 +131,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.ltEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -129,6 +154,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gt( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { @@ -148,6 +177,10 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) + case DateType if SQLConf.get.parquetFilterPushDownDate => + (n: String, v: Any) => FilterApi.gtEq( + intColumn(n), + Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 33801954ebd51..1d3476e747046 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets +import java.sql.Date import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -76,8 +77,10 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -102,7 +105,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex maybeFilter.exists(_.getClass === filterClass) } checker(stripSparkFilter(query), expected) - } } } @@ -313,6 +315,48 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - date") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21") + + withParquetDataFrame(data.map(i => Tuple1(i.date))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 === "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 <=> "2018-03-18".date, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate('_1 =!= "2018-03-18".date, classOf[NotEq[_]], + Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(i.date))) + + checkFilterPredicate('_1 < "2018-03-19".date, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate('_1 > "2018-03-20".date, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate('_1 <= "2018-03-18".date, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate('_1 >= "2018-03-21".date, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate( + Literal("2018-03-18".date) === '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-18".date) <=> '_1, classOf[Eq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-19".date) > '_1, classOf[Lt[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-20".date) < '_1, classOf[Gt[_]], "2018-03-21".date) + checkFilterPredicate( + Literal("2018-03-18".date) >= '_1, classOf[LtEq[_]], "2018-03-18".date) + checkFilterPredicate( + Literal("2018-03-21".date) <= '_1, classOf[GtEq[_]], "2018-03-21".date) + + checkFilterPredicate(!('_1 < "2018-03-21".date), classOf[GtEq[_]], "2018-03-21".date) + checkFilterPredicate( + '_1 < "2018-03-19".date || '_1 > "2018-03-20".date, + classOf[Operators.Or], + Seq(Row("2018-03-18".date), Row("2018-03-21".date))) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From 5b5a36ed6d2bb0971edfeccddf0f280936d2275f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 30 Mar 2018 21:54:26 +0800 Subject: [PATCH 0540/1161] Roll forward "[SPARK-23096][SS] Migrate rate source to V2" ## What changes were proposed in this pull request? Roll forward c68ec4e (#20688). There are two minor test changes required: * An error which used to be TreeNodeException[ArithmeticException] is no longer wrapped and is now just ArithmeticException. * The test framework simply does not set the active Spark session. (Or rather, it doesn't do so early enough - I think it only happens when a query is analyzed.) I've added the required logic to SQLTestUtils. ## How was this patch tested? existing tests Author: Jose Torres Author: jerryshao Closes #20922 from jose-torres/ratefix. --- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../execution/datasources/DataSource.scala | 6 +- .../streaming/RateSourceProvider.scala | 262 ------------------ .../ContinuousRateStreamSource.scala | 25 +- .../sources/RateStreamMicroBatchReader.scala | 222 +++++++++++++++ .../sources/RateStreamProvider.scala | 125 +++++++++ .../sources/RateStreamSourceV2.scala | 187 ------------- .../streaming/RateSourceV2Suite.scala | 191 ------------- .../RateStreamProviderSuite.scala} | 166 ++++++++++- 9 files changed, 524 insertions(+), 663 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{RateSourceSuite.scala => sources/RateStreamProviderSuite.scala} (50%) diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1fe9c093af99f..1b37905543b4e 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider -org.apache.spark.sql.execution.streaming.RateSourceProvider +org.apache.spark.sql.execution.streaming.sources.RateStreamProvider org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider -org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 31fa89b4570a6..b84ea769808f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode @@ -566,6 +566,7 @@ object DataSource extends Logging { val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" val nativeOrc = classOf[OrcFileFormat].getCanonicalName val socket = classOf[TextSocketSourceProvider].getCanonicalName + val rate = classOf[RateStreamProvider].getCanonicalName Map( "org.apache.spark.sql.jdbc" -> jdbc, @@ -587,7 +588,8 @@ object DataSource extends Logging { "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, "org.apache.spark.ml.source.libsvm" -> libsvm, "com.databricks.spark.csv" -> csv, - "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket + "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket, + "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala deleted file mode 100644 index 649fbbfa184ec..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader -import org.apache.spark.sql.types._ -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * A source that generates increment long values with timestamps. Each generated row has two - * columns: a timestamp column for the generated time and an auto increment long column starting - * with 0L. - * - * This source supports the following options: - * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. - * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed - * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer - * seconds. - * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the - * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may - * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. - */ -class RateSourceProvider extends StreamSourceProvider with DataSourceRegister - with DataSourceV2 with ContinuousReadSupport { - - override def sourceSchema( - sqlContext: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - if (schema.nonEmpty) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - (shortName(), RateSourceProvider.SCHEMA) - } - - override def createSource( - sqlContext: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - val params = CaseInsensitiveMap(parameters) - - val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) - if (rowsPerSecond <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + - "must be positive") - } - - val rampUpTimeSeconds = - params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) - if (rampUpTimeSeconds < 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + - "must not be negative") - } - - val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( - sqlContext.sparkContext.defaultParallelism) - if (numPartitions <= 0) { - throw new IllegalArgumentException( - s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + - "must be positive") - } - - new RateStreamSource( - sqlContext, - metadataPath, - rowsPerSecond, - rampUpTimeSeconds, - numPartitions, - params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing - ) - } - - override def createContinuousReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - new RateStreamContinuousReader(options) - } - - override def shortName(): String = "rate" -} - -object RateSourceProvider { - val SCHEMA = - StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) - - val VERSION = 1 -} - -class RateStreamSource( - sqlContext: SQLContext, - metadataPath: String, - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int, - useManualClock: Boolean) extends Source with Logging { - - import RateSourceProvider._ - import RateStreamSource._ - - val clock = if (useManualClock) new ManualClock else new SystemClock - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private val startTimeMs = { - val metadataLog = - new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ - @volatile private var lastTimeMs = startTimeMs - - override def schema: StructType = RateSourceProvider.SCHEMA - - override def getOffset: Option[Offset] = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) - } - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return sqlContext.internalCreateDataFrame( - sqlContext.sparkContext.emptyRDD, schema, isStreaming = true) - } - - val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - - val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => - val relative = math.round((v - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) - } - sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true) - } - - override def stop(): Unit = {} - - override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" -} - -object RateStreamSource { - - /** Calculate the end value we will emit at the time `seconds`. */ - def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { - // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 - // Then speedDeltaPerSecond = 2 - // - // seconds = 0 1 2 3 4 5 6 - // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) - // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 - val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) - if (seconds <= rampUpTimeSeconds) { - // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to - // avoid overflow - if (seconds % 2 == 1) { - (seconds + 1) / 2 * speedDeltaPerSecond * seconds - } else { - seconds / 2 * speedDeltaPerSecond * (seconds + 1) - } - } else { - // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds - val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) - rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 20d90069163a6..2f0de2612c150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2 +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} @@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions) val creationTime = System.currentTimeMillis() - val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong + val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt + val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { @@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateSourceProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA private var offset: Offset = _ override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime)) + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) } override def getStartOffset(): Offset = offset @@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def commit(end: Offset): Unit = {} override def stop(): Unit = {} + private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { + RateStreamOffset( + Range(0, numPartitions).map { i => + // Note that the starting offset is exclusive, so we have to decrement the starting value + // by the increment that will later be applied. The first row output in each + // partition will have a value equal to the partition index. + (i, + ValueRunTimeMsPair( + (i - numPartitions).toLong, + creationTimeMs)) + }.toMap) + } + } case class RateStreamContinuousDataReaderFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000000000..6cf8520fc544f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchDataReaderFactory( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : DataReaderFactory[Row] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchDataReaderFactory( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReaderFactory[Row] { + + override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( + partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchDataReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends DataReader[Row] { + private var count = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): Row = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + Row( + DateTimeUtils.toJavaTimestamp( + DateTimeUtils.fromMillis(relative + localStartTimeMs)), + currValue + ) + } + + override def close(): Unit = {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala new file mode 100644 index 0000000000000..6bdd492f0cb35 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.Optional + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.types._ + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateStreamProvider extends DataSourceV2 + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + import RateStreamProvider._ + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): MicroBatchReader = { + if (options.get(ROWS_PER_SECOND).isPresent) { + val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") + } + } + + if (options.get(RAMP_UP_TIME).isPresent) { + val rampUpTimeSeconds = + JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get()) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") + } + } + + if (options.get(NUM_PARTITIONS).isPresent) { + val numPartitions = options.get(NUM_PARTITIONS).get().toInt + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") + } + } + + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) + } + + override def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + + override def shortName(): String = "rate" +} + +object RateStreamProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 + + val NUM_PARTITIONS = "numPartitions" + val ROWS_PER_SECOND = "rowsPerSecond" + val RAMP_UP_TIME = "rampUpTime" + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala deleted file mode 100644 index 4e2459bb05bd6..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.util.Optional - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.json4s.DefaultFormats -import org.json4s.jackson.Serialization - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType} -import org.apache.spark.util.{ManualClock, SystemClock} - -/** - * This is a temporary register as we build out v2 migration. Microbatch read support should - * be implemented in the same register as v1. - */ -class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister { - override def createMicroBatchReader( - schema: Optional[StructType], - checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { - new RateStreamMicroBatchReader(options) - } - - override def shortName(): String = "ratev2" -} - -class RateStreamMicroBatchReader(options: DataSourceOptions) - extends MicroBatchReader { - implicit val defaultFormats: DefaultFormats = DefaultFormats - - val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock - else new SystemClock - } - - private val numPartitions = - options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt - private val rowsPerSecond = - options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong - - // The interval (in milliseconds) between rows in each partition. - // e.g. if there are 4 global rows per second, and 2 partitions, each partition - // should output rows every (1000 * 2 / 4) = 500 ms. - private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond - - override def readSchema(): StructType = { - StructType( - StructField("timestamp", TimestampType, false) :: - StructField("value", LongType, false) :: Nil) - } - - val creationTimeMs = clock.getTimeMillis() - - private var start: RateStreamOffset = _ - private var end: RateStreamOffset = _ - - override def setOffsetRange( - start: Optional[Offset], - end: Optional[Offset]): Unit = { - this.start = start.orElse( - RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs)) - .asInstanceOf[RateStreamOffset] - - this.end = end.orElse { - val currentTime = clock.getTimeMillis() - RateStreamOffset( - this.start.partitionToValueAndRunTimeMs.map { - case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - // Calculate the number of rows we should advance in this partition (based on the - // current time), and output a corresponding offset. - val readInterval = currentTime - currentReadTime - val numNewRows = readInterval / msPerPartitionBetweenRows - if (numNewRows <= 0) { - startOffset - } else { - (part, ValueRunTimeMsPair( - currentVal + (numNewRows * numPartitions), - currentReadTime + (numNewRows * msPerPartitionBetweenRows))) - } - } - ) - }.asInstanceOf[RateStreamOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) - } - - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { - val startMap = start.partitionToValueAndRunTimeMs - val endMap = end.partitionToValueAndRunTimeMs - endMap.keys.toSeq.map { part => - val ValueRunTimeMsPair(endVal, _) = endMap(part) - val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part) - - val packedRows = mutable.ListBuffer[(Long, Long)]() - var outVal = startVal + numPartitions - var outTimeMs = startTimeMs - while (outVal <= endVal) { - packedRows.append((outTimeMs, outVal)) - outVal += numPartitions - outTimeMs += msPerPartitionBetweenRows - } - - RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - override def stop(): Unit = {} -} - -case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals) -} - -class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] { - private var currentIndex = -1 - - override def next(): Boolean = { - // Return true as long as the new index is in the seq. - currentIndex += 1 - currentIndex < vals.size - } - - override def get(): Row = { - Row( - DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)), - vals(currentIndex)._2) - } - - override def close(): Unit = {} -} - -object RateStreamSourceV2 { - val NUM_PARTITIONS = "numPartitions" - val ROWS_PER_SECOND = "rowsPerSecond" - - private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = { - RateStreamOffset( - Range(0, numPartitions).map { i => - // Note that the starting offset is exclusive, so we have to decrement the starting value - // by the increment that will later be applied. The first row output in each - // partition will have a value equal to the partition index. - (i, - ValueRunTimeMsPair( - (i - numPartitions).toLong, - creationTimeMs)) - }.toMap) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala deleted file mode 100644 index 983ba1668f58f..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala +++ /dev/null @@ -1,191 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.util.ManualClock - -class RateSourceV2Suite extends StreamTest { - import testImplicits._ - - case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { - assert(query.nonEmpty) - val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source - }.head - rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - rateSource.setOffsetRange(Optional.empty(), Optional.empty()) - (rateSource, rateSource.getEndOffset()) - } - } - - test("microbatch in registry") { - DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("basic microbatch execution") { - val input = spark.readStream - .format("rateV2") - .option("numPartitions", "1") - .option("rowsPerSecond", "10") - .option("useManualClock", "true") - .load() - testStream(input, useV2Sink = true)( - AdvanceRateManualClock(seconds = 1), - CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), - StopStream, - StartStream(), - // Advance 2 seconds because creating a new RateSource will also create a new ManualClock - AdvanceRateManualClock(seconds = 2), - CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) - ) - } - - test("microbatch - numPartitions propagated") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - } - - test("microbatch - set offset") { - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty()) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - - test("microbatch - infer offsets") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava)) - reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { - case r: RateStreamOffset => - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs) - case _ => throw new IllegalStateException("unexpected offset type") - } - reader.getEndOffset() match { - case r: RateStreamOffset => - // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted - // longer than 100ms. It should never be early. - assert(r.partitionToValueAndRunTimeMs(0).value >= 9) - assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100) - - case _ => throw new IllegalStateException("unexpected offset type") - } - } - - test("microbatch - predetermined batch size") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)) - val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000)))) - val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000)))) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 1) - assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20) - } - - test("microbatch - data read") { - val reader = new RateStreamMicroBatchReader( - new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)) - val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs) - val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map { - case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) => - (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000)) - }.toMap) - - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 11) - - val readData = tasks.asScala - .map(_.createDataReader()) - .flatMap { reader => - val buf = scala.collection.mutable.ListBuffer[Row]() - while (reader.next()) buf.append(reader.get()) - buf - } - - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) - } - - test("continuous in registry") { - DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) - case _ => - throw new IllegalStateException("Could not find v2 read support for rate") - } - } - - test("continuous data") { - val reader = new RateStreamContinuousReader( - new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() - assert(tasks.size == 2) - - val data = scala.collection.mutable.ListBuffer[Row]() - tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => - val startTimeMs = reader.getStartOffset() - .asInstanceOf[RateStreamOffset] - .partitionToValueAndRunTimeMs(t.partitionIndex) - .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] - for (rowIndex <- 0 to 9) { - r.next() - data.append(r.get()) - assert(r.getOffset() == - RateStreamPartitionOffset( - t.partitionIndex, - t.partitionIndex + rowIndex * 2, - startTimeMs + (rowIndex + 1) * 100)) - } - assert(System.currentTimeMillis() >= startTimeMs + 1000) - - case _ => throw new IllegalStateException("Unexpected task type") - } - - assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala similarity index 50% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 03d0f63fa4d7f..ff14ec38e66a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -15,13 +15,24 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources +import java.nio.file.Files +import java.util.Optional import java.util.concurrent.TimeUnit -import org.apache.spark.sql.AnalysisException +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock class RateSourceSuite extends StreamTest { @@ -29,18 +40,40 @@ class RateSourceSuite extends StreamTest { import testImplicits._ case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (Source, Offset) = { + override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => - source.asInstanceOf[RateStreamSource] + case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) - (rateSource, rateSource.getOffset.get) + val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds( + rateSource.clock.getTimeMillis() - rateSource.creationTimeMs)) + (rateSource, offset) + } + } + + test("microbatch in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case _ => + throw new IllegalStateException("Could not find read support for rate") + } + } + + test("compatible with old path in registry") { + DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", + spark.sqlContext.conf).newInstance() match { + case ds: MicroBatchReadSupport => + assert(ds.isInstanceOf[RateStreamProvider]) + case _ => + throw new IllegalStateException("Could not find read support for rate") } } - test("basic") { + test("microbatch - basic") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "10") @@ -57,7 +90,7 @@ class RateSourceSuite extends StreamTest { ) } - test("uniform distribution of event timestamps") { + test("microbatch - uniform distribution of event timestamps") { val input = spark.readStream .format("rate") .option("rowsPerSecond", "1500") @@ -74,8 +107,74 @@ class RateSourceSuite extends StreamTest { ) } + test("microbatch - set offset") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + assert(reader.getStartOffset() == startOffset) + assert(reader.getEndOffset() == endOffset) + } + + test("microbatch - infer offsets") { + val tempFolder = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions( + Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), + tempFolder) + reader.clock.asInstanceOf[ManualClock].advance(100000) + reader.setOffsetRange(Optional.empty(), Optional.empty()) + reader.getStartOffset() match { + case r: LongOffset => assert(r.offset === 0L) + case _ => throw new IllegalStateException("unexpected offset type") + } + reader.getEndOffset() match { + case r: LongOffset => assert(r.offset >= 100) + case _ => throw new IllegalStateException("unexpected offset type") + } + } + + test("microbatch - predetermined batch size") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 1) + val dataReader = tasks.get(0).createDataReader() + val data = ArrayBuffer[Row]() + while (dataReader.next()) { + data.append(dataReader.get()) + } + assert(data.size === 20) + } + + test("microbatch - data read") { + val temp = Files.createTempDirectory("dummy").toString + val reader = new RateStreamMicroBatchReader( + new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp) + val startOffset = LongOffset(0L) + val endOffset = LongOffset(1L) + reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 11) + + val readData = tasks.asScala + .map(_.createDataReader()) + .flatMap { reader => + val buf = scala.collection.mutable.ListBuffer[Row]() + while (reader.next()) buf.append(reader.get()) + buf + } + + assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + } + test("valueAtSecond") { - import RateStreamSource._ + import RateStreamProvider._ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) @@ -161,7 +260,7 @@ class RateSourceSuite extends StreamTest { option: String, value: String, expectedMessages: Seq[String]): Unit = { - val e = intercept[StreamingQueryException] { + val e = intercept[IllegalArgumentException] { spark.readStream .format("rate") .option(option, value) @@ -171,9 +270,8 @@ class RateSourceSuite extends StreamTest { .start() .awaitTermination() } - assert(e.getCause.isInstanceOf[IllegalArgumentException]) for (msg <- expectedMessages) { - assert(e.getCause.getMessage.contains(msg)) + assert(e.getMessage.contains(msg)) } } @@ -191,4 +289,46 @@ class RateSourceSuite extends StreamTest { assert(exception.getMessage.contains( "rate source does not support a user-specified schema")) } + + test("continuous in registry") { + DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { + case ds: ContinuousReadSupport => + val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) + assert(reader.isInstanceOf[RateStreamContinuousReader]) + case _ => + throw new IllegalStateException("Could not find read support for continuous rate") + } + } + + test("continuous data") { + val reader = new RateStreamContinuousReader( + new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) + reader.setStartOffset(Optional.empty()) + val tasks = reader.createDataReaderFactories() + assert(tasks.size == 2) + + val data = scala.collection.mutable.ListBuffer[Row]() + tasks.asScala.foreach { + case t: RateStreamContinuousDataReaderFactory => + val startTimeMs = reader.getStartOffset() + .asInstanceOf[RateStreamOffset] + .partitionToValueAndRunTimeMs(t.partitionIndex) + .runTimeMs + val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + for (rowIndex <- 0 to 9) { + r.next() + data.append(r.get()) + assert(r.getOffset() == + RateStreamPartitionOffset( + t.partitionIndex, + t.partitionIndex + rowIndex * 2, + startTimeMs + (rowIndex + 1) * 100)) + } + assert(System.currentTimeMillis() >= startTimeMs + 1000) + + case _ => throw new IllegalStateException("Unexpected task type") + } + + assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20)) + } } From bc8d0931170cfa20a4fb64b3b11a2027ddb0d6e9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 30 Mar 2018 23:21:07 +0800 Subject: [PATCH 0541/1161] [SPARK-23500][SQL][FOLLOWUP] Fix complex type simplification rules to apply to entire plan ## What changes were proposed in this pull request? This PR is to improve the test coverage of the original PR https://github.com/apache/spark/pull/20687 ## How was this patch tested? N/A Author: gatorsmile Closes #20911 from gatorsmile/addTests. --- .../optimizer/complexTypesSuite.scala | 176 ++++++++++++------ .../apache/spark/sql/ComplexTypesSuite.scala | 109 +++++++++++ 2 files changed, 233 insertions(+), 52 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index e44a6692ad8e2..21ed987627b3b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -47,10 +47,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { SimplifyExtractValueOps) :: Nil } - val idAtt = ('id).long.notNull - val nullableIdAtt = ('nullable_id).long + private val idAtt = ('id).long.notNull + private val nullableIdAtt = ('nullable_id).long - lazy val relation = LocalRelation(idAtt, nullableIdAtt) + private val relation = LocalRelation(idAtt, nullableIdAtt) + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) + + private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { + val optimized = Optimizer.execute(originalQuery.analyze) + assert(optimized.resolved, "optimized plans must be still resolvable") + comparePlans(optimized, correctAnswer.analyze) + } test("explicit get from namedStruct") { val query = relation @@ -58,31 +65,28 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField( CreateNamedStruct(Seq("att", 'id )), 0, - None) as "outerAtt").analyze - val expected = relation.select('id as "outerAtt").analyze + None) as "outerAtt") + val expected = relation.select('id as "outerAtt") - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("explicit get from named_struct- expression maintains original deduced alias") { val query = relation .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) - .analyze val expected = relation .select('id as "named_struct(att, id).att") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed getStructField ontop of namedStruct") { val query = relation .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") .select(GetStructField('struct1, 0, None) as "struct1Att") - .analyze - val expected = relation.select('id as "struct1Att").analyze - comparePlans(Optimizer execute query, expected) + val expected = relation.select('id as "struct1Att") + checkRule(query, expected) } test("collapse multiple CreateNamedStruct/GetStructField pairs") { @@ -94,16 +98,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None) as "struct1Att1", GetStructField('struct1, 1, None) as "struct1Att2") - .analyze val expected = relation. select( 'id as "struct1Att1", ('id * 'id) as "struct1Att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("collapsed2 - deduced names") { @@ -115,16 +117,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetStructField('struct1, 0, None), GetStructField('struct1, 1, None)) - .analyze val expected = relation. select( 'id as "struct1.att1", ('id * 'id) as "struct1.att2") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplified array ops") { @@ -151,7 +151,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { 1, false), 1) as "a4") - .analyze val expected = relation .select( @@ -161,8 +160,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { "att2", (('id + 1L) * ('id + 1L)))) as "a2", ('id + 1L) as "a3", ('id + 1L) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("SPARK-22570: CreateArray should not create a lot of global variables") { @@ -188,7 +186,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", GetMapValue('m, "r32") as "a3", GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") - .analyze val expected = relation.select( @@ -201,8 +198,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ) ) as "a3", Literal.create(null, LongType) as "a4") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, constant lookup, dynamic keys") { @@ -216,7 +212,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -225,8 +220,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo(13L, ('id + 1L)), ('id + 2L)), (EqualTo(13L, ('id + 2L)), ('id + 3L)), (Literal(true), 'id))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, dynamic lookup, dynamic keys, lookup is equivalent to one of the keys") { @@ -240,7 +234,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), ('id + 3L)) as "a") - .analyze val expected = relation .select( CaseWhen(Seq( @@ -248,8 +241,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), (Literal(true), ('id + 4L)))) as "a") - .analyze - comparePlans(Optimizer execute query, expected) + checkRule(query, expected) } test("simplify map ops, no positive match") { @@ -263,7 +255,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 'id + 30L) as "a") - .analyze val expected = relation.select( CaseWhen(Seq( (EqualTo('id + 30L, 'id), ('id + 1L)), @@ -271,8 +262,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, constant lookup, mixed keys, eliminated constants") { @@ -287,7 +277,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 13L) as "a") - .analyze val expected = relation .select( @@ -297,9 +286,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 2L), ('id + 3L), ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) } test("simplify map ops, potential dynamic match with null value + an absolute constant match") { @@ -314,7 +302,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { ('id + 3L), ('id + 4L), ('id + 4L), ('id + 5L))), 2L ) as "a") - .analyze val expected = relation .select( @@ -327,18 +314,69 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // but it cannot override a potential match with ('id + 2L), // which is exactly what [[Coalesce]] would do in this case. (Literal.TrueLiteral, 'id))) as "a") - .analyze - comparePlans(Optimizer execute rel, expected) + checkRule(rel, expected) + } + + test("SPARK-23500: Simplify array ops that are not at the top node") { + val query = LocalRelation('id.long) + .select( + CreateArray(Seq( + CreateNamedStruct(Seq( + "att1", 'id, + "att2", 'id * 'id)), + CreateNamedStruct(Seq( + "att1", 'id + 1, + "att2", ('id + 1) * ('id + 1)) + )) + ) as "arr") + .select( + GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", + GetArrayItem( + GetArrayStructFields('arr, + StructField("att1", LongType, nullable = false), + ordinal = 0, + numFields = 1, + containsNull = false), + ordinal = 1) as "a2") + .orderBy('id.asc) + + val expected = LocalRelation('id.long) + .select( + ('id + 1L) as "a1", + ('id + 1L) as "a2") + .orderBy('id.asc) + checkRule(query, expected) + } + + test("SPARK-23500: Simplify map ops that are not top nodes") { + val query = + LocalRelation('id.long) + .select( + CreateMap(Seq( + "r1", 'id, + "r2", 'id + 1L)) as "m") + .select( + GetMapValue('m, "r1") as "a1", + GetMapValue('m, "r32") as "a2") + .orderBy('id.asc) + .select('a1, 'a2) + + val expected = + LocalRelation('id.long).select( + 'id as "a1", + Literal.create(null, LongType) as "a2") + .orderBy('id.asc) + checkRule(query, expected) } test("SPARK-23500: Simplify complex ops that aren't at the plan root") { val structRel = relation .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") - .groupBy($"foo")("1").analyze + .groupBy($"foo")("1") val structExpected = relation .select('nullable_id as "foo") - .groupBy($"foo")("1").analyze - comparePlans(Optimizer execute structRel, structExpected) + .groupBy($"foo")("1") + checkRule(structRel, structExpected) // These tests must use nullable attributes from the base relation for the following reason: // in the 'original' plans below, the Aggregate node produced by groupBy() has a @@ -351,17 +389,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") - .groupBy($"a1")("1").analyze - val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze - comparePlans(Optimizer execute arrayRel, arrayExpected) + .groupBy($"a1")("1") + val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") + checkRule(arrayRel, arrayExpected) val mapRel = relation .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") - .groupBy($"m1")("1").analyze + .groupBy($"m1")("1") val mapExpected = relation .select('nullable_id as "m1") - .groupBy($"m1")("1").analyze - comparePlans(Optimizer execute mapRel, mapExpected) + .groupBy($"m1")("1") + checkRule(mapRel, mapExpected) } test("SPARK-23500: Ensure that aggregation expressions are not simplified") { @@ -369,11 +407,45 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // grouping exprs so aren't tested here. val structAggRel = relation.groupBy( CreateNamedStruct(Seq("att1", 'nullable_id)))( - GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze - comparePlans(Optimizer execute structAggRel, structAggRel) + GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) + checkRule(structAggRel, structAggRel) val arrayAggRel = relation.groupBy( - CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze - comparePlans(Optimizer execute arrayAggRel, arrayAggRel) + CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) + checkRule(arrayAggRel, arrayAggRel) + + // This could be done if we had a more complex rule that checks that + // the CreateMap does not come from key. + val originalQuery = relation + .groupBy('id)( + GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + ) + checkRule(originalQuery, originalQuery) + } + + test("SPARK-23500: namedStruct and getField in the same Project #1") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) + .select('s1 getField "col2" as 's1Col2, + namedStruct("col1", 'a, "col2", 'b).as("s2")) + .select('s1Col2, 's2 getField "col2" as 's2Col2) + val correctAnswer = + testRelation + .select('c as 's1Col2, 'b as 's2Col2) + checkRule(originalQuery, correctAnswer) + } + + test("SPARK-23500: namedStruct and getField in the same Project #2") { + val originalQuery = + testRelation + .select( + namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, + namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) + val correctAnswer = + testRelation + .select('c as 'sCol2, 'a as 'sCol1) + checkRule(originalQuery, correctAnswer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala new file mode 100644 index 0000000000000..b74fe2f90df23 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.test.SharedSQLContext + +class ComplexTypesSuite extends QueryTest with SharedSQLContext { + + override def beforeAll() { + super.beforeAll() + spark.range(10).selectExpr( + "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5") + .write.saveAsTable("tab") + } + + override def afterAll() { + try { + spark.sql("DROP TABLE IF EXISTS tab") + } finally { + super.afterAll() + } + } + + def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = { + var count = 0 + plan.foreach { operator => + operator.transformExpressions { + case c: CreateNamedStruct => + count += 1 + c + } + } + + if (expectedCount != count) { + fail(s"expect $expectedCount CreateNamedStruct but got $count.") + } + } + + test("simple case") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2") + .filter("col2.c > 11").selectExpr("col1.a") + checkAnswer(df, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("named_struct is used in the top Project") { + val df = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .selectExpr("col1.a", "col1") + .filter("col1.a > 8") + checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1) + + val df1 = spark.table("tab").selectExpr( + "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)") + .sort("col1") + .selectExpr("col1.a") + .filter("col1.a > 8") + checkAnswer(df1, Row(9) :: Row(10) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1) + } + + test("expression in named_struct") { + val df = spark.table("tab") + .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola") + .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10") + checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola") + .selectExpr("cola.i3").filter("cola.exp > 10") + checkAnswer(df1, Row(12) :: Nil) + checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0) + } + + test("nested case") { + val df = spark.table("tab") + .selectExpr("struct(struct(i2, i3) as exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10") + checkAnswer(df, Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + + val df1 = spark.table("tab") + .selectExpr("struct(i2, i3) as exp", "i4") + .selectExpr("struct(exp, i4) as cola") + .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11") + checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) + checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) + } +} From ae9172017c361e5c1039bc2ca94048117021974a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 30 Mar 2018 14:09:14 -0700 Subject: [PATCH 0542/1161] [SPARK-23640][CORE] Fix hadoop config may override spark config ## What changes were proposed in this pull request? It may be get `spark.shuffle.service.port` from https://github.com/apache/spark/blob/9745ec3a61c99be59ef6a9d5eebd445e8af65b7a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala#L459 Therefore, the client configuration `spark.shuffle.service.port` does not working unless the configuration is `spark.hadoop.spark.shuffle.service.port`. - This configuration is not working: ``` bin/spark-sql --master yarn --conf spark.shuffle.service.port=7338 ``` - This configuration works: ``` bin/spark-sql --master yarn --conf spark.hadoop.spark.shuffle.service.port=7338 ``` This PR fix this issue. ## How was this patch tested? It's difficult to carry out unit testing. But I've tested it manually. Author: Yuming Wang Closes #20785 from wangyum/SPARK-23640. --- .../scala/org/apache/spark/util/Utils.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5caedeb526469..d2be93226e2a2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2302,16 +2302,20 @@ private[spark] object Utils extends Logging { } /** - * Return the value of a config either through the SparkConf or the Hadoop configuration - * if this is Yarn mode. In the latter case, this defaults to the value set through SparkConf - * if the key is not set in the Hadoop configuration. + * Return the value of a config either through the SparkConf or the Hadoop configuration. + * We Check whether the key is set in the SparkConf before look at any Hadoop configuration. + * If the key is set in SparkConf, no matter whether it is running on YARN or not, + * gets the value from SparkConf. + * Only when the key is not set in SparkConf and running on YARN, + * gets the value from Hadoop configuration. */ def getSparkOrYarnConfig(conf: SparkConf, key: String, default: String): String = { - val sparkValue = conf.get(key, default) - if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { - new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, sparkValue) + if (conf.contains(key)) { + conf.get(key, default) + } else if (conf.get(SparkLauncher.SPARK_MASTER, null) == "yarn") { + new YarnConfiguration(SparkHadoopUtil.get.newConfiguration(conf)).get(key, default) } else { - sparkValue + default } } From 15298b99ac8944e781328423289586176cf824d7 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 30 Mar 2018 16:48:26 -0700 Subject: [PATCH 0543/1161] [SPARK-23827][SS] StreamingJoinExec should ensure that input data is partitioned into specific number of partitions ## What changes were proposed in this pull request? Currently, the requiredChildDistribution does not specify the partitions. This can cause the weird corner cases where the child's distribution is `SinglePartition` which satisfies the required distribution of `ClusterDistribution(no-num-partition-requirement)`, thus eliminating the shuffle needed to repartition input data into the required number of partitions (i.e. same as state stores). That can lead to "file not found" errors on the state store delta files as the micro-batch-with-no-shuffle will not run certain tasks and therefore not generate the expected state store delta files. This PR adds the required constraint on the number of partitions. ## How was this patch tested? Modified test harness to always check that ANY stateful operator should have a constraint on the number of partitions. As part of that, the existing opt-in checks on child output partitioning were removed, as they are redundant. Author: Tathagata Das Closes #20941 from tdas/SPARK-23827. --- .../streaming/IncrementalExecution.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 3 +- .../sql/streaming/DeduplicateSuite.scala | 8 +-- .../FlatMapGroupsWithStateSuite.scala | 5 +- .../sql/streaming/StatefulOperatorTest.scala | 49 ------------------- .../spark/sql/streaming/StreamTest.scala | 19 +++++++ .../streaming/StreamingAggregationSuite.scala | 4 +- 7 files changed, 25 insertions(+), 65 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a10ed5f2df1b5..1a83c884d55bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -62,7 +62,7 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } - private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index c351f658cb955..fa7c8ee906ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index caf2bab8a5859..0088b64d6195e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -25,9 +25,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplic import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { +class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -44,8 +42,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -63,8 +59,6 @@ class DeduplicateSuite extends StateStoreMetricsTest AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index de2b51678cea6..b1416bff87ee7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -42,8 +42,7 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest - with BeforeAndAfterAll - with StatefulOperatorTest { + with BeforeAndAfterAll { import testImplicits._ import GroupStateImpl._ @@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), - AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( - sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala deleted file mode 100644 index 45142278993bb..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.streaming._ - -trait StatefulOperatorTest { - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - colNames: Seq[String]): Boolean = { - val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output - val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions - val groupingAttr = attr.filter(a => colNames.contains(a.name)) - checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) - } - - /** - * Check that the output partitioning of a child operator of a Stateful operator satisfies the - * distribution that we expect for our Stateful operator. - */ - protected def checkChildOutputPartitioning[T <: StatefulOperator]( - sq: StreamingQuery, - expectedPartitioning: Partitioning): Boolean = { - val operator = sq.asInstanceOf[StreamExecution].lastExecution - .executedPlan.collect { case p: T => p } - operator.head.children.forall( - _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e44aef09f1f3c..00741d660dd2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -37,6 +37,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.physical.AllTuples import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ @@ -444,6 +445,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } + val lastExecution = currentStream.lastExecution + if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) { + // Verify if stateful operators have correct metadata and distribution + // This can often catch hard to debug errors when developing stateful operators + lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s => + assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores)) + s.requiredChildDistribution.foreach { d => + withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") { + assert(d.requiredNumPartitions.isDefined) + assert(d.requiredNumPartitions.get >= 1) + if (d != AllTuples) { + assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions) + } + } + } + } + } + val (latestBatchData, allData) = sink match { case s: MemorySink => (s.latestBatchData, s.allData) case s: MemorySinkV2 => (s.latestBatchData, s.allData) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 97e065193fd05..1cae8cb8d47f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions with StatefulOperatorTest { + with BeforeAndAfterAll with Assertions { override def afterAll(): Unit = { super.afterAll() @@ -281,8 +281,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), - AssertOnQuery(sq => - checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), From 529f847105fa8d98a5dc4d20955e4870df6bc1c5 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 31 Mar 2018 10:34:01 +0800 Subject: [PATCH 0544/1161] [SPARK-23040][CORE][FOLLOW-UP] Avoid double wrap result Iterator. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20449#discussion_r172414393, If `resultIter` is already a `InterruptibleIterator`, don't double wrap it. ## How was this patch tested? Existing tests. Author: Xingbo Jiang Closes #20920 from jiangxb1987/SPARK-23040. --- .../spark/shuffle/BlockStoreShuffleReader.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 85e7e56a04a7d..4103dfb10175e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -111,8 +111,13 @@ private[spark] class BlockStoreShuffleReader[K, C]( case None => aggregatedIter } - // Use another interruptible iterator here to support task cancellation as aggregator or(and) - // sorter may have consumed previous interruptible iterator. - new InterruptibleIterator[Product2[K, C]](context, resultIter) + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } } } From 44a9f8e6e82c300dc61ca18515aee16f17f27501 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 2 Apr 2018 09:53:37 -0700 Subject: [PATCH 0545/1161] [SPARK-15009][PYTHON][FOLLOWUP] Add default param checks for CountVectorizerModel ## What changes were proposed in this pull request? Adding test for default params for `CountVectorizerModel` constructed from vocabulary. This required that the param `maxDF` be added, which was done in SPARK-23615. ## How was this patch tested? Added an explicit test for CountVectorizerModel in DefaultValuesTests. Author: Bryan Cutler Closes #20942 from BryanCutler/pyspark-CountVectorizerModel-default-param-test-SPARK-15009. --- python/pyspark/ml/tests.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6b4376cbf14e8..c2c4861e2aff4 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2096,6 +2096,11 @@ def test_java_params(self): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel + ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + def _squared_distance(a, b): if isinstance(a, Vector): From 6151f29f9f589301159482044fc32717f430db6e Mon Sep 17 00:00:00 2001 From: David Vogelbacher Date: Mon, 2 Apr 2018 12:00:37 -0700 Subject: [PATCH 0546/1161] [SPARK-23825][K8S] Requesting memory + memory overhead for pod memory ## What changes were proposed in this pull request? Kubernetes driver and executor pods should request `memory + memoryOverhead` as their resources instead of just `memory`, see https://issues.apache.org/jira/browse/SPARK-23825 ## How was this patch tested? Existing unit tests were adapted. Author: David Vogelbacher Closes #20943 from dvogelbacher/spark-23825. --- .../k8s/submit/steps/BasicDriverConfigurationStep.scala | 5 +---- .../spark/scheduler/cluster/k8s/ExecutorPodFactory.scala | 5 +---- .../submit/steps/BasicDriverConfigurationStepSuite.scala | 2 +- .../scheduler/cluster/k8s/ExecutorPodFactorySuite.scala | 6 ++++-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index 347c4d2d66826..b811db324108c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -93,9 +93,6 @@ private[spark] class BasicDriverConfigurationStep( .withAmount(driverCpuCores) .build() val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryMiB}Mi") - .build() - val driverMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${driverMemoryWithOverheadMiB}Mi") .build() val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => @@ -117,7 +114,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewResources() .addToRequests("cpu", driverCpuQuantity) .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryLimitQuantity) + .addToLimits("memory", driverMemoryQuantity) .addToLimits(maybeCpuLimitQuantity.toMap.asJava) .endResources() .addToArgs("driver") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 98cbd5607da00..ac42385459dda 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -108,9 +108,6 @@ private[spark] class ExecutorPodFactory( SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ executorLabels val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryMiB}Mi") - .build() - val executorMemoryLimitQuantity = new QuantityBuilder(false) .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) @@ -167,7 +164,7 @@ private[spark] class ExecutorPodFactory( .withImagePullPolicy(imagePullPolicy) .withNewResources() .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryLimitQuantity) + .addToLimits("memory", executorMemoryQuantity) .addToRequests("cpu", executorCpuQuantity) .endResources() .addAllToEnv(executorEnv.asJava) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index ce068531c7673..e59c6d28a8cc2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -91,7 +91,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { val resourceRequirements = preparedDriverSpec.driverContainer.getResources val requests = resourceRequirements.getRequests.asScala assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "256Mi") + assert(requests("memory").getAmount === "456Mi") val limits = resourceRequirements.getLimits.asScala assert(limits("memory").getAmount === "456Mi") assert(limits("cpu").getAmount === "4") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7755b93835047..cee8fe27039c9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -66,12 +66,14 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef assert(executor.getMetadata.getLabels.size() === 3) assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - // There is exactly 1 container with no volume mounts and default memory limits. - // Default memory limit is 1024M + 384M (minimum overhead constant). + // There is exactly 1 container with no volume mounts and default memory limits and requests. + // Default memory limit/request is 1024M + 384M (minimum overhead constant). assert(executor.getSpec.getContainers.size() === 1) assert(executor.getSpec.getContainers.get(0).getImage === executorImage) assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources + .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") From fe2b7a4568d65a62da6e6eb00fff05f248b4332c Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Mon, 2 Apr 2018 12:20:55 -0700 Subject: [PATCH 0547/1161] [SPARK-23285][K8S] Add a config property for specifying physical executor cores ## What changes were proposed in this pull request? As mentioned in SPARK-23285, this PR introduces a new configuration property `spark.kubernetes.executor.cores` for specifying the physical CPU cores requested for each executor pod. This is to avoid changing the semantics of `spark.executor.cores` and `spark.task.cpus` and their role in task scheduling, task parallelism, dynamic resource allocation, etc. The new configuration property only determines the physical CPU cores available to an executor. An executor can still run multiple tasks simultaneously by using appropriate values for `spark.executor.cores` and `spark.task.cpus`. ## How was this patch tested? Unit tests. felixcheung srowen jiangxb1987 jerryshao mccheah foxish Author: Yinan Li Author: Yinan Li Closes #20553 from liyinan926/master. --- docs/running-on-kubernetes.md | 15 ++++++++--- .../org/apache/spark/deploy/k8s/Config.scala | 6 +++++ .../cluster/k8s/ExecutorPodFactory.scala | 12 ++++++--- .../cluster/k8s/ExecutorPodFactorySuite.scala | 27 +++++++++++++++++++ 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 975b28de47e20..9c4644947c911 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -549,14 +549,23 @@ specific to Spark on Kubernetes. spark.kubernetes.driver.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for the driver pod. + + spark.kubernetes.executor.request.cores + (none) + + Specify the cpu request for each executor pod. Values conform to the Kubernetes [convention](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#meaning-of-cpu). + Example values include 0.1, 500m, 1.5, 5, etc., with the definition of cpu units documented in [CPU units](https://kubernetes.io/docs/tasks/configure-pod-container/assign-cpu-resource/#cpu-units). + This is distinct from spark.executor.cores: it is only used and takes precedence over spark.executor.cores for specifying the executor pod cpu request if set. Task + parallelism, e.g., number of tasks an executor can run concurrently is not affected by this. + spark.kubernetes.executor.limit.cores (none) - Specify the hard CPU [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. + Specify a hard cpu [limit](https://kubernetes.io/docs/concepts/configuration/manage-compute-resources-container/#resource-requests-and-limits-of-pod-and-container) for each executor pod launched for the Spark Application. @@ -593,4 +602,4 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. - \ No newline at end of file + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index da34a7e06238a..405ea476351bb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -91,6 +91,12 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_EXECUTOR_REQUEST_CORES = + ConfigBuilder("spark.kubernetes.executor.request.cores") + .doc("Specify the cpu request for each executor pod") + .stringConf + .createOptional + val KUBERNETES_DRIVER_POD_NAME = ConfigBuilder("spark.kubernetes.driver.pod.name") .doc("Name of the driver pod.") diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ac42385459dda..7143f7a6f0b71 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -83,7 +83,12 @@ private[spark] class ExecutorPodFactory( MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - private val executorCores = sparkConf.getDouble("spark.executor.cores", 1) + private val executorCores = sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) /** @@ -111,7 +116,7 @@ private[spark] class ExecutorPodFactory( .withAmount(s"${executorMemoryWithOverhead}Mi") .build() val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCores.toString) + .withAmount(executorCoresRequest) .build() val executorExtraClasspathEnv = executorExtraClasspath.map { cp => new EnvVarBuilder() @@ -130,8 +135,7 @@ private[spark] class ExecutorPodFactory( }.getOrElse(Seq.empty[EnvVar]) val executorEnv = (Seq( (ENV_DRIVER_URL, driverUrl), - // Executor backend expects integral value for executor cores, so round it up to an int. - (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_CORES, executorCores.toString), (ENV_EXECUTOR_MEMORY, executorMemoryString), (ENV_APPLICATION_ID, applicationId), // This is to set the SPARK_CONF_DIR to be /opt/spark/conf diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index cee8fe27039c9..a71a2a1b888bc 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -85,6 +85,33 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef checkOwnerReferences(executor, driverPodUid) } + test("executor core request specification") { + var factory = new ExecutorPodFactory(baseConf, None) + var executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "1") + + val conf = baseConf.clone() + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") + factory = new ExecutorPodFactory(conf, None) + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.size() === 1) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "0.1") + + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + factory = new ExecutorPodFactory(conf, None) + conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") + executor = factory.createExecutorPod( + "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount + === "100m") + } + test("executor pod hostnames get truncated to 63 characters") { val conf = baseConf.clone() conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, From a7c19d9c21d59fd0109a7078c80b33d3da03fafd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 2 Apr 2018 21:48:44 +0200 Subject: [PATCH 0548/1161] [SPARK-23713][SQL] Cleanup UnsafeWriter and BufferHolder classes ## What changes were proposed in this pull request? This PR implemented the following cleanups related to `UnsafeWriter` class: - Remove code duplication between `UnsafeRowWriter` and `UnsafeArrayWriter` - Make `BufferHolder` class internal by delegating its accessor methods to `UnsafeWriter` - Replace `UnsafeRow.setTotalSize(...)` with `UnsafeRowWriter.setTotalSize()` ## How was this patch tested? Tested by existing UTs Author: Kazuaki Ishizaki Closes #20850 from kiszk/SPARK-23713. --- .../sql/kafka010/KafkaContinuousReader.scala | 3 - .../KafkaRecordToUnsafeRowConverter.scala | 11 +- .../expressions/codegen/BufferHolder.java | 32 +-- .../codegen/UnsafeArrayWriter.java | 133 +++--------- .../expressions/codegen/UnsafeRowWriter.java | 189 +++++++----------- .../expressions/codegen/UnsafeWriter.java | 157 ++++++++++++++- .../InterpretedUnsafeProjection.scala | 90 ++++----- .../codegen/GenerateUnsafeProjection.scala | 124 +++++------- .../RowBasedKeyValueBatchSuite.java | 28 +-- .../aggregate/RowBasedHashMapGenerator.scala | 12 +- .../columnar/GenerateColumnAccessor.scala | 9 +- .../datasources/text/TextFileFormat.scala | 11 +- 12 files changed, 391 insertions(+), 408 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index e7e27876088f3..f26c134c2f6e9 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -27,13 +27,10 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String /** * A [[ContinuousReader]] for data from kafka. diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala index 1acdd56125741..f35a143e00374 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToUnsafeRowConverter.scala @@ -20,18 +20,16 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String /** A simple class for converting Kafka ConsumerRecord to UnsafeRow */ private[kafka010] class KafkaRecordToUnsafeRowConverter { - private val sharedRow = new UnsafeRow(7) - private val bufferHolder = new BufferHolder(sharedRow) - private val rowWriter = new UnsafeRowWriter(bufferHolder, 7) + private val rowWriter = new UnsafeRowWriter(7) def toUnsafeRow(record: ConsumerRecord[Array[Byte], Array[Byte]]): UnsafeRow = { - bufferHolder.reset() + rowWriter.reset() if (record.key == null) { rowWriter.setNullAt(0) @@ -46,7 +44,6 @@ private[kafka010] class KafkaRecordToUnsafeRowConverter { 5, DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(record.timestamp))) rowWriter.write(6, record.timestampType.id) - sharedRow.setTotalSize(bufferHolder.totalSize) - sharedRow + rowWriter.getRow() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 259976118c12f..537ef244b7e81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -30,25 +30,21 @@ * this class per writing program, so that the memory segment/data buffer can be reused. Note that * for each incoming record, we should call `reset` of BufferHolder instance before write the record * and reuse the data buffer. - * - * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update - * the size of the result row, after writing a record to the buffer. However, we can skip this step - * if the fields of row are all fixed-length, as the size of result row is also fixed. */ -public class BufferHolder { +final class BufferHolder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - public byte[] buffer; - public int cursor = Platform.BYTE_ARRAY_OFFSET; + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; private final UnsafeRow row; private final int fixedSize; - public BufferHolder(UnsafeRow row) { + BufferHolder(UnsafeRow row) { this(row, 64); } - public BufferHolder(UnsafeRow row, int initialSize) { + BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( @@ -64,7 +60,7 @@ public BufferHolder(UnsafeRow row, int initialSize) { /** * Grows the buffer by at least neededSize and points the row to the buffer. */ - public void grow(int neededSize) { + void grow(int neededSize) { if (neededSize > ARRAY_MAX - totalSize()) { throw new UnsupportedOperationException( "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + @@ -86,11 +82,23 @@ public void grow(int neededSize) { } } - public void reset() { + byte[] getBuffer() { + return buffer; + } + + int getCursor() { + return cursor; + } + + void increaseCursor(int val) { + cursor += val; + } + + void reset() { cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize; } - public int totalSize() { + int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index 82cd1b24607e1..a78dd970d23e4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes; @@ -32,14 +30,12 @@ */ public final class UnsafeArrayWriter extends UnsafeWriter { - private BufferHolder holder; - - // The offset of the global buffer where we start to write this array. - private int startingOffset; - // The number of elements in this array private int numElements; + // The element size in this array + private int elementSize; + private int headerInBytes; private void assertIndexIsValid(int index) { @@ -47,13 +43,17 @@ private void assertIndexIsValid(int index) { assert index < numElements : "index (" + index + ") should < " + numElements; } - public void initialize(BufferHolder holder, int numElements, int elementSize) { + public UnsafeArrayWriter(UnsafeWriter writer, int elementSize) { + super(writer.getBufferHolder()); + this.elementSize = elementSize; + } + + public void initialize(int numElements) { // We need 8 bytes to store numElements in header this.numElements = numElements; this.headerInBytes = calculateHeaderPortionInBytes(numElements); - this.holder = holder; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); // Grows the global buffer ahead for header and fixed size data. int fixedPartInBytes = @@ -61,112 +61,92 @@ public void initialize(BufferHolder holder, int numElements, int elementSize) { holder.grow(headerInBytes + fixedPartInBytes); // Write numElements and clear out null bits to header - Platform.putLong(holder.buffer, startingOffset, numElements); + Platform.putLong(getBuffer(), startingOffset, numElements); for (int i = 8; i < headerInBytes; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } // fill 0 into reminder part of 8-bytes alignment in unsafe array for (int i = elementSize * numElements; i < fixedPartInBytes; i++) { - Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0); + Platform.putByte(getBuffer(), startingOffset + headerInBytes + i, (byte) 0); } - holder.cursor += (headerInBytes + fixedPartInBytes); + increaseCursor(headerInBytes + fixedPartInBytes); } - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); - } - } - - private long getElementOffset(int ordinal, int elementSize) { + private long getElementOffset(int ordinal) { return startingOffset + headerInBytes + ordinal * elementSize; } - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - assertIndexIsValid(ordinal); - final long relativeOffset = currentCursor - startingOffset; - final long offsetAndSize = (relativeOffset << 32) | (long)size; - - write(ordinal, offsetAndSize); - } - private void setNullBit(int ordinal) { assertIndexIsValid(ordinal); - BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal); + BitSetMethods.set(getBuffer(), startingOffset + 8, ordinal); } public void setNull1Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0); + writeByte(getElementOffset(ordinal), (byte)0); } public void setNull2Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0); + writeShort(getElementOffset(ordinal), (short)0); } public void setNull4Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), 0); + writeInt(getElementOffset(ordinal), 0); } public void setNull8Bytes(int ordinal) { setNullBit(ordinal); // put zero into the corresponding field when set null - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0); + writeLong(getElementOffset(ordinal), 0); } public void setNull(int ordinal) { setNull8Bytes(ordinal); } public void write(int ordinal, boolean value) { assertIndexIsValid(ordinal); - Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value); + writeBoolean(getElementOffset(ordinal), value); } public void write(int ordinal, byte value) { assertIndexIsValid(ordinal); - Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value); + writeByte(getElementOffset(ordinal), value); } public void write(int ordinal, short value) { assertIndexIsValid(ordinal); - Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value); + writeShort(getElementOffset(ordinal), value); } public void write(int ordinal, int value) { assertIndexIsValid(ordinal); - Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value); + writeInt(getElementOffset(ordinal), value); } public void write(int ordinal, long value) { assertIndexIsValid(ordinal); - Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value); + writeLong(getElementOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } assertIndexIsValid(ordinal); - Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value); + writeFloat(getElementOffset(ordinal), value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } assertIndexIsValid(ordinal); - Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value); + writeDouble(getElementOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { // make sure Decimal object has the same scale as DecimalType assertIndexIsValid(ordinal); - if (input.changePrecision(precision, scale)) { + if (input != null && input.changePrecision(precision, scale)) { if (precision <= Decimal.MAX_LONG_DIGITS()) { write(ordinal, input.toUnscaledLong()); } else { @@ -180,65 +160,14 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - setOffsetAndSize(ordinal, holder.cursor, numBytes); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); // move the cursor forward with 8-bytes boundary - holder.cursor += roundedSize; + increaseCursor(roundedSize); } } else { setNull(ordinal); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - final int numBytes = input.length; - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, holder.cursor, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, holder.cursor, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 2620bbcfb87a2..71c49d8ed0177 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -20,10 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * A helper class to write data into global row buffer using `UnsafeRow` format. @@ -31,7 +28,7 @@ * It will remember the offset of row buffer which it starts to write, and move the cursor of row * buffer while writing. If new data(can be the input record if this is the outermost writer, or * nested struct if this is an inner writer) comes, the starting cursor of row buffer may be - * changed, so we need to call `UnsafeRowWriter.reset` before writing, to update the + * changed, so we need to call `UnsafeRowWriter.resetRowWriter` before writing, to update the * `startingOffset` and clear out null bits. * * Note that if this is the outermost writer, which means we will always write from the very @@ -40,29 +37,58 @@ */ public final class UnsafeRowWriter extends UnsafeWriter { - private final BufferHolder holder; - // The offset of the global buffer where we start to write this row. - private int startingOffset; + private final UnsafeRow row; + private final int nullBitsSize; private final int fixedSize; - public UnsafeRowWriter(BufferHolder holder, int numFields) { - this.holder = holder; + public UnsafeRowWriter(int numFields) { + this(new UnsafeRow(numFields)); + } + + public UnsafeRowWriter(int numFields, int initialBufferSize) { + this(new UnsafeRow(numFields), initialBufferSize); + } + + public UnsafeRowWriter(UnsafeWriter writer, int numFields) { + this(null, writer.getBufferHolder(), numFields); + } + + private UnsafeRowWriter(UnsafeRow row) { + this(row, new BufferHolder(row), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, int initialBufferSize) { + this(row, new BufferHolder(row, initialBufferSize), row.numFields()); + } + + private UnsafeRowWriter(UnsafeRow row, BufferHolder holder, int numFields) { + super(holder); + this.row = row; this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields); this.fixedSize = nullBitsSize + 8 * numFields; - this.startingOffset = holder.cursor; + this.startingOffset = cursor(); + } + + /** + * Updates total size of the UnsafeRow using the size collected by BufferHolder, and returns + * the UnsafeRow created at a constructor + */ + public UnsafeRow getRow() { + row.setTotalSize(totalSize()); + return row; } /** * Resets the `startingOffset` according to the current cursor of row buffer, and clear out null * bits. This should be called before we write a new nested struct to the row buffer. */ - public void reset() { - this.startingOffset = holder.cursor; + public void resetRowWriter() { + this.startingOffset = cursor(); // grow the global buffer to make sure it has enough space to write fixed-length data. - holder.grow(fixedSize); - holder.cursor += fixedSize; + grow(fixedSize); + increaseCursor(fixedSize); zeroOutNullBytes(); } @@ -72,25 +98,17 @@ public void reset() { */ public void zeroOutNullBytes() { for (int i = 0; i < nullBitsSize; i += 8) { - Platform.putLong(holder.buffer, startingOffset + i, 0L); - } - } - - private void zeroOutPaddingBytes(int numBytes) { - if ((numBytes & 0x07) > 0) { - Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); + Platform.putLong(getBuffer(), startingOffset + i, 0L); } } - public BufferHolder holder() { return holder; } - public boolean isNullAt(int ordinal) { - return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); + return BitSetMethods.isSet(getBuffer(), startingOffset, ordinal); } public void setNullAt(int ordinal) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); - Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L); + BitSetMethods.set(getBuffer(), startingOffset, ordinal); + write(ordinal, 0L); } @Override @@ -117,67 +135,49 @@ public long getFieldOffset(int ordinal) { return startingOffset + nullBitsSize + 8 * ordinal; } - public void setOffsetAndSize(int ordinal, int size) { - setOffsetAndSize(ordinal, holder.cursor, size); - } - - public void setOffsetAndSize(int ordinal, int currentCursor, int size) { - final long relativeOffset = currentCursor - startingOffset; - final long fieldOffset = getFieldOffset(ordinal); - final long offsetAndSize = (relativeOffset << 32) | (long) size; - - Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putBoolean(holder.buffer, offset, value); + writeLong(offset, 0L); + writeBoolean(offset, value); } public void write(int ordinal, byte value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putByte(holder.buffer, offset, value); + writeLong(offset, 0L); + writeByte(offset, value); } public void write(int ordinal, short value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putShort(holder.buffer, offset, value); + writeLong(offset, 0L); + writeShort(offset, value); } public void write(int ordinal, int value) { final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putInt(holder.buffer, offset, value); + writeLong(offset, 0L); + writeInt(offset, value); } public void write(int ordinal, long value) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), value); + writeLong(getFieldOffset(ordinal), value); } public void write(int ordinal, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } final long offset = getFieldOffset(ordinal); - Platform.putLong(holder.buffer, offset, 0L); - Platform.putFloat(holder.buffer, offset, value); + writeLong(offset, 0); + writeFloat(offset, value); } public void write(int ordinal, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } - Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value); + writeDouble(getFieldOffset(ordinal), value); } public void write(int ordinal, Decimal input, int precision, int scale) { if (precision <= Decimal.MAX_LONG_DIGITS()) { // make sure Decimal object has the same scale as DecimalType - if (input.changePrecision(precision, scale)) { - Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong()); + if (input != null && input.changePrecision(precision, scale)) { + write(ordinal, input.toUnscaledLong()); } else { setNullAt(ordinal); } @@ -185,82 +185,31 @@ public void write(int ordinal, Decimal input, int precision, int scale) { // grow the global buffer before writing data. holder.grow(16); - // zero-out the bytes - Platform.putLong(holder.buffer, holder.cursor, 0L); - Platform.putLong(holder.buffer, holder.cursor + 8, 0L); - // Make sure Decimal object has the same scale as DecimalType. // Note that we may pass in null Decimal object to set null for it. if (input == null || !input.changePrecision(precision, scale)) { - BitSetMethods.set(holder.buffer, startingOffset, ordinal); + // zero-out the bytes + Platform.putLong(getBuffer(), cursor(), 0L); + Platform.putLong(getBuffer(), cursor() + 8, 0L); + + BitSetMethods.set(getBuffer(), startingOffset, ordinal); // keep the offset for future update setOffsetAndSize(ordinal, 0); } else { final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - assert bytes.length <= 16; + final int numBytes = bytes.length; + assert numBytes <= 16; + + zeroOutPaddingBytes(numBytes); // Write the bytes to the variable length portion. Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length); + bytes, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, bytes.length); } // move the cursor forward. - holder.cursor += 16; + increaseCursor(16); } } - - public void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(holder.buffer, holder.cursor); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, byte[] input) { - write(ordinal, input, 0, input.length); - } - - public void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - holder.grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset, - holder.buffer, holder.cursor, numBytes); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - holder.cursor += roundedSize; - } - - public void write(int ordinal, CalendarInterval input) { - // grow the global buffer before writing data. - holder.grow(16); - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(holder.buffer, holder.cursor, input.months); - Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds); - - setOffsetAndSize(ordinal, 16); - - // move the cursor forward. - holder.cursor += 16; - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index c94b5c7a367ef..de0eb6dbb76be 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -24,10 +26,73 @@ * Base class for writing Unsafe* structures. */ public abstract class UnsafeWriter { + // Keep internal buffer holder + protected final BufferHolder holder; + + // The offset of the global buffer where we start to write this structure. + protected int startingOffset; + + protected UnsafeWriter(BufferHolder holder) { + this.holder = holder; + } + + /** + * Accessor methods are delegated from BufferHolder class + */ + public final BufferHolder getBufferHolder() { + return holder; + } + + public final byte[] getBuffer() { + return holder.getBuffer(); + } + + public final void reset() { + holder.reset(); + } + + public final int totalSize() { + return holder.totalSize(); + } + + public final void grow(int neededSize) { + holder.grow(neededSize); + } + + public final int cursor() { + return holder.getCursor(); + } + + public final void increaseCursor(int val) { + holder.increaseCursor(val); + } + + public final void setOffsetAndSizeFromPreviousCursor(int ordinal, int previousCursor) { + setOffsetAndSize(ordinal, previousCursor, cursor() - previousCursor); + } + + protected void setOffsetAndSize(int ordinal, int size) { + setOffsetAndSize(ordinal, cursor(), size); + } + + protected void setOffsetAndSize(int ordinal, int currentCursor, int size) { + final long relativeOffset = currentCursor - startingOffset; + final long offsetAndSize = (relativeOffset << 32) | (long)size; + + write(ordinal, offsetAndSize); + } + + protected final void zeroOutPaddingBytes(int numBytes) { + if ((numBytes & 0x07) > 0) { + Platform.putLong(getBuffer(), cursor() + ((numBytes >> 3) << 3), 0L); + } + } + public abstract void setNull1Bytes(int ordinal); public abstract void setNull2Bytes(int ordinal); public abstract void setNull4Bytes(int ordinal); public abstract void setNull8Bytes(int ordinal); + public abstract void write(int ordinal, boolean value); public abstract void write(int ordinal, byte value); public abstract void write(int ordinal, short value); @@ -36,8 +101,92 @@ public abstract class UnsafeWriter { public abstract void write(int ordinal, float value); public abstract void write(int ordinal, double value); public abstract void write(int ordinal, Decimal input, int precision, int scale); - public abstract void write(int ordinal, UTF8String input); - public abstract void write(int ordinal, byte[] input); - public abstract void write(int ordinal, CalendarInterval input); - public abstract void setOffsetAndSize(int ordinal, int currentCursor, int size); + + public final void write(int ordinal, UTF8String input) { + final int numBytes = input.numBytes(); + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + input.writeToMemory(getBuffer(), cursor()); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, byte[] input) { + write(ordinal, input, 0, input.length); + } + + public final void write(int ordinal, byte[] input, int offset, int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + + // grow the global buffer before writing data. + grow(roundedSize); + + zeroOutPaddingBytes(numBytes); + + // Write the bytes to the variable length portion. + Platform.copyMemory( + input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); + + setOffsetAndSize(ordinal, numBytes); + + // move the cursor forward. + increaseCursor(roundedSize); + } + + public final void write(int ordinal, CalendarInterval input) { + // grow the global buffer before writing data. + grow(16); + + // Write the months and microseconds fields of Interval to the variable length portion. + Platform.putLong(getBuffer(), cursor(), input.months); + Platform.putLong(getBuffer(), cursor() + 8, input.microseconds); + + setOffsetAndSize(ordinal, 16); + + // move the cursor forward. + increaseCursor(16); + } + + protected final void writeBoolean(long offset, boolean value) { + Platform.putBoolean(getBuffer(), offset, value); + } + + protected final void writeByte(long offset, byte value) { + Platform.putByte(getBuffer(), offset, value); + } + + protected final void writeShort(long offset, short value) { + Platform.putShort(getBuffer(), offset, value); + } + + protected final void writeInt(long offset, int value) { + Platform.putInt(getBuffer(), offset, value); + } + + protected final void writeLong(long offset, long value) { + Platform.putLong(getBuffer(), offset, value); + } + + protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } + Platform.putFloat(getBuffer(), offset, value); + } + + protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } + Platform.putDouble(getBuffer(), offset, value); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 0da5ece7e47fe..b31466f5c92d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{UserDefinedType, _} import org.apache.spark.unsafe.Platform @@ -42,17 +42,12 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** The row representing the expression results. */ private[this] val intermediate = new GenericInternalRow(values) - /** The row returned by the projection. */ - private[this] val result = new UnsafeRow(numFields) - - /** The buffer which holds the resulting row's backing data. */ - private[this] val holder = new BufferHolder(result, numFields * 32) + /* The row writer for UnsafeRow result */ + private[this] val rowWriter = new UnsafeRowWriter(numFields, numFields * 32) /** The writer that writes the intermediate result to the result row. */ private[this] val writer: InternalRow => Unit = { - val rowWriter = new UnsafeRowWriter(holder, numFields) val baseWriter = generateStructWriter( - holder, rowWriter, expressions.map(e => StructField("", e.dataType, e.nullable))) if (!expressions.exists(_.nullable)) { @@ -83,10 +78,9 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe } // Write the intermediate row to an unsafe row. - holder.reset() + rowWriter.reset() writer(intermediate) - result.setTotalSize(holder.totalSize()) - result + rowWriter.getRow() } } @@ -111,14 +105,13 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * given buffer using the given [[UnsafeRowWriter]]. */ private def generateStructWriter( - bufferHolder: BufferHolder, rowWriter: UnsafeRowWriter, fields: Array[StructField]): InternalRow => Unit = { val numFields = fields.length // Create field writers. val fieldWriters = fields.map { field => - generateFieldWriter(bufferHolder, rowWriter, field.dataType, field.nullable) + generateFieldWriter(rowWriter, field.dataType, field.nullable) } // Create basic writer. row => { @@ -136,7 +129,6 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * or array) to the given buffer using the given [[UnsafeWriter]]. */ private def generateFieldWriter( - bufferHolder: BufferHolder, writer: UnsafeWriter, dt: DataType, nullable: Boolean): (SpecializedGetters, Int) => Unit = { @@ -178,81 +170,79 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { case StructType(fields) => val numFields = fields.length - val rowWriter = new UnsafeRowWriter(bufferHolder, numFields) - val structWriter = generateStructWriter(bufferHolder, rowWriter, fields) + val rowWriter = new UnsafeRowWriter(writer, numFields) + val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => writeUnsafeData( - bufferHolder, + rowWriter, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) case row => // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. - rowWriter.reset() + rowWriter.resetRowWriter() structWriter.apply(row) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => - val arrayWriter = new UnsafeArrayWriter - val elementSize = getElementSize(elementType) + val arrayWriter = new UnsafeArrayWriter(writer, getElementSize(elementType)) val elementWriter = generateFieldWriter( - bufferHolder, arrayWriter, elementType, containsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor - writeArray(bufferHolder, arrayWriter, elementWriter, v.getArray(i), elementSize) - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + val previousCursor = writer.cursor() + writeArray(arrayWriter, elementWriter, v.getArray(i)) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case MapType(keyType, valueType, valueContainsNull) => - val keyArrayWriter = new UnsafeArrayWriter - val keySize = getElementSize(keyType) + val keyArrayWriter = new UnsafeArrayWriter(writer, getElementSize(keyType)) val keyWriter = generateFieldWriter( - bufferHolder, keyArrayWriter, keyType, nullable = false) - val valueArrayWriter = new UnsafeArrayWriter - val valueSize = getElementSize(valueType) + val valueArrayWriter = new UnsafeArrayWriter(writer, getElementSize(valueType)) val valueWriter = generateFieldWriter( - bufferHolder, valueArrayWriter, valueType, valueContainsNull) (v, i) => { - val tmpCursor = bufferHolder.cursor + val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => writeUnsafeData( - bufferHolder, + valueArrayWriter, map.getBaseObject, map.getBaseOffset, map.getSizeInBytes) case map => // preserve 8 bytes to write the key array numBytes later. - bufferHolder.grow(8) - bufferHolder.cursor += 8 + valueArrayWriter.grow(8) + valueArrayWriter.increaseCursor(8) // Write the keys and write the numBytes of key array into the first 8 bytes. - writeArray(bufferHolder, keyArrayWriter, keyWriter, map.keyArray(), keySize) - Platform.putLong(bufferHolder.buffer, tmpCursor, bufferHolder.cursor - tmpCursor - 8) + writeArray(keyArrayWriter, keyWriter, map.keyArray()) + Platform.putLong( + valueArrayWriter.getBuffer, + previousCursor, + valueArrayWriter.cursor - previousCursor - 8 + ) // Write the values. - writeArray(bufferHolder, valueArrayWriter, valueWriter, map.valueArray(), valueSize) + writeArray(valueArrayWriter, valueWriter, map.valueArray()) } - writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - tmpCursor) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => - generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable) + generateFieldWriter(writer, udt.sqlType, nullable) case NullType => (_, _) => {} @@ -324,20 +314,18 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * copy. */ private def writeArray( - bufferHolder: BufferHolder, arrayWriter: UnsafeArrayWriter, elementWriter: (SpecializedGetters, Int) => Unit, - array: ArrayData, - elementSize: Int): Unit = array match { + array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => writeUnsafeData( - bufferHolder, + arrayWriter, unsafe.getBaseObject, unsafe.getBaseOffset, unsafe.getSizeInBytes) case _ => val numElements = array.numElements() - arrayWriter.initialize(bufferHolder, numElements, elementSize) + arrayWriter.initialize(numElements) var i = 0 while (i < numElements) { elementWriter.apply(array, i) @@ -350,17 +338,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. */ private def writeUnsafeData( - bufferHolder: BufferHolder, + writer: UnsafeWriter, baseObject: AnyRef, baseOffset: Long, sizeInBytes: Int) : Unit = { - bufferHolder.grow(sizeInBytes) + writer.grow(sizeInBytes) Platform.copyMemory( baseObject, baseOffset, - bufferHolder.buffer, - bufferHolder.cursor, + writer.getBuffer, + writer.cursor, sizeInBytes) - bufferHolder.cursor += sizeInBytes + writer.increaseCursor(sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6682ba55b18b1..ab2254cd9f70a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,19 +48,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") + s""" final InternalRow $tmpInput = $input; if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} } """ } @@ -70,12 +74,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro row: String, inputs: Seq[ExprCode], inputTypes: Seq[DataType], - bufferHolder: String, + rowWriter: String, isTopLevel: Boolean = false): String = { - val rowWriterClass = classOf[UnsafeRowWriter].getName - val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", - v => s"$v = new $rowWriterClass($bufferHolder, ${inputs.length});") - val resetWriter = if (isTopLevel) { // For top level row writer, it always writes to the beginning of the global buffer holder, // which means its fixed-size region always in the same position, so we don't need to call @@ -88,7 +88,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriter.resetRowWriter();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -97,7 +97,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case udt: UserDefinedType[_] => udt.sqlType case other => other } - val tmpCursor = ctx.freshName("tmpCursor") val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -105,33 +104,34 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } + val previousCursor = ctx.freshName("previousCursor") val writeField = dt match { case t: StructType => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)} - $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $rowWriter.cursor(); + ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} + $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -181,12 +181,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") - val arrayWriterClass = classOf[UnsafeArrayWriter].getName - val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", - v => s"$v = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -203,28 +200,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => 8 // we need 8 bytes to store offset and length } - val tmpCursor = ctx.freshName("tmpCursor") + val arrayWriterClass = classOf[UnsafeArrayWriter].getName + val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", + v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") + val previousCursor = ctx.freshName("previousCursor") + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case a @ ArrayType(et, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeArrayToBuffer(ctx, element, et, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case m @ MapType(kt, vt, _) => s""" - final int $tmpCursor = $bufferHolder.cursor; - ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + final int $previousCursor = $arrayWriter.cursor(); + ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} + $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); """ case t: DecimalType => @@ -240,10 +241,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} } else { final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriter.initialize($numElements); for (int $index = 0; $index < $numElements; $index++) { if ($tmpInput.isNullAt($index)) { @@ -262,7 +263,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, keyType: DataType, valueType: DataType, - bufferHolder: String): String = { + rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") @@ -271,20 +272,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final MapData $tmpInput = $input; if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} } else { // preserve 8 bytes to write the key array numBytes later. - $bufferHolder.grow(8); - $bufferHolder.cursor += 8; + $rowWriter.grow(8); + $rowWriter.increaseCursor(8); // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $bufferHolder.cursor; + final int $tmpCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); + Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} } """ } @@ -293,14 +294,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); // grow the global buffer before writing data. - $bufferHolder.grow($sizeInBytes); - $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor); - $bufferHolder.cursor += $sizeInBytes; + $rowWriter.grow($sizeInBytes); + $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); + $rowWriter.increaseCursor($sizeInBytes); """ } @@ -317,38 +318,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => true } - val result = ctx.addMutableState("UnsafeRow", "result", - v => s"$v = new UnsafeRow(${expressions.length});") - - val holderClass = classOf[BufferHolder].getName - val holder = ctx.addMutableState(holderClass, "holder", - v => s"$v = new $holderClass($result, ${numVarLenFields * 32});") - - val resetBufferHolder = if (numVarLenFields == 0) { - "" - } else { - s"$holder.reset();" - } - val updateRowSize = if (numVarLenFields == 0) { - "" - } else { - s"$result.setTotalSize($holder.totalSize());" - } + val rowWriterClass = classOf[UnsafeRowWriter].getName + val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", + v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + val writeExpressions = writeExpressionsToBuffer( + ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = s""" - $resetBufferHolder + $rowWriter.reset(); $evalSubexpr $writeExpressions - $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", s"$rowWriter.getRow()") } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java index fb3dbe8ed1996..2da87113c6229 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java @@ -27,7 +27,6 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; import org.apache.spark.unsafe.types.UTF8String; @@ -55,36 +54,27 @@ private String getRandomString(int length) { } private UnsafeRow makeKeyRow(long k1, String k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 32); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, UTF8String.fromString(k2)); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeKeyRow(long k1, long k2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, k1); writer.write(1, k2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow makeValueRow(long v1, long v2) { - UnsafeRow row = new UnsafeRow(2); - BufferHolder holder = new BufferHolder(row, 0); - UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2); - holder.reset(); + UnsafeRowWriter writer = new UnsafeRowWriter(2); + writer.reset(); writer.write(0, v1); writer.write(1, v2); - row.setTotalSize(holder.totalSize()); - return row; + return writer.getRow(); } private UnsafeRow appendRow(RowBasedKeyValueBatch batch, UnsafeRow key, UnsafeRow value) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 8617be88f3570..d5508275c48c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -165,18 +165,14 @@ class RowBasedHashMapGenerator( | if (buckets[idx] == -1) { | if (numRows < capacity && !isBatchFull) { | // creating the unsafe for new entry - | UnsafeRow agg_result = new UnsafeRow(${groupingKeySchema.length}); - | org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder - | = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, - | ${numVarLenFields * 32}); | org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter | = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter( - | agg_holder, - | ${groupingKeySchema.length}); - | agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed + | ${groupingKeySchema.length}, ${numVarLenFields * 32}); + | agg_rowWriter.reset(); //TODO: investigate if reset or zeroout are actually needed | agg_rowWriter.zeroOutNullBytes(); | ${createUnsafeRowForKey}; - | agg_result.setTotalSize(agg_holder.totalSize()); + | org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result + | = agg_rowWriter.getRow(); | Object kbase = agg_result.getBaseObject(); | long koff = agg_result.getBaseOffset(); | int klen = agg_result.getSizeInBytes(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 3b5655ba0582e..2d699e8a9d088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -165,9 +165,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow($numFields); - private BufferHolder bufferHolder = new BufferHolder(unsafeRow); - private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter($numFields); private MutableUnsafeRow mutableRow = null; private int currentRow = 0; @@ -212,11 +210,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public InternalRow next() { currentRow += 1; - bufferHolder.reset(); + rowWriter.reset(); rowWriter.zeroOutNullBytes(); ${extractorCalls} - unsafeRow.setTotalSize(bufferHolder.totalSize()); - return unsafeRow; + return rowWriter.getRow(); } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 9647f09867643..e93908da43535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -130,16 +130,13 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { val emptyUnsafeRow = new UnsafeRow(0) reader.map(_ => emptyUnsafeRow) } else { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + val unsafeRowWriter = new UnsafeRowWriter(1) reader.map { line => // Writes to an UnsafeRow directly - bufferHolder.reset() + unsafeRowWriter.reset() unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + unsafeRowWriter.getRow() } } } From 28ea4e3142b88eb396aa8dd5daf7b02b556204ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 2 Apr 2018 14:35:07 -0700 Subject: [PATCH 0549/1161] [SPARK-23834][TEST] Wait for connection before disconnect in LauncherServer test. It was possible that the disconnect() was called on the handle before the server had received the handshake messages, so no connection was yet attached to the handle. The fix waits until we're sure the handle has been mapped to a client connection. Author: Marcelo Vanzin Closes #20950 from vanzin/SPARK-23834. --- .../org/apache/spark/launcher/LauncherServerSuite.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 5413d3a416545..f8dc0ec7a0bf6 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -196,6 +196,14 @@ public void testAppHandleDisconnect() throws Exception { Socket s = new Socket(InetAddress.getLoopbackAddress(), server.getPort()); client = new TestClient(s); client.send(new Hello(secret, "1.4.0")); + client.send(new SetAppId("someId")); + + // Wait until we know the server has received the messages and matched the handle to the + // connection before disconnecting. + eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> { + assertEquals("someId", handle.getAppId()); + }); + handle.disconnect(); waitForError(client, secret); } finally { From a1351828d376a01e5ee0959cf608f767d756dd86 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Mon, 2 Apr 2018 16:41:26 -0700 Subject: [PATCH 0550/1161] [SPARK-23690][ML] Add handleinvalid to VectorAssembler ## What changes were proposed in this pull request? Introduce `handleInvalid` parameter in `VectorAssembler` that can take in `"keep", "skip", "error"` options. "error" throws an error on seeing a row containing a `null`, "skip" filters out all such rows, and "keep" adds relevant number of NaN. "keep" figures out an example to find out what this number of NaN s should be added and throws an error when no such number could be found. ## How was this patch tested? Unit tests are added to check the behavior of `assemble` on specific rows and the transformer is called on `DataFrame`s of different configurations to test different corner cases. Author: Yogesh Garg Author: Bago Amirbekian Author: Yogesh Garg <1059168+yogeshg@users.noreply.github.com> Closes #20829 from yogeshg/rformula_handleinvalid. --- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 198 ++++++++++++++---- .../ml/feature/VectorAssemblerSuite.scala | 131 +++++++++++- 3 files changed, 284 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 1cdcdfcaeab78..67cdb097217a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -234,7 +234,7 @@ class StringIndexerModel ( val metadata = NominalAttribute.defaultAttr .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. - val (filteredDataset, keepInvalid) = getHandleInvalid match { + val (filteredDataset, keepInvalid) = $(handleInvalid) match { case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index b373ae921ed38..6bf4aa38b1fcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -17,14 +17,17 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.ArrayBuilder +import java.util.NoSuchElementException + +import scala.collection.mutable +import scala.language.existentials import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -33,10 +36,14 @@ import org.apache.spark.sql.types._ /** * A feature transformer that merges multiple columns into a vector column. + * + * This requires one pass over the entire dataset. In case we need to infer column lengths from the + * data we require an additional call to the 'first' Dataset method, see 'handleInvalid' parameter. */ @Since("1.4.0") class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid + with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("vecAssembler")) @@ -49,32 +56,63 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.4.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). Column lengths are taken from the size of ML Attribute Group, which can be set using + * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + * Default: "error" + * @group param + */ + @Since("2.4.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + |output). Column lengths are taken from the size of ML Attribute Group, which can be set using + |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred + |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. + |""".stripMargin.replaceAll("\n", " "), + ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema - lazy val first = dataset.toDF.first() - val attrs = $(inputCols).flatMap { c => + + val vectorCols = $(inputCols).filter { c => + schema(c).dataType match { + case _: VectorUDT => true + case _ => false + } + } + val vectorColsLengths = VectorAssembler.getLengths(dataset, vectorCols, $(handleInvalid)) + + val featureAttributesMap = $(inputCols).map { c => val field = schema(c) - val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.fromStructField(field) - // If the input column doesn't have ML attribute, assume numeric. - if (attr == UnresolvedAttribute) { - Some(NumericAttribute.defaultAttr.withName(c)) - } else { - Some(attr.withName(c)) + val attribute = Attribute.fromStructField(field) + attribute match { + case UnresolvedAttribute => + Seq(NumericAttribute.defaultAttr.withName(c)) + case _ => + Seq(attribute.withName(c)) } case _: NumericType | BooleanType => // If the input column type is a compatible scalar type, assume numeric. - Some(NumericAttribute.defaultAttr.withName(c)) + Seq(NumericAttribute.defaultAttr.withName(c)) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.zipWithIndex.map { case (attr, i) => + val attributeGroup = AttributeGroup.fromStructField(field) + if (attributeGroup.attributes.isDefined) { + attributeGroup.attributes.get.zipWithIndex.toSeq.map { case (attr, i) => if (attr.name.isDefined) { // TODO: Define a rigorous naming scheme. attr.withName(c + "_" + attr.name.get) @@ -85,18 +123,25 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. - val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) + (0 until vectorColsLengths(c)).map { i => + NumericAttribute.defaultAttr.withName(c + "_" + i) + } } case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") } } - val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() - + val featureAttributes = featureAttributesMap.flatten[Attribute].toArray + val lengths = featureAttributesMap.map(a => a.length).toArray + val metadata = new AttributeGroup($(outputCol), featureAttributes).toMetadata() + val (filteredDataset, keepInvalid) = $(handleInvalid) match { + case VectorAssembler.SKIP_INVALID => (dataset.na.drop($(inputCols)), false) + case VectorAssembler.KEEP_INVALID => (dataset, true) + case VectorAssembler.ERROR_INVALID => (dataset, false) + } // Data transformation. val assembleFunc = udf { r: Row => - VectorAssembler.assemble(r.toSeq: _*) + VectorAssembler.assemble(lengths, keepInvalid)(r.toSeq: _*) }.asNondeterministic() val args = $(inputCols).map { c => schema(c).dataType match { @@ -106,7 +151,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } } - dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) + filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } @Since("1.4.0") @@ -136,34 +181,117 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.6.0") object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + + /** + * Infers lengths of vector columns from the first row of the dataset + * @param dataset the dataset + * @param columns name of vector columns whose lengths need to be inferred + * @return map of column names to lengths + */ + private[feature] def getVectorLengthsFromFirstRow( + dataset: Dataset[_], + columns: Seq[String]): Map[String, Int] = { + try { + val first_row = dataset.toDF().select(columns.map(col): _*).first() + columns.zip(first_row.toSeq).map { + case (c, x) => c -> x.asInstanceOf[Vector].size + }.toMap + } catch { + case e: NullPointerException => throw new NullPointerException( + s"""Encountered null value while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + case e: NoSuchElementException => throw new NoSuchElementException( + s"""Encountered empty dataframe while inferring lengths from the first row. Consider using + |VectorSizeHint to add metadata for columns: ${columns.mkString("[", ", ", "]")}. """ + .stripMargin.replaceAll("\n", " ") + e.toString) + } + } + + private[feature] def getLengths( + dataset: Dataset[_], + columns: Seq[String], + handleInvalid: String): Map[String, Int] = { + val groupSizes = columns.map { c => + c -> AttributeGroup.fromStructField(dataset.schema(c)).size + }.toMap + val missingColumns = groupSizes.filter(_._2 == -1).keys.toSeq + val firstSizes = (missingColumns.nonEmpty, handleInvalid) match { + case (true, VectorAssembler.ERROR_INVALID) => + getVectorLengthsFromFirstRow(dataset, missingColumns) + case (true, VectorAssembler.SKIP_INVALID) => + getVectorLengthsFromFirstRow(dataset.na.drop(missingColumns), missingColumns) + case (true, VectorAssembler.KEEP_INVALID) => throw new RuntimeException( + s"""Can not infer column lengths with handleInvalid = "keep". Consider using VectorSizeHint + |to add metadata for columns: ${columns.mkString("[", ", ", "]")}.""" + .stripMargin.replaceAll("\n", " ")) + case (_, _) => Map.empty + } + groupSizes ++ firstSizes + } + + @Since("1.6.0") override def load(path: String): VectorAssembler = super.load(path) - private[feature] def assemble(vv: Any*): Vector = { - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - var cur = 0 + /** + * Returns a function that has the required information to assemble each row. + * @param lengths an array of lengths of input columns, whose size should be equal to the number + * of cells in the row (vv) + * @param keepInvalid indicate whether to throw an error or not on seeing a null in the rows + * @return a udf that can be applied on each row + */ + private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = { + val indices = mutable.ArrayBuilder.make[Int] + val values = mutable.ArrayBuilder.make[Double] + var featureIndex = 0 + + var inputColumnIndex = 0 vv.foreach { case v: Double => - if (v != 0.0) { - indices += cur + if (v.isNaN && !keepInvalid) { + throw new SparkException( + s"""Encountered NaN while assembling a row with handleInvalid = "error". Consider + |removing NaNs from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } else if (v != 0.0) { + indices += featureIndex values += v } - cur += 1 + inputColumnIndex += 1 + featureIndex += 1 case vec: Vector => vec.foreachActive { case (i, v) => if (v != 0.0) { - indices += cur + i + indices += featureIndex + i values += v } } - cur += vec.size + inputColumnIndex += 1 + featureIndex += vec.size case null => - // TODO: output Double.NaN? - throw new SparkException("Values to assemble cannot be null.") + if (keepInvalid) { + val length: Int = lengths(inputColumnIndex) + Array.range(0, length).foreach { i => + indices += featureIndex + i + values += Double.NaN + } + inputColumnIndex += 1 + featureIndex += length + } else { + throw new SparkException( + s"""Encountered null while assembling a row with handleInvalid = "keep". Consider + |removing nulls from dataset or using handleInvalid = "keep" or "skip".""" + .stripMargin) + } case o => throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - Vectors.sparse(cur, indices.result(), values.result()).compressed + Vectors.sparse(featureIndex, indices.result(), values.result()).compressed } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index eca065f7e775d..91fb24a268b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions.{col, udf} class VectorAssemblerSuite @@ -31,30 +31,49 @@ class VectorAssemblerSuite import testImplicits._ + @transient var dfWithNullsAndNaNs: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sv = Vectors.sparse(2, Array(1), Array(3.0)) + dfWithNullsAndNaNs = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long, String)]( + (1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (2, 1, 0.0, null, "a", sv, 6L, null), + (3, 3, null, Vectors.dense(1.0, 2.0), "a", sv, 8L, null), + (4, 4, null, null, "a", sv, 9L, null), + (5, 5, java.lang.Double.NaN, Vectors.dense(1.0, 2.0), "a", sv, 7L, null), + (6, 6, java.lang.Double.NaN, null, "a", sv, 8L, null)) + .toDF("id1", "id2", "x", "y", "name", "z", "n", "nulls") + } + test("params") { ParamsSuite.checkParams(new VectorAssembler) } test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble - assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) - assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0))) + assert(assemble(Array(1), keepInvalid = true)(0.0) + === Vectors.sparse(1, Array.empty, Array.empty)) + assert(assemble(Array(1, 1), keepInvalid = true)(0.0, 1.0) + === Vectors.sparse(2, Array(1), Array(1.0))) val dv = Vectors.dense(2.0, 0.0) - assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) + assert(assemble(Array(1, 2, 1), keepInvalid = true)(0.0, dv, 1.0) === + Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0))) val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0)) - assert(assemble(0.0, dv, 1.0, sv) === + assert(assemble(Array(1, 2, 1, 2), keepInvalid = true)(0.0, dv, 1.0, sv) === Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0))) - for (v <- Seq(1, "a", null)) { - intercept[SparkException](assemble(v)) - intercept[SparkException](assemble(1.0, v)) + for (v <- Seq(1, "a")) { + intercept[SparkException](assemble(Array(1), keepInvalid = true)(v)) + intercept[SparkException](assemble(Array(1, 1), keepInvalid = true)(1.0, v)) } } test("assemble should compress vectors") { import org.apache.spark.ml.feature.VectorAssembler.assemble - val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + val v1 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(0.0, 0.0, 0.0, Vectors.dense(4.0)) assert(v1.isInstanceOf[SparseVector]) - val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + val sv = Vectors.sparse(1, Array(0), Array(4.0)) + val v2 = assemble(Array(1, 1, 1, 1), keepInvalid = true)(1.0, 2.0, 3.0, sv) assert(v2.isInstanceOf[DenseVector]) } @@ -147,4 +166,94 @@ class VectorAssemblerSuite .filter(vectorUDF($"features") > 1) .count() == 1) } + + test("assemble should keep nulls when keepInvalid is true") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + assert(assemble(Array(1, 1), keepInvalid = true)(1.0, null) === Vectors.dense(1.0, Double.NaN)) + assert(assemble(Array(1, 2), keepInvalid = true)(1.0, null) + === Vectors.dense(1.0, Double.NaN, Double.NaN)) + assert(assemble(Array(1), keepInvalid = true)(null) === Vectors.dense(Double.NaN)) + assert(assemble(Array(2), keepInvalid = true)(null) === Vectors.dense(Double.NaN, Double.NaN)) + } + + test("assemble should throw errors when keepInvalid is false") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + intercept[SparkException](assemble(Array(1, 1), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1, 2), keepInvalid = false)(1.0, null)) + intercept[SparkException](assemble(Array(1), keepInvalid = false)(null)) + intercept[SparkException](assemble(Array(2), keepInvalid = false)(null)) + } + + test("get lengths functions") { + import org.apache.spark.ml.feature.VectorAssembler._ + val df = dfWithNullsAndNaNs + assert(getVectorLengthsFromFirstRow(df, Seq("y")) === Map("y" -> 2)) + assert(intercept[NullPointerException](getVectorLengthsFromFirstRow(df.sort("id2"), Seq("y"))) + .getMessage.contains("VectorSizeHint")) + assert(intercept[NoSuchElementException](getVectorLengthsFromFirstRow(df.filter("id1 > 6"), + Seq("y"))).getMessage.contains("VectorSizeHint")) + + assert(getLengths(df.sort("id2"), Seq("y"), SKIP_INVALID).exists(_ == "y" -> 2)) + assert(intercept[NullPointerException](getLengths(df.sort("id2"), Seq("y"), ERROR_INVALID)) + .getMessage.contains("VectorSizeHint")) + assert(intercept[RuntimeException](getLengths(df.sort("id2"), Seq("y"), KEEP_INVALID)) + .getMessage.contains("VectorSizeHint")) + } + + test("Handle Invalid should behave properly") { + val assembler = new VectorAssembler() + .setInputCols(Array("x", "y", "z", "n")) + .setOutputCol("features") + + def runWithMetadata(mode: String, additional_filter: String = "true"): Dataset[_] = { + val attributeY = new AttributeGroup("y", 2) + val attributeZ = new AttributeGroup( + "z", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val dfWithMetadata = dfWithNullsAndNaNs.withColumn("y", col("y"), attributeY.toMetadata()) + .withColumn("z", col("z"), attributeZ.toMetadata()).filter(additional_filter) + val output = assembler.setHandleInvalid(mode).transform(dfWithMetadata) + output.collect() + output + } + + def runWithFirstRow(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode).transform(dfWithNullsAndNaNs) + output.collect() + output + } + + def runWithAllNullVectors(mode: String): Dataset[_] = { + val output = assembler.setHandleInvalid(mode) + .transform(dfWithNullsAndNaNs.filter("0 == id1 % 2")) + output.collect() + output + } + + // behavior when vector size hint is given + assert(runWithMetadata("keep").count() == 6, "should keep all rows") + assert(runWithMetadata("skip").count() == 1, "should skip rows with nulls") + // should throw error with nulls + intercept[SparkException](runWithMetadata("error")) + // should throw error with NaNs + intercept[SparkException](runWithMetadata("error", additional_filter = "id1 > 4")) + + // behavior when first row has information + assert(intercept[RuntimeException](runWithFirstRow("keep").count()) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(runWithFirstRow("skip").count() == 1, "should infer size and skip rows with nulls") + intercept[SparkException](runWithFirstRow("error")) + + // behavior when vector column is all null + assert(intercept[RuntimeException](runWithAllNullVectors("skip")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + assert(intercept[NullPointerException](runWithAllNullVectors("error")) + .getMessage.contains("VectorSizeHint"), "should suggest to use metadata") + + // behavior when scalar column is all null + assert(runWithMetadata("keep", additional_filter = "id1 > 2").count() == 4) + } + } From 441d0d0766e9a6ac4c6ff79680394999ff7191fd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 3 Apr 2018 09:31:47 +0800 Subject: [PATCH 0551/1161] [SPARK-19964][CORE] Avoid reading from remote repos in SparkSubmitSuite. These tests can fail with a timeout if the remote repos are not responding, or slow. The tests don't need anything from those repos, so use an empty ivy config file to avoid setting up the defaults. The tests are passing reliably for me locally now, and failing more often than not today without this change since http://dl.bintray.com/spark-packages/maven doesn't seem to be loading from my machine. Author: Marcelo Vanzin Closes #20916 from vanzin/SPARK-19964. --- .../org/apache/spark/deploy/DependencyUtils.scala | 13 ++++++++----- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- .../apache/spark/deploy/SparkSubmitArguments.scala | 2 ++ .../apache/spark/deploy/worker/DriverWrapper.scala | 13 +++++++++---- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 9 ++++++--- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index ab319c860ee69..fac834a70b893 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -33,7 +33,8 @@ private[deploy] object DependencyUtils { packagesExclusions: String, packages: String, repositories: String, - ivyRepoPath: String): String = { + ivyRepoPath: String, + ivySettingsPath: Option[String]): String = { val exclusions: Seq[String] = if (!StringUtils.isBlank(packagesExclusions)) { packagesExclusions.split(",") @@ -41,10 +42,12 @@ private[deploy] object DependencyUtils { Nil } // Create the IvySettings, either load from file or build defaults - val ivySettings = sys.props.get("spark.jars.ivySettings").map { ivySettingsFile => - SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(repositories), Option(ivyRepoPath)) - }.getOrElse { - SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) + val ivySettings = ivySettingsPath match { + case Some(path) => + SparkSubmitUtils.loadIvySettings(path, Option(repositories), Option(ivyRepoPath)) + + case None => + SparkSubmitUtils.buildIvySettings(Option(repositories), Option(ivyRepoPath)) } SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 3965f17f4b56e..eddbedeb1024d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -359,7 +359,8 @@ object SparkSubmit extends CommandLineUtils with Logging { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies( - args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath) + args.packagesExclusions, args.packages, args.repositories, args.ivyRepoPath, + args.ivySettingsPath) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index e7796d4ddbe34..8e7070593687b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -63,6 +63,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var ivySettingsPath: Option[String] = None var packagesExclusions: String = null var verbose: Boolean = false var isPython: Boolean = false @@ -184,6 +185,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull + ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) .orElse(sparkProperties.get("spark.jars.excludes")).orNull diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index b19c9904d5982..3f71237164a15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -79,12 +79,17 @@ object DriverWrapper extends Logging { val secMgr = new SecurityManager(sparkConf) val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf) - val Seq(packagesExclusions, packages, repositories, ivyRepoPath) = - Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy") - .map(sys.props.get(_).orNull) + val Seq(packagesExclusions, packages, repositories, ivyRepoPath, ivySettingsPath) = + Seq( + "spark.jars.excludes", + "spark.jars.packages", + "spark.jars.repositories", + "spark.jars.ivy", + "spark.jars.ivySettings" + ).map(sys.props.get(_).orNull) val resolvedMavenCoordinates = DependencyUtils.resolveMavenDependencies(packagesExclusions, - packages, repositories, ivyRepoPath) + packages, repositories, ivyRepoPath, Option(ivySettingsPath)) val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d86ef907b4492..0d7c342a5eacd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,9 @@ class SparkSubmitSuite // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x implicit val defaultSignaler: Signaler = ThreadSignaler + private val emptyIvySettings = File.createTempFile("ivy", ".xml") + FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + override def beforeEach() { super.beforeEach() } @@ -520,6 +523,7 @@ class SparkSubmitSuite "--repositories", repo, "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -530,7 +534,6 @@ class SparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - // Test using "spark.jars.packages" and "spark.jars.repositories" configurations. IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), @@ -540,6 +543,7 @@ class SparkSubmitSuite "--conf", s"spark.jars.repositories=$repo", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", unusedJar.toString, "my.great.lib.MyLib", "my.great.dep.MyLib") runSparkSubmit(args) @@ -550,7 +554,6 @@ class SparkSubmitSuite // See https://gist.github.com/shivaram/3a2fecce60768a603dac for a error log ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -563,6 +566,7 @@ class SparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--packages", main.toString, "--repositories", repo, + "--conf", s"spark.jars.ivySettings=${emptyIvySettings.getAbsolutePath()}", "--verbose", "--conf", "spark.ui.enabled=false", rScriptDir) @@ -573,7 +577,6 @@ class SparkSubmitSuite test("include an external JAR in SparkR") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Check if the SparkR package is installed assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val rScriptDir = Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator) From 8020f66fc47140a1b5f843fb18c34ec80541d5ca Mon Sep 17 00:00:00 2001 From: lemonjing <932191671@qq.com> Date: Tue, 3 Apr 2018 09:36:44 +0800 Subject: [PATCH 0552/1161] [MINOR][DOC] Fix a few markdown typos ## What changes were proposed in this pull request? Easy fix in the markdown. ## How was this patch tested? jekyII build test manually. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: lemonjing <932191671@qq.com> Closes #20897 from Lemonjing/master. --- docs/ml-guide.md | 2 +- docs/mllib-feature-extraction.md | 4 ++-- docs/mllib-pmml-model-export.md | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 702bcf748fc74..aea07be34cb86 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -111,7 +111,7 @@ and the migration guide below will explain all changes between releases. * The class and trait hierarchy for logistic regression model summaries was changed to be cleaner and better accommodate the addition of the multi-class summary. This is a breaking change for user code that casts a `LogisticRegressionTrainingSummary` to a -` BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` +`BinaryLogisticRegressionTrainingSummary`. Users should instead use the `model.binarySummary` method. See [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139) for more detail (_note_ this is an `Experimental` API). This _does not_ affect the Python `summary` method, which will still work correctly for both multinomial and binary cases. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 75aea70601875..8b89296b14cdd 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -278,8 +278,8 @@ for details on the API. multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. -Qu8T948*1# -Denoting the `scalingVec` as "`w`," this transformation may be written as: + +Denoting the `scalingVec` as "`w`", this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index d3530908706d0..f567565437927 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -7,7 +7,7 @@ displayTitle: PMML model export - RDD-based API * Table of contents {:toc} -## `spark.mllib` supported models +## spark.mllib supported models `spark.mllib` supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). @@ -15,7 +15,7 @@ The table below outlines the `spark.mllib` models that can be exported to PMML a - + From 7cf9fab33457ccc9b2d548f15dd5700d5e8d08ef Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 3 Apr 2018 21:26:49 +0800 Subject: [PATCH 0553/1161] [MINOR][CORE] Show block manager id when remove RDD/Broadcast fails. ## What changes were proposed in this pull request? Address https://github.com/apache/spark/pull/20924#discussion_r177987175, show block manager id when remove RDD/Broadcast fails. ## How was this patch tested? N/A Author: Xingbo Jiang Closes #20960 from jiangxb1987/bmid. --- .../apache/spark/storage/BlockManagerMasterEndpoint.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 56b95c31eb4c3..8e8f7d197c9ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -164,7 +164,8 @@ class BlockManagerMasterEndpoint( val futures = blockManagerInfo.values.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove RDD $rddId", e) + logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}", + e) 0 // zero blocks were removed } }.toSeq @@ -195,7 +196,8 @@ class BlockManagerMasterEndpoint( val futures = requiredBlockManagers.map { bm => bm.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove broadcast $broadcastId", e) + logWarning(s"Error trying to remove broadcast $broadcastId from block manager " + + s"${bm.blockManagerId}", e) 0 // zero blocks were removed } }.toSeq From 66a3a5a2dc83e03dedcee9839415c1ddc1fb8125 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 3 Apr 2018 11:05:29 -0700 Subject: [PATCH 0554/1161] [SPARK-23099][SS] Migrate foreach sink to DataSourceV2 ## What changes were proposed in this pull request? Migrate foreach sink to DataSourceV2. Since the previous attempt at this PR #20552, we've changed and strictly defined the lifecycle of writer components. This means we no longer need the complicated lifecycle shim from that PR; it just naturally works. ## How was this patch tested? existing tests Author: Jose Torres Closes #20951 from jose-torres/foreach. --- .../sql/execution/streaming/ForeachSink.scala | 68 ----------- .../sources/ForeachWriterProvider.scala | 111 ++++++++++++++++++ .../sql/streaming/DataStreamWriter.scala | 4 +- .../ForeachWriterSuite.scala} | 83 ++++++------- .../sql/streaming/StreamingQuerySuite.scala | 1 + 5 files changed, 156 insertions(+), 111 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/{ForeachSinkSuite.scala => sources/ForeachWriterSuite.scala} (77%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala deleted file mode 100644 index 2cc54107f8b83..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.TaskContext -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.encoders.encoderFor - -/** - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @tparam T The expected type of the sink. - */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) - } - } - } - - override def toString(): String = "ForeachSink" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 0000000000000..df5d69d57e36f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. + */ +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + ForeachWriterFactory(writer, encoder) + } + + override def toString: String = "ForeachSink" + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, encoder, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * @param writer The [[ForeachWriter]] to process all data. + * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T : Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T], + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(encoder.fromRow(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2fc903168cfa0..effc1471e8e12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala similarity index 77% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index b249dd41a84a6..03bf71b3f4b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.streaming +package org.apache.spark.sql.execution.streaming.sources import java.util.concurrent.ConcurrentLinkedQueue @@ -25,11 +25,12 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext -class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ @@ -47,9 +48,9 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .start() def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { - import ForeachSinkSuite._ + import ForeachWriterSuite._ - val events = ForeachSinkSuite.allEvents() + val events = ForeachWriterSuite.allEvents() assert(events.size === 2) // one seq of events for each of the 2 partitions // Verify both seq of events have an Open event as the first event @@ -64,13 +65,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } // -- batch 0 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(1, 2, 3, 4) query.processAllAvailable() verifyOutput(expectedVersion = 0, expectedData = 1 to 4) // -- batch 1 --------------------------------------- - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() input.addData(5, 6, 7, 8) query.processAllAvailable() verifyOutput(expectedVersion = 1, expectedData = 5 to 8) @@ -95,27 +96,27 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf input.addData(1, 2, 3, 4) query.processAllAvailable() - var allEvents = ForeachSinkSuite.allEvents() + var allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) var expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 4), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() // -- batch 1 --------------------------------------- input.addData(5, 6, 7, 8) query.processAllAvailable() - allEvents = ForeachSinkSuite.allEvents() + allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Process(value = 8), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) @@ -131,7 +132,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf .foreach(new TestForeachWriter() { override def process(value: Int): Unit = { super.process(value) - throw new RuntimeException("error") + throw new RuntimeException("ForeachSinkSuite error") } }).start() input.addData(1, 2, 3, 4) @@ -141,18 +142,18 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") + assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error") assert(query.isActive === false) - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) - assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1)) // `close` should be called with the error - val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close] assert(errorEvent.error.get.isInstanceOf[RuntimeException]) - assert(errorEvent.error.get.getMessage === "error") + assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") } } @@ -177,12 +178,12 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf inputData.addData(10, 11, 12) query.processAllAvailable() - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 1) val expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) assert(allEvents === Seq(expectedEvents)) } finally { @@ -216,21 +217,21 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. - val allEvents = ForeachSinkSuite.allEvents() + val allEvents = ForeachWriterSuite.allEvents() assert(allEvents.size === 3) val expectedEvents = Seq( Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Close(None) ), Seq( - ForeachSinkSuite.Open(partition = 0, version = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) + ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) ) ) assert(allEvents === expectedEvents) @@ -258,7 +259,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } /** A global object to collect events in the executor */ -object ForeachSinkSuite { +object ForeachWriterSuite { trait Event @@ -285,21 +286,21 @@ object ForeachSinkSuite { /** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ class TestForeachWriter extends ForeachWriter[Int] { - ForeachSinkSuite.clear() + ForeachWriterSuite.clear() - private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() + private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]() override def open(partitionId: Long, version: Long): Boolean = { - events += ForeachSinkSuite.Open(partition = partitionId, version = version) + events += ForeachWriterSuite.Open(partition = partitionId, version = version) true } override def process(value: Int): Unit = { - events += ForeachSinkSuite.Process(value) + events += ForeachWriterSuite.Process(value) } override def close(errorOrNull: Throwable): Unit = { - events += ForeachSinkSuite.Close(error = Option(errorOrNull)) - ForeachSinkSuite.addEvents(events) + events += ForeachWriterSuite.Close(error = Option(errorOrNull)) + ForeachWriterSuite.addEvents(events) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 08749b49997e0..20942ed93897c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.DataReaderFactory From 1035aaa61704b2790192d3186fe37e678553d36d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Apr 2018 01:36:58 +0200 Subject: [PATCH 0555/1161] [SPARK-23587][SQL] Add interpreted execution for MapObjects expression ## What changes were proposed in this pull request? Add interpreted execution for `MapObjects` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20771 from viirya/SPARK-23587. --- .../expressions/objects/objects.scala | 110 ++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 67 ++++++++++- 2 files changed, 165 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index adf9ddf327c96..0e9d357c19c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier +import scala.collection.JavaConverters._ import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag @@ -501,12 +502,22 @@ case class LambdaVariable( value: String, isNull: String, dataType: DataType, - nullable: Boolean = true) extends LeafExpression - with Unevaluable with NonSQLExpression { + nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. + override def eval(input: InternalRow): Any = { + assert(input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field.") + input.get(0, dataType) + } override def genCode(ctx: CodegenContext): ExprCode = { ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") } + + // This won't be called as `genCode` is overrided, just overriding it to make + // `LambdaVariable` non-abstract. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev } /** @@ -599,8 +610,92 @@ case class MapObjects private( override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + // The data with UserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + lazy private val inputDataType = inputData.dataType match { + case u: UserDefinedType[_] => u.sqlType + case _ => inputData.dataType + } + + private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = { + val row = new GenericInternalRow(1) + inputCollection.toIterator.map { element => + row.update(0, element) + lambdaFunction.eval(row) + } + } + + private lazy val convertToSeq: Any => Seq[_] = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + _.asInstanceOf[Seq[_]] + case ObjectType(cls) if cls.isArray => + _.asInstanceOf[Array[_]].toSeq + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.List[_]].asScala + case ObjectType(cls) if cls == classOf[Object] => + (inputCollection) => { + if (inputCollection.getClass.isArray) { + inputCollection.asInstanceOf[Array[_]].toSeq + } else { + inputCollection.asInstanceOf[Seq[_]] + } + } + case ArrayType(et, _) => + _.asInstanceOf[ArrayData].array + } + + private lazy val mapElements: Seq[_] => Any = customCollectionCls match { + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence + executeFuncOnCollection(_).toSeq + case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala set + executeFuncOnCollection(_).toSet + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + // Specifying non concrete implementations of `java.util.List` + executeFuncOnCollection(_).toSeq.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.List` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]] + results.foreach(builder.add(_)) + builder + } + } + case None => + // array + x => new GenericArrayData(executeFuncOnCollection(x).toArray) + case Some(cls) => + throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " + + "resulting collection.") + } + + override def eval(input: InternalRow): Any = { + val inputCollection = inputData.eval(input) + + if (inputCollection == null) { + return null + } + mapElements(convertToSeq(inputCollection)) + } override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse( @@ -647,13 +742,6 @@ case class MapObjects private( case _ => "" } - // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. - // When we want to apply MapObjects on it, we have to use it. - val inputDataType = inputData.dataType match { - case p: PythonUserDefinedType => p.sqlType - case _ => inputData.dataType - } - // `MapObjects` generates a while loop to traverse the elements of the input collection. We // need to take care of Seq and List because they may have O(n) complexity for indexed accessing // like `list.get(1)`. Here we use Iterator to traverse Seq and List. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1f6964dfef598..0edd27c8241e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkFunSuite} @@ -25,7 +26,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23587: MapObjects should support interpreted execution") { + def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val expected = Seq(2, 3, 4) + + val inputObject = BoundReference(0, inputType, nullable = true) + val optClass = Option(collectionCls) + val mapObj = MapObjects(function, inputObject, elementType, true, optClass) + val row = InternalRow.fromSeq(Seq(collection)) + val result = mapObj.eval(row) + + collectionCls match { + case null => + assert(result.asInstanceOf[ArrayData].array.toSeq == expected) + case l if classOf[java.util.List[_]].isAssignableFrom(l) => + assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected) + case s if classOf[Seq[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[Seq[_]].toSeq == expected) + case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet) + } + } + + val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]], + classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], + classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], + classOf[java.util.Stack[Int]], null) + + val list = new java.util.ArrayList[Int]() + list.add(1) + list.add(2) + list.add(3) + val arrayData = new GenericArrayData(Array(1, 2, 3)) + val vector = new java.util.Vector[Int]() + vector.add(1) + vector.add(2) + vector.add(3) + val stack = new java.util.Stack[Int]() + stack.add(1) + stack.add(2) + stack.add(3) + + Seq( + (Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])), + (Array(1, 2, 3), ObjectType(classOf[Array[Int]])), + (Seq(1, 2, 3), ObjectType(classOf[Object])), + (Array(1, 2, 3), ObjectType(classOf[Object])), + (list, ObjectType(classOf[java.util.List[Int]])), + (vector, ObjectType(classOf[java.util.Vector[Int]])), + (stack, ObjectType(classOf[java.util.Stack[Int]])), + (arrayData, ArrayType(IntegerType)) + ).foreach { case (collection, inputType) => + customCollectionClasses.foreach(testMapObjects(collection, _, inputType)) + + // Unsupported custom collection class + val errMsg = intercept[RuntimeException] { + testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType) + }.getMessage() + assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " + + "as resulting collection.")) + } + } + test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") { val cls = classOf[java.lang.Integer] val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true) From 359375eff74630c9f0ea5a90ab7d45bf1b281ed0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 3 Apr 2018 17:09:12 -0700 Subject: [PATCH 0556/1161] [SPARK-23809][SQL] Active SparkSession should be set by getOrCreate ## What changes were proposed in this pull request? Currently, the active spark session is set inconsistently (e.g., in createDataFrame, prior to query execution). Many places in spark also incorrectly query active session when they should be calling activeSession.getOrElse(defaultSession) and so might get None even if a Spark session exists. The semantics here can be cleaned up if we also set the active session when the default session is set. Related: https://github.com/apache/spark/pull/20926/files ## How was this patch tested? Unit test, existing test. Note that if https://github.com/apache/spark/pull/20926 merges first we should also update the tests there. Author: Eric Liang Closes #20927 from ericl/active-session-cleanup. --- .../org/apache/spark/sql/SparkSession.scala | 14 +++++++++++++- .../spark/sql/SparkSessionBuilderSuite.scala | 18 ++++++++++++++++++ .../apache/spark/sql/test/TestSQLContext.scala | 1 + .../apache/spark/sql/hive/test/TestHive.scala | 3 +++ 4 files changed, 35 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 734573ba31f71..b107492fbb330 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -951,7 +951,8 @@ object SparkSession { session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } - defaultSession.set(session) + setDefaultSession(session) + setActiveSession(session) // Register a successfully instantiated context to the singleton. This should be at the // end of the class definition so that the singleton is updated only if there is no @@ -1027,6 +1028,17 @@ object SparkSession { */ def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 2.4.0 + */ + def active: SparkSession = { + getActiveSession.getOrElse(getDefaultSession.getOrElse( + throw new IllegalStateException("No active or default Spark session found"))) + } + //////////////////////////////////////////////////////////////////////////////////////// // Private methods from now on //////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index c0301f2ce2d66..44bf8624a6bcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -50,6 +50,24 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { assert(SparkSession.builder().getOrCreate() == session) } + test("sets default and active session") { + assert(SparkSession.getDefaultSession == None) + assert(SparkSession.getActiveSession == None) + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.getDefaultSession == Some(session)) + assert(SparkSession.getActiveSession == Some(session)) + } + + test("get active or default session") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(SparkSession.active == session) + SparkSession.clearActiveSession() + assert(SparkSession.active == session) + SparkSession.clearDefaultSession() + intercept[IllegalStateException](SparkSession.active) + session.stop() + } + test("config options are propagated to existing SparkSession") { val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 3038b822beb4a..17603deacdcdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -35,6 +35,7 @@ private[spark] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) } SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) @transient override lazy val sessionState: SessionState = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 814038d4ef7af..a7006a16d7b73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -179,6 +179,9 @@ private[hive] class TestHiveSparkSession( loadTestTables) } + SparkSession.setDefaultSession(this) + SparkSession.setActiveSession(this) + { // set the metastore temporary configuration val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map( ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true", From 5cfd5fabcdbd77a806b98a6dd59b02772d2f6dee Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Tue, 3 Apr 2018 17:25:54 -0700 Subject: [PATCH 0557/1161] [SPARK-23802][SQL] PropagateEmptyRelation can leave query plan in unresolved state ## What changes were proposed in this pull request? Add cast to nulls introduced by PropagateEmptyRelation so in cases they're part of coalesce they will not break its type checking rules ## How was this patch tested? Added unit test Author: Robert Kruszewski Closes #20914 from robert3005/rk/propagate-empty-fix. --- .../optimizer/PropagateEmptyRelation.scala | 8 ++++-- .../PropagateEmptyRelationSuite.scala | 26 ++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index a6e5aa6daca65..c3fdb924243df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf /** * Collapse plans consisting empty local relations generated by [[PruneFilters]]. @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ -object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { +object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper with CastSupport { private def isEmptyLocalRelation(plan: LogicalPlan): Boolean = plan match { case p: LocalRelation => p.data.isEmpty case _ => false @@ -43,7 +45,9 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { // Construct a project list from plan's output, while the value is always NULL. private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] = - plan.output.map{ a => Alias(Literal(null), a.name)(a.exprId) } + plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) } + + override def conf: SQLConf = SQLConf.get def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p: Union if p.children.forall(isEmptyLocalRelation) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 3964508e3a55e..f1ce7543ffdc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{IntegerType, StructType} class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -37,7 +37,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceIntersectWithSemiJoin, PushDownPredicate, PruneFilters, - PropagateEmptyRelation) :: Nil + PropagateEmptyRelation, + CollapseProject) :: Nil } object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { @@ -48,7 +49,8 @@ class PropagateEmptyRelationSuite extends PlanTest { ReplaceExceptWithAntiJoin, ReplaceIntersectWithSemiJoin, PushDownPredicate, - PruneFilters) :: Nil + PruneFilters, + CollapseProject) :: Nil } val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) @@ -79,9 +81,11 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, LeftOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, Some(Project(Seq('a, Literal(null).as('b)), testRelation1).analyze)), + (true, false, FullOuter, + Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), (true, false, LeftAnti, Some(testRelation1)), (true, false, LeftSemi, Some(LocalRelation('a.int))), @@ -89,8 +93,9 @@ class PropagateEmptyRelationSuite extends PlanTest { (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), (false, true, RightOuter, - Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), - (false, true, FullOuter, Some(Project(Seq(Literal(null).as('a), 'b), testRelation2).analyze)), + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), + (false, true, FullOuter, + Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), (false, true, LeftAnti, Some(LocalRelation('a.int))), (false, true, LeftSemi, Some(LocalRelation('a.int))), @@ -209,4 +214,11 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("propagate empty relation keeps the plan resolved") { + val query = testRelation1.join( + LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None) + val optimized = Optimize.execute(query.analyze) + assert(optimized.resolved) + } } From 16ef6baa36ac11c72cfeafaa2363e6b69f0ba573 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 4 Apr 2018 14:31:03 +0800 Subject: [PATCH 0558/1161] [SPARK-23826][TEST] TestHiveSparkSession should set default session ## What changes were proposed in this pull request? In TestHive, the base spark session does this in getOrCreate(), we emulate that behavior for tests. ## How was this patch tested? N/A Author: gatorsmile Closes #20969 from gatorsmile/setDefault. --- .../main/scala/org/apache/spark/sql/hive/test/TestHive.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a7006a16d7b73..965aea2b61456 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -159,10 +159,6 @@ private[hive] class TestHiveSparkSession( private val loadTestTables: Boolean) extends SparkSession(sc) with Logging { self => - // TODO(SPARK-23826): TestHiveSparkSession should set default session the same way as - // TestSparkSession, but doing this the same way breaks many tests in the package. We need - // to investigate and find a different strategy. - def this(sc: SparkContext, loadTestTables: Boolean) { this( sc, From 5197562afe8534b29f5a0d72683c2859f796275d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Apr 2018 14:39:19 +0800 Subject: [PATCH 0559/1161] [SPARK-21351][SQL] Update nullability based on children's output ## What changes were proposed in this pull request? This pr added a new optimizer rule `UpdateNullabilityInAttributeReferences ` to update the nullability that `Filter` changes when having `IsNotNull`. In the master, optimized plans do not respect the nullability when `Filter` has `IsNotNull`. This wrongly generates unnecessary code. For example: ``` scala> val df = Seq((Some(1), Some(2))).toDF("a", "b") scala> val bIsNotNull = df.where($"b" =!= 2).select($"b") scala> val targetQuery = bIsNotNull.distinct scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = true scala> targetQuery.debugCodegen Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- Exchange hashpartitioning(b#19, 200) +- *HashAggregate(keys=[b#19], functions=[], output=[b#19]) +- *Project [_2#16 AS b#19] +- *Filter isnotnull(_2#16) +- LocalTableScan [_1#15, _2#16] Generated code: ... /* 124 */ protected void processNext() throws java.io.IOException { ... /* 132 */ // output the result /* 133 */ /* 134 */ while (agg_mapIter.next()) { /* 135 */ wholestagecodegen_numOutputRows.add(1); /* 136 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 137 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 138 */ /* 139 */ boolean agg_isNull4 = agg_aggKey.isNullAt(0); /* 140 */ int agg_value4 = agg_isNull4 ? -1 : (agg_aggKey.getInt(0)); /* 141 */ agg_rowWriter1.zeroOutNullBytes(); /* 142 */ // We don't need this NULL check because NULL is filtered out in `$"b" =!=2` /* 143 */ if (agg_isNull4) { /* 144 */ agg_rowWriter1.setNullAt(0); /* 145 */ } else { /* 146 */ agg_rowWriter1.write(0, agg_value4); /* 147 */ } /* 148 */ append(agg_result1); /* 149 */ /* 150 */ if (shouldStop()) return; /* 151 */ } /* 152 */ /* 153 */ agg_mapIter.close(); /* 154 */ if (agg_sorter == null) { /* 155 */ agg_hashMap.free(); /* 156 */ } /* 157 */ } /* 158 */ /* 159 */ } ``` In the line 143, we don't need this NULL check because NULL is filtered out in `$"b" =!=2`. This pr could remove this NULL check; ``` scala> val targetQuery.queryExecution.optimizedPlan.output(0).nullable res5: Boolean = false scala> targetQuery.debugCodegen ... Generated code: ... /* 144 */ protected void processNext() throws java.io.IOException { ... /* 152 */ // output the result /* 153 */ /* 154 */ while (agg_mapIter.next()) { /* 155 */ wholestagecodegen_numOutputRows.add(1); /* 156 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 157 */ UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue(); /* 158 */ /* 159 */ int agg_value4 = agg_aggKey.getInt(0); /* 160 */ agg_rowWriter1.write(0, agg_value4); /* 161 */ append(agg_result1); /* 162 */ /* 163 */ if (shouldStop()) return; /* 164 */ } /* 165 */ /* 166 */ agg_mapIter.close(); /* 167 */ if (agg_sorter == null) { /* 168 */ agg_hashMap.free(); /* 169 */ } /* 170 */ } ``` ## How was this patch tested? Added `UpdateNullabilityInAttributeReferencesSuite` for unit tests. Author: Takeshi Yamamuro Closes #18576 from maropu/SPARK-21351. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 ++++++- ...ullabilityInAttributeReferencesSuite.scala | 57 +++++++++++++++++++ .../optimizer/complexTypesSuite.scala | 9 --- .../org/apache/spark/sql/DataFrameSuite.scala | 5 -- 4 files changed, 75 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2829d1d81eb1a..9a1bbc675e397 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,7 +153,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RewritePredicateSubquery, ColumnPruning, CollapseProject, - RemoveRedundantProject) + RemoveRedundantProject) :+ + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) } /** @@ -1309,3 +1311,18 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } } } + +/** + * Updates nullability in [[AttributeReference]]s if nullability is different between + * non-leaf plan's expressions and the children output. + */ +object UpdateNullabilityInAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.isInstanceOf[LeafNode] => + val nullabilityMap = AttributeMap(p.children.flatMap(_.output).map { x => x -> x.nullable }) + p transformExpressions { + case ar: AttributeReference if nullabilityMap.contains(ar) => + ar.withNullability(nullabilityMap(ar)) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala new file mode 100644 index 0000000000000..09b11f5aba2a0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UpdateNullabilityInAttributeReferencesSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class UpdateNullabilityInAttributeReferencesSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps) :: + Batch("UpdateAttributeReferences", Once, + UpdateNullabilityInAttributeReferences) :: Nil + } + + test("update nullability in AttributeReference") { + val rel = LocalRelation('a.long.notNull) + // In the 'original' plans below, the Aggregate node produced by groupBy() has a + // nullable AttributeReference to `b`, because both array indexing and map lookup are + // nullable expressions. After optimization, the same attribute is now non-nullable, + // but the AttributeReference is not updated to reflect this. So, we need to update nullability + // by the `UpdateNullabilityInAttributeReferences` rule. + val original = rel + .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b") + .groupBy($"b")("1") + val expected = rel.select('a as "b").groupBy($"b")("1").analyze + val optimized = Optimizer.execute(original.analyze) + comparePlans(optimized, expected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 21ed987627b3b..633d86d495581 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -378,15 +378,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .groupBy($"foo")("1") checkRule(structRel, structExpected) - // These tests must use nullable attributes from the base relation for the following reason: - // in the 'original' plans below, the Aggregate node produced by groupBy() has a - // nullable AttributeReference to a1, because both array indexing and map lookup are - // nullable expressions. After optimization, the same attribute is now non-nullable, - // but the AttributeReference is not updated to reflect this. In the 'expected' plans, - // the grouping expressions have the same nullability as the original attribute in the - // relation. If that attribute is non-nullable, the tests will fail as the plans will - // compare differently, so for these tests we must use a nullable attribute. See - // SPARK-23634. val arrayRel = relation .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") .groupBy($"a1")("1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f7b3393f65cb1..60e84e6ee7504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2055,11 +2055,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr: String, expectedNonNullableColumns: Seq[String]): Unit = { val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr) - // In the logical plan, all the output columns of input dataframe are nullable - dfWithFilter.queryExecution.optimizedPlan.collect { - case e: Filter => assert(e.output.forall(_.nullable)) - } - dfWithFilter.queryExecution.executedPlan.collect { // When the child expression in isnotnull is null-intolerant (i.e. any null input will // result in null output), the involved columns are converted to not nullable; From a35523653cdac039ee2ddff316bc2c25d6514a91 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 4 Apr 2018 18:36:15 +0200 Subject: [PATCH 0560/1161] [SPARK-23583][SQL] Invoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `Invoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20797 from kiszk/SPARK-28583. --- .../spark/sql/catalyst/ScalaReflection.scala | 48 +++++++++++++- .../expressions/objects/objects.scala | 56 ++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 65 +++++++++++++++++++ 3 files changed, 163 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9a4bf0075a178..1aae3aea3a31a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -794,6 +794,52 @@ object ScalaReflection extends ScalaReflection { "interface", "long", "native", "new", "null", "package", "private", "protected", "public", "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true", "try", "void", "volatile", "while") + + val typeJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[Boolean], + ByteType -> classOf[Byte], + ShortType -> classOf[Short], + IntegerType -> classOf[Int], + LongType -> classOf[Long], + FloatType -> classOf[Float], + DoubleType -> classOf[Double], + StringType -> classOf[UTF8String], + DateType -> classOf[DateType.InternalType], + TimestampType -> classOf[TimestampType.InternalType], + BinaryType -> classOf[BinaryType.InternalType], + CalendarIntervalType -> classOf[CalendarInterval] + ) + + val typeBoxedJavaMapping = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], + ShortType -> classOf[java.lang.Short], + IntegerType -> classOf[java.lang.Integer], + LongType -> classOf[java.lang.Long], + FloatType -> classOf[java.lang.Float], + DoubleType -> classOf[java.lang.Double], + DateType -> classOf[java.lang.Integer], + TimestampType -> classOf[java.lang.Long] + ) + + def dataTypeJavaClass(dt: DataType): Class[_] = { + dt match { + case _: DecimalType => classOf[Decimal] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case ObjectType(cls) => cls + case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + } + + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { + if (arguments != Nil) { + arguments.map(e => dataTypeJavaClass(e.dataType)) + } else { + Seq.empty + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0e9d357c19c63..a455c1c821a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.objects -import java.lang.reflect.Modifier +import java.lang.reflect.{Method, Modifier} import scala.collection.JavaConverters._ import scala.collection.mutable.Builder @@ -28,7 +28,7 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ @@ -104,6 +104,38 @@ trait InvokeLike extends Expression with NonSQLExpression { (argCode, argValues.mkString(", "), resultIsNull) } + + /** + * Evaluate each argument with a given row, invoke a method with a given object and arguments, + * and cast a return value if the return type can be mapped to a Java Boxed type + * + * @param obj the object for the method to be called. If null, perform s static method call + * @param method the method object to be called + * @param arguments the arguments used for the method call + * @param input the row used for evaluating arguments + * @param dataType the data type of the return object + * @return the return object of a method call + */ + def invoke( + obj: Any, + method: Method, + arguments: Seq[Expression], + input: InternalRow, + dataType: DataType): Any = { + val args = arguments.map(e => e.eval(input).asInstanceOf[Object]) + if (needNullCheck && args.exists(_ == null)) { + // return null if one of arguments is null + null + } else { + val ret = method.invoke(obj, args: _*) + val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType) + if (boxedClass.isDefined) { + boxedClass.get.cast(ret) + } else { + ret + } + } + } } /** @@ -264,12 +296,11 @@ case class Invoke( propagateNull: Boolean = true, returnNullable : Boolean = true) extends InvokeLike { + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - private lazy val encodedFunctionName = TermName(functionName).encodedName.toString @transient lazy val method = targetObject.dataType match { @@ -283,6 +314,21 @@ case class Invoke( case _ => None } + override def eval(input: InternalRow): Any = { + val obj = targetObject.eval(input) + if (obj == null) { + // return null if obj is null + null + } else { + val invokeMethod = if (method.isDefined) { + method.get + } else { + obj.getClass.getDeclaredMethod(functionName, argClasses: _*) + } + invoke(obj, invokeMethod, arguments, input, dataType) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 0edd27c8241e8..9bfe2916b0820 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -24,11 +24,23 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +class InvokeTargetClass extends Serializable { + def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 + def filterPrimitiveInt(e: Int): Boolean = e > 0 + def binOp(e1: Int, e2: Double): Double = e1 + e2 +} + +class InvokeTargetSubClass extends InvokeTargetClass { + override def binOp(e1: Int, e2: Double): Double = e1 - e2 +} class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -81,6 +93,41 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23583: Invoke should support interpreted execution") { + val targetObject = new InvokeTargetClass + val funcClass = classOf[InvokeTargetClass] + val funcObj = Literal.create(targetObject, ObjectType(funcClass)) + val targetSubObject = new InvokeTargetSubClass + val funcSubObj = Literal.create(targetSubObject, ObjectType(classOf[InvokeTargetSubClass])) + val funcNullObj = Literal.create(null, ObjectType(funcClass)) + + val inputInt = Seq(BoundReference(0, ObjectType(classOf[Any]), true)) + val inputPrimitiveInt = Seq(BoundReference(0, IntegerType, false)) + val inputSum = Seq(BoundReference(0, IntegerType, false), BoundReference(1, DoubleType, false)) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + java.lang.Boolean.valueOf(true), InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterPrimitiveInt", BooleanType, inputPrimitiveInt), + false, InternalRow.fromSeq(Seq(-1))) + + checkObjectExprEvaluation( + Invoke(funcObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(null))) + + checkObjectExprEvaluation( + Invoke(funcNullObj, "filterInt", ObjectType(classOf[Any]), inputInt), + null, InternalRow.fromSeq(Seq(Integer.valueOf(1)))) + + checkObjectExprEvaluation( + Invoke(funcObj, "binOp", DoubleType, inputSum), 1.25, InternalRow.apply(1, 0.25)) + + checkObjectExprEvaluation( + Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -105,6 +152,24 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(createExternalRow, Row.fromSeq(Seq(1, "x")), InternalRow.fromSeq(Seq())) } + // by scala values instead of catalyst values. + private def checkObjectExprEvaluation( + expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { + val serializer = new JavaSerializer(new SparkConf()).newInstance + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) + checkEvaluationWithoutCodegen(expr, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expr, expected, inputRow) + if (GenerateUnsafeProjection.canSupport(expr.dataType)) { + checkEvaluationWithUnsafeProjection( + expr, + expected, + inputRow, + UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + } + checkEvaluationWithOptimization(expr, expected, inputRow) + } + test("SPARK-23594 GetExternalRowField should support interpreted execution") { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") From cccaaa14ad775fb981e501452ba2cc06ff5c0f0a Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Wed, 4 Apr 2018 12:30:52 -0700 Subject: [PATCH 0561/1161] [SPARK-23668][K8S] Add config option for passing through k8s Pod.spec.imagePullSecrets ## What changes were proposed in this pull request? Pass through the `imagePullSecrets` option to the k8s pod in order to allow user to access private image registries. See https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/ ## How was this patch tested? Unit tests + manual testing. Manual testing procedure: 1. Have private image registry. 2. Spark-submit application with no `spark.kubernetes.imagePullSecret` set. Do `kubectl describe pod ...`. See the error message: ``` Error syncing pod, skipping: failed to "StartContainer" for "spark-kubernetes-driver" with ErrImagePull: "rpc error: code = 2 desc = Error: Status 400 trying to pull repository ...: \"{\\n \\\"errors\\\" : [ {\\n \\\"status\\\" : 400,\\n \\\"message\\\" : \\\"Unsupported docker v1 repository request for '...'\\\"\\n } ]\\n}\"" ``` 3. Create secret `kubectl create secret docker-registry ...` 4. Spark-submit with `spark.kubernetes.imagePullSecret` set to the new secret. See that deployment was successful. Author: Andrew Korzhuev Author: Andrew Korzhuev Closes #20811 from andrusha/spark-23668-image-pull-secrets. --- .../org/apache/spark/deploy/k8s/Config.scala | 7 ++++ .../spark/deploy/k8s/KubernetesUtils.scala | 13 +++++++ .../steps/BasicDriverConfigurationStep.scala | 7 +++- .../cluster/k8s/ExecutorPodFactory.scala | 4 +++ .../deploy/k8s/KubernetesUtilsTest.scala | 36 +++++++++++++++++++ .../BasicDriverConfigurationStepSuite.scala | 8 ++++- .../cluster/k8s/ExecutorPodFactorySuite.scala | 5 +++ 7 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 405ea476351bb..82f6c714f3555 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -54,6 +54,13 @@ private[spark] object Config extends Logging { .checkValues(Set("Always", "Never", "IfNotPresent")) .createWithDefault("IfNotPresent") + val IMAGE_PULL_SECRETS = + ConfigBuilder("spark.kubernetes.container.image.pullSecrets") + .doc("Comma separated list of the Kubernetes secrets used " + + "to access private image registries.") + .stringConf + .createOptional + val KUBERNETES_AUTH_DRIVER_CONF_PREFIX = "spark.kubernetes.authenticate.driver" val KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX = diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5bc070147d3a8..5b2bb819cdb14 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.deploy.k8s +import io.fabric8.kubernetes.api.model.LocalObjectReference + import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -35,6 +37,17 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } + /** + * Parses comma-separated list of imagePullSecrets into K8s-understandable format + */ + def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { + imagePullSecrets match { + case Some(secretsCommaSeparated) => + secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList + case None => Nil + } + } + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b811db324108c..fcb1db8008053 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit.steps import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.deploy.k8s.Config._ @@ -51,6 +51,8 @@ private[spark] class BasicDriverConfigurationStep( .get(DRIVER_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the driver container image")) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) + // CPU settings private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) @@ -129,6 +131,8 @@ private[spark] class BasicDriverConfigurationStep( case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() } + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + val baseDriverPod = new PodBuilder(driverSpec.driverPod) .editOrNewMetadata() .withName(driverPodName) @@ -138,6 +142,7 @@ private[spark] class BasicDriverConfigurationStep( .withNewSpec() .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index 7143f7a6f0b71..8607d6fba3234 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -68,6 +68,7 @@ private[spark] class ExecutorPodFactory( .get(EXECUTOR_CONTAINER_IMAGE) .getOrElse(throw new SparkException("Must specify the executor container image")) private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) private val blockManagerPort = sparkConf .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) @@ -103,6 +104,8 @@ private[spark] class ExecutorPodFactory( nodeToLocalTaskCount: Map[String, Int]): Pod = { val name = s"$executorPodNamePrefix-exec-$executorId" + val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod // name as the hostname. This preserves uniqueness since the end of name contains // executorId @@ -194,6 +197,7 @@ private[spark] class ExecutorPodFactory( .withHostname(hostname) .withRestartPolicy("Never") .withNodeSelector(nodeSelector.asJava) + .withImagePullSecrets(parsedImagePullSecrets.asJava) .endSpec() .build() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala new file mode 100644 index 0000000000000..cf41b22e241af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.LocalObjectReference + +import org.apache.spark.SparkFunSuite + +class KubernetesUtilsTest extends SparkFunSuite { + + test("testParseImagePullSecrets") { + val noSecrets = KubernetesUtils.parseImagePullSecrets(None) + assert(noSecrets === Nil) + + val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) + assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) + + val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) + assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e59c6d28a8cc2..ee450fff8d376 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -51,6 +51,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") + .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") val submissionStep = new BasicDriverConfigurationStep( APP_ID, @@ -103,7 +104,12 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, SPARK_APP_NAME_ANNOTATION -> APP_NAME) assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - assert(preparedDriverSpec.driverPod.getSpec.getRestartPolicy === "Never") + + val driverPodSpec = preparedDriverSpec.driverPod.getSpec + assert(driverPodSpec.getRestartPolicy === "Never") + assert(driverPodSpec.getImagePullSecrets.size() === 2) + assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index a71a2a1b888bc..d73df20f0f956 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -33,6 +33,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef private val driverPodUid: String = "driver-uid" private val executorPrefix: String = "base" private val executorImage: String = "executor-image" + private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" private val driverPod = new PodBuilder() .withNewMetadata() .withName(driverPodName) @@ -54,6 +55,7 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) .set(CONTAINER_IMAGE, executorImage) .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set(IMAGE_PULL_SECRETS, imagePullSecrets) } test("basic executor pod has reasonable defaults") { @@ -76,6 +78,9 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef .getRequests.get("memory").getAmount === "1408Mi") assert(executor.getSpec.getContainers.get(0).getResources .getLimits.get("memory").getAmount === "1408Mi") + assert(executor.getSpec.getImagePullSecrets.size() === 2) + assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") + assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") // The pod has no node selector, volumes. assert(executor.getSpec.getNodeSelector.isEmpty) From d8379e5bc3629f4e8233ad42831bdaf68c24cfeb Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 4 Apr 2018 15:43:58 -0700 Subject: [PATCH 0562/1161] [SPARK-23838][WEBUI] Running SQL query is displayed as "completed" in SQL tab ## What changes were proposed in this pull request? A running SQL query would appear as completed in the Spark UI: ![image1](https://user-images.githubusercontent.com/1097932/38170733-3d7cb00c-35bf-11e8-994c-43f2d4fa285d.png) We can see the query in "Completed queries", while in in the job page we see it's still running Job 132. ![image2](https://user-images.githubusercontent.com/1097932/38170735-48f2c714-35bf-11e8-8a41-6fae23543c46.png) After some time in the query still appears in "Completed queries" (while it's still running), but the "Duration" gets increased. ![image3](https://user-images.githubusercontent.com/1097932/38170737-50f87ea4-35bf-11e8-8b60-000f6f918964.png) To reproduce, we can run a query with multiple jobs. E.g. Run TPCDS q6. The reason is that updates from executions are written into kvstore periodically, and the job start event may be missed. ## How was this patch tested? Manually run the job again and check the SQL Tab. The fix is pretty simple. Author: Gengliang Wang Closes #20955 from gengliangwang/jobCompleted. --- .../apache/spark/sql/execution/ui/AllExecutionsPage.scala | 3 ++- .../spark/sql/execution/ui/SQLAppStatusListener.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index e751ce39cd5d7..582528777f90e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -39,7 +39,8 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L val failed = new mutable.ArrayBuffer[SQLExecutionUIData]() sqlStore.executionsList().foreach { e => - val isRunning = e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } + val isRunning = e.completionTime.isEmpty || + e.jobs.exists { case (_, status) => status == JobExecutionStatus.RUNNING } val isFailed = e.jobs.exists { case (_, status) => status == JobExecutionStatus.FAILED } if (isRunning) { running += e diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 71e9f93c4566e..2b6bb48467eb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -88,7 +88,7 @@ class SQLAppStatusListener( exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING) exec.stages ++= event.stageIds.toSet - update(exec) + update(exec, force = true) } override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { @@ -308,11 +308,13 @@ class SQLAppStatusListener( }) } - private def update(exec: LiveExecutionData): Unit = { + private def update(exec: LiveExecutionData, force: Boolean = false): Unit = { val now = System.nanoTime() if (exec.endEvents >= exec.jobs.size + 1) { exec.write(kvstore, now) liveExecutions.remove(exec.executionId) + } else if (force) { + exec.write(kvstore, now) } else if (liveUpdatePeriodNs >= 0) { if (now - exec.lastWriteTime > liveUpdatePeriodNs) { exec.write(kvstore, now) From d3bd0435ee4ff3d414f32cce3f58b6b9f67e68bc Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 4 Apr 2018 15:51:27 -0700 Subject: [PATCH 0563/1161] [SPARK-23637][YARN] Yarn might allocate more resource if a same executor is killed multiple times. ## What changes were proposed in this pull request? `YarnAllocator` uses `numExecutorsRunning` to track the number of running executor. `numExecutorsRunning` is used to check if there're executors missing and need to allocate more. In current code, `numExecutorsRunning` can be negative when driver asks to kill a same idle executor multiple times. ## How was this patch tested? UT added Author: jinxing Closes #20781 from jinxing64/SPARK-23637. --- .../spark/deploy/yarn/YarnAllocator.scala | 36 +++++++------- .../deploy/yarn/YarnAllocatorSuite.scala | 48 ++++++++++++++++++- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index a537243d641cb..ebee3d431744d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -81,7 +81,8 @@ private[yarn] class YarnAllocator( private val releasedContainers = Collections.newSetFromMap[ContainerId]( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) - private val numExecutorsRunning = new AtomicInteger(0) + private val runningExecutors = Collections.newSetFromMap[String]( + new ConcurrentHashMap[String, java.lang.Boolean]()) private val numExecutorsStarting = new AtomicInteger(0) @@ -166,7 +167,7 @@ private[yarn] class YarnAllocator( clock = newClock } - def getNumExecutorsRunning: Int = numExecutorsRunning.get() + def getNumExecutorsRunning: Int = runningExecutors.size() def getNumExecutorsFailed: Int = synchronized { val endTime = clock.getTimeMillis() @@ -242,12 +243,11 @@ private[yarn] class YarnAllocator( * Request that the ResourceManager release the container running the specified executor. */ def killExecutor(executorId: String): Unit = synchronized { - if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.get(executorId).get - internalReleaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - logWarning(s"Attempted to kill unknown executor $executorId!") + executorIdToContainer.get(executorId) match { + case Some(container) if !releasedContainers.contains(container.getId) => + internalReleaseContainer(container) + runningExecutors.remove(executorId) + case _ => logWarning(s"Attempted to kill unknown executor $executorId!") } } @@ -274,7 +274,7 @@ private[yarn] class YarnAllocator( "Launching executor count: %d. Cluster resources: %s.") .format( allocatedContainers.size, - numExecutorsRunning.get, + runningExecutors.size, numExecutorsStarting.get, allocateResponse.getAvailableResources)) @@ -286,7 +286,7 @@ private[yarn] class YarnAllocator( logDebug("Completed %d containers".format(completedContainers.size)) processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." - .format(completedContainers.size, numExecutorsRunning.get)) + .format(completedContainers.size, runningExecutors.size)) } } @@ -300,9 +300,9 @@ private[yarn] class YarnAllocator( val pendingAllocate = getPendingAllocate val numPendingAllocate = pendingAllocate.size val missing = targetNumExecutors - numPendingAllocate - - numExecutorsStarting.get - numExecutorsRunning.get + numExecutorsStarting.get - runningExecutors.size logDebug(s"Updating resource requests, target: $targetNumExecutors, " + - s"pending: $numPendingAllocate, running: ${numExecutorsRunning.get}, " + + s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " + s"executorsStarting: ${numExecutorsStarting.get}") if (missing > 0) { @@ -502,7 +502,7 @@ private[yarn] class YarnAllocator( s"for executor with ID $executorId") def updateInternalState(): Unit = synchronized { - numExecutorsRunning.incrementAndGet() + runningExecutors.add(executorId) numExecutorsStarting.decrementAndGet() executorIdToContainer(executorId) = container containerIdToExecutorId(container.getId) = executorId @@ -513,7 +513,7 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.put(containerId, executorHostname) } - if (numExecutorsRunning.get < targetNumExecutors) { + if (runningExecutors.size() < targetNumExecutors) { numExecutorsStarting.incrementAndGet() if (launchContainers) { launcherPool.execute(new Runnable { @@ -554,7 +554,7 @@ private[yarn] class YarnAllocator( } else { logInfo(("Skip launching executorRunnable as running executors count: %d " + "reached target executors count: %d.").format( - numExecutorsRunning.get, targetNumExecutors)) + runningExecutors.size, targetNumExecutors)) } } } @@ -569,7 +569,11 @@ private[yarn] class YarnAllocator( val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() + containerIdToExecutorId.get(containerId) match { + case Some(executorId) => runningExecutors.remove(executorId) + case None => logWarning(s"Cannot find executorId for container: ${containerId.toString}") + } + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, onHostStr, diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index cb1e3c5268510..525abb6f2b350 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -251,11 +251,55 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (1) } + test("kill same executor multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val executorToKill = handler.executorIdToContainer.keys.head + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.killExecutor(executorToKill) + handler.getNumExecutorsRunning should be (1) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty, Set.empty) + handler.updateResourceRequests() + handler.getPendingAllocate.size should be (1) + } + + test("process same completed container multiple times") { + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getPendingAllocate.size should be (2) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + handler.getNumExecutorsRunning should be (2) + handler.getPendingAllocate.size should be (0) + + val statuses = Seq(container1, container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Finished", 0) + } + handler.processCompletedContainers(statuses) + handler.getNumExecutorsRunning should be (0) + + } + test("lost executor removed from backend") { val handler = createAllocator(4) handler.updateResourceRequests() @@ -272,7 +316,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) } handler.updateResourceRequests() - handler.processCompletedContainers(statuses.toSeq) + handler.processCompletedContainers(statuses) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (2) From c5c8b544047a83cb6128a20d31f1d943a15f9260 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 13:39:45 +0200 Subject: [PATCH 0564/1161] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20756 from viirya/SPARK-23593. --- .../expressions/objects/objects.scala | 47 ++++++++++++++++- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 +++++++++++++++++++ 3 files changed, 103 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a455c1c821a26..20c4f4c7324fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1410,8 +1410,47 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + if (paramVal == null) { + throw new NullPointerException("The parameter value for setters in " + + "`InitializeJavaBean` can not be null") + } else { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1424,6 +1463,10 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} + |if (${fieldGen.isNull}) { + | throw new NullPointerException("The parameter value for setters in " + + | "`InitializeJavaBean` can not be null"); + |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 9bfe2916b0820..44fecd602e854 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -128,6 +128,50 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("Can not pass in null into setters in InitializeJavaBean") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + intercept[NullPointerException] { + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + intercept[NullPointerException] { + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -278,3 +322,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 1822ecda51cc9e14bb18050e0b8c270fee47ced7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 5 Apr 2018 13:47:06 +0200 Subject: [PATCH 0565/1161] [SPARK-23582][SQL] StaticInvoke should support interpreted execution ## What changes were proposed in this pull request? This pr added interpreted execution for `StaticInvoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki Closes #20753 from kiszk/SPARK-23582. --- .../expressions/objects/objects.scala | 14 +++- .../expressions/ObjectExpressionsSuite.scala | 66 ++++++++++++++++++- 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 20c4f4c7324fd..9ca0b6137679e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. @@ -217,12 +218,21 @@ case class StaticInvoke( returnNullable: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") + val cls = if (staticObject.getName == objectName) { + staticObject + } else { + Utils.classForName(objectName) + } override def nullable: Boolean = needNullCheck || returnNullable override def children: Seq[Expression] = arguments - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) + + override def eval(input: InternalRow): Any = { + invoke(null, method, arguments, input, dataType) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 44fecd602e854..eb89e01b5ff9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} + import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -28,9 +30,11 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -93,6 +97,66 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed } + test("SPARK-23582: StaticInvoke should support interpreted execution") { + Seq((classOf[java.lang.Boolean], "true", true), + (classOf[java.lang.Byte], "1", 1.toByte), + (classOf[java.lang.Short], "257", 257.toShort), + (classOf[java.lang.Integer], "12345", 12345), + (classOf[java.lang.Long], "12345678", 12345678.toLong), + (classOf[java.lang.Float], "12.34", 12.34.toFloat), + (classOf[java.lang.Double], "1.2345678", 1.2345678) + ).foreach { case (cls, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))), + expected, InternalRow.fromSeq(Seq(arg))) + } + + // Return null when null argument is passed with propagateNull = true + val stringCls = classOf[java.lang.String] + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true), + null, InternalRow.fromSeq(Seq(null))) + checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf", + Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false), + "null", InternalRow.fromSeq(Seq(null))) + + // test no argument + val clCls = classOf[java.lang.ClassLoader] + checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil), + ClassLoader.getSystemClassLoader, InternalRow.empty) + // test more than one argument + val intCls = classOf[java.lang.Integer] + checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare", + Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))), + 0, InternalRow.fromSeq(Seq(7, 7))) + + Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]), + new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))), + (DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]), + new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))), + (classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]), + "abc", UTF8String.fromString("abc")), + (Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]), + BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))), + (Decimal.getClass, DecimalType.SYSTEM_DEFAULT, + "apply", ObjectType(classOf[java.math.BigInteger]), + new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))), + (classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]), + Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))), + (classOf[UnsafeArrayData], ArrayType(IntegerType, false), + "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), + Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), + (DateTimeUtils.getClass, ObjectType(classOf[Date]), + "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), + "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) + ).foreach { case (cls, dataType, methodName, argType, arg, expected) => + checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, + Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg))) + } + } + test("SPARK-23583: Invoke should support interpreted execution") { val targetObject = new InvokeTargetClass val funcClass = classOf[InvokeTargetClass] From b2329fb1fcdc0e93c4bdc39d574cde7328ef6094 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 5 Apr 2018 13:57:41 +0200 Subject: [PATCH 0566/1161] Revert "[SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression" This reverts commit c5c8b544047a83cb6128a20d31f1d943a15f9260. --- .../expressions/objects/objects.scala | 47 +---------------- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 52 ------------------- 3 files changed, 5 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9ca0b6137679e..3fa91bd36bb60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,47 +1420,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - private lazy val resolvedSetters = { - assert(beanInstance.dataType.isInstanceOf[ObjectType]) - - val ObjectType(beanClass) = beanInstance.dataType - setters.map { - case (name, expr) => - // Looking for known type mapping. - // But also looking for general `Object`-type parameter for generic methods. - val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) - val methods = paramTypes.flatMap { fieldClass => - try { - Some(beanClass.getDeclaredMethod(name, fieldClass)) - } catch { - case e: NoSuchMethodException => None - } - } - if (methods.isEmpty) { - throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + - "in any enclosing class nor any supertype") - } - methods.head -> expr - } - } - - override def eval(input: InternalRow): Any = { - val instance = beanInstance.eval(input) - if (instance != null) { - val bean = instance.asInstanceOf[Object] - resolvedSetters.foreach { - case (setter, expr) => - val paramVal = expr.eval(input) - if (paramVal == null) { - throw new NullPointerException("The parameter value for setters in " + - "`InitializeJavaBean` can not be null") - } else { - setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) - } - } - } - instance - } + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1473,10 +1434,6 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |if (${fieldGen.isNull}) { - | throw new NullPointerException("The parameter value for setters in " + - | "`InitializeJavaBean` can not be null"); - |} |$javaBeanInstance.$setterMethod(${fieldGen.value}); """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..3828f172a15cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,8 +55,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -112,14 +111,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (!errMsg.contains(expectedErrMsg)) { + if (errMsg != expectedErrMsg) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - - // Make it as method to obtain fresh expression everytime. - def expr = prepareEvaluation(expression) + val expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index eb89e01b5ff9d..1d59b20077fa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,50 +192,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } - test("SPARK-23593: InitializeJavaBean should support interpreted execution") { - val list = new java.util.LinkedList[Int]() - list.add(1) - - val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), - Map("add" -> Literal(1))) - checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) - - val initializeWithNonexistingMethod = InitializeJavaBean( - Literal.fromObject(new java.util.LinkedList[Int]), - Map("nonexisting" -> Literal(1))) - checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), - """A method named "nonexisting" is not declared in any enclosing class """ + - "nor any supertype") - - val initializeWithWrongParamType = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setX" -> Literal("1"))) - intercept[Exception] { - evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) - }.getMessage.contains( - """A method named "setX" is not declared in any enclosing class """ + - "nor any supertype") - } - - test("Can not pass in null into setters in InitializeJavaBean") { - val initializeBean = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal(null))) - intercept[NullPointerException] { - evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - intercept[NullPointerException] { - evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) - }.getMessage.contains("The parameter value for setters in `InitializeJavaBean` can not be null") - - val initializeBean2 = InitializeJavaBean( - Literal.fromObject(new TestBean), - Map("setNonPrimitive" -> Literal("string"))) - evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) - evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) - } - test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -386,11 +342,3 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } - -class TestBean extends Serializable { - private var x: Int = 0 - - def setX(i: Int): Unit = x = i - def setNonPrimitive(i: AnyRef): Unit = - assert(i != null, "this setter should not be called with null.") -} From d9ca1c906bd0571802f2297c36b407e660fcdb64 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Apr 2018 20:43:05 +0200 Subject: [PATCH 0567/1161] [SPARK-23593][SQL] Add interpreted execution for InitializeJavaBean expression ## What changes were proposed in this pull request? Add interpreted execution for `InitializeJavaBean` expression. ## How was this patch tested? Added unit test. Author: Liang-Chi Hsieh Closes #20985 from viirya/SPARK-23593-2. --- .../expressions/objects/objects.scala | 45 +++++++++++++++-- .../expressions/ExpressionEvalHelper.scala | 9 ++-- .../expressions/ObjectExpressionsSuite.scala | 48 +++++++++++++++++++ 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 3fa91bd36bb60..9252425f86473 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1420,8 +1420,45 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def children: Seq[Expression] = beanInstance +: setters.values.toSeq override def dataType: DataType = beanInstance.dataType - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private lazy val resolvedSetters = { + assert(beanInstance.dataType.isInstanceOf[ObjectType]) + + val ObjectType(beanClass) = beanInstance.dataType + setters.map { + case (name, expr) => + // Looking for known type mapping. + // But also looking for general `Object`-type parameter for generic methods. + val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object]) + val methods = paramTypes.flatMap { fieldClass => + try { + Some(beanClass.getDeclaredMethod(name, fieldClass)) + } catch { + case e: NoSuchMethodException => None + } + } + if (methods.isEmpty) { + throw new NoSuchMethodException(s"""A method named "$name" is not declared """ + + "in any enclosing class nor any supertype") + } + methods.head -> expr + } + } + + override def eval(input: InternalRow): Any = { + val instance = beanInstance.eval(input) + if (instance != null) { + val bean = instance.asInstanceOf[Object] + resolvedSetters.foreach { + case (setter, expr) => + val paramVal = expr.eval(input) + // We don't call setter if input value is null. + if (paramVal != null) { + setter.invoke(bean, paramVal.asInstanceOf[AnyRef]) + } + } + } + instance + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) @@ -1434,7 +1471,9 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val fieldGen = fieldValue.genCode(ctx) s""" |${fieldGen.code} - |$javaBeanInstance.$setterMethod(${fieldGen.value}); + |if (!${fieldGen.isNull}) { + | $javaBeanInstance.$setterMethod(${fieldGen.value}); + |} """.stripMargin } val initializeCode = ctx.splitExpressionsWithCurrentInputs( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3828f172a15cf..a5ecd1b68fac4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val expr = prepareEvaluation(expression) + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val errMsg = intercept[T] { eval }.getMessage - if (errMsg != expectedErrMsg) { + if (!errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } } - val expr = prepareEvaluation(expression) + + // Make it as method to obtain fresh expression everytime. + def expr = prepareEvaluation(expression) checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode") checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode") if (GenerateUnsafeProjection.canSupport(expr.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 1d59b20077fa9..b1bc67dfac1b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -192,6 +192,46 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25)) } + test("SPARK-23593: InitializeJavaBean should support interpreted execution") { + val list = new java.util.LinkedList[Int]() + list.add(1) + + val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]), + Map("add" -> Literal(1))) + checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq())) + + val initializeWithNonexistingMethod = InitializeJavaBean( + Literal.fromObject(new java.util.LinkedList[Int]), + Map("nonexisting" -> Literal(1))) + checkExceptionInExpression[Exception](initializeWithNonexistingMethod, + InternalRow.fromSeq(Seq()), + """A method named "nonexisting" is not declared in any enclosing class """ + + "nor any supertype") + + val initializeWithWrongParamType = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setX" -> Literal("1"))) + intercept[Exception] { + evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq())) + }.getMessage.contains( + """A method named "setX" is not declared in any enclosing class """ + + "nor any supertype") + } + + test("InitializeJavaBean doesn't call setters if input in null") { + val initializeBean = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal(null))) + evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq())) + + val initializeBean2 = InitializeJavaBean( + Literal.fromObject(new TestBean), + Map("setNonPrimitive" -> Literal("string"))) + evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq())) + evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq())) + } + test("SPARK-23585: UnwrapOption should support interpreted execution") { val cls = classOf[Option[Int]] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) @@ -342,3 +382,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + +class TestBean extends Serializable { + private var x: Int = 0 + + def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = + assert(i != null, "this setter should not be called with null.") +} From 4807d381bb113a5c61e6dad88202f23a8b6dd141 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:13:59 +0800 Subject: [PATCH 0568/1161] [SPARK-10399][CORE][SQL] Introduce multiple MemoryBlocks to choose several types of memory block ## What changes were proposed in this pull request? This PR allows us to use one of several types of `MemoryBlock`, such as byte array, int array, long array, or `java.nio.DirectByteBuffer`. To use `java.nio.DirectByteBuffer` allows to have off heap memory which is automatically deallocated by JVM. `MemoryBlock` class has primitive accessors like `Platform.getInt()`, `Platform.putint()`, or `Platform.copyMemory()`. This PR uses `MemoryBlock` for `OffHeapColumnVector`, `UTF8String`, and other places. This PR can improve performance of operations involving memory accesses (e.g. `UTF8String.trim`) by 1.8x. For now, this PR does not use `MemoryBlock` for `BufferHolder` based on cloud-fan's [suggestion](https://github.com/apache/spark/pull/11494#issuecomment-309694290). Since this PR is a successor of #11494, close #11494. Many codes were ported from #11494. Many efforts were put here. **I think this PR should credit to yzotov.** This PR can achieve **1.1-1.4x performance improvements** for operations in `UTF8String` or `Murmur3_x86_32`. Other operations are almost comparable performances. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 526 / 536 0.0 131399881.5 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 525 / 552 1022.6 1.0 1.0X substring 414 / 423 1298.0 0.8 1.3X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-22-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Hash byte arrays with length 268435487: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Murmur3_x86_32 474 / 488 0.0 118552232.0 1.0X UTF8String benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ hashCode 476 / 480 1127.3 0.9 1.0X substring 287 / 291 1869.9 0.5 1.7X ``` Benchmark program ``` test("benchmark Murmur3_x86_32") { val length = 8192 * 32768 + 31 val seed = 42L val iters = 1 << 2 val random = new Random(seed) val arrays = Array.fill[MemoryBlock](numArrays) { val bytes = new Array[Byte](length) random.nextBytes(bytes) new ByteArrayMemoryBlock(bytes, Platform.BYTE_ARRAY_OFFSET, length) } val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays, minNumIters = 20) benchmark.addCase("HiveHasher") { _: Int => var sum = 0L for (_ <- 0L until iters) { sum += HiveHasher.hashUnsafeBytesBlock( arrays(i), Platform.BYTE_ARRAY_OFFSET, length) } } benchmark.run() } test("benchmark UTF8String") { val N = 512 * 1024 * 1024 val iters = 2 val benchmark = new Benchmark("UTF8String benchmark", N, minNumIters = 20) val str0 = new java.io.StringWriter() { { for (i <- 0 until N) { write(" ") } } }.toString val s0 = UTF8String.fromString(str0) benchmark.addCase("hashCode") { _: Int => var h: Int = 0 for (_ <- 0L until iters) { h += s0.hashCode } } benchmark.addCase("substring") { _: Int => var s: UTF8String = null for (_ <- 0L until iters) { s = s0.substring(N / 2 - 5, N / 2 + 5) } } benchmark.run() } ``` I run [this benchmark program](https://gist.github.com/kiszk/94f75b506c93a663bbbc372ffe8f05de) using [the commit](https://github.com/apache/spark/pull/19222/commits/ee5a79861c18725fb1cd9b518cdfd2489c05b81d6). I got the following results: ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Memory access benchmarks: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ByteArrayMemoryBlock get/putInt() 220 / 221 609.3 1.6 1.0X Platform get/putInt(byte[]) 220 / 236 610.9 1.6 1.0X Platform get/putInt(Object) 492 / 494 272.8 3.7 0.4X OnHeapMemoryBlock get/putLong() 322 / 323 416.5 2.4 0.7X long[] 221 / 221 608.0 1.6 1.0X Platform get/putLong(long[]) 321 / 321 418.7 2.4 0.7X Platform get/putLong(Object) 561 / 563 239.2 4.2 0.4X ``` I also run [this benchmark program](https://gist.github.com/kiszk/5fdb4e03733a5d110421177e289d1fb5) for comparing performance of `Platform.copyMemory()`. ``` OpenJDK 64-Bit Server VM 1.8.0_151-8u151-b12-0ubuntu0.16.04.2-b12 on Linux 4.4.0-66-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz Platform copyMemory: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Object to Object 1961 / 1967 8.6 116.9 1.0X System.arraycopy Object to Object 1917 / 1921 8.8 114.3 1.0X byte array to byte array 1961 / 1968 8.6 116.9 1.0X System.arraycopy byte array to byte array 1909 / 1937 8.8 113.8 1.0X int array to int array 1921 / 1990 8.7 114.5 1.0X double array to double array 1918 / 1923 8.7 114.3 1.0X Object to byte array 1961 / 1967 8.6 116.9 1.0X Object to short array 1965 / 1972 8.5 117.1 1.0X Object to int array 1910 / 1915 8.8 113.9 1.0X Object to float array 1971 / 1978 8.5 117.5 1.0X Object to double array 1919 / 1944 8.7 114.4 1.0X byte array to Object 1959 / 1967 8.6 116.8 1.0X int array to Object 1961 / 1970 8.6 116.9 1.0X double array to Object 1917 / 1924 8.8 114.3 1.0X ``` These results show three facts: 1. According to the second/third or sixth/seventh results in the first experiment, if we use `Platform.get/putInt(Object)`, we achieve more than 2x worse performance than `Platform.get/putInt(byte[])` with concrete type (i.e. `byte[]`). 2. According to the second/third or fourth/fifth/sixth results in the first experiment, the fastest way to access an array element on Java heap is `array[]`. **Cons of `array[]` is that it is not possible to support unaligned-8byte access.** 3. According to the first/second/third or fourth/sixth/seventh results in the first experiment, `getInt()/putInt() or getLong()/putLong()` in subclasses of `MemoryBlock` can achieve comparable performance to `Platform.get/putInt()` or `Platform.get/putLong()` with concrete type (second or sixth result). There is no overhead regarding virtual call. 4. According to results in the second experiment, for `Platform.copy()`, to pass `Object` can achieve the same performance as to pass any type of primitive array as source or destination. 5. According to second/fourth results in the second experiment, `Platform.copy()` can achieve the same performance as `System.arrayCopy`. **It would be good to use `Platform.copy()` since `Platform.copy()` can take any types for src and dst.** We are incrementally replace `Platform.get/putXXX` with `MemoryBlock.get/putXXX`. This is because we have two advantages. 1) Achieve better performance due to having a concrete type for an array. 2) Use simple OO design instead of passing `Object` It is easy to use `MemoryBlock` in `InternalRow`, `BufferHolder`, `TaskMemoryManager`, and others that are already abstracted. It is not easy to use `MemoryBlock` in utility classes related to hashing or others. Other candidates are - UnsafeRow, UnsafeArrayData, UnsafeMapData, SpecificUnsafeRowJoiner - UTF8StringBuffer - BufferHolder - TaskMemoryManager - OnHeapColumnVector - BytesToBytesMap - CachedBatch - classes for hash - others. ## How was this patch tested? Added `UnsafeMemoryAllocator` Author: Kazuaki Ishizaki Closes #19222 from kiszk/SPARK-10399. --- .../sql/catalyst/expressions/HiveHasher.java | 12 +- .../org/apache/spark/unsafe/Platform.java | 2 +- .../spark/unsafe/array/ByteArrayMethods.java | 13 +- .../apache/spark/unsafe/array/LongArray.java | 17 +- .../spark/unsafe/hash/Murmur3_x86_32.java | 45 +++-- .../unsafe/memory/ByteArrayMemoryBlock.java | 128 +++++++++++++ .../unsafe/memory/HeapMemoryAllocator.java | 19 +- .../spark/unsafe/memory/MemoryAllocator.java | 4 +- .../spark/unsafe/memory/MemoryBlock.java | 157 ++++++++++++++-- .../spark/unsafe/memory/MemoryLocation.java | 54 ------ .../unsafe/memory/OffHeapMemoryBlock.java | 105 +++++++++++ .../unsafe/memory/OnHeapMemoryBlock.java | 132 +++++++++++++ .../unsafe/memory/UnsafeMemoryAllocator.java | 21 ++- .../apache/spark/unsafe/types/UTF8String.java | 148 +++++++-------- .../spark/unsafe/PlatformUtilSuite.java | 4 +- .../spark/unsafe/array/LongArraySuite.java | 5 +- .../unsafe/hash/Murmur3_x86_32Suite.java | 18 ++ .../spark/unsafe/memory/MemoryBlockSuite.java | 175 ++++++++++++++++++ .../spark/unsafe/types/UTF8StringSuite.java | 29 +-- .../spark/memory/TaskMemoryManager.java | 22 +-- .../shuffle/sort/ShuffleInMemorySorter.java | 14 +- .../shuffle/sort/ShuffleSortDataFormat.java | 11 +- .../unsafe/sort/UnsafeExternalSorter.java | 2 +- .../unsafe/sort/UnsafeInMemorySorter.java | 13 +- .../spark/memory/TaskMemoryManagerSuite.java | 2 +- .../util/collection/ExternalSorterSuite.scala | 7 +- .../unsafe/sort/RadixSortSuite.scala | 10 +- .../spark/ml/feature/FeatureHasher.scala | 5 +- .../spark/mllib/feature/HashingTF.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 4 +- .../spark/sql/catalyst/expressions/XXH64.java | 46 +++-- .../spark/sql/catalyst/expressions/hash.scala | 39 ++-- .../catalyst/expressions/HiveHasherSuite.java | 20 +- .../sql/catalyst/expressions/XXH64Suite.java | 18 +- .../vectorized/OffHeapColumnVector.java | 3 +- .../sql/vectorized/ArrowColumnVector.java | 6 +- .../execution/benchmark/SortBenchmark.scala | 16 +- .../sql/execution/python/RowQueueSuite.scala | 4 +- 39 files changed, 1002 insertions(+), 334 deletions(-) create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java delete mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java create mode 100644 common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java create mode 100644 common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 73577437ac506..5d905943a3aa7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -38,12 +39,17 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + public static int hashUnsafeBytesBlock(MemoryBlock mb) { + long lengthInBytes = mb.size(); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (int i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) Platform.getByte(base, offset + i); + for (long i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) mb.getByte(i); } return result; } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..54dcadf3a7754 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index a6b1f7a16d605..c334c9651cf6b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -48,6 +49,16 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); + /** + * MemoryBlock equality check for MemoryBlocks. + * @return true if the arrays are equal, false otherwise + */ + public static boolean arrayEqualsBlock( + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, + rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); + } + /** * Optimized byte array equality check for byte arrays. * @return true if the arrays are equal, false otherwise @@ -56,7 +67,7 @@ public static boolean arrayEquals( Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - // check if stars align and we can get both offsets to be aligned + // check if starts align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 2cd39bd60c2ac..b74d2de0691d5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,6 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -33,16 +32,12 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; - private final Object baseObj; - private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; - this.baseObj = memory.getBaseObject(); - this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -51,11 +46,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return baseObj; + return memory.getBaseObject(); } public long getBaseOffset() { - return baseOffset; + return memory.getBaseOffset(); } /** @@ -69,8 +64,8 @@ public long size() { * Fill this all with 0L. */ public void zeroOut() { - for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { - Platform.putLong(baseObj, off, 0); + for (long off = 0; off < length * WIDTH; off += WIDTH) { + memory.putLong(off, 0); } } @@ -80,7 +75,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + memory.putLong(index * WIDTH, value); } /** @@ -89,6 +84,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return Platform.getLong(baseObj, baseOffset + index * WIDTH); + return memory.getLong(index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index d239de6083ad0..f372b19fac119 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,9 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.Platform; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.memory.MemoryBlock; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -49,49 +51,66 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWords(base, offset, lengthInBytes, seed); + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + int h1 = hashBytesByIntBlock(base, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { // This is not compatible with original and another implementations. // But remain it for backward compatibility for the components existing before 2.3. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = Platform.getByte(base, offset + i); + int halfWord = base.getByte(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { // This is compatible with original and another implementations. // Use this method for new components after Spark 2.3. - assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthInBytes = Ints.checkedCast(base.size()); + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { - k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + k1 ^= (base.getByte(i) & 0xFF) << shift; } h1 ^= mixK1(k1); return fmix(h1, lengthInBytes); } - private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + private static int hashBytesByIntBlock(MemoryBlock base, int seed) { + long lengthInBytes = base.size(); assert (lengthInBytes % 4 == 0); int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = Platform.getInt(base, offset + i); + for (long i = 0; i < lengthInBytes; i += 4) { + int halfWord = base.getInt(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java new file mode 100644 index 0000000000000..99a9868a49a79 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a byte array on Java heap. + */ +public final class ByteArrayMemoryBlock extends MemoryBlock { + + private final byte[] array; + + public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); + } + + public ByteArrayMemoryBlock(long length) { + this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new ByteArrayMemoryBlock(array, this.offset + offset, size); + } + + public byte[] getByteArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the byte array. + */ + public static ByteArrayMemoryBlock fromArray(final byte[] array) { + return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; + } + + @Override + public final void putByte(long offset, byte value) { + array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 2733760dd19ef..acf28fd7ee59b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -58,7 +58,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -70,7 +70,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[numWords]; - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -79,12 +79,13 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj != null) : + assert(memory instanceof OnHeapMemoryBlock); + assert (memory.getBaseObject() != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; @@ -94,12 +95,12 @@ public void free(MemoryBlock memory) { } // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. - long[] array = (long[]) memory.obj; - memory.setObjAndOffset(null, 0); + long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); + memory.resetObjAndOffset(); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 7b588681d9790..38315fb97b46a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - MemoryAllocator HEAP = new HeapMemoryAllocator(); + HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index c333857358d30..b086941108522 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,10 +22,10 @@ import org.apache.spark.unsafe.Platform; /** - * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + * A representation of a consecutive memory block in Spark. It defines the common interfaces + * for memory accessing and mutating. */ -public class MemoryBlock extends MemoryLocation { - +public abstract class MemoryBlock { /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ public static final int NO_PAGE_NUMBER = -1; @@ -45,38 +45,163 @@ public class MemoryBlock extends MemoryLocation { */ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; - private final long length; + @Nullable + protected Object obj; + + protected long offset; + + protected long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, - * which lives in a different package. + * TaskMemoryManager. This field can be updated using setPageNumber method so that + * this can be modified by the TaskMemoryManager, which lives in a different package. */ - public int pageNumber = NO_PAGE_NUMBER; + private int pageNumber = NO_PAGE_NUMBER; - public MemoryBlock(@Nullable Object obj, long offset, long length) { - super(obj, offset); + protected MemoryBlock(@Nullable Object obj, long offset, long length) { + if (offset < 0 || length < 0) { + throw new IllegalArgumentException( + "Length " + length + " and offset " + offset + "must be non-negative"); + } + this.obj = obj; + this.offset = offset; this.length = length; } + protected MemoryBlock() { + this(null, 0, 0); + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } + + public void resetObjAndOffset() { + this.obj = null; + this.offset = 0; + } + /** * Returns the size of the memory block. */ - public long size() { + public final long size() { return length; } - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + public final void setPageNumber(int pageNum) { + pageNumber = pageNum; + } + + public final int getPageNumber() { + return pageNumber; } /** * Fills the memory block with the specified byte value. */ - public void fill(byte value) { + public final void fill(byte value) { Platform.setMemory(obj, offset, length, value); } + + /** + * Instantiate MemoryBlock for given object type with new offset + */ + public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + MemoryBlock mb = null; + if (obj instanceof byte[]) { + byte[] array = (byte[])obj; + mb = new ByteArrayMemoryBlock(array, offset, length); + } else if (obj instanceof long[]) { + long[] array = (long[])obj; + mb = new OnHeapMemoryBlock(array, offset, length); + } else if (obj == null) { + // we assume that to pass null pointer means off-heap + mb = new OffHeapMemoryBlock(offset, length); + } else { + throw new UnsupportedOperationException( + "Instantiate MemoryBlock for type " + obj.getClass() + " is not supported now"); + } + return mb; + } + + /** + * Just instantiate the sub-block with the same type of MemoryBlock with the new size and relative + * offset from the original offset. The data is not copied. + * If parameters are invalid, an exception is thrown. + */ + public abstract MemoryBlock subBlock(long offset, long size); + + protected void checkSubBlockRange(long offset, long size) { + if (offset < 0 || size < 0) { + throw new ArrayIndexOutOfBoundsException( + "Size " + size + " and offset " + offset + " must be non-negative"); + } + if (offset + size > length) { + throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + + offset + " should not be larger than the length " + length + " in the MemoryBlock"); + } + } + + /** + * getXXX/putXXX does not ensure guarantee behavior if the offset is invalid. e.g cause illegal + * memory access, throw an exception, or etc. + * getXXX/putXXX uses an index based on this.offset that includes the size of metadata such as + * JVM object header. The offset is 0-based and is expected as an logical offset in the memory + * block. + */ + public abstract int getInt(long offset); + + public abstract void putInt(long offset, int value); + + public abstract boolean getBoolean(long offset); + + public abstract void putBoolean(long offset, boolean value); + + public abstract byte getByte(long offset); + + public abstract void putByte(long offset, byte value); + + public abstract short getShort(long offset); + + public abstract void putShort(long offset, short value); + + public abstract long getLong(long offset); + + public abstract void putLong(long offset, long value); + + public abstract float getFloat(long offset); + + public abstract void putFloat(long offset, float value); + + public abstract double getDouble(long offset); + + public abstract void putDouble(long offset, double value); + + public static final void copyMemory( + MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, + dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); + } + + public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { + assert(length <= src.length && length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), + dst.getBaseObject(), dst.getBaseOffset(), length); + } + + public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); + } + + public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java deleted file mode 100644 index 74ebc87dc978c..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import javax.annotation.Nullable; - -/** - * A memory location. Tracked either by a memory address (with off-heap allocation), - * or by an offset from a JVM object (in-heap allocation). - */ -public class MemoryLocation { - - @Nullable - Object obj; - - long offset; - - public MemoryLocation(@Nullable Object obj, long offset) { - this.obj = obj; - this.offset = offset; - } - - public MemoryLocation() { - this(null, 0); - } - - public void setObjAndOffset(Object newObj, long newOffset) { - this.obj = newObj; - this.offset = newOffset; - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java new file mode 100644 index 0000000000000..f90f62bf21dcb --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +public class OffHeapMemoryBlock extends MemoryBlock { + static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + + public OffHeapMemoryBlock(long address, long size) { + super(null, address, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OffHeapMemoryBlock(this.offset + offset, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(null, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(null, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(null, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(null, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(null, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(null, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(null, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(null, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(null, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(null, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(null, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(null, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(null, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(null, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java new file mode 100644 index 0000000000000..12f67c7bd593e --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a long array on Java heap. + */ +public final class OnHeapMemoryBlock extends MemoryBlock { + + private final long[] array; + + public OnHeapMemoryBlock(long[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); + } + + public OnHeapMemoryBlock(long size) { + this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OnHeapMemoryBlock(array, this.offset + offset, size); + } + + public long[] getLongArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static OnHeapMemoryBlock fromArray(final long[] array) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + } + + public static OnHeapMemoryBlock fromArray(final long[] array, long size) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(array, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(array, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 4368fb615ba1e..5310bdf2779a9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public MemoryBlock allocate(long size) throws OutOfMemoryError { + public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - MemoryBlock memory = new MemoryBlock(null, address, size); + OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,22 +36,25 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj == null) : - "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert(memory instanceof OffHeapMemoryBlock) : + "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; + if (memory == OffHeapMemoryBlock.NULL) return; + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to reset its pointer. - memory.offset = 0; + memory.resetObjAndOffset(); // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 5d468aed42337..e9b3d9b045af5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -30,9 +30,12 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.google.common.primitives.Ints; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -50,12 +53,13 @@ public final class UTF8String implements Comparable, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private Object base; - private long offset; + private MemoryBlock base; + // While numBytes has the same value as base.size(), to keep as int avoids cast from long to int private int numBytes; - public Object getBaseObject() { return base; } - public long getBaseOffset() { return offset; } + public MemoryBlock getMemoryBlock() { return base; } + public Object getBaseObject() { return base.getBaseObject(); } + public long getBaseOffset() { return base.getBaseOffset(); } /** * A char in UTF-8 encoding can take 1-4 bytes depending on the first byte which @@ -108,7 +112,8 @@ public final class UTF8String implements Comparable, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); } else { return null; } @@ -121,19 +126,13 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); } else { return null; } } - /** - * Creates an UTF8String from given address (base and offset) and length. - */ - public static UTF8String fromAddress(Object base, long offset, int numBytes) { - return new UTF8String(base, offset, numBytes); - } - /** * Creates an UTF8String from String. */ @@ -150,16 +149,13 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int numBytes) { + public UTF8String(MemoryBlock base) { this.base = base; - this.offset = offset; - this.numBytes = numBytes; + this.numBytes = Ints.checkedCast(base.size()); } // for serialization - public UTF8String() { - this(null, 0, 0); - } + public UTF8String() {} /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -167,7 +163,7 @@ public UTF8String() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - Platform.copyMemory(base, offset, target, targetOffset, numBytes); + base.writeTo(0, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -187,8 +183,9 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = (byte[]) base; + long offset = base.getBaseOffset(); + if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); // the offset includes an object header... this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -255,12 +252,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) Platform.getInt(base, offset); + p = (long) base.getInt(0); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -269,12 +266,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) Platform.getInt(base, offset)) << 32; + p = ((long) base.getInt(0)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -289,12 +286,13 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] - && ((byte[]) base).length == numBytes) { - return (byte[]) base; + long offset = base.getBaseOffset(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock + && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { + return ((ByteArrayMemoryBlock) base).getByteArray(); } else { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -324,7 +322,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -365,14 +363,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. */ private byte getByte(int i) { - return Platform.getByte(base, offset + i); + return base.getByte(i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -499,8 +497,7 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } lastComma = i; @@ -508,8 +505,7 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } return 0; @@ -524,7 +520,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -671,8 +667,7 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - copyMemory(this.base, this.offset + i, result, - BYTE_ARRAY_OFFSET + result.length - i - len, len); + base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -686,7 +681,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -723,7 +718,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -739,7 +734,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start += 1; @@ -753,7 +748,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start -= 1; @@ -786,7 +781,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -806,7 +801,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -829,15 +824,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -865,13 +860,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -896,8 +891,8 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -936,8 +931,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -945,8 +940,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. if (j < numInputs) { - copyMemory( - separator.base, separator.offset, + separator.base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1220,7 +1215,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1228,11 +1223,10 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - long roffset = other.offset; - Object rbase = other.base; + MemoryBlock rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = getLong(base, offset + i); - long right = getLong(rbase, roffset + i); + long left = base.getLong(i); + long right = rbase.getLong(i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1243,7 +1237,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); + int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); if (res != 0) { return res; } @@ -1262,7 +1256,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); } else { return false; } @@ -1318,8 +1312,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, - s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, + i_bytes, num_bytes_j)) ? 0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1334,7 +1328,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); + return Murmur3_x86_32.hashUnsafeBytesBlock(base,42); } /** @@ -1397,10 +1391,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - base = new byte[numBytes]; - in.readFully((byte[]) base); + byte[] bytes = new byte[numBytes]; + in.readFully(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } @Override @@ -1412,10 +1406,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - this.offset = BYTE_ARRAY_OFFSET; - this.numBytes = in.readInt(); - this.base = new byte[numBytes]; - in.read((byte[]) base); + numBytes = in.readInt(); + byte[] bytes = new byte[numBytes]; + in.read(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..583a148b3845d 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -81,7 +81,7 @@ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { MemoryAllocator.HEAP.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test @@ -92,7 +92,7 @@ public void freeingOffHeapMemoryBlockResetsOffset() { MemoryAllocator.UNSAFE.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index fb8e53b3348f3..8c2e98c2bfc54 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,14 +20,13 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; public class LongArraySuite { @Test public void basicTest() { - long[] bytes = new long[2]; - LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 6348a73bf3895..d7ed005db1891 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -70,6 +70,24 @@ public void testKnownBytesInputs() { Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); } + @Test + public void testKnownWordsInputs() { + byte[] bytes = new byte[16]; + long offset = Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < 16; i++) { + bytes[i] = 0; + } + Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = -1; + } + Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = (byte)i; + } + Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java new file mode 100644 index 0000000000000..47f05c928f2e5 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteOrder; + +import static org.hamcrest.core.StringContains.containsString; + +public class MemoryBlockSuite { + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + + private void check(MemoryBlock memory, Object obj, long offset, int length) { + memory.setPageNumber(1); + memory.fill((byte)-1); + memory.putBoolean(0, true); + memory.putByte(1, (byte)127); + memory.putShort(2, (short)257); + memory.putInt(4, 0x20000002); + memory.putLong(8, 0x1234567089ABCDEFL); + memory.putFloat(16, 1.0F); + memory.putLong(20, 0x1234567089ABCDEFL); + memory.putDouble(28, 2.0); + MemoryBlock.copyMemory(memory, 0L, memory, 36, 4); + int[] a = new int[2]; + a[0] = 0x12345678; + a[1] = 0x13579BDF; + memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8); + byte[] b = new byte[8]; + memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8); + + Assert.assertEquals(obj, memory.getBaseObject()); + Assert.assertEquals(offset, memory.getBaseOffset()); + Assert.assertEquals(length, memory.size()); + Assert.assertEquals(1, memory.getPageNumber()); + Assert.assertEquals(true, memory.getBoolean(0)); + Assert.assertEquals((byte)127, memory.getByte(1 )); + Assert.assertEquals((short)257, memory.getShort(2)); + Assert.assertEquals(0x20000002, memory.getInt(4)); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8)); + Assert.assertEquals(1.0F, memory.getFloat(16), 0); + Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20)); + Assert.assertEquals(2.0, memory.getDouble(28), 0); + Assert.assertEquals(true, memory.getBoolean(36)); + Assert.assertEquals((byte)127, memory.getByte(37 )); + Assert.assertEquals((short)257, memory.getShort(38)); + Assert.assertEquals(a[0], memory.getInt(40)); + Assert.assertEquals(a[1], memory.getInt(44)); + if (bigEndianPlatform) { + Assert.assertEquals(a[0], + ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 | + ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 | + ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff)); + } else { + Assert.assertEquals(a[0], + ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 | + ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff)); + Assert.assertEquals(a[1], + ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 | + ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff)); + } + for (int i = 48; i < memory.size(); i++) { + Assert.assertEquals((byte) -1, memory.getByte(i)); + } + + assert(memory.subBlock(0, memory.size()) == memory); + + try { + memory.subBlock(-8, 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, -8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("non-negative")); + } + + try { + memory.subBlock(0, length + 8); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(8, length - 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + + try { + memory.subBlock(length + 8, 4); + Assert.fail(); + } catch (Exception expected) { + Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); + } + } + + @Test + public void ByteArrayMemoryBlockTest() { + byte[] obj = new byte[56]; + long offset = Platform.BYTE_ARRAY_OFFSET; + int length = obj.length; + + MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = ByteArrayMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new byte[112]; + memory = new ByteArrayMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OnHeapMemoryBlockTest() { + long[] obj = new long[7]; + long offset = Platform.LONG_ARRAY_OFFSET; + int length = obj.length * 8; + + MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + + memory = OnHeapMemoryBlock.fromArray(obj); + check(memory, obj, offset, length); + + obj = new long[14]; + memory = new OnHeapMemoryBlock(obj, offset, length); + check(memory, obj, offset, length); + } + + @Test + public void OffHeapArrayMemoryBlockTest() { + MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); + MemoryBlock memory = memoryAllocator.allocate(56); + Object obj = memory.getBaseObject(); + long offset = memory.getBaseOffset(); + int length = 56; + + check(memory, obj, offset, length); + + long address = Platform.allocateMemory(112); + memory = new OffHeapMemoryBlock(address, length); + obj = memory.getBaseObject(); + offset = memory.getBaseOffset(); + check(memory, obj, offset, length); + } +} diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7c34d419574ef..bad908fcaf136 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -26,6 +26,9 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; import static org.junit.Assert.*; @@ -519,7 +522,8 @@ public void writeToOutputStreamUnderflow() throws IOException { final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) + new UTF8String( + new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); @@ -534,7 +538,7 @@ public void writeToOutputStreamSlice() throws IOException { for (int i = 0; i < test.length; ++i) { for (int j = 0; j < test.length - i; ++j) { - UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j)) .writeTo(outputStream); assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); @@ -565,7 +569,7 @@ public void writeToOutputStreamOverflow() throws IOException { for (final long offset : offsets) { try { - fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) + new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length)) .writeTo(outputStream); throw new IllegalStateException(Long.toString(offset)); @@ -592,26 +596,25 @@ public void writeToOutputStream() throws IOException { } @Test - public void writeToOutputStreamIntArray() throws IOException { + public void writeToOutputStreamLongArray() throws IOException { // verify that writes work on objects that are not byte arrays - final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界"); buffer.position(0); buffer.order(ByteOrder.nativeOrder()); final int length = buffer.limit(); - assertEquals(12, length); + assertEquals(16, length); - final int ints = length / 4; - final int[] array = new int[ints]; + final int longs = length / 8; + final long[] array = new long[longs]; - for (int i = 0; i < ints; ++i) { - array[i] = buffer.getInt(); + for (int i = 0; i < longs; ++i) { + array[i] = buffer.getLong(); } final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - fromAddress(array, Platform.INT_ARRAY_OFFSET, length) - .writeTo(outputStream); - assertEquals("大千世界", outputStream.toString("UTF-8")); + new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream); + assertEquals("3千大千世界", outputStream.toString("UTF-8")); } @Test diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index d07faf1da1248..8651a639c07f7 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -311,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { // this could trigger spilling to free some pages. return allocatePage(size, consumer); } - page.pageNumber = pageNumber; + page.setPageNumber(pageNumber); pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); @@ -323,25 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ public void freePage(MemoryBlock page, MemoryConsumer consumer) { - assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) : "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : "Called freePage() on a memory block that has already been freed"; - assert(allocatedPages.get(page.pageNumber)); - pageTable[page.pageNumber] = null; + assert(allocatedPages.get(page.getPageNumber())); + pageTable[page.getPageNumber()] = null; synchronized (this) { - allocatedPages.clear(page.pageNumber); + allocatedPages.clear(page.getPageNumber()); } if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size()); } long pageSize = page.size(); // Clear the page number before passing the block to the MemoryAllocator's free(). // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed // page has been inappropriately directly freed without calling TMM.freePage(). - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); releaseExecutionMemory(pageSize, consumer); } @@ -363,7 +363,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { // relative to the page's base offset; this relative offset will fit in 51 bits. offsetInPage -= page.getBaseOffset(); } - return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage); } @VisibleForTesting @@ -434,7 +434,7 @@ public long cleanUpAllAllocatedMemory() { for (MemoryBlock page : pageTable) { if (page != null) { logger.debug("unreleased page: " + page + " in task " + taskAttemptId); - page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index dc36809d8911f..8f49859746b89 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -20,7 +20,6 @@ import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; @@ -105,13 +104,7 @@ public void reset() { public void expandPointerArray(LongArray newArray) { assert(newArray.size() > array.size()); - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L - ); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -180,10 +173,7 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new ShuffleSortDataFormat(buffer)); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 717bdd79d47ef..254449e95443e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.sort; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { @@ -60,13 +60,8 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { @Override public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { - Platform.copyMemory( - src.getBaseObject(), - src.getBaseOffset() + srcPos * 8L, - dst.getBaseObject(), - dst.getBaseOffset() + dstPos * 8L, - length * 8L - ); + MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L, + dst.memoryBlock(),dstPos * 8L,length * 8L); } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 66118f454159b..4fc19b1721518 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -544,7 +544,7 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.pageNumber != + if (!loaded || page.getPageNumber() != ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index b3c27d83da172..20a7a8b267438 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -26,7 +26,6 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -216,12 +215,7 @@ public void expandPointerArray(LongArray newArray) { if (newArray.size() < array.size()) { throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); } - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L); + MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L); consumer.freeArray(array); array = newArray; usableCapacity = getUsableCapacity(); @@ -348,10 +342,7 @@ public UnsafeSorterIterator getSortedIterator() { array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { - MemoryBlock unused = new MemoryBlock( - array.getBaseObject(), - array.getBaseOffset() + pos * 8L, - (array.size() - pos) * 8L); + MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L); LongArray buffer = new LongArray(unused); Sorter sorter = new Sorter<>(new UnsafeSortDataFormat(buffer)); diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index a0664b30d6cc2..d7d2d0b012bd3 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() { final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); final MemoryBlock dataPage = manager.allocatePage(256, c); c.freePage(dataPage); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 47173b89e91e2..3e56db5ea116a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { @@ -105,9 +105,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i } - val buf = new LongArray(MemoryBlock.fromLongArray(ref)) - val tmp = new Array[Long](size/2) - val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref)) + val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L)) new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index d5956ea32096a..ddf3740e76a7a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -27,7 +27,7 @@ import com.google.common.primitives.Ints import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging { private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) - (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { @@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index c78f61ac3ef71..d67e4819b161a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2} +import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashMap @@ -243,8 +243,7 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] { case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed) case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => - val utf8 = UTF8String.fromString(s) - hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed) case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 8935c8496cdbb..7b73b286fb91c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -160,7 +160,7 @@ object HashingTF { case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed) case s: String => val utf8 = UTF8String.fromString(s) - hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed) + hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed) case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " + s"support type ${term.getClass.getCanonicalName} of input data.") } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d18542b188f71..8546c28335536 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -27,6 +27,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -230,7 +231,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 71c086029cc5b..29a1411241cf6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -37,6 +37,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -414,7 +415,8 @@ public UTF8String getUTF8String(int ordinal) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size); + return new UTF8String(mb); } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index f37ef83ad92b4..883748932ad33 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,7 +16,10 @@ */ package org.apache.spark.sql.catalyst.expressions; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off /** @@ -71,13 +74,13 @@ public static long hashLong(long input, long seed) { return fmix(hash); } - public long hashUnsafeWords(Object base, long offset, int length) { - return hashUnsafeWords(base, offset, length, seed); + public long hashUnsafeWordsBlock(MemoryBlock mb) { + return hashUnsafeWordsBlock(mb, seed); } - public static long hashUnsafeWords(Object base, long offset, int length, long seed) { - assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; - long hash = hashBytesByWords(base, offset, length, seed); + public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) { + assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)"; + long hash = hashBytesByWordsBlock(mb, seed); return fmix(hash); } @@ -85,26 +88,32 @@ public long hashUnsafeBytes(Object base, long offset, int length) { return hashUnsafeBytes(base, offset, length, seed); } - public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); assert (length >= 0) : "lengthInBytes cannot be negative"; - long hash = hashBytesByWords(base, offset, length, seed); + long hash = hashBytesByWordsBlock(mb, seed); long end = offset + length; offset += length & -8; if (offset + 4L <= end) { - hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1; + hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1; hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3; offset += 4L; } while (offset < end) { - hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5; + hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5; hash = Long.rotateLeft(hash, 11) * PRIME64_1; offset++; } return fmix(hash); } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); + } + private static long fmix(long hash) { hash ^= hash >>> 33; hash *= PRIME64_2; @@ -114,30 +123,31 @@ private static long fmix(long hash) { return hash; } - private static long hashBytesByWords(Object base, long offset, int length, long seed) { - long end = offset + length; + private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) { + long offset = 0; + long length = mb.size(); long hash; if (length >= 32) { - long limit = end - 32; + long limit = length - 32; long v1 = seed + PRIME64_1 + PRIME64_2; long v2 = seed + PRIME64_2; long v3 = seed; long v4 = seed - PRIME64_1; do { - v1 += Platform.getLong(base, offset) * PRIME64_2; + v1 += mb.getLong(offset) * PRIME64_2; v1 = Long.rotateLeft(v1, 31); v1 *= PRIME64_1; - v2 += Platform.getLong(base, offset + 8) * PRIME64_2; + v2 += mb.getLong(offset + 8) * PRIME64_2; v2 = Long.rotateLeft(v2, 31); v2 *= PRIME64_1; - v3 += Platform.getLong(base, offset + 16) * PRIME64_2; + v3 += mb.getLong(offset + 16) * PRIME64_2; v3 = Long.rotateLeft(v3, 31); v3 *= PRIME64_1; - v4 += Platform.getLong(base, offset + 24) * PRIME64_2; + v4 += mb.getLong(offset + 24) * PRIME64_2; v4 = Long.rotateLeft(v4, 31); v4 *= PRIME64_1; @@ -178,9 +188,9 @@ private static long hashBytesByWords(Object base, long offset, int length, long hash += length; - long limit = end - 8; + long limit = length - 8; while (offset <= limit) { - long k1 = Platform.getLong(base, offset); + long k1 = mb.getLong(offset); hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1; hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4; offset += 8L; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b702422ed7a1d..b76b64ab5096f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -360,10 +361,8 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" } protected def genHashForMap( @@ -465,6 +464,8 @@ abstract class InterpretedHashFunction { protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long + protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long + /** * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity * of input `value`. @@ -490,8 +491,7 @@ abstract class InterpretedHashFunction { case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed)) case a: Array[Byte] => hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) - case s: UTF8String => - hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) + case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed) case array: ArrayData => val elementType = dataType match { @@ -578,9 +578,15 @@ object Murmur3HashFunction extends InterpretedHashFunction { Murmur3_x86_32.hashLong(l, seed.toInt) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt) } + + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = { + Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt) + } } /** @@ -605,9 +611,14 @@ object XxHash64Function extends InterpretedHashFunction { override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed) - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { XXH64.hashUnsafeBytes(base, offset, len, seed) } + + override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = { + XXH64.hashUnsafeBytesBlock(base, seed) + } } /** @@ -714,10 +725,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val baseObject = s"$input.getBaseObject()" - val baseOffset = s"$input.getBaseOffset()" - val numBytes = s"$input.numBytes()" - s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);" + val mb = s"$input.getMemoryBlock()" + s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" } override protected def genHashForArray( @@ -805,10 +814,14 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashLong(l) } - override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { + override protected def hashUnsafeBytes( + base: AnyRef, offset: Long, len: Int, seed: Long): Long = { HiveHasher.hashUnsafeBytes(base, offset, len) } + override protected def hashUnsafeBytesBlock( + base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base) + private val HIVE_DECIMAL_MAX_PRECISION = 38 private val HIVE_DECIMAL_MAX_SCALE = 38 diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index b67c6f3e6e85e..8ffc1d7c24d61 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import org.junit.Test; @@ -53,7 +55,7 @@ public void testKnownStringAndIntInputs() { for (int i = 0; i < inputs.length; i++) { UTF8String s = UTF8String.fromString("val_" + inputs[i]); - int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes()); + int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock()); Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash)); } } @@ -89,13 +91,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. @@ -112,13 +114,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + HiveHasher.hashUnsafeBytesBlock(mb), + HiveHasher.hashUnsafeBytesBlock(mb)); - hashcodes.add(HiveHasher.hashUnsafeBytes( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb)); } // A very loose bound. diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java index 1baee91b3439c..cd8bce623c5df 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java @@ -24,6 +24,8 @@ import java.util.Set; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.junit.Assert; import org.junit.Test; @@ -142,13 +144,13 @@ public void randomizedStressTestBytes() { int byteArrSize = rand.nextInt(100) * 8; byte[] bytes = new byte[byteArrSize]; rand.nextBytes(bytes); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes); Assert.assertEquals( - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. @@ -165,13 +167,13 @@ public void randomizedStressTestPaddedStrings() { byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8); byte[] paddedBytes = new byte[byteArrSize]; System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes); Assert.assertEquals( - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize), - hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hasher.hashUnsafeWordsBlock(mb), + hasher.hashUnsafeWordsBlock(mb)); - hashcodes.add(hasher.hashUnsafeWords( - paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize)); + hashcodes.add(hasher.hashUnsafeWordsBlock(mb)); } // A very loose bound. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 754c26579ff08..4733f36174f42 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -206,7 +207,7 @@ public byte[] getBytes(int rowId, int count) { @Override protected UTF8String getBytesAsUTF8String(int rowId, int count) { - return UTF8String.fromAddress(null, data + rowId, count); + return new UTF8String(new OffHeapMemoryBlock(data + rowId, count)); } // diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index f8e37e995a17f..227a16f7e69e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ArrowUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.memory.OffHeapMemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -377,9 +378,10 @@ final UTF8String getUTF8String(int rowId) { if (stringResult.isSet == 0) { return null; } else { - return UTF8String.fromAddress(null, + return new UTF8String(new OffHeapMemoryBlock( stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); + stringResult.end - stringResult.start + )); } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala index 50ae26a3ff9d9..470b93efd1974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark import java.util.{Arrays, Comparator} import org.apache.spark.unsafe.array.LongArray -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.collection.unsafe.sort._ @@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom class SortBenchmark extends BenchmarkBase { private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L)) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( @@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase { private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { val ref = Array.tabulate[Long](size * 2) { i => rand } val extended = ref ++ Array.fill[Long](size * 2)(0) - (new LongArray(MemoryBlock.fromLongArray(ref)), - new LongArray(MemoryBlock.fromLongArray(extended))) + (new LongArray(OnHeapMemoryBlock.fromArray(ref)), + new LongArray(OnHeapMemoryBlock.fromArray(extended))) } ignore("sort") { @@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase { val benchmark = new Benchmark("radix sort " + size, size) benchmark.addTimerCase("reference TimSort key prefix array") { timer => val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) timer.stopTiming() @@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong & 0xffff i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() @@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase { array(i) = rand.nextLong i += 1 } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) + val buf = new LongArray(OnHeapMemoryBlock.fromArray(array)) timer.startTiming() RadixSort.sort(buf, size, 0, 7, false, false) timer.stopTiming() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index ffda33cf906c5..25ee95daa034c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -22,13 +22,13 @@ import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock import org.apache.spark.util.Utils class RowQueueSuite extends SparkFunSuite { test("in-memory queue") { - val page = MemoryBlock.fromLongArray(new Array[Long](1<<10)) + val page = new OnHeapMemoryBlock((1<<10) * 8L) val queue = new InMemoryRowQueue(page, 1) { override def close() {} } From f2ac0879561cde63ed4eb759f5efa0a5ce393a22 Mon Sep 17 00:00:00 2001 From: Yogesh Garg Date: Thu, 5 Apr 2018 19:55:42 -0700 Subject: [PATCH 0569/1161] [SPARK-23870][ML] Forward RFormula handleInvalid Param to VectorAssembler to handle invalid values in non-string columns ## What changes were proposed in this pull request? `handleInvalid` Param was forwarded to the VectorAssembler used by RFormula. ## How was this patch tested? added a test and ran all tests for RFormula and VectorAssembler Author: Yogesh Garg Closes #20970 from yogeshg/spark_23562. --- .../apache/spark/ml/feature/RFormula.scala | 1 + .../spark/ml/feature/RFormulaSuite.scala | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 22e7b8bbf1ff5..e214765e3307f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -278,6 +278,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) + .setHandleInvalid($(handleInvalid)) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 27d570f0b68ad..a250331efeb1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite @@ -592,4 +593,26 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { assert(features.toArray === a +: b.toArray) } } + + test("SPARK-23562 RFormula handleInvalid should handle invalid values in non-string columns.") { + val d1 = Seq( + (1001L, "a"), + (1002L, "b")).toDF("id1", "c1") + val d2 = Seq[(java.lang.Long, String)]( + (20001L, "x"), + (20002L, "y"), + (null, null)).toDF("id2", "c2") + val dataset = d1.crossJoin(d2) + + def get_output(mode: String): DataFrame = { + val formula = new RFormula().setFormula("c1 ~ id2").setHandleInvalid(mode) + formula.fit(dataset).transform(dataset).select("features", "label") + } + + assert(intercept[SparkException](get_output("error").collect()) + .getMessage.contains("Encountered null while assembling a row")) + assert(get_output("skip").count() == 4) + assert(get_output("keep").count() == 6) + } + } From d65e531b44a388fed25d3cbf28fdce5a2d0598e6 Mon Sep 17 00:00:00 2001 From: JiahuiJiang Date: Thu, 5 Apr 2018 20:06:08 -0700 Subject: [PATCH 0570/1161] [SPARK-23823][SQL] Keep origin in transformExpression Fixes https://issues.apache.org/jira/browse/SPARK-23823 Keep origin for all the methods using transformExpression ## What changes were proposed in this pull request? Keep origin in transformExpression ## How was this patch tested? Manually tested that this fixes https://issues.apache.org/jira/browse/SPARK-23823 and columns have correct origins after Analyzer.analyze Author: JiahuiJiang Author: Jiahui Jiang Closes #20961 from JiahuiJiang/jj/keep-origin. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 6 ++- .../sql/catalyst/plans/QueryPlanSuite.scala | 42 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index ddf2cbf2ab911..64cb8c726772f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -103,7 +103,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT var changed = false @inline def transformExpression(e: Expression): Expression = { - val newE = f(e) + val newE = CurrentOrigin.withOrigin(e.origin) { + f(e) + } if (newE.fastEquals(e)) { e } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala new file mode 100644 index 0000000000000..27914ef5565c0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.types.IntegerType + +class QueryPlanSuite extends SparkFunSuite { + + test("origin remains the same after mapExpressions (SPARK-23823)") { + CurrentOrigin.setPosition(0, 0) + val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId) + val query = plans.DslLogicalPlan(plans.table("table")).select(column) + CurrentOrigin.reset() + + val mappedQuery = query mapExpressions { + case _: Expression => Literal(1) + } + + val mappedOrigin = mappedQuery.expressions.apply(0).origin + assert(mappedOrigin == Origin.apply(Some(0), Some(0))) + } + +} From 249007e37f51f00d14e596692aeac87fbc10b520 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 5 Apr 2018 20:19:25 -0700 Subject: [PATCH 0571/1161] [SPARK-19724][SQL] create a managed table with an existed default table should throw an exception ## What changes were proposed in this pull request? This PR is to finish https://github.com/apache/spark/pull/17272 This JIRA is a follow up work after SPARK-19583 As we discussed in that PR The following DDL for a managed table with an existed default location should throw an exception: CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... CREATE TABLE ... (PARTITIONED BY ...) Currently there are some situations which are not consist with above logic: CREATE TABLE ... (PARTITIONED BY ...) succeed with an existed default location situation: for both hive/datasource(with HiveExternalCatalog/InMemoryCatalog) CREATE TABLE ... (PARTITIONED BY ...) AS SELECT ... situation: hive table succeed with an existed default location This PR is going to make above two situations consist with the logic that it should throw an exception with an existed default location. ## How was this patch tested? unit test added Author: Gengliang Wang Closes #20886 from gengliangwang/pr-17272. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 23 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 11 +++ .../command/createDataSourceTables.scala | 5 +- .../spark/sql/StatisticsCollectionSuite.scala | 7 ++ .../sql/execution/command/DDLSuite.scala | 67 +++++++++++++++++++ 6 files changed, 110 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2b393f30d1435..9822d669050d5 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1809,6 +1809,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 64e7ca11270b4..52ed89ef8d781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -289,6 +289,7 @@ class SessionCatalog( def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) + val tableIdentifier = TableIdentifier(table, Some(db)) validateName(table) val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined @@ -298,15 +299,33 @@ class SessionCatalog( makeQualifiedPath(tableDefinition.storage.locationUri.get) tableDefinition.copy( storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)), - identifier = TableIdentifier(table, Some(db))) + identifier = tableIdentifier) } else { - tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + tableDefinition.copy(identifier = tableIdentifier) } requireDbExists(db) + if (!ignoreIfExists) { + validateTableLocation(newTableDefinition) + } externalCatalog.createTable(newTableDefinition, ignoreIfExists) } + def validateTableLocation(table: CatalogTable): Unit = { + // SPARK-19724: the default location of a managed table should be non-existent or empty. + if (table.tableType == CatalogTableType.MANAGED && + !conf.allowCreatingManagedTableUsingNonemptyLocation) { + val tableLocation = + new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier))) + val fs = tableLocation.getFileSystem(hadoopConf) + + if (fs.exists(tableLocation) && fs.listStatus(tableLocation).nonEmpty) { + throw new AnalysisException(s"Can not create the managed table('${table.identifier}')" + + s". The associated location('${tableLocation.toString}') already exists.") + } + } + } + /** * Alter the metadata of an existing metastore table identified by `tableDefinition`. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 13f31a6b2eb93..1c8ab9c62623e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1159,6 +1159,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = + buildConf("spark.sql.allowCreatingManagedTableUsingNonemptyLocation") + .internal() + .doc("When this option is set to true, creating managed tables with nonempty location " + + "is allowed. Otherwise, an analysis exception is thrown. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1581,6 +1589,9 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = + getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index e9747769dfcfc..f7c3e9b019258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -167,7 +167,7 @@ case class CreateDataSourceTableAsSelectCommand( sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) - + sparkSession.sessionState.catalog.validateTableLocation(table) val tableLocation = if (table.tableType == CatalogTableType.MANAGED) { Some(sessionState.catalog.defaultTablePath(table.identifier)) } else { @@ -181,7 +181,8 @@ case class CreateDataSourceTableAsSelectCommand( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) - sessionState.catalog.createTable(newTable, ignoreIfExists = false) + // Table location is already validated. No need to check it again during table creation. + sessionState.catalog.createTable(newTable, ignoreIfExists = true) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index ed4ea0231f1a7..14a565863d66c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.io.File + import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier @@ -26,6 +28,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -242,6 +245,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("change stats after set location command") { val table = "change_stats_set_location_table" + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier(table))) Seq(false, true).foreach { autoUpdate => withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) { withTable(table) { @@ -269,6 +273,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetched3.get.sizeInBytes == fetched1.get.sizeInBytes) } else { checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None) + // SPARK-19724: clean up the previous table location. + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4df8fbfe1c0db..4304d0b6f6b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -180,6 +180,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { private val escapedIdentifier = "`(.+)`".r + private def dataSource: String = { + if (isUsingHiveMetastore) { + "HIVE" + } else { + "PARQUET" + } + } protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = { @@ -365,6 +372,66 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("CTAS a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + tableLoc.mkdir() + withTable("tab1") { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + sql("INSERT INTO tab1 VALUES (1, 'a')") + checkAnswer(spark.table("tab1"), Row(1, "a")) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + + test("create a managed table with the existing non-empty directory") { + withTable("tab1") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab1"))) + try { + // create an empty hidden file + tableLoc.mkdir() + val hiddenGarbageFile = new File(tableLoc.getCanonicalPath, ".garbage") + hiddenGarbageFile.createNewFile() + val exMsg = "Can not create the managed table('`tab1`'). The associated location" + val exMsgWithDefaultDB = + "Can not create the managed table('`default`.`tab1`'). The associated location" + var ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 USING ${dataSource} AS SELECT 1, 'a'") + }.getMessage + if (isUsingHiveMetastore) { + assert(ex.contains(exMsgWithDefaultDB)) + } else { + assert(ex.contains(exMsg)) + } + + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], From 6ade5cbb498f6c6ea38779b97f2325d5cf5013f2 Mon Sep 17 00:00:00 2001 From: Daniel Sakuma Date: Fri, 6 Apr 2018 13:37:08 +0800 Subject: [PATCH 0572/1161] [MINOR][DOC] Fix some typos and grammar issues ## What changes were proposed in this pull request? Easy fix in the documentation. ## How was this patch tested? N/A Closes #20948 Author: Daniel Sakuma Closes #20928 from dsakuma/fix_typo_configuration_docs. --- docs/README.md | 2 +- docs/_plugins/include_example.rb | 2 +- docs/building-spark.md | 2 +- docs/cloud-integration.md | 4 +-- docs/configuration.md | 20 ++++++------ docs/css/pygments-default.css | 2 +- docs/graphx-programming-guide.md | 4 +-- docs/job-scheduling.md | 4 +-- docs/ml-advanced.md | 2 +- docs/ml-classification-regression.md | 6 ++-- docs/ml-collaborative-filtering.md | 2 +- docs/ml-features.md | 2 +- docs/ml-migration-guides.md | 2 +- docs/ml-tuning.md | 2 +- docs/mllib-clustering.md | 2 +- docs/mllib-collaborative-filtering.md | 4 +-- docs/mllib-data-types.md | 2 +- docs/mllib-dimensionality-reduction.md | 2 +- docs/mllib-evaluation-metrics.md | 2 +- docs/mllib-feature-extraction.md | 2 +- docs/mllib-isotonic-regression.md | 4 +-- docs/mllib-linear-methods.md | 2 +- docs/mllib-optimization.md | 4 +-- docs/monitoring.md | 4 +-- docs/quick-start.md | 6 ++-- docs/rdd-programming-guide.md | 2 +- docs/running-on-kubernetes.md | 4 +-- docs/running-on-mesos.md | 12 +++---- docs/running-on-yarn.md | 2 +- docs/security.md | 2 +- docs/spark-standalone.md | 2 +- docs/sparkr.md | 6 ++-- docs/sql-programming-guide.md | 32 +++++++++---------- docs/storage-openstack-swift.md | 2 +- docs/streaming-flume-integration.md | 6 ++-- docs/streaming-kafka-0-8-integration.md | 10 +++--- docs/streaming-programming-guide.md | 26 +++++++-------- .../structured-streaming-kafka-integration.md | 2 +- .../structured-streaming-programming-guide.md | 8 ++--- docs/submitting-applications.md | 2 +- docs/tuning.md | 2 +- python/README.md | 2 +- sql/README.md | 2 +- 43 files changed, 107 insertions(+), 107 deletions(-) diff --git a/docs/README.md b/docs/README.md index 225bb1b2040de..9eac4ba35c458 100644 --- a/docs/README.md +++ b/docs/README.md @@ -5,7 +5,7 @@ here with the Spark source code. You can also find documentation specific to rel Spark at http://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the -documentation yourself. Why build it yourself? So that you have the docs that corresponds to +documentation yourself. Why build it yourself? So that you have the docs that correspond to whichever version of Spark you currently have checked out of revision control. ## Prerequisites diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ea1d438f529e..1e91f12518e0b 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -48,7 +48,7 @@ def render(context) begin code = File.open(@file).read.encode("UTF-8") rescue => e - # We need to explicitly exit on execptions here because Jekyll will silently swallow + # We need to explicitly exit on exceptions here because Jekyll will silently swallow # them, leading to silent build failures (see https://github.com/jekyll/jekyll/issues/5104) puts(e) puts(e.backtrace) diff --git a/docs/building-spark.md b/docs/building-spark.md index c391255a91596..0236bb05849ad 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -113,7 +113,7 @@ Note: Flume support is deprecated as of Spark 2.3.0. ## Building submodules individually -It's possible to build Spark sub-modules using the `mvn -pl` option. +It's possible to build Spark submodules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index c150d9efc06ff..ac1c336988930 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -27,13 +27,13 @@ description: Introduction to cloud storage support in Apache Spark SPARK_VERSION All major cloud providers offer persistent data storage in *object stores*. These are not classic "POSIX" file systems. In order to store hundreds of petabytes of data without any single points of failure, -object stores replace the classic filesystem directory tree +object stores replace the classic file system directory tree with a simpler model of `object-name => data`. To enable remote access, operations on objects are usually offered as (slow) HTTP REST operations. Spark can read and write data in object stores through filesystem connectors implemented in Hadoop or provided by the infrastructure suppliers themselves. -These connectors make the object stores look *almost* like filesystems, with directories and files +These connectors make the object stores look *almost* like file systems, with directories and files and the classic operations on them such as list, delete and rename. diff --git a/docs/configuration.md b/docs/configuration.md index 2eb6a77434ea6..4d4d0c58dd07d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -558,7 +558,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1288,7 +1288,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1513,7 +1513,7 @@ Apart from these, the following properties are also available, and may be useful @@ -1722,7 +1722,7 @@ Apart from these, the following properties are also available, and may be useful When spark.task.reaper.enabled = true, this setting specifies a timeout after which the executor JVM will kill itself if a killed task has not stopped running. The default value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose - of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering + of this setting is to act as a safety-net to prevent runaway noncancellable tasks from rendering an executor unusable. @@ -1915,8 +1915,8 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1971,7 +1971,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -1980,7 +1980,7 @@ showDF(properties, numRows = 200, truncate = FALSE) @@ -2178,7 +2178,7 @@ Spark's classpath for each application. In a Spark cluster running on YARN, thes files are set cluster-wide, and cannot safely be changed by the application. The better choice is to use spark hadoop properties in the form of `spark.hadoop.*`. -They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-defalut.conf` +They can be considered as same as normal spark properties which can be set in `$SPARK_HOME/conf/spark-default.conf` In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, Spark allows you to simply create an empty conf and set spark/spark hadoop properties. diff --git a/docs/css/pygments-default.css b/docs/css/pygments-default.css index 6247cd8396cf1..a4d583b366603 100644 --- a/docs/css/pygments-default.css +++ b/docs/css/pygments-default.css @@ -5,7 +5,7 @@ To generate this, I had to run But first I had to install pygments via easy_install pygments I had to override the conflicting bootstrap style rules by linking to -this stylesheet lower in the html than the bootstap css. +this stylesheet lower in the html than the bootstrap css. Also, I was thrown off for a while at first when I was using markdown code block inside my {% highlight scala %} ... {% endhighlight %} tags diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 5c97a248df4bc..35293348e3f3d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -491,7 +491,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts)( The more general [`outerJoinVertices`][Graph.outerJoinVertices] behaves similarly to `joinVertices` except that the user defined `map` function is applied to all vertices and can change the vertex property type. Because not all vertices may have a matching value in the input RDD the `map` -function takes an `Option` type. For example, we can setup a graph for PageRank by initializing +function takes an `Option` type. For example, we can set up a graph for PageRank by initializing vertex properties with their `outDegree`. @@ -969,7 +969,7 @@ A vertex is part of a triangle when it has two adjacent vertices with an edge be # Examples Suppose I want to build a graph from some text files, restrict the graph -to important relationships and users, run page-rank on the sub-graph, and +to important relationships and users, run page-rank on the subgraph, and then finally return attributes associated with the top users. I can do all of this in just a few lines with GraphX: diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index e6d881639a13b..da90342406c84 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -23,7 +23,7 @@ run tasks and store data for that application. If multiple users need to share y different options to manage allocation, depending on the cluster manager. The simplest option, available on all cluster managers, is _static partitioning_ of resources. With -this approach, each application is given a maximum amount of resources it can use, and holds onto them +this approach, each application is given a maximum amount of resources it can use and holds onto them for its whole duration. This is the approach used in Spark's [standalone](spark-standalone.html) and [YARN](running-on-yarn.html) modes, as well as the [coarse-grained Mesos mode](running-on-mesos.html#mesos-run-modes). @@ -230,7 +230,7 @@ properties: * `minShare`: Apart from an overall weight, each pool can be given a _minimum shares_ (as a number of CPU cores) that the administrator would like it to have. The fair scheduler always attempts to meet all active pools' minimum shares before redistributing extra resources according to the weights. - The `minShare` property can therefore be another way to ensure that a pool can always get up to a + The `minShare` property can, therefore, be another way to ensure that a pool can always get up to a certain number of resources (e.g. 10 cores) quickly without giving it a high priority for the rest of the cluster. By default, each pool's `minShare` is 0. diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index 2747f2df7cb10..375957e92cc4c 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -77,7 +77,7 @@ Quasi-Newton methods in this case. This fallback is currently always enabled for L1 regularization is applied (i.e. $\alpha = 0$), there exists an analytical solution and either Cholesky or Quasi-Newton solver may be used. When $\alpha > 0$ no analytical solution exists and we instead use the Quasi-Newton solver to find the coefficients iteratively. -In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead. +In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features is no more than 4096. For larger problems, use L-BFGS instead. ## Iteratively reweighted least squares (IRLS) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index ddd2f4b49ca07..d660655e193eb 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -420,7 +420,7 @@ Refer to the [R API docs](api/R/spark.svmLinear.html) for more details. [OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as "One-vs-All." -`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +`OneVsRest` is implemented as an `Estimator`. For the base classifier, it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. @@ -908,7 +908,7 @@ Refer to the [R API docs](api/R/spark.survreg.html) for more details. belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -927,7 +927,7 @@ We implement a which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is a DataFrame which contains three columns -label, features and weight. Additionally IsotonicRegression algorithm has one +label, features and weight. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 58f2d4b531e70..8b0f287dc39ad 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -35,7 +35,7 @@ but the ids must be within the integer value range. ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. diff --git a/docs/ml-features.md b/docs/ml-features.md index 3370eb3893272..7aed2341584fc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1174,7 +1174,7 @@ for more details on the API. ## SQLTransformer `SQLTransformer` implements the transformations which are defined by SQL statement. -Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` +Currently, we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` where `"__THIS__"` represents the underlying table of the input dataset. The select clause specifies the fields, constants, and expressions to display in the output, and can be any select clause that Spark SQL supports. Users can also diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index f4b0df58cf63b..e4736411fb5fe 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -347,7 +347,7 @@ rather than using the old parameter class `Strategy`. These new training method separate classification and regression, and they replace specialized parameter types with simple `String` types. -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +Examples of the new recommended `trainClassifier` and `trainRegressor` are given in the [Decision Trees Guide](mllib-decision-tree.html#examples). ## From 0.9 to 1.0 diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 54d9cd21909df..028bfec465bab 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -103,7 +103,7 @@ Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.m In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. `TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in - the case of `CrossValidator`. It is therefore less expensive, + the case of `CrossValidator`. It is, therefore, less expensive, but will not produce as reliable results when the training dataset is not sufficiently large. Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index df2be92d860e4..dc6b095f5d59b 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -42,7 +42,7 @@ The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute Within -Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact the +Set Sum of Squared Error (WSSSE). You can reduce this error measure by increasing *k*. In fact, the optimal *k* is usually one where there is an "elbow" in the WSSSE graph. Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`KMeansModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeansModel) for details on the API. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 76a00f18b3b90..b2300028e151b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -31,7 +31,7 @@ following parameters: ### Explicit vs. implicit feedback -The standard approach to matrix factorization based collaborative filtering treats +The standard approach to matrix factorization-based collaborative filtering treats the entries in the user-item matrix as *explicit* preferences given by the user to the item, for example, users giving ratings to movies. @@ -60,7 +60,7 @@ best parameter learned from a sampled subset to the full dataset and expect simi
    -In the following example we load rating data. Each row consists of a user, a product and a rating. +In the following example, we load rating data. Each row consists of a user, a product and a rating. We use the default [ALS.train()](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS$) method which assumes ratings are explicit. We evaluate the recommendation model by measuring the Mean Squared Error of rating prediction. diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 35cee3275e3b5..5066bb29387dc 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -350,7 +350,7 @@ which is a tuple of `(Int, Int, Matrix)`. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -In general the use of non-deterministic RDDs can lead to errors. +In general, the use of non-deterministic RDDs can lead to errors. ### RowMatrix diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index a72680d52a26c..4e6b4530942f1 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -91,7 +91,7 @@ The same code applies to `IndexedRowMatrix` if `U` is defined as an [Principal component analysis (PCA)](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical method to find a rotation such that the first coordinate has the largest variance -possible, and each succeeding coordinate in turn has the largest variance possible. The columns of +possible, and each succeeding coordinate, in turn, has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. `spark.mllib` supports PCA for tall-and-skinny matrices stored in row-oriented format and any Vectors. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 7f277543d2e9a..d9dbbab4840a3 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -13,7 +13,7 @@ of the model on some criteria, which depends on the application and its requirem suite of metrics for the purpose of evaluating the performance of machine learning models. Specific machine learning algorithms fall under broader types of machine learning applications like classification, -regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +regression, clustering, etc. Each of these types have well-established metrics for performance evaluation and those metrics that are currently available in `spark.mllib` are detailed in this section. ## Classification model evaluation diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 8b89296b14cdd..bb29f65c0322f 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -105,7 +105,7 @@ p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top \]` where $V$ is the vocabulary size. -The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to $O(\log(V))$ diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index ca84551506b2b..99cab98c690c6 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -9,7 +9,7 @@ displayTitle: Regression - RDD-based API belongs to the family of regression algorithms. Formally isotonic regression is a problem where given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted -finding a function that minimises +finding a function that minimizes `\begin{equation} f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2 @@ -28,7 +28,7 @@ best fitting the original data points. which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent -label, feature and weight in this order. Additionally IsotonicRegression algorithm has one +label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 034e89e25000e..73f6e206ca543 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -425,7 +425,7 @@ We create our model by initializing the weights to zero and register the streams testing then start the job. Printing predictions alongside true labels lets us easily see the result. -Finally we can save text files with data to the training or testing folders. +Finally, we can save text files with data to the training or testing folders. Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label and `x1,x2,x3` are the features. Anytime a text file is placed in `args(0)` the model will update. Anytime a text file is placed in `args(1)` you will see predictions. diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 14d76a6e41e23..04758903da89c 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -121,7 +121,7 @@ computation of the sum of the partial results from each worker machine is perfor standard spark routines. If the fraction of points `miniBatchFraction` is set to 1 (default), then the resulting step in -each iteration is exact (sub)gradient descent. In this case there is no randomness and no +each iteration is exact (sub)gradient descent. In this case, there is no randomness and no variance in the used step directions. On the other extreme, if `miniBatchFraction` is chosen very small, such that only a single point is sampled, i.e. `$|S|=$ miniBatchFraction $\cdot n = 1$`, then the algorithm is equivalent to @@ -135,7 +135,7 @@ algorithm in the family of quasi-Newton methods to solve the optimization proble quadratic without evaluating the second partial derivatives of the objective function to construct the Hessian matrix. The Hessian matrix is approximated by previous gradient evaluations, so there is no vertical scalability issue (the number of training features) when computing the Hessian matrix -explicitly in Newton's method. As a result, L-BFGS often achieves rapider convergence compared with +explicitly in Newton's method. As a result, L-BFGS often achieves more rapid convergence compared with other first-order optimization. ### Choosing an Optimization Method diff --git a/docs/monitoring.md b/docs/monitoring.md index 01736c77b0979..6eaf33135744d 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -214,7 +214,7 @@ incomplete attempt or the final successful attempt. 2. Incomplete applications are only updated intermittently. The time between updates is defined by the interval between checks for changed files (`spark.history.fs.update.interval`). -On larger clusters the update interval may be set to large values. +On larger clusters, the update interval may be set to large values. The way to view a running application is actually to view its own web UI. 3. Applications which exited without registering themselves as completed will be listed @@ -422,7 +422,7 @@ configuration property. If, say, users wanted to set the metrics namespace to the name of the application, they can set the `spark.metrics.namespace` property to a value like `${spark.app.name}`. This value is then expanded appropriately by Spark and is used as the root namespace of the metrics system. -Non driver and executor metrics are never prefixed with `spark.app.id`, nor does the +Non-driver and executor metrics are never prefixed with `spark.app.id`, nor does the `spark.metrics.namespace` property have any such affect on such metrics. Spark's metrics are decoupled into different diff --git a/docs/quick-start.md b/docs/quick-start.md index 07c520cbee6be..f1a2096cd4dbd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -11,11 +11,11 @@ This tutorial provides a quick introduction to using Spark. We will first introd interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -To follow along with this guide, first download a packaged release of Spark from the +To follow along with this guide, first, download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. -Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more detailed reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. # Interactive Analysis with the Spark Shell @@ -47,7 +47,7 @@ scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. +Now let's transform this Dataset into a new one. We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 2e29aef7f21a2..b6424090d2fea 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -818,7 +818,7 @@ The behavior of the above code is undefined, and may not work as intended. To ex The variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. +In local mode, in some circumstances, the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 9c4644947c911..e9e1f3e280609 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -17,7 +17,7 @@ container images and entrypoints.** * A runnable distribution of Spark 2.3 or above. * A running Kubernetes cluster at version >= 1.6 with access configured to it using [kubectl](https://kubernetes.io/docs/user-guide/prereqs/). If you do not already have a working Kubernetes cluster, -you may setup a test cluster on your local machine using +you may set up a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. * Be aware that the default minikube configuration is not enough for running Spark applications. @@ -221,7 +221,7 @@ that allows driver pods to create pods and services under the default Kubernetes [RBAC](https://kubernetes.io/docs/admin/authorization/rbac/) policies. Sometimes users may need to specify a custom service account that has the right role granted. Spark on Kubernetes supports specifying a custom service account to be used by the driver pod through the configuration property -`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example to make the driver pod +`spark.kubernetes.authenticate.driver.serviceAccountName=`. For example, to make the driver pod use the `spark` service account, a user simply adds the following option to the `spark-submit` command: ``` diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8e58892e2689f..3c2a1501ca692 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -90,7 +90,7 @@ Depending on your deployment environment you may wish to create a single set of Framework credentials may be specified in a variety of ways depending on your deployment environment and security requirements. The most simple way is to specify the `spark.mesos.principal` and `spark.mesos.secret` values directly in your Spark configuration. Alternatively you may specify these values indirectly by instead specifying `spark.mesos.principal.file` and `spark.mesos.secret.file`, these settings point to files containing the principal and secret. These files must be plaintext files in UTF-8 encoding. Combined with appropriate file ownership and mode/ACLs this provides a more secure way to specify these credentials. -Additionally if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. +Additionally, if you prefer to use environment variables you can specify all of the above via environment variables instead, the environment variable names are simply the configuration settings uppercased with `.` replaced with `_` e.g. `SPARK_MESOS_PRINCIPAL`. ### Credential Specification Preference Order @@ -225,7 +225,7 @@ details and default values. Executors are brought up eagerly when the application starts, until `spark.cores.max` is reached. If you don't set `spark.cores.max`, the Spark application will consume all resources offered to it by Mesos, -so we of course urge you to set this variable in any sort of +so we, of course, urge you to set this variable in any sort of multi-tenant cluster, including one which runs multiple concurrent Spark applications. @@ -233,14 +233,14 @@ The scheduler will start executors round-robin on the offers Mesos gives it, but there are no spread guarantees, as Mesos does not provide such guarantees on the offer stream. -In this mode spark executors will honor port allocation if such is -provided from the user. Specifically if the user defines +In this mode Spark executors will honor port allocation if such is +provided from the user. Specifically, if the user defines `spark.blockManager.port` in Spark configuration, the mesos scheduler will check the available offers for a valid port range containing the port numbers. If no such range is available it will not launch any task. If no restriction is imposed on port numbers by the user, ephemeral ports are used as usual. This port honouring implementation -implies one task per host if the user defines a port. In the future network +implies one task per host if the user defines a port. In the future network, isolation shall be supported. The benefit of coarse-grained mode is much lower startup overhead, but @@ -486,7 +486,7 @@ See the [configuration page](configuration.html) for information on Spark config
    - + @@ -1797,6 +1798,23 @@ Apart from these, the following properties are also available, and may be useful Lower bound for the number of executors if dynamic allocation is enabled. + + + + + From 2a24c481da3f30b510deb62e5cf21c9463cf250c Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 24 Apr 2018 09:25:41 -0700 Subject: [PATCH 0680/1161] [SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features ## What changes were proposed in this pull request? - Multiple possible input types is added in validateAndTransformSchema() and computeCost() while checking column type - Add if statement in transform() to support array type as featuresCol - Add the case statement in fit() while selecting columns from dataset These changes will be applied to KMeans first, then to other clustering method ## How was this patch tested? unit test is added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21081 from ludatabricks/SPARK-23975. --- .../apache/spark/ml/clustering/KMeans.scala | 32 +++++++--- .../apache/spark/ml/util/DatasetUtils.scala | 63 +++++++++++++++++++ .../spark/ml/clustering/KMeansSuite.scala | 38 +++++++++++ 3 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1ad157a695a7d..d475c726e6f08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,13 +86,24 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) + /** + * Validates the input schema. + * @param schema input schema + */ + private[clustering] def validateSchema(schema: StructType): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + + SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) + } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + validateSchema(schema) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -125,8 +136,11 @@ class KMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) + val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("1.5.0") @@ -146,8 +160,10 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + validateSchema(dataset.schema) + + val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } parentModel.computeCost(data) @@ -335,7 +351,9 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[OldVector] = dataset.select( + DatasetUtils.columnToVector(dataset, getFeaturesCol)) + .rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala new file mode 100644 index 0000000000000..52619cb65489a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.ml.linalg.{Vectors, VectorUDT} +import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} + + +private[spark] object DatasetUtils { + + /** + * Cast a column in a Dataset to Vector type. + * + * The supported data types of the input column are + * - Vector + * - float/double type Array. + * + * Note: The returned column does not have Metadata. + * + * @param dataset input DataFrame + * @param colName column name. + * @return Vector column + */ + def columnToVector(dataset: Dataset[_], colName: String): Column = { + val columnDataType = dataset.schema(colName).dataType + columnDataType match { + case _: VectorUDT => col(colName) + case fdt: ArrayType => + val transferUDF = fdt.elementType match { + case _: FloatType => udf(f = (vector: Seq[Float]) => { + val inputArray = Array.fill[Double](vector.size)(0.0) + vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble) + Vectors.dense(inputArray) + }) + case _: DoubleType => udf((vector: Seq[Double]) => { + Vectors.dense(vector.toArray) + }) + case other => + throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector") + } + transferUDF(col(colName)) + case other => + throw new IllegalArgumentException(s"$other column cannot be cast to Vector") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 77c9d482d95b6..5445ebe5c95eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -30,6 +30,8 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -199,6 +201,42 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(e.getCause.getMessage.contains("Cosine distance is not defined")) } + test("KMean with Array input") { + val featuresColNameD = "array_double_features" + val featuresColNameF = "array_float_features" + + val doubleUDF = udf { (features: Vector) => + val featureArray = Array.fill[Double](features.size)(0.0) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + val floatUDF = udf { (features: Vector) => + val featureArray = Array.fill[Float](features.size)(0.0f) + features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) + featureArray + } + + val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) + .drop("features") + val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) + .drop("features") + assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) + assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) + + val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) + val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) + val modelD = kmeansD.fit(newdatasetD) + val modelF = kmeansF.fit(newdatasetF) + val transformedD = modelD.transform(newdatasetD) + val transformedF = modelF.transform(newdatasetF) + + val predictDifference = transformedD.select("prediction") + .except(transformedF.select("prediction")) + assert(predictDifference.count() == 0) + assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + } + + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) From ce7ba2e98e0a3b038e881c271b5905058c43155b Mon Sep 17 00:00:00 2001 From: Steve Loughran Date: Tue, 24 Apr 2018 09:57:09 -0700 Subject: [PATCH 0681/1161] [SPARK-23807][BUILD] Add Hadoop 3.1 profile with relevant POM fix ups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? 1. Adds a `hadoop-3.1` profile build depending on the hadoop-3.1 artifacts. 1. In the hadoop-cloud module, adds an explicit hadoop-3.1 profile which switches from explicitly pulling in cloud connectors (hadoop-openstack, hadoop-aws, hadoop-azure) to depending on the hadoop-cloudstorage POM artifact, which pulls these in, has pre-excluded things like hadoop-common, and stays up to date with new connectors (hadoop-azuredatalake, hadoop-allyun). Goal: it becomes the Hadoop projects homework of keeping this clean, and the spark project doesn't need to handle new hadoop releases adding more dependencies. 1. the hadoop-cloud/hadoop-3.1 profile also declares support for jetty-ajax and jetty-util to ensure that these jars get into the distribution jar directory when needed by unshaded libraries. 1. Increases the curator and zookeeper versions to match those in hadoop-3, fixing spark core to build in sbt with the hadoop-3 dependencies. ## How was this patch tested? * Everything this has been built and tested against both ASF Hadoop branch-3.1 and hadoop trunk. * spark-shell was used to create connectors to all the stores and verify that file IO could take place. The spark hive-1.2.1 JAR has problems here, as it's version check logic fails for Hadoop versions > 2. This can be avoided with either of * The hadoop JARs built to declare their version as Hadoop 2.11 `mvn install -DskipTests -DskipShade -Ddeclared.hadoop.version=2.11` . This is safe for local test runs, not for deployment (HDFS is very strict about cross-version deployment). * A modified version of spark hive whose version check switch statement is happy with hadoop 3. I've done both, with maven and SBT. Three issues surfaced 1. A spark-core test failure —fixed in SPARK-23787. 1. SBT only: Zookeeper not being found in spark-core. Somehow curator 2.12.0 triggers some slightly different dependency resolution logic from previous versions, and Ivy was missing zookeeper.jar entirely. This patch adds the explicit declaration for all spark profiles, setting the ZK version = 3.4.9 for hadoop-3.1 1. Marking jetty-utils as provided in spark was stopping hadoop-azure from being able to instantiate the azure wasb:// client; it was using jetty-util-ajax, which could then not find a class in jetty-util. Author: Steve Loughran Closes #20923 from steveloughran/cloud/SPARK-23807-hadoop-31. --- assembly/pom.xml | 8 ++ core/pom.xml | 6 + dev/deps/spark-deps-hadoop-3.1 | 221 +++++++++++++++++++++++++++++++++ dev/test-dependencies.sh | 1 + hadoop-cloud/pom.xml | 83 ++++++++++++- pom.xml | 9 ++ 6 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 dev/deps/spark-deps-hadoop-3.1 diff --git a/assembly/pom.xml b/assembly/pom.xml index a207dae5a74ff..9608c96fd5369 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -254,6 +254,14 @@ spark-hadoop-cloud_${scala.binary.version} ${project.version} + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + diff --git a/core/pom.xml b/core/pom.xml index 9258a856028a0..093a9869b6dd7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,12 @@ org.apache.curator curator-recipes + + + org.apache.zookeeper + zookeeper + diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 new file mode 100644 index 0000000000000..97ad65a4096cb --- /dev/null +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -0,0 +1,221 @@ +HikariCP-java7-2.4.12.jar +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +accessors-smart-1.2.jar +activation-1.1.1.jar +aircompressor-0.8.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar +antlr4-runtime-4.7.jar +aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +arrow-format-0.8.0.jar +arrow-memory-0.8.0.jar +arrow-vector-0.8.0.jar +automaton-1.11-8.jar +avro-1.7.7.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.58.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.11-0.13.2.jar +breeze_2.11-0.13.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.8.4.jar +chill_2.11-0.8.4.jar +commons-beanutils-1.9.3.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-3.0.8.jar +commons-compress-1.4.1.jar +commons-configuration2-2.1.1.jar +commons-crypto-1.0.0.jar +commons-daemon-1.0.13.jar +commons-dbcp-1.4.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.5.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-3.1.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +core-1.1.2.jar +curator-client-2.12.0.jar +curator-framework-2.12.0.jar +curator-recipes-2.12.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.12.1.1.jar +dnsjava-2.1.7.jar +ehcache-3.3.1.jar +eigenbase-properties-1.1.5.jar +flatbuffers-1.2.0-3f79e055.jar +generex-1.0.1.jar +geronimo-jcache_1.0_spec-1.0-alpha-1.jar +gson-2.2.4.jar +guava-14.0.1.jar +guice-4.0.jar +guice-servlet-4.0.jar +hadoop-annotations-3.1.0.jar +hadoop-auth-3.1.0.jar +hadoop-client-3.1.0.jar +hadoop-common-3.1.0.jar +hadoop-hdfs-client-3.1.0.jar +hadoop-mapreduce-client-common-3.1.0.jar +hadoop-mapreduce-client-core-3.1.0.jar +hadoop-mapreduce-client-jobclient-3.1.0.jar +hadoop-yarn-api-3.1.0.jar +hadoop-yarn-client-3.1.0.jar +hadoop-yarn-common-3.1.0.jar +hadoop-yarn-registry-3.1.0.jar +hadoop-yarn-server-common-3.1.0.jar +hadoop-yarn-server-web-proxy-3.1.0.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +hppc-0.7.2.jar +htrace-core4-4.1.0-incubating.jar +httpclient-4.5.4.jar +httpcore-4.4.8.jar +ivy-2.4.0.jar +jackson-annotations-2.6.7.jar +jackson-core-2.6.7.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.6.7.1.jar +jackson-dataformat-yaml-2.6.7.jar +jackson-jaxrs-base-2.7.8.jar +jackson-jaxrs-json-provider-2.7.8.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-jaxb-annotations-2.6.7.jar +jackson-module-paranamer-2.7.9.jar +jackson-module-scala_2.11-2.6.7.1.jar +janino-3.0.8.jar +java-xmlbuilder-1.1.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar +javax.inject-1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar +javolution-5.5.1.jar +jaxb-api-2.2.11.jar +jcip-annotations-1.0-1.jar +jcl-over-slf4j-1.7.16.jar +jdo-api-3.0.1.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar +jets3t-0.9.4.jar +jetty-webapp-9.3.20.v20170531.jar +jetty-xml-9.3.20.v20170531.jar +jline-2.12.1.jar +joda-time-2.9.3.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-smart-2.3.jar +json4s-ast_2.11-3.5.3.jar +json4s-core_2.11-3.5.3.jar +json4s-jackson_2.11-3.5.3.jar +json4s-scalap_2.11-3.5.3.jar +jsp-api-2.1.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.16.jar +kerb-admin-1.0.1.jar +kerb-client-1.0.1.jar +kerb-common-1.0.1.jar +kerb-core-1.0.1.jar +kerb-crypto-1.0.1.jar +kerb-identity-1.0.1.jar +kerb-server-1.0.1.jar +kerb-simplekdc-1.0.1.jar +kerb-util-1.0.1.jar +kerby-asn1-1.0.1.jar +kerby-config-1.0.1.jar +kerby-pkix-1.0.1.jar +kerby-util-1.0.1.jar +kerby-xdr-1.0.1.jar +kryo-shaded-3.0.3.jar +kubernetes-client-3.0.0.jar +kubernetes-model-2.0.0.jar +leveldbjni-all-1.8.jar +libfb303-0.9.3.jar +libthrift-0.9.3.jar +log4j-1.2.17.jar +logging-interceptor-3.8.1.jar +lz4-java-1.4.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar +mesos-1.4.0-shaded-protobuf.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar +minlog-1.3.0.jar +mssql-jdbc-6.2.1.jre7.jar +netty-3.9.9.Final.jar +netty-all-4.1.17.Final.jar +nimbus-jose-jwt-4.41.1.jar +objenesis-2.1.jar +okhttp-2.7.5.jar +okhttp-3.8.1.jar +okio-1.13.0.jar +opencsv-2.3.jar +orc-core-1.4.3-nohive.jar +orc-mapreduce-1.4.3-nohive.jar +oro-2.0.8.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.8.jar +parquet-column-1.8.2.jar +parquet-common-1.8.2.jar +parquet-encoding-1.8.2.jar +parquet-format-2.3.1.jar +parquet-hadoop-1.8.2.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.8.2.jar +protobuf-java-2.5.0.jar +py4j-0.10.6.jar +pyrolite-4.13.jar +re2j-1.1.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar +scala-parser-combinators_2.11-1.0.4.jar +scala-reflect-2.11.8.jar +scala-xml_2.11-1.0.5.jar +shapeless_2.11-2.3.2.jar +slf4j-api-1.7.16.jar +slf4j-log4j12-1.7.16.jar +snakeyaml-1.15.jar +snappy-0.2.jar +snappy-java-1.1.7.1.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar +stax-api-1.0.1.jar +stax2-api-3.1.4.jar +stream-2.7.0.jar +stringtemplate-3.2.1.jar +super-csv-2.2.0.jar +token-provider-1.0.1.jar +univocity-parsers-2.5.9.jar +validation-api-1.1.0.Final.jar +woodstox-core-5.0.3.jar +xbean-asm5-shaded-4.4.jar +xz-1.0.jar +zjsonpatch-0.3.0.jar +zookeeper-3.4.9.jar +zstd-jni-1.3.2-2.jar diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 3bf7618e1ea96..2fbd6b5e98f7f 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -34,6 +34,7 @@ MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 hadoop-2.7 + hadoop-3.1 ) # We'll switch the version to a temp. one, publish POMs using that new version, then switch back to diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 8e424b1c50236..2c39a7df0146e 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -38,7 +38,32 @@ hadoop-cloud + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.hadoop + hadoop-client + ${hadoop.version} + provided + + + + hadoop-3.1 + + + + org.apache.hadoop + hadoop-cloud-storage + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-common + + + org.codehaus.jackson + jackson-mapper-asl + + + com.fasterxml.jackson.core + jackson-core + + + com.google.guava + guava + + + + + + org.eclipse.jetty + jetty-util + ${hadoop.deps.scope} + + + org.eclipse.jetty + jetty-util-ajax + ${jetty.version} + ${hadoop.deps.scope} + + + + diff --git a/pom.xml b/pom.xml index 0a711f287a53f..88e77ff874748 100644 --- a/pom.xml +++ b/pom.xml @@ -2671,6 +2671,15 @@ + + hadoop-3.1 + + 3.1.0 + 2.12.0 + 3.4.9 + + + yarn From 83013752e3cfcbc3edeef249439ac20b143eeabc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Apr 2018 10:40:25 -0700 Subject: [PATCH 0682/1161] [SPARK-23455][ML] Default Params in ML should be saved separately in metadata ## What changes were proposed in this pull request? We save ML's user-supplied params and default params as one entity in metadata. During loading the saved models, we set all the loaded params into created ML model instances as user-supplied params. It causes some problems, e.g., if we strictly disallow some params to be set at the same time, a default param can fail the param check because it is treated as user-supplied param after loading. The loaded default params should not be set as user-supplied params. We should save ML default params separately in metadata. For backward compatibility, when loading metadata, if it is a metadata file from previous Spark, we shouldn't raise error if we can't find the default param field. ## How was this patch tested? Pass existing tests and added tests. Author: Liang-Chi Hsieh Closes #20633 from viirya/save-ml-default-params. --- .../DecisionTreeClassifier.scala | 2 +- .../ml/classification/GBTClassifier.scala | 4 +- .../spark/ml/classification/LinearSVC.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../MultilayerPerceptronClassifier.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 2 +- .../spark/ml/classification/OneVsRest.scala | 4 +- .../RandomForestClassifier.scala | 4 +- .../spark/ml/clustering/BisectingKMeans.scala | 2 +- .../spark/ml/clustering/GaussianMixture.scala | 2 +- .../apache/spark/ml/clustering/KMeans.scala | 2 +- .../org/apache/spark/ml/clustering/LDA.scala | 4 +- .../feature/BucketedRandomProjectionLSH.scala | 2 +- .../apache/spark/ml/feature/Bucketizer.scala | 24 ---- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../org/apache/spark/ml/feature/Imputer.scala | 2 +- .../spark/ml/feature/MaxAbsScaler.scala | 2 +- .../apache/spark/ml/feature/MinHashLSH.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../ml/feature/OneHotEncoderEstimator.scala | 2 +- .../org/apache/spark/ml/feature/PCA.scala | 2 +- .../ml/feature/QuantileDiscretizer.scala | 24 ---- .../apache/spark/ml/feature/RFormula.scala | 6 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/VectorIndexer.scala | 2 +- .../apache/spark/ml/feature/Word2Vec.scala | 2 +- .../org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/param/params.scala | 13 +- .../apache/spark/ml/recommendation/ALS.scala | 2 +- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 4 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/IsotonicRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 4 +- .../spark/ml/tuning/CrossValidator.scala | 6 +- .../ml/tuning/TrainValidationSplit.scala | 6 +- .../org/apache/spark/ml/util/ReadWrite.scala | 130 ++++++++++++------ .../spark/ml/util/DefaultReadWriteTest.scala | 73 +++++++++- project/MimaExcludes.scala | 6 + 44 files changed, 223 insertions(+), 147 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 771cd4fe91dcf..57797d1cc4978 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -279,7 +279,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) val model = new DecisionTreeClassificationModel(metadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c0255103bc313..0aa24f0a3cfcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -379,14 +379,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 8f950cd28c3aa..80c537e1e0eb2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -377,7 +377,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { val Row(coefficients: Vector, intercept: Double) = data.select("coefficients", "intercept").head() val model = new LinearSVCModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ee4b01058c75c..e426263910f26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1270,7 +1270,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { numClasses, isMultinomial) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index af2e4699924e5..57ba47e596a97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 0293e03d47435..45fb585ed2262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -407,7 +407,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 5348d882cfd67..7df53a6b8ad10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -289,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) } val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) - DefaultParamsReader.getAndSetParams(ovrModel, metadata) + metadata.getAndSetParams(ovrModel) ovrModel.set("classifier", classifier) ovrModel } @@ -484,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] { override def load(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) - DefaultParamsReader.getAndSetParams(ovr, metadata) + metadata.getAndSetParams(ovr) ovr.setClassifier(classifier) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bb972e9706fc1..f1ef26a07d3f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -319,14 +319,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica case (treeMetadata, root) => val tree = new DecisionTreeClassificationModel(treeMetadata.uid, root.asInstanceOf[ClassificationNode], numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f7c422dc0faea..addc12ac52ec1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -193,7 +193,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) val model = new BisectingKMeansModel(metadata.uid, mllibModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index f19ad7a5a6938..b5804900c0358 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -233,7 +233,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { } val model = new GaussianMixtureModel(metadata.uid, weights, gaussians) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d475c726e6f08..de61c9c089a36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -280,7 +280,7 @@ object KMeansModel extends MLReadable[KMeansModel] { sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters } val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 4bab670cc159f..47077230fac0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM private object LDAParams { /** - * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. * * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with @@ -391,7 +391,7 @@ private object LDAParams { s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } case _ => // 2.0+ - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 41eaaf9679914..a906e954fecd5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -238,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject val model = new BucketedRandomProjectionLSHModel(metadata.uid, randUnitVectors.rowIter.toArray) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index f49c410cbcfe2..f99649f7fa164 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - override def write: MLWriter = new Bucketizer.BucketizerWriter(this) } @Since("1.6.0") @@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } - - private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 16abc4949dea3..dbfb199ccd58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 9e0ed437e7bfc..10c48c3f52085 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -363,7 +363,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { .head() val vocabulary = data.getAs[Seq[String]](0).toArray val model = new CountVectorizerModel(metadata.uid, vocabulary) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 46a0730f5ddb8..58897cca4e5c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] { .select("idf") .head() val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 730ee9fc08db8..1c074e204ad99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] { val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) val model = new ImputerModel(metadata.uid, surrogateDF) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 85f9732f79f67..90eceb0d61b40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 556848e45532d..a67a3b0abbc1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -205,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { .map(tuple => (tuple(0), tuple(1))).toArray val model = new MinHashLSHModel(metadata.uid, randCoefficients) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index f648deced54cd..2e0ae4af66f06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { .select("originalMin", "originalMax") .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index bd1e3426c8780..4a44f3186538d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { .head() val categorySizes = data.getAs[Seq[Int]](0).toArray val model = new OneHotEncoderModel(metadata.uid, categorySizes) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 4143d864d7930..8172491a517d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] { new PCAModel(metadata.uid, pc.asML, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 3b4c25478fb1d..56e2c543d100a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this) } @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - private[QuantileDiscretizer] - class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // SPARK-23377: The default params will be saved and loaded as user-supplied params. - // Once `inputCols` is set, the default value of `outputCol` param causes the error - // when checking exclusive params. As a temporary to fix it, we skip the default value - // of `outputCol` if `inputCols` is set when saving the metadata. - // TODO: If we modify the persistence mechanism later to better handle default params, - // we can get rid of this. - var paramWithoutOutputCol: Option[JValue] = None - if (instance.isSet(instance.inputCols)) { - val params = instance.extractParamMap().toSeq - val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) => - p.name -> parse(p.jsonEncode(v)) - }.toList - paramWithoutOutputCol = Some(render(jsonParams)) - } - DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index e214765e3307f..55e595eee6ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -446,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } @@ -510,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) - DefaultParamsReader.getAndSetParams(pruner, metadata) + metadata.getAndSetParams(pruner) pruner } } @@ -602,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) - DefaultParamsReader.getAndSetParams(rewriter, metadata) + metadata.getAndSetParams(rewriter) rewriter } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 8f125d8fd51d2..91b0707dec3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 67cdb097217a2..a833d8b270cf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { .head() val labels = data.getAs[Seq[String]](0).toArray val model = new StringIndexerModel(metadata.uid, labels) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index e6ec4e2e36ff0..0e7396a621dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { val numFeatures = data.getAs[Int](0) val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index fe3306e1e50d6..fc9996d69ba72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { } val model = new Word2VecModel(metadata.uid, oldModel) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 3d041fc80eb7f..0bf405d9abf9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -335,7 +335,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { val dataPath = new Path(path, "data").toString val frequentItems = sparkSession.read.parquet(dataPath) val model = new FPGrowthModel(metadata.uid, frequentItems) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9a83a5882ce29..e6c347ed17c15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable { } /** Internal param map for user-supplied values. */ - private val paramMap: ParamMap = ParamMap.empty + private[ml] val paramMap: ParamMap = ParamMap.empty /** Internal param map for default values. */ - private val defaultParamMap: ParamMap = ParamMap.empty + private[ml] val defaultParamMap: ParamMap = ParamMap.empty /** Validates that the input param belongs to this instance. */ private def shouldOwn(param: Param[_]): Unit = { @@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable { } } +private[ml] object Params { + /** + * Sets a default param value for a `Params`. + */ + private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = { + params.defaultParamMap.put(param -> value) + } +} + /** * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 81a8f50761e0e..a23f9552b9e5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] { val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 4b46c3831d75f..7c6ec2a8419fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] .head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 5cef5c9f21f1e..8bcf0793a64c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -282,7 +282,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) val model = new DecisionTreeRegressionModel(metadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 834aaa0e362d1..8598e808c4946 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -311,7 +311,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } @@ -319,7 +319,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { s" trees based on metadata but found ${trees.length} trees.") val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 4c3f1431d5077..e030a40cb19be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -1146,7 +1146,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8faab52ea474b..b046897ab2b7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val model = new IsotonicRegressionModel( metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 9cdd3a051e719..f1d9a4453deaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -799,7 +799,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { new LinearRegressionModel(metadata.uid, coefficients, intercept, scale) } - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f77398ba2a22..4509f85aafd12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -276,14 +276,14 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => val tree = new DecisionTreeRegressionModel(treeMetadata.uid, root.asInstanceOf[RegressionNode], numFeatures) - DefaultParamsReader.getAndSetParams(tree, treeMetadata) + treeMetadata.getAndSetParams(tree) tree } require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) - DefaultParamsReader.getAndSetParams(model, metadata) + metadata.getAndSetParams(model) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index c2826dcc08634..5e916cc4a9fdd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -234,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(cv, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps"))) cv } } @@ -424,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8d1b9a8ddab59..13369c4df7180 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -228,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - DefaultParamsReader.getAndSetParams(tvs, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps"))) tvs } } @@ -407,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - DefaultParamsReader.getAndSetParams(model, metadata, - skipParams = Option(List("estimatorParamMaps"))) + metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps"))) model } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7edcd498678cc..72a60e04360d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -39,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.util.{Utils, VersionUtils} /** * Trait for `MLWriter` and `MLReader`. @@ -421,6 +421,7 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid + * - defaultParamMap * - paramMap * - (optionally, extra metadata) * @@ -453,15 +454,20 @@ private[ml] object DefaultParamsWriter { paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName - val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val params = instance.paramMap.toSeq + val defaultParams = instance.defaultParamMap.toSeq val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }.toList)) + val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList) val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("defaultParamMap" -> jsonDefaultParams) val metadata = extraMetadata match { case Some(jObject) => basicMetadata ~ jObject @@ -488,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] { val cls = Utils.classForName(metadata.className) val instance = cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] - DefaultParamsReader.getAndSetParams(instance, metadata) + metadata.getAndSetParams(instance) instance.asInstanceOf[T] } } @@ -499,6 +505,8 @@ private[ml] object DefaultParamsReader { * All info from metadata file. * * @param params paramMap, as a `JValue` + * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4, + * this is `JNothing`. * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) */ @@ -508,27 +516,90 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + defaultParams: JValue, metadata: JValue, metadataJson: String) { + + private def getValueFromParams(params: JValue): Seq[(String, JValue)] = { + params match { + case JObject(pairs) => pairs + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: $metadataJson.") + } + } + /** * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. * This can be useful for getting a Param value before an instance of `Params` - * is available. + * is available. This will look up `params` first, if not existing then looking up + * `defaultParams`. */ def getParamValue(paramName: String): JValue = { implicit val format = DefaultFormats - params match { + + // Looking up for `params` first. + var pairs = getValueFromParams(params) + var foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + if (foundPairs.length == 0) { + // Looking up for `defaultParams` then. + pairs = getValueFromParams(defaultParams) + foundPairs = pairs.filter { case (pName, jsonValue) => + pName == paramName + } + } + assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" + + s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) + + foundPairs.map(_._2).head + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params (except params included by `skipParams` list) implement + * [[org.apache.spark.ml.param.Param.jsonDecode()]]. + * + * @param skipParams The params included in `skipParams` won't be set. This is useful if some + * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] + * and need special handling. + */ + def getAndSetParams( + instance: Params, + skipParams: Option[List[String]] = None): Unit = { + setParams(instance, skipParams, isDefault = false) + + // For metadata file prior to Spark 2.4, there is no default section. + val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion) + if (major > 2 || (major == 2 && minor >= 4)) { + setParams(instance, skipParams, isDefault = true) + } + } + + private def setParams( + instance: Params, + skipParams: Option[List[String]], + isDefault: Boolean): Unit = { + implicit val format = DefaultFormats + val paramsToSet = if (isDefault) defaultParams else params + paramsToSet match { case JObject(pairs) => - val values = pairs.filter { case (pName, jsonValue) => - pName == paramName - }.map(_._2) - assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + - s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) - values.head + pairs.foreach { case (paramName, jsonValue) => + if (skipParams == None || !skipParams.get.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + if (isDefault) { + Params.setDefault(instance, param, value) + } else { + instance.set(param, value) + } + } + } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: $metadataJson.") + s"Cannot recognize JSON metadata: ${metadataJson}.") } } } @@ -561,43 +632,14 @@ private[ml] object DefaultParamsReader { val uid = (metadata \ "uid").extract[String] val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] + val defaultParams = metadata \ "defaultParamMap" val params = metadata \ "paramMap" if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) - } - - /** - * Extract Params from metadata, and set them in the instance. - * This works if all Params (except params included by `skipParams` list) implement - * [[org.apache.spark.ml.param.Param.jsonDecode()]]. - * - * @param skipParams The params included in `skipParams` won't be set. This is useful if some - * params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]] - * and need special handling. - * TODO: Move to [[Metadata]] method - */ - def getAndSetParams( - instance: Params, - metadata: Metadata, - skipParams: Option[List[String]] = None): Unit = { - implicit val format = DefaultFormats - metadata.params match { - case JObject(pairs) => - pairs.foreach { case (paramName, jsonValue) => - if (skipParams == None || !skipParams.get.contains(paramName)) { - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) - } - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") - } + Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4da95e74434ee..4d9e664850c12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -19,9 +19,10 @@ package org.apache.spark.ml.util import java.io.{File, IOException} +import org.json4s.JNothing import org.scalatest.Suite -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val shouldNotSetIfSetintParamWithDefault: IntParam = + new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") @@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable { set(doubleArrayParam -> Array(8.0, 9.0)) set(stringArrayParam -> Array("10", "11")) + def checkExclusiveParams(): Unit = { + if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) { + throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " + + "shouldn't be set at the same time") + } + } + override def copy(extra: ParamMap): Params = defaultCopy(extra) override def write: MLWriter = new DefaultParamsWriter(this) @@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext val myParams = new MyParams("my_params") testDefaultReadWrite(myParams) } + + test("default param shouldn't become user-supplied param after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1) + myParams.checkExclusiveParams() + val loadedMyParams = testDefaultReadWrite(myParams) + loadedMyParams.checkExclusiveParams() + assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) == + myParams.getDefault(myParams.intParamWithDefault)) + + loadedMyParams.set(myParams.intParamWithDefault, 1) + intercept[SparkException] { + loadedMyParams.checkExclusiveParams() + } + } + + test("User-supplied value for default param should be kept after persistence") { + val myParams = new MyParams("my_params") + myParams.set(myParams.intParamWithDefault, 100) + val loadedMyParams = testDefaultReadWrite(myParams) + assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100) + } + + test("Read metadata without default field prior to 2.4") { + // default params are saved in `paramMap` field in metadata file prior to Spark 2.4. + val metadata = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.3.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata = DefaultParamsReader.parseMetadata(metadata) + val myParams = new MyParams("my_params") + assert(!myParams.isSet(myParams.intParamWithDefault)) + parsedMetadata.getAndSetParams(myParams) + + // The behavior prior to Spark 2.4, default params are set in loaded ML instance. + assert(myParams.isSet(myParams.intParamWithDefault)) + } + + test("Should raise error when read metadata without default field after Spark 2.4") { + val myParams = new MyParams("my_params") + + val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"2.4.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1) + val err1 = intercept[IllegalArgumentException] { + parsedMetadata1.getAndSetParams(myParams) + } + assert(err1.getMessage().contains("Cannot recognize JSON metadata")) + + val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams", + |"timestamp":1518852502761,"sparkVersion":"3.0.0", + |"uid":"my_params", + |"paramMap":{"intParamWithDefault":0}}""".stripMargin + val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2) + val err2 = intercept[IllegalArgumentException] { + parsedMetadata2.getAndSetParams(myParams) + } + assert(err2.getMessage().contains("Cannot recognize JSON metadata")) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a87fa68422c34..7d0e88ee20c3f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,6 +62,12 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), From 379bffa0525a4343f8c10e51ed192031922f9874 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 24 Apr 2018 11:02:22 -0700 Subject: [PATCH 0683/1161] [SPARK-23990][ML] Instruments logging improvements - ML regression package ## What changes were proposed in this pull request? Instruments logging improvements - ML regression package I add an `OptionalInstrument` class which used in `WeightLeastSquares` and `IterativelyReweightedLeastSquares`. ## How was this patch tested? N/A Author: WeichenXu Closes #21078 from WeichenXu123/inst_reg. --- .../classification/LogisticRegression.scala | 4 +- .../IterativelyReweightedLeastSquares.scala | 18 +++-- .../spark/ml/optim/WeightedLeastSquares.scala | 32 +++++---- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../GeneralizedLinearRegression.scala | 14 ++-- .../ml/regression/LinearRegression.scala | 22 +++--- .../spark/ml/tree/impl/RandomForest.scala | 2 + .../spark/ml/util/Instrumentation.scala | 68 ++++++++++++++++++- 8 files changed, 125 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e426263910f26..06ca37bc75146 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -500,7 +500,7 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -816,7 +816,7 @@ class LogisticRegression @Since("1.2.0") ( if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 6961b45f55e4d..572b8cf0051b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -61,9 +61,12 @@ private[ml] class IterativelyReweightedLeastSquares( val fitIntercept: Boolean, val regParam: Double, val maxIter: Int, - val tol: Double) extends Logging with Serializable { + val tol: Double) extends Serializable { - def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = { + def fit( + instances: RDD[OffsetInstance], + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[IterativelyReweightedLeastSquares])): IterativelyReweightedLeastSquaresModel = { var converged = false var iter = 0 @@ -83,7 +86,8 @@ private[ml] class IterativelyReweightedLeastSquares( // Estimate new model model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, - standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + standardizeFeatures = false, standardizeLabel = false) + .fit(newInstances, instr = instr) // Check convergence val oldCoefficients = oldModel.coefficients @@ -96,14 +100,14 @@ private[ml] class IterativelyReweightedLeastSquares( if (maxTol < tol) { converged = true - logInfo(s"IRLS converged in $iter iterations.") + instr.logInfo(s"IRLS converged in $iter iterations.") } - logInfo(s"Iteration $iter : relative tolerance = $maxTol") + instr.logInfo(s"Iteration $iter : relative tolerance = $maxTol") iter = iter + 1 if (iter == maxIter) { - logInfo(s"IRLS reached the max number of iterations: $maxIter.") + instr.logInfo(s"IRLS reached the max number of iterations: $maxIter.") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index c5c9c8eb2bd29..1b7c15f1f0a8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.optim -import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.util.OptionalInstrumentation import org.apache.spark.rdd.RDD /** @@ -81,13 +81,11 @@ private[ml] class WeightedLeastSquares( val standardizeLabel: Boolean, val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, val maxIter: Int = 100, - val tol: Double = 1e-6) extends Logging with Serializable { + val tol: Double = 1e-6 + ) extends Serializable { import WeightedLeastSquares._ require(regParam >= 0.0, s"regParam cannot be negative: $regParam") - if (regParam == 0.0) { - logWarning("regParam is zero, which might cause numerical instability and overfitting.") - } require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") @@ -96,10 +94,17 @@ private[ml] class WeightedLeastSquares( /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. */ - def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = { + def fit( + instances: RDD[Instance], + instr: OptionalInstrumentation = OptionalInstrumentation.create(classOf[WeightedLeastSquares]) + ): WeightedLeastSquaresModel = { + if (regParam == 0.0) { + instr.logWarning("regParam is zero, which might cause numerical instability and overfitting.") + } + val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) summary.validate() - logInfo(s"Number of instances: ${summary.count}.") + instr.logInfo(s"Number of instances: ${summary.count}.") val k = if (fitIntercept) summary.k + 1 else summary.k val numFeatures = summary.k val triK = summary.triK @@ -114,11 +119,12 @@ private[ml] class WeightedLeastSquares( if (rawBStd == 0) { if (fitIntercept || rawBBar == 0.0) { if (rawBBar == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } val coefficients = new DenseVector(Array.ofDim(numFeatures)) @@ -128,7 +134,7 @@ private[ml] class WeightedLeastSquares( } else { require(!(regParam > 0.0 && standardizeLabel), "The standard deviation of the label is " + "zero. Model cannot be regularized with standardization=true") - logWarning(s"The standard deviation of the label is zero. Consider setting " + + instr.logWarning(s"The standard deviation of the label is zero. Consider setting " + s"fitIntercept=true.") } } @@ -256,7 +262,7 @@ private[ml] class WeightedLeastSquares( // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to // Quasi-Newton solver. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto => - logWarning("Cholesky solver failed due to singular covariance matrix. " + + instr.logWarning("Cholesky solver failed due to singular covariance matrix. " + "Retrying with Quasi-Newton solver.") // ab and aa were modified in place, so reconstruct them val _aa = getAtA(aaBarValues, aBarValues) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 7c6ec2a8419fd..e27a96e1f5dfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -237,7 +237,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { - logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is different from R survival::survreg.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index e030a40cb19be..143c8a3548b1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -404,7 +404,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - val wlsModel = optimizer.fit(instances) + val wlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) @@ -418,10 +418,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val OffsetInstance(label, weight, offset, features) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). - val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam)) + val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam), + instr = OptionalInstrumentation.create(instr)) val optimizer = new IterativelyReweightedLeastSquares(initialModel, familyAndLink.reweightFunc, $(fitIntercept), $(regParam), $(maxIter), $(tol)) - val irlsModel = optimizer.fit(instances) + val irlsModel = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) @@ -492,7 +493,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def initialize( instances: RDD[OffsetInstance], fitIntercept: Boolean, - regParam: Double): WeightedLeastSquaresModel = { + regParam: Double, + instr: OptionalInstrumentation = OptionalInstrumentation.create( + classOf[GeneralizedLinearRegression]) + ): WeightedLeastSquaresModel = { val newInstances = instances.map { instance => val mu = family.initialize(instance.label, instance.weight) val eta = predict(mu) - instance.offset @@ -501,7 +505,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine // TODO: Make standardizeFeatures and standardizeLabel configurable. val initialModel = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, standardizeFeatures = true, standardizeLabel = true) - .fit(newInstances) + .fit(newInstances, instr) initialModel } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f1d9a4453deaa..c45ade94a4e33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -339,7 +339,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = $(elasticNetParam), $(standardization), true, solverType = WeightedLeastSquares.Auto, maxIter = $(maxIter), tol = $(tol)) - val model = optimizer.fit(instances) + val model = optimizer.fit(instances, instr = OptionalInstrumentation.create(instr)) // When it is trained by WeightedLeastSquares, training summary does not // attach returned model. val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept)) @@ -378,6 +378,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val yMean = ySummarizer.mean(0) val rawYStd = math.sqrt(ySummarizer.variance(0)) + + instr.logNumExamples(ySummarizer.count) + instr.logNamedValue(Instrumentation.loggerTags.meanOfLabels, yMean) + instr.logNamedValue(Instrumentation.loggerTags.varianceOfLabels, rawYStd) + if (rawYStd == 0.0) { if ($(fitIntercept) || yMean == 0.0) { // If the rawYStd==0 and fitIntercept==true, then the intercept is yMean with @@ -385,11 +390,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of // the fitIntercept. if (yMean == 0.0) { - logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + - s"and the intercept will all be zero; as a result, training is not needed.") + instr.logWarning(s"Mean and standard deviation of the label are zero, so the " + + s"coefficients and the intercept will all be zero; as a result, training is not " + + s"needed.") } else { - logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + - s"zeros and the intercept will be the mean of the label; as a result, " + + instr.logWarning(s"The standard deviation of the label is zero, so the coefficients " + + s"will be zeros and the intercept will be the mean of the label; as a result, " + s"training is not needed.") } if (handlePersistence) instances.unpersist() @@ -415,7 +421,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") - logWarning(s"The standard deviation of the label is zero. " + + instr.logWarning(s"The standard deviation of the label is zero. " + "Consider setting fitIntercept=true.") } } @@ -430,7 +436,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LinearRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -522,7 +528,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 056a94b351f79..905870178e549 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -108,9 +108,11 @@ private[spark] object RandomForest extends Logging { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) instrumentation.logNumClasses(metadata.numClasses) + instrumentation.logNumExamples(metadata.numExamples) case None => logInfo("numFeatures: " + metadata.numFeatures) logInfo("numClasses: " + metadata.numClasses) + logInfo("numExamples: " + metadata.numExamples) } // Find the splits and the corresponding bins (interval between the splits) using a sample diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index e694bc27b2f1e..3247c394dfa64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.util import java.util.UUID +import scala.reflect.ClassTag + import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -40,7 +42,8 @@ import org.apache.spark.sql.Dataset * @tparam E the type of the estimator */ private[spark] class Instrumentation[E <: Estimator[_]] private ( - estimator: E, dataset: RDD[_]) extends Logging { + val estimator: E, + val dataset: RDD[_]) extends Logging { private val id = UUID.randomUUID() private val prefix = { @@ -103,6 +106,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( logNamedValue(Instrumentation.loggerTags.numClasses, num) } + def logNumExamples(num: Long): Unit = { + logNamedValue(Instrumentation.loggerTags.numExamples, num) + } + /** * Logs the value with customized name field. */ @@ -114,6 +121,10 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Double): Unit = { + log(compact(render(name -> value))) + } + /** * Logs the successful completion of the training session. */ @@ -131,6 +142,8 @@ private[spark] object Instrumentation { val numFeatures = "numFeatures" val numClasses = "numClasses" val numExamples = "numExamples" + val meanOfLabels = "meanOfLabels" + val varianceOfLabels = "varianceOfLabels" } /** @@ -150,3 +163,56 @@ private[spark] object Instrumentation { } } + +/** + * A small wrapper that contains an optional `Instrumentation` object. + * Provide some log methods, if the containing `Instrumentation` object is defined, + * will log via it, otherwise will log via common logger. + */ +private[spark] class OptionalInstrumentation private( + val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val className: String) extends Logging { + + protected override def logName: String = className + + override def logInfo(msg: => String) { + instrumentation match { + case Some(instr) => instr.logInfo(msg) + case None => super.logInfo(msg) + } + } + + override def logWarning(msg: => String) { + instrumentation match { + case Some(instr) => instr.logWarning(msg) + case None => super.logWarning(msg) + } + } + + override def logError(msg: => String) { + instrumentation match { + case Some(instr) => instr.logError(msg) + case None => super.logError(msg) + } + } +} + +private[spark] object OptionalInstrumentation { + + /** + * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. + */ + def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + new OptionalInstrumentation(Some(instr), + instr.estimator.getClass.getName.stripSuffix("$")) + } + + /** + * Creates an `OptionalInstrumentation` object from a `Class` object. + * The created `OptionalInstrumentation` object will log messages via common logger and use the + * specified class name as logger name. + */ + def create(clazz: Class[_]): OptionalInstrumentation = { + new OptionalInstrumentation(None, clazz.getName.stripSuffix("$")) + } +} From 7b1e6523af3c96043aa8d2763e5f18b6e2781c3d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Apr 2018 14:33:33 -0700 Subject: [PATCH 0684/1161] [SPARK-24056][SS] Make consumer creation lazy in Kafka source for Structured streaming ## What changes were proposed in this pull request? Currently, the driver side of the Kafka source (i.e. KafkaMicroBatchReader) eagerly creates a consumer as soon as the Kafk aMicroBatchReader is created. However, we create dummy KafkaMicroBatchReader to get the schema and immediately stop it. Its better to make the consumer creation lazy, it will be created on the first attempt to fetch offsets using the KafkaOffsetReader. ## How was this patch tested? Existing unit tests Author: Tathagata Das Closes #21134 from tdas/SPARK-24056. --- .../sql/kafka010/KafkaOffsetReader.scala | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 551641cfdbca8..82066697cb95a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -75,7 +75,17 @@ private[kafka010] class KafkaOffsetReader( * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - protected var consumer = createConsumer() + @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null + + protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer == null) { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + _consumer = consumerStrategy.createConsumer(newKafkaParams) + } + _consumer + } private val maxOffsetFetchAttempts = readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -95,9 +105,7 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - runUninterruptibly { - consumer.close() - } + if (_consumer != null) runUninterruptibly { stopConsumer() } kafkaReaderThread.shutdown() } @@ -304,19 +312,14 @@ private[kafka010] class KafkaOffsetReader( } } - /** - * Create a consumer using the new generated group id. We always use a new consumer to avoid - * just using a broken consumer to retry on Kafka errors, which likely will fail again. - */ - private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { - val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) - newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) - consumerStrategy.createConsumer(newKafkaParams) + private def stopConsumer(): Unit = synchronized { + assert(Thread.currentThread().isInstanceOf[UninterruptibleThread]) + if (_consumer != null) _consumer.close() } private def resetConsumer(): Unit = synchronized { - consumer.close() - consumer = createConsumer() + stopConsumer() + _consumer = null // will automatically get reinitialized again } } From d6c26d1c9a8f747a3e0d281a27ea9eb4d92102e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 24 Apr 2018 17:06:03 -0700 Subject: [PATCH 0685/1161] [SPARK-24038][SS] Refactor continuous writing to its own class ## What changes were proposed in this pull request? Refactor continuous writing to its own class. See WIP https://github.com/jose-torres/spark/pull/13 for the overall direction this is going, but I think this PR is very isolated and necessary anyway. ## How was this patch tested? existing unit tests - refactoring only Author: Jose Torres Closes #21116 from jose-torres/SPARK-24038. --- .../datasources/v2/DataSourceV2Strategy.scala | 4 + .../datasources/v2/WriteToDataSourceV2.scala | 74 +---------- .../continuous/ContinuousExecution.scala | 2 +- .../WriteToContinuousDataSource.scala | 31 +++++ .../WriteToContinuousDataSourceExec.scala | 124 ++++++++++++++++++ 5 files changed, 165 insertions(+), 70 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1ac9572de6412..c2a31442d2be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -32,6 +33,9 @@ object DataSourceV2Strategy extends Strategy { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case WriteToContinuousDataSource(writer, query) => + WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index e80b44c1cdc66..ea283ed77efda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -65,25 +65,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e s"The input RDD has ${messages.length} partitions.") try { - val runTask = writer match { - // This case means that we're doing continuous processing. In microbatch streaming, the - // StreamWriter is wrapped in a MicroBatchWriter, which is executed as a normal batch. - case w: StreamWriter => - EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) - .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) - - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.runContinuous(writeTask, context, iter) - case _ => - (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator) - } - sparkContext.runJob( rdd, - runTask, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message @@ -91,14 +76,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } ) - if (!writer.isInstanceOf[StreamWriter]) { - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") - } + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") } catch { - case _: InterruptedException if writer.isInstanceOf[StreamWriter] => - // Interruption is how continuous queries are ended, so accept and ignore the exception. case cause: Throwable => logError(s"Data source writer $writer is aborting.") try { @@ -111,8 +92,6 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e } logError(s"Data source writer $writer aborted.") cause match { - // Do not wrap interruption exceptions that will be handled by streaming specially. - case _ if StreamExecution.isInterruptionException(cause) => throw cause // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) case _ => throw cause @@ -168,49 +147,6 @@ object DataWritingSparkTask extends Logging { logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") }) } - - def runContinuous( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): WriterCommitMessage = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - val currentMsg: WriterCommitMessage = null - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - - currentMsg - } } class InternalRowDataWriterFactory( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 951d694355ec5..f58146ac42398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -199,7 +199,7 @@ class ContinuousExecution( triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan) + val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) val reader = withSink.collect { case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala new file mode 100644 index 0000000000000..943c731a70529 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * The logical plan for writing data in a continuous stream. + */ +case class WriteToContinuousDataSource( + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala new file mode 100644 index 0000000000000..ba88ae1af469a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory} +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.util.Utils + +/** + * The physical plan for writing data into a continuous processing [[StreamWriter]]. + */ +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) + extends SparkPlan with Logging { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writerFactory = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${rdd.getNumPartitions} partitions.") + // Let the epoch coordinator know how many partitions the write RDD has. + EpochCoordinatorRef.get( + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) + .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) + + try { + // Force the RDD to run so continuous processing starts; no data is actually being collected + // to the driver, as ContinuousWriteRDD outputs nothing. + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + WriteToContinuousDataSourceExec.run(writerFactory, context, iter), + rdd.partitions.indices) + } catch { + case _: InterruptedException => + // Interruption is how continuous queries are ended, so accept and ignore the exception. + case cause: Throwable => + cause match { + // Do not wrap interruption exceptions that will be handled by streaming specially. + case _ if StreamExecution.isInterruptionException(cause) => throw cause + // Only wrap non fatal exceptions. + case NonFatal(e) => throw new SparkException("Writing job aborted.", e) + case _ => throw cause + } + } + + sparkContext.emptyRDD + } +} + +object WriteToContinuousDataSourceExec extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): Unit = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + do { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so runContinuous will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } while (!context.isInterrupted()) + } +} From 5fea17b3befc50aef59b799711d03b9552f21b19 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 25 Apr 2018 11:19:08 +0900 Subject: [PATCH 0686/1161] [SPARK-23821][SQL] Collection function: flatten ## What changes were proposed in this pull request? This PR adds a new collection function that transforms an array of arrays into a single array. The PR comprises: - An expression for flattening array structure - Flatten function - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(Seq(1, 2), Seq(4, 5)), Seq(null, Seq(1)) ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(flatten($"i")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 064 */ project_numElements, /* 065 */ 4); /* 066 */ if (project_size > 2147483632) { /* 067 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 068 */ project_size + " bytes of data due to exceeding the limit 2147483632" + /* 069 */ " bytes for UnsafeArrayData."); /* 070 */ } /* 071 */ /* 072 */ byte[] project_array = new byte[(int)project_size]; /* 073 */ UnsafeArrayData project_tempArrayData = new UnsafeArrayData(); /* 074 */ Platform.putLong(project_array, 16, project_numElements); /* 075 */ project_tempArrayData.pointTo(project_array, 16, (int)project_size); /* 076 */ int project_counter = 0; /* 077 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 078 */ ArrayData arr = inputadapter_value.getArray(k); /* 079 */ for (int l = 0; l < arr.numElements(); l++) { /* 080 */ if (arr.isNullAt(l)) { /* 081 */ project_tempArrayData.setNullAt(project_counter); /* 082 */ } else { /* 083 */ project_tempArrayData.setInt( /* 084 */ project_counter, /* 085 */ arr.getInt(l) /* 086 */ ); /* 087 */ } /* 088 */ project_counter++; /* 089 */ } /* 090 */ } /* 091 */ project_value = project_tempArrayData; /* 092 */ /* 093 */ } /* 094 */ /* 095 */ } ``` ### Non-primitive type ``` val df = Seq( Seq(Seq("a", "b"), Seq(null, "d")), Seq(null, Seq("a")) ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(flatten($"s")).debugCodegen ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ boolean filter_value = true; /* 038 */ /* 039 */ if (!(!inputadapter_isNull)) { /* 040 */ filter_value = inputadapter_isNull; /* 041 */ } /* 042 */ if (!filter_value) continue; /* 043 */ /* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 045 */ /* 046 */ boolean project_isNull = inputadapter_isNull; /* 047 */ ArrayData project_value = null; /* 048 */ /* 049 */ if (!inputadapter_isNull) { /* 050 */ for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) { /* 051 */ project_isNull |= inputadapter_value.isNullAt(z); /* 052 */ } /* 053 */ if (!project_isNull) { /* 054 */ long project_numElements = 0; /* 055 */ for (int z = 0; z < inputadapter_value.numElements(); z++) { /* 056 */ project_numElements += inputadapter_value.getArray(z).numElements(); /* 057 */ } /* 058 */ if (project_numElements > 2147483632) { /* 059 */ throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + /* 060 */ project_numElements + " elements due to exceeding the array size limit 2147483632."); /* 061 */ } /* 062 */ /* 063 */ Object[] project_arrayObject = new Object[(int)project_numElements]; /* 064 */ int project_counter = 0; /* 065 */ for (int k = 0; k < inputadapter_value.numElements(); k++) { /* 066 */ ArrayData arr = inputadapter_value.getArray(k); /* 067 */ for (int l = 0; l < arr.numElements(); l++) { /* 068 */ project_arrayObject[project_counter] = arr.getUTF8String(l); /* 069 */ project_counter++; /* 070 */ } /* 071 */ } /* 072 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObject); /* 073 */ /* 074 */ } /* 075 */ /* 076 */ } ``` Author: mn-mikke Closes #20938 from mn-mikke/feature/array-api-flatten-to-master. --- python/pyspark/sql/functions.py | 17 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 176 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 95 ++++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 79 ++++++++ 6 files changed, 376 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index da32ab25cad0c..de53b48b6f3b4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2191,6 +2191,23 @@ def reverse(col): return Column(sc._jvm.functions.reverse(_to_java_column(col))) +@since(2.4) +def flatten(col): + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.select(flatten(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.flatten(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c41f16c61d7a2..6afcf309bd690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -413,6 +413,7 @@ object FunctionRegistry { expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), expression[Concat]("concat"), + expression[Flatten]("flatten"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c16793bda028e..bc71b5f34ce4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression { override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" } + +/** + * Transforms an array of arrays into a single array. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.", + examples = """ + Examples: + > SELECT _FUNC_(array(array(1, 2), array(3, 4)); + [1,2,3,4] + """, + since = "2.4.0") +case class Flatten(child: Expression) extends UnaryExpression { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + + override def nullable: Boolean = child.nullable || childDataType.containsNull + + override def dataType: DataType = childDataType.elementType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(_: ArrayType, _) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"The argument should be an array of arrays, " + + s"but '${child.sql}' is of ${child.dataType.simpleString} type." + ) + } + + override def nullSafeEval(child: Any): Any = { + val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType) + + if (elements.contains(null)) { + null + } else { + val arrayData = elements.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val flattenedData = new Array(numberOfElements.toInt) + var position = 0 + for (ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, flattenedData, position, arr.length) + position += arr.length + } + new GenericArrayData(flattenedData) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val code = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value) + } else { + genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) + } + if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + }) + } + + private def nullElementsProtection( + ev: ExprCode, + childVariableName: String, + coreLogic: String): String = { + s""" + |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { + | ${ev.isNull} |= $childVariableName.isNullAt(z); + |} + |if (!${ev.isNull}) { + | $coreLogic + |} + """.stripMargin + } + + private def genCodeForNumberOfElements( + ctx: CodegenContext, + childVariableName: String) : (String, String) = { + val variableName = ctx.freshName("numElements") + val code = s""" + |long $variableName = 0; + |for (int z = 0; z < $childVariableName.numElements(); z++) { + | $variableName += $childVariableName.getArray(z).numElements(); + |} + |if ($variableName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + (code, variableName) + } + + private def genCodeForFlattenOfPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val tempArrayDataName = ctx.freshName("tempArrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + + | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + + | " bytes for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[(int)$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | if (arr.isNullAt(l)) { + | $tempArrayDataName.setNullAt($counter); + | } else { + | $tempArrayDataName.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue("arr", elementType, "l")} + | ); + | } + | $counter++; + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForFlattenOfNonPrimitiveElements( + ctx: CodegenContext, + childVariableName: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val counter = ctx.freshName("counter") + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |int $counter = 0; + |for (int k = 0; k < $childVariableName.numElements(); k++) { + | ArrayData arr = $childVariableName.getArray(k); + | for (int l = 0; l < arr.numElements(); l++) { + | $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")}; + | $counter++; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + + override def prettyName: String = "flatten" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 43c5dda2e4a48..b49fa76b2a781 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) } + + test("Flatten") { + // Primitive-type test cases + val intArrayType = ArrayType(ArrayType(IntegerType)) + + // Main test cases (primitive type) + val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType) + val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType) + + checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6)) + checkEvaluation(Flatten(aim2), Seq(1, 2, 3)) + + // Test cases with an empty array (primitive type) + val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType) + val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType) + val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType) + val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType) + val aie5 = Literal.create(Seq(Seq.empty), intArrayType) + val aie6 = Literal.create(Seq.empty, intArrayType) + + checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4)) + checkEvaluation(Flatten(aie4), Seq.empty) + checkEvaluation(Flatten(aie5), Seq.empty) + checkEvaluation(Flatten(aie6), Seq.empty) + + // Test cases with null elements (primitive type) + val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType) + val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType) + val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType) + + checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null)) + checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null)) + checkEvaluation(Flatten(ain3), Seq(null, null, null, null)) + + // Test cases with a null array (primitive type) + val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType) + val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType) + val aia3 = Literal.create(Seq(null), intArrayType) + val aia4 = Literal.create(null, intArrayType) + + checkEvaluation(Flatten(aia1), null) + checkEvaluation(Flatten(aia2), null) + checkEvaluation(Flatten(aia3), null) + checkEvaluation(Flatten(aia4), null) + + // Non-primitive-type test cases + val strArrayType = ArrayType(ArrayType(StringType)) + val arrArrayType = ArrayType(ArrayType(ArrayType(StringType))) + + // Main test cases (non-primitive type) + val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType) + val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType) + val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType) + + checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f")) + checkEvaluation(Flatten(asm2), Seq("a", "b")) + checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e"))) + + // Test cases with an empty array (non-primitive type) + val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType) + val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType) + val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType) + val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType) + val ase5 = Literal.create(Seq(Seq.empty), strArrayType) + val ase6 = Literal.create(Seq.empty, strArrayType) + + checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d")) + checkEvaluation(Flatten(ase4), Seq.empty) + checkEvaluation(Flatten(ase5), Seq.empty) + checkEvaluation(Flatten(ase6), Seq.empty) + + // Test cases with null elements (non-primitive type) + val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType) + val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType) + val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType) + + checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null)) + checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null)) + checkEvaluation(Flatten(asn3), Seq(null, null, null, null)) + + // Test cases with a null array (non-primitive type) + val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType) + val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType) + val asa3 = Literal.create(Seq(null), strArrayType) + val asa4 = Literal.create(null, strArrayType) + + checkEvaluation(Flatten(asa1), null) + checkEvaluation(Flatten(asa2), null) + checkEvaluation(Flatten(asa3), null) + checkEvaluation(Flatten(asa4), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bea8c0e445002..d2f057310f89b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3340,6 +3340,14 @@ object functions { */ def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** + * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than + * two levels, only one level of nesting is removed. + * @group collection_funcs + * @since 2.4.0 + */ + def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25e5cd60dd236..03605c30036a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -691,6 +691,85 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("flatten function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr") + + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq(1), null)), + (Seq(null, Seq(1))), + (Seq(null, null)) + ).toDF("i") + + val intDFResult = Seq( + Row(Seq(1, 2, 3, 4, 5, 6)), + Row(Seq(1, 2)), + Row(Seq(1)), + Row(Seq(1)), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.filter(dummyFilter($"i"))select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"), + Seq(Row(Seq(1, 2, 3, null, 5, 6, null)))) + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))), + (Seq(Seq.empty, Seq.empty)), + (Seq(Seq("a"), null)), + (Seq(null, Seq("a"))), + (Seq(null, null)) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a")), + Row(Seq.empty), + Row(null), + Row(null), + Row(null)) + + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + checkAnswer( + oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"), + Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3))))) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.select(flatten($"arr")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"i")) + } + intercept[AnalysisException] { + oneRowDF.select(flatten($"s")) + } + intercept[AnalysisException] { + oneRowDF.selectExpr("flatten(null)") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 64e8408e6fa2d74929601b01a29771738f6d8c65 Mon Sep 17 00:00:00 2001 From: liutang123 Date: Wed, 25 Apr 2018 18:10:51 +0800 Subject: [PATCH 0687/1161] [SPARK-24012][SQL] Union of map and other compatible column ## What changes were proposed in this pull request? Union of map and other compatible column result in unresolved operator 'Union; exception Reproduction `spark-sql>select map(1,2), 'str' union all select map(1,2,3,null), 1` Output: ``` Error in query: unresolved operator 'Union;; 'Union :- Project [map(1, 2) AS map(1, 2)#106, str AS str#107] : +- OneRowRelation$ +- Project [map(1, cast(2 as int), 3, cast(null as int)) AS map(1, CAST(2 AS INT), 3, CAST(NULL AS INT))#109, 1 AS 1#108] +- OneRowRelation$ ``` So, we should cast part of columns to be compatible when appropriate. ## How was this patch tested? Added a test (query union of map and other columns) to SQLQueryTestSuite's union.sql. Author: liutang123 Closes #21100 from liutang123/SPARK-24012. --- .../sql/catalyst/analysis/TypeCoercion.scala | 8 ++++ .../test/resources/sql-tests/inputs/union.sql | 11 +++++ .../resources/sql-tests/results/union.sql.out | 42 ++++++++++++++----- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cfcbd8db559a3..25bad28a2a209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -112,6 +112,14 @@ object TypeCoercion { StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) })) + case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + + case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) + Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) + case _ => None } diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql index e57d69eaad033..6da1b9b49b226 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/union.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql @@ -35,6 +35,17 @@ FROM (SELECT col AS col SELECT col FROM p3) T1) T2; +-- SPARK-24012 Union of map and other compatible columns. +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1; + +-- SPARK-24012 Union of array and other compatible columns. +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1; + + -- Clean-up DROP VIEW IF EXISTS t1; DROP VIEW IF EXISTS t2; diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out index d123b7fdbe0cf..b023df825d814 100644 --- a/sql/core/src/test/resources/sql-tests/results/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 14 +-- Number of queries: 16 -- !query 0 @@ -105,23 +105,29 @@ struct -- !query 9 -DROP VIEW IF EXISTS t1 +SELECT map(1, 2), 'str' +UNION ALL +SELECT map(1, 2, 3, NULL), 1 -- !query 9 schema -struct<> +struct,str:string> -- !query 9 output - +{1:2,3:null} 1 +{1:2} str -- !query 10 -DROP VIEW IF EXISTS t2 +SELECT array(1, 2), 'str' +UNION ALL +SELECT array(1, 2, 3, NULL), 1 -- !query 10 schema -struct<> +struct,str:string> -- !query 10 output - +[1,2,3,null] 1 +[1,2] str -- !query 11 -DROP VIEW IF EXISTS p1 +DROP VIEW IF EXISTS t1 -- !query 11 schema struct<> -- !query 11 output @@ -129,7 +135,7 @@ struct<> -- !query 12 -DROP VIEW IF EXISTS p2 +DROP VIEW IF EXISTS t2 -- !query 12 schema struct<> -- !query 12 output @@ -137,8 +143,24 @@ struct<> -- !query 13 -DROP VIEW IF EXISTS p3 +DROP VIEW IF EXISTS p1 -- !query 13 schema struct<> -- !query 13 output + + +-- !query 14 +DROP VIEW IF EXISTS p2 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP VIEW IF EXISTS p3 +-- !query 15 schema +struct<> +-- !query 15 output + From 20ca208bcda6f22fe7d9fb54144de435b4237536 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 25 Apr 2018 19:06:18 +0800 Subject: [PATCH 0688/1161] [SPARK-23880][SQL] Do not trigger any jobs for caching data ## What changes were proposed in this pull request? This pr fixed code so that `cache` could prevent any jobs from being triggered. For example, in the current master, an operation below triggers a actual job; ``` val df = spark.range(10000000000L) .filter('id > 1000) .orderBy('id.desc) .cache() ``` This triggers a job while the cache should be lazy. The problem is that, when creating `InMemoryRelation`, we build the RDD, which calls `SparkPlan.execute` and may trigger jobs, like sampling job for range partitioner, or broadcast job. This pr removed the code to build a cached `RDD` in the constructor of `InMemoryRelation` and added `CachedRDDBuilder` to lazily build the `RDD` in `InMemoryRelation`. Then, the first call of `CachedRDDBuilder.cachedColumnBuffers` triggers a job to materialize the cache in `InMemoryTableScanExec` . ## How was this patch tested? Added tests in `CachedTableSuite`. Author: Takeshi Yamamuro Closes #21018 from maropu/SPARK-23880. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/execution/CacheManager.scala | 14 +- .../execution/columnar/InMemoryRelation.scala | 155 ++++++++++-------- .../columnar/InMemoryTableScanExec.scala | 10 +- .../apache/spark/sql/CachedTableSuite.scala | 36 +++- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 6 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 8 files changed, 133 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 917168162b236..cd4def71e6f3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2933,7 +2933,7 @@ class Dataset[T] private[sql]( */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.storageLevel + cachedData.cachedRepresentation.cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a8794be7280c7..93bf91e56f1bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -71,7 +71,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) cachedData.clear() } @@ -119,7 +119,7 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (cd.plan.find(_.sameResult(plan)).isDefined) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } @@ -138,16 +138,14 @@ class CacheManager extends Logging { while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cachedColumnBuffers.unpersist() + cd.cachedRepresentation.cacheBuilder.clearCache() // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() + val plan = spark.sessionState.executePlan(cd.plan).executedPlan val newCache = InMemoryRelation( - useCompression = cd.cachedRepresentation.useCompression, - batchSize = cd.cachedRepresentation.batchSize, - storageLevel = cd.cachedRepresentation.storageLevel, - child = spark.sessionState.executePlan(cd.plan).executedPlan, - tableName = cd.cachedRepresentation.tableName, + cacheBuilder = cd.cachedRepresentation + .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index a7ba9b86a176f..da35a4734e65a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -32,19 +32,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator -object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String], - logicalPlan: LogicalPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) -} - - /** * CachedBatch is a cached batch of rows. * @@ -55,58 +42,41 @@ object InMemoryRelation { private[columnar] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) -case class InMemoryRelation( - output: Seq[Attribute], +case class CachedRDDBuilder( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - @transient child: SparkPlan, + @transient cachedPlan: SparkPlan, tableName: Option[String])( - @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, - val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics, - override val outputOrdering: Seq[SortOrder]) - extends logical.LeafNode with MultiInstanceRelation { - - override protected def innerChildren: Seq[SparkPlan] = Seq(child) - - override def doCanonicalize(): logical.LogicalPlan = - copy(output = output.map(QueryPlan.normalizeExprId(_, child.output)), - storageLevel = StorageLevel.NONE, - child = child.canonicalized, - tableName = None)( - _cachedColumnBuffers, - sizeInBytesStats, - statsOfPlanToCache, - outputOrdering) + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null) { - override def producedAttributes: AttributeSet = outputSet - - @transient val partitionStatistics = new PartitionStatistics(output) + val sizeInBytesStats: LongAccumulator = cachedPlan.sqlContext.sparkContext.longAccumulator - override def computeStats(): Statistics = { - if (sizeInBytesStats.value == 0L) { - // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. - // Note that we should drop the hint info here. We may cache a plan whose root node is a hint - // node. When we lookup the cache with a semantically same plan without hint info, the plan - // returned by cache lookup should not have hint info. If we lookup the cache with a - // semantically same plan with a different hint info, `CacheManager.useCachedData` will take - // care of it and retain the hint info in the lookup input plan. - statsOfPlanToCache.copy(hints = HintInfo()) - } else { - Statistics(sizeInBytes = sizeInBytesStats.value.longValue) + def cachedColumnBuffers: RDD[CachedBatch] = { + if (_cachedColumnBuffers == null) { + synchronized { + if (_cachedColumnBuffers == null) { + _cachedColumnBuffers = buildBuffers() + } + } } + _cachedColumnBuffers } - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. - if (_cachedColumnBuffers == null) { - buildBuffers() + def clearCache(blocking: Boolean = true): Unit = { + if (_cachedColumnBuffers != null) { + synchronized { + if (_cachedColumnBuffers != null) { + _cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } + } + } } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => + private def buildBuffers(): RDD[CachedBatch] = { + val output = cachedPlan.output + val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -154,32 +124,77 @@ case class InMemoryRelation( cached.setName( tableName.map(n => s"In-memory table $n") - .getOrElse(StringUtils.abbreviate(child.toString, 1024))) - _cachedColumnBuffers = cached + .getOrElse(StringUtils.abbreviate(cachedPlan.toString, 1024))) + cached + } +} + +object InMemoryRelation { + + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String], + logicalPlan: LogicalPlan): InMemoryRelation = { + val cacheBuilder = CachedRDDBuilder(useCompression, batchSize, storageLevel, child, tableName)() + new InMemoryRelation(child.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } + + def apply(cacheBuilder: CachedRDDBuilder, logicalPlan: LogicalPlan): InMemoryRelation = { + new InMemoryRelation(cacheBuilder.cachedPlan.output, cacheBuilder)( + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) + } +} + +case class InMemoryRelation( + output: Seq[Attribute], + @transient cacheBuilder: CachedRDDBuilder)( + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[SparkPlan] = Seq(cachedPlan) + + override def doCanonicalize(): logical.LogicalPlan = + copy(output = output.map(QueryPlan.normalizeExprId(_, cachedPlan.output)), + cacheBuilder)( + statsOfPlanToCache, + outputOrdering) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + def cachedPlan: SparkPlan = cacheBuilder.cachedPlan + + override def computeStats(): Statistics = { + if (cacheBuilder.sizeInBytesStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, use the stats from the plan to cache. + // Note that we should drop the hint info here. We may cache a plan whose root node is a hint + // node. When we lookup the cache with a semantically same plan without hint info, the plan + // returned by cache lookup should not have hint info. If we lookup the cache with a + // semantically same plan with a different hint info, `CacheManager.useCachedData` will take + // care of it and retain the hint info in the lookup input plan. + statsOfPlanToCache.copy(hints = HintInfo()) + } else { + Statistics(sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue) + } } def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) + InMemoryRelation(newOutput, cacheBuilder)(statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { new InMemoryRelation( output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - sizeInBytesStats, + cacheBuilder)( statsOfPlanToCache, outputOrdering).asInstanceOf[this.type] } - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers - - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index e73e1378d52e3..ea315fb71617c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -154,7 +154,7 @@ case class InMemoryTableScanExec( private def updateAttribute(expr: Expression): Expression = { // attributes can be pruned so using relation's output. // E.g., relation.output is [id, item] but this scan's output can be [item] only. - val attrMap = AttributeMap(relation.child.output.zip(relation.output)) + val attrMap = AttributeMap(relation.cachedPlan.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } @@ -163,16 +163,16 @@ case class InMemoryTableScanExec( // The cached version does not change the outputPartitioning of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { - relation.child.outputPartitioning match { + relation.cachedPlan.outputPartitioning match { case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.child.outputPartitioning + case _ => relation.cachedPlan.outputPartitioning } } // The cached version does not change the outputOrdering of the original SparkPlan. // But the cached version could alias output, so we need to replace output. override def outputOrdering: Seq[SortOrder] = - relation.child.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) + relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. private val stats = relation.partitionStatistics @@ -252,7 +252,7 @@ case class InMemoryTableScanExec( // within the map Partitions closure. val schema = stats.schema val schemaIndex = schema.zipWithIndex - val buffers = relation.cachedColumnBuffers + val buffers = relation.cacheBuilder.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 669e5f2bf4e65..81b7e18773f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.spark.CleanerListener +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} @@ -52,7 +53,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head @@ -78,7 +79,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.child) + 1 + getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 }.sum } @@ -200,7 +201,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r + case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r }.size } @@ -367,12 +368,12 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -794,4 +795,29 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } } } + + private def checkIfNoJobTriggered[T](f: => T): T = { + var numJobTrigered = 0 + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numJobTrigered += 1 + } + } + sparkContext.addSparkListener(jobListener) + try { + val result = f + sparkContext.listenerBus.waitUntilEmpty(10000L) + assert(numJobTrigered === 0) + result + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + + test("SPARK-23880 table cache should be lazy and don't trigger any jobs") { + val cachedData = checkIfNoJobTriggered { + spark.range(1002).filter('id > 1000).orderBy('id.desc).cache() + } + assert(cachedData.collect === Seq(1001)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 40915a102bab0..f0dfe6b76f7ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -194,7 +194,7 @@ class PlannerSuite extends SharedSQLContext { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select('key, 'value).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] - assert(planned.child.isInstanceOf[CollectLimitExec]) + assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9b7b316211d30..863703b15f4f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -45,8 +45,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, data.logicalPlan) - assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) - inMemoryRelation.cachedColumnBuffers.collect().head match { + assert(inMemoryRelation.cacheBuilder.cachedColumnBuffers.getStorageLevel == storageLevel) + inMemoryRelation.cacheBuilder.cachedColumnBuffers.collect().head match { case _: CachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } @@ -337,7 +337,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) + assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.size * INT.defaultSize) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 48ab4eb9a6178..569f00c053e5f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -38,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val plan = table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => - relation.cachedColumnBuffers.id + relation.cacheBuilder.cachedColumnBuffers.id case _ => fail(s"Table $tableName is not cached\n" + plan) }.head From 396938ef02c70468e1695872f96b1e9aff28b7ea Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 12:21:55 -0700 Subject: [PATCH 0689/1161] [SPARK-24050][SS] Calculate input / processing rates correctly for DataSourceV2 streaming sources ## What changes were proposed in this pull request? In some streaming queries, the input and processing rates are not calculated at all (shows up as zero) because MicroBatchExecution fails to associated metrics from the executed plan of a trigger with the sources in the logical plan of the trigger. The way this executed-plan-leaf-to-logical-source attribution works is as follows. With V1 sources, there was no way to identify which execution plan leaves were generated by a streaming source. So did a best-effort attempt to match logical and execution plan leaves when the number of leaves were same. In cases where the number of leaves is different, we just give up and report zero rates. An example where this may happen is as follows. ``` val cachedStaticDF = someStaticDF.union(anotherStaticDF).cache() val streamingInputDF = ... val query = streamingInputDF.join(cachedStaticDF).writeStream.... ``` In this case, the `cachedStaticDF` has multiple logical leaves, but in the trigger's execution plan it only has leaf because a cached subplan is represented as a single InMemoryTableScanExec leaf. This leads to a mismatch in the number of leaves causing the input rates to be computed as zero. With DataSourceV2, all inputs are represented in the executed plan using `DataSourceV2ScanExec`, each of which has a reference to the associated logical `DataSource` and `DataSourceReader`. So its easy to associate the metrics to the original streaming sources. In this PR, the solution is as follows. If all the streaming sources in a streaming query as v2 sources, then use a new code path where the execution-metrics-to-source mapping is done directly. Otherwise we fall back to existing mapping logic. ## How was this patch tested? - New unit tests using V2 memory source - Existing unit tests using V1 source Author: Tathagata Das Closes #21126 from tdas/SPARK-24050. --- .../kafka010/KafkaMicroBatchSourceSuite.scala | 9 +- .../streaming/ProgressReporter.scala | 146 +++++++++++++----- .../sql/streaming/StreamingQuerySuite.scala | 134 +++++++++++++++- 3 files changed, 245 insertions(+), 44 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e017fd9b84d21..d2d04b68de6ab 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -563,7 +563,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } - test("ensure stream-stream self-join generates only one offset in offset log") { + test("ensure stream-stream self-join generates only one offset in log and correct metrics") { val topic = newTopic() testUtils.createTopic(topic, partitions = 2) require(testUtils.getLatestOffsets(Set(topic)).size === 2) @@ -587,7 +587,12 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { AddKafkaData(Set(topic), 1, 2), CheckAnswer((1, 1, 1), (2, 2, 2)), AddKafkaData(Set(topic), 6, 3), - CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)) + CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 6, 6)), + AssertOnQuery { q => + assert(q.availableOffsets.iterator.size == 1) + assert(q.recentProgress.map(_.numInputRows).sum == 4) + true + } ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d1e5be9c12762..16ad3ef9a3d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -141,7 +143,7 @@ trait ProgressReporter extends Logging { } logDebug(s"Execution stats: $executionStats") - val sourceProgress = sources.map { source => + val sourceProgress = sources.distinct.map { source => val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, @@ -207,62 +209,126 @@ trait ProgressReporter extends Logging { return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp) } - // We want to associate execution plan leaves to sources that generate them, so that we match - // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. - // Consider the translation from the streaming logical plan to the final executed plan. - // - // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan - // - // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan - // - Each logical plan leaf will be associated with a single streaming source. - // - There can be multiple logical plan leaves associated with a streaming source. - // - There can be leaves not associated with any streaming source, because they were - // generated from a batch source (e.g. stream-batch joins) - // - // 2. Assuming that the executed plan has same number of leaves in the same order as that of - // the trigger logical plan, we associate executed plan leaves with corresponding - // streaming sources. - // - // 3. For each source, we sum the metrics of the associated execution plan leaves. - // - val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => - logicalPlan.collectLeaves().map { leaf => leaf -> source } + val numInputRows = extractSourceToNumInputRows() + + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } + + /** Extract number of input sources for each streaming source in plan */ + private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + + import java.util.IdentityHashMap + import scala.collection.JavaConverters._ + + def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } - val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming - val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() - val numInputRows: Map[BaseStreamingSource, Long] = + + val onlyDataSourceV2Sources = { + // Check whether the streaming query's logical plan has only V2 data sources + val allStreamingLeaves = + logicalPlan.collect { case s: StreamingExecutionRelation => s } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + } + + if (onlyDataSourceV2Sources) { + // DataSourceV2ScanExec is the execution plan leaf that is responsible for reading data + // from a V2 source and has a direct reference to the V2 source that generated it. Each + // DataSourceV2ScanExec records the number of rows it has read using SQLMetrics. However, + // just collecting all DataSourceV2ScanExec nodes and getting the metric is not correct as + // a DataSourceV2ScanExec instance may be referred to in the execution plan from two (or + // even multiple times) points and considering it twice will leads to double counting. We + // can't dedup them using their hashcode either because two different instances of + // DataSourceV2ScanExec can have the same hashcode but account for separate sets of + // records read, and deduping them to consider only one of them would be undercounting the + // records read. Therefore the right way to do this is to consider the unique instances of + // DataSourceV2ScanExec (using their identity hash codes) and get metrics from them. + // Hence we calculate in the following way. + // + // 1. Collect all the unique DataSourceV2ScanExec instances using IdentityHashMap. + // + // 2. Extract the source and the number of rows read from the DataSourceV2ScanExec instanes. + // + // 3. Multiple DataSourceV2ScanExec instance may refer to the same source (can happen with + // self-unions or self-joins). Add up the number of rows for each unique source. + val uniqueStreamingExecLeavesMap = + new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() + + lastExecution.executedPlan.collectLeaves().foreach { + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + uniqueStreamingExecLeavesMap.put(s, s) + case _ => + } + + val sourceToInputRowsTuples = + uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + source -> numRows + }.toSeq + logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) + sumRows(sourceToInputRowsTuples) + } else { + + // Since V1 source do not generate execution plan leaves that directly link with source that + // generated it, we can only do a best-effort association between execution plan leaves to the + // sources. This is known to fail in a few cases, see SPARK-24050. + // + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. + // + val logicalPlanLeafToSource = newData.flatMap { case (source, logicalPlan) => + logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = lastExecution.logical.collectLeaves() // includes non-streaming + val allExecPlanLeaves = lastExecution.executedPlan.collectLeaves() if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } } - val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val sourceToInputRowsTuples = execLeafToSource.map { case (execLeaf, source) => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) source -> numRows } - sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + sumRows(sourceToInputRowsTuples) } else { if (!metricWarningLogged) { def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( "Could not report metrics as number leaves in trigger logical plan did not match that" + - s" of the execution plan:\n" + - s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + - s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") metricWarningLogged = true } Map.empty } - - val eventTimeStats = lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - val stats = e.eventTimeStats.value - Map( - "max" -> stats.max, - "min" -> stats.min, - "avg" -> stats.avg.toLong).mapValues(formatTimestamp) - }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp - - ExecutionStats(numInputRows, stateOperators, eventTimeStats) + } } /** Records the duration of running `body` for the next query progress update. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 20942ed93897c..390d67d1feb27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -466,7 +466,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } - test("input row calculation with mixed batch and streaming sources") { + test("input row calculation with same V1 source used twice in self-join") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + + val progress = getFirstProgress(streamingInputDF.join(streamingInputDF, "value")) + assert(progress.numInputRows === 20) // data is read multiple times in self-joins + assert(progress.sources.size === 1) + assert(progress.sources(0).numInputRows === 20) + } + + test("input row calculation with mixed batch and streaming V1 sources") { val streamingTriggerDF = spark.createDataset(1 to 10).toDF val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") @@ -479,7 +489,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } - test("input row calculation with trigger input DF having multiple leaves") { + test("input row calculation with trigger input DF having multiple leaves in V1 source") { val streamingTriggerDF = spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) @@ -492,6 +502,121 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources(0).numInputRows === 10) } + test("input row calculation with same V2 source used twice in self-union") { + val streamInput = MemoryStream[Int] + + testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 1, 2, 2, 3, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with same V2 source used twice in self-join") { + val streamInput = MemoryStream[Int] + val df = streamInput.toDF() + testStream(df.join(df, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + CheckAnswer(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 6) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 6) + true + } + ) + } + + test("input row calculation with trigger having data for only one of two V2 sources") { + val streamInput1 = MemoryStream[Int] + val streamInput2 = MemoryStream[Int] + + testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + AddData(streamInput1, 1, 2, 3), + CheckLastBatch(1, 2, 3), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 3) + assert(lastProgress.get.sources(1).numInputRows == 0) + true + }, + AddData(streamInput2, 4, 5), + CheckLastBatch(4, 5), + AssertOnQuery { q => + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 2) + assert(lastProgress.get.sources.length == 2) + assert(lastProgress.get.sources(0).numInputRows == 0) + assert(lastProgress.get.sources(1).numInputRows == 2) + true + } + ) + } + + test("input row calculation with mixed batch and streaming V2 sources") { + + val streamInput = MemoryStream[Int] + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + AddData(streamInput, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + + // The number of leaves in the trigger's logical plan should be same as the executed plan. + require( + q.lastExecution.logical.collectLeaves().length == + q.lastExecution.executedPlan.collectLeaves().length) + + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + + val streamInput2 = MemoryStream[Int] + val staticInputDF2 = staticInputDF.union(staticInputDF).cache() + + testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + AddData(streamInput2, 1, 2, 3), + AssertOnQuery { q => + q.processAllAvailable() + // The number of leaves in the trigger's logical plan should be different from + // the executed plan. The static input will have two leaves in the logical plan + // (due to the union), but will be converted to a single leaf in the executed plan + // (due to the caching, the cached subplan is replaced by a single InMemoryTableScanExec). + require( + q.lastExecution.logical.collectLeaves().length != + q.lastExecution.executedPlan.collectLeaves().length) + + // Despite the mismatch in total number of leaves in the logical and executed plans, + // we should be able to attribute streaming input metrics to the streaming sources. + val lastProgress = getLastProgressWithData(q) + assert(lastProgress.nonEmpty) + assert(lastProgress.get.numInputRows == 3) + assert(lastProgress.get.sources.length == 1) + assert(lastProgress.get.sources(0).numInputRows == 3) + true + } + ) + } + testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) @@ -733,6 +858,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + /** Returns the last query progress from query.recentProgress where numInputRows is positive */ + def getLastProgressWithData(q: StreamingQuery): Option[StreamingQueryProgress] = { + q.recentProgress.filter(_.numInputRows > 0).lastOption + } + /** * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. * From ac4ca7c4dd3ff666ec70aeb26ac84cffa557ee12 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 25 Apr 2018 13:42:44 -0700 Subject: [PATCH 0690/1161] [SPARK-24012][SQL][TEST][FOLLOWUP] add unit test ## What changes were proposed in this pull request? a followup of https://github.com/apache/spark/pull/21100 ## How was this patch tested? N/A Author: Wenchen Fan Closes #21154 from cloud-fan/test. --- .../catalyst/analysis/TypeCoercionSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index fd6a3121663ed..1cc431aaf0a60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -429,6 +429,24 @@ class TypeCoercionSuite extends AnalysisTest { Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), isSymmetric = false) } + + widenTest( + ArrayType(IntegerType, containsNull = true), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = true))) + + widenTest( + MapType(IntegerType, StringType, valueContainsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + Some(MapType(IntegerType, StringType, valueContainsNull = true))) + + widenTest( + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = false), + new StructType() + .add("arr", ArrayType(IntegerType, containsNull = false), nullable = true), + Some(new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true))) } test("wider common type for decimal and array") { From 95a651339ec39d5753e849e578ad715be0d7c83e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 09:12:38 +0800 Subject: [PATCH 0691/1161] [SPARK-24069][R] Add array_min / array_max functions ## What changes were proposed in this pull request? This PR proposes to add array_max and array_min in R side too. array_max: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_max(mutated$v1))) ``` ``` array_max(v1) 1 4 2 4 3 4 4 3 5 3 6 3 ``` array_min: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, array_min(mutated$v1))) ``` ``` array_min(v1) 1 6 2 6 3 4 4 6 5 8 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21142 from HyukjinKwon/sparkr_array_min_array_max. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 27 +++++++++++++++++++++++++++ R/pkg/R/generics.R | 8 ++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 9 ++++++++- 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 55dec177ea853..f36d462a83cb0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,8 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_max", + "array_min", "array_position", "asc", "ascii", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 7b3aa05074563..ec4bd4e73c7e5 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -206,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -2992,6 +2993,32 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_max}: Returns the maximum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_max array_max,Column-method +#' @note array_max since 2.4.0 +setMethod("array_max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc) + column(jc) + }) + +#' @details +#' \code{array_min}: Returns the minimum value of the array. +#' +#' @rdname column_collection_functions +#' @aliases array_min array_min,Column-method +#' @note array_min since 2.4.0 +setMethod("array_min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc) + column(jc) + }) + #' @details #' \code{array_position}: Locates the position of the first occurrence of the given value #' in the given array. Returns NA if either of the arguments are NA. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f30ac9e4295e4..562d3399ee9c8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,14 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_max", function(x) { standardGeneric("array_max") }) + +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_min", function(x) { standardGeneric("array_min") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a384997830276..8cc2db7a140f9 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,11 +1479,18 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_position(), element_at() and sort_array() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() + # and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_max(df[[1]])))[[1]] + expect_equal(result, c(3, 6)) + + result <- collect(select(df, array_min(df[[1]])))[[1]] + expect_equal(result, c(1, 4)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 0)) From 3f1e999d3d215bb3b867bcd83ec5c799448ec730 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 26 Apr 2018 09:14:24 +0800 Subject: [PATCH 0692/1161] [SPARK-23849][SQL] Tests for samplingRatio of json datasource ## What changes were proposed in this pull request? Added the `samplingRatio` option to the `json()` method of PySpark DataFrame Reader. Improving existing tests for Scala API according to review of the PR: https://github.com/apache/spark/pull/20959 ## How was this patch tested? Added new test for PySpark, updated 2 existing tests according to reviews of https://github.com/apache/spark/pull/20959 and added new negative test Author: Maxim Gekk Closes #21056 from MaxGekk/json-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 8 +++ .../apache/spark/sql/DataFrameReader.scala | 2 + .../datasources/json/JsonSuite.scala | 63 ++++++++++--------- .../datasources/json/TestJsonData.scala | 12 ++++ 5 files changed, 61 insertions(+), 31 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6bd79bc2f43e5..df176c579fc8b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -239,6 +239,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, including tab and line feed characters) or not. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. + :param samplingRatio: defines fraction of input JSON objects used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -256,7 +258,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, - allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep) + allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, + samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e99c8e3c6b10..98fa1b54b0a17 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3018,6 +3018,14 @@ def test_sort_with_nulls_order(self): df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) + def test_json_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x)) + schema = self.spark.read.option('inferSchema', True) \ + .option('samplingRatio', 0.5) \ + .json(rdd).schema + self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d640fdc530ce2..b44552f0eb17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -374,6 +374,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * per file *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used + * for schema inferring.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 70aee561ff0f6..a58dff827b92d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2128,38 +2128,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } - test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - withTempPath { path => - val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), - StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) - for (i <- 0 until 100) { - if (predefinedSample.contains(i)) { - writer.write(s"""{"f1":${i.toString}}""" + "\n") - } else { - writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") - } - } - writer.close() + test("SPARK-23849: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + val readback = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + + assert(readback.schema == new StructType().add("f1", LongType)) + }) + } - val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) - assert(ds.schema == new StructType().add("f1", LongType)) - } + test("SPARK-23849: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read.option("samplingRatio", 0.1).json(ds) + + assert(readback.schema == new StructType().add("f1", LongType)) } - test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { - val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => - val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, - 57, 62, 68, 72) - if (predefinedSample.contains(i)) { - s"""{"f1":${i.toString}}""" + "\n" - } else { - s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" - } - }.toDS() - val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + test("SPARK-23849: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", -1).json(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("samplingRatio", 0).json(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) - assert(ds.schema == new StructType().add("f1", LongType)) + val sampled = spark.read.option("samplingRatio", 1.0).json(ds) + assert(sampled.count() == ds.count()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 13084ba4a7f04..6e9559edf8ec2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -233,4 +233,16 @@ private[json] trait TestJsonData { spark.createDataset(spark.sparkContext.parallelize("""{"a":123}""" :: Nil))(Encoders.STRING) def empty: Dataset[String] = spark.emptyDataset(Encoders.STRING) + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + s"""{"f1":${index.toString}}""" + } else { + s"""{"f1":${(index.toDouble + 0.1).toString}}""" + } + }(Encoders.STRING) + } } From 58c55cb4a6d72d72df908e37aa63f617b3cc5587 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 12:19:20 +0900 Subject: [PATCH 0693/1161] [SPARK-23902][SQL] Add roundOff flag to months_between ## What changes were proposed in this pull request? HIVE-15511 introduced the `roundOff` flag in order to disable the rounding to 8 digits which is performed in `months_between`. Since this can be a computational intensive operation, skipping it may improve performances when the rounding is not needed. ## How was this patch tested? modified existing UT Author: Marco Gaido Closes #21008 from mgaido91/SPARK-23902. --- python/pyspark/sql/functions.py | 10 +++- .../expressions/datetimeExpressions.scala | 33 +++++++---- .../sql/catalyst/util/DateTimeUtils.scala | 32 ++++------ .../expressions/DateExpressionsSuite.scala | 59 +++++++++++-------- .../catalyst/util/DateTimeUtilsSuite.scala | 30 +++++++--- .../org/apache/spark/sql/functions.scala | 13 +++- .../apache/spark/sql/DateFunctionsSuite.scala | 7 +++ 7 files changed, 118 insertions(+), 66 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index de53b48b6f3b4..38ae41a5dafe6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1088,16 +1088,20 @@ def add_months(start, months): @since(1.5) -def months_between(date1, date2): +def months_between(date1, date2, roundOff=True): """ Returns the number of months between date1 and date2. + Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() - [Row(months=3.9495967...)] + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) + return Column(sc._jvm.functions.months_between( + _to_java_column(date1), _to_java_column(date2), roundOff)) @since(2.2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b9b2cd5bdb9f0..d882d06cfd625 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1156,38 +1156,49 @@ case class AddMonths(startDate: Expression, numMonths: Expression) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(timestamp1, timestamp2) - Returns number of months between `timestamp1` and `timestamp2`.", + usage = """ + _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. + The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); 3.94959677 + > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30', false); + 3.9495967741935485 """, since = "1.5.0") // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { +case class MonthsBetween( + date1: Expression, + date2: Expression, + roundOff: Expression, + timeZoneId: Option[String] = None) + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) - def this(date1: Expression, date2: Expression) = this(date1, date2, None) + def this(date1: Expression, date2: Expression, roundOff: Expression) = + this(date1, date2, roundOff, None) - override def left: Expression = date1 - override def right: Expression = date2 + override def children: Seq[Expression] = Seq(date1, date2, roundOff) - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType, BooleanType) override def dataType: DataType = DoubleType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) + override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { + DateTimeUtils.monthsBetween( + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r, $tz)""" + defineCodeGen(ctx, ev, (d1, d2, roundOff) => { + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index fa69b8af62c85..4b00a61c6cf91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -870,24 +870,14 @@ object DateTimeUtils { * If time1 and time2 having the same day of month, or both are the last day of month, * it returns an integer (time under a day will be ignored). * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. + * Otherwise, the difference is calculated based on 31 days per month. + * If `roundOff` is set to true, the result is rounded to 8 decimal places. */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { - monthsBetween(time1, time2, defaultTimeZone()) - } - - /** - * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. - * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). - * - * Otherwise, the difference is calculated based on 31 days per month, and rounding to - * 8 digits. - */ - def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { + def monthsBetween( + time1: SQLTimestamp, + time2: SQLTimestamp, + roundOff: Boolean, + timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L val date1 = millisToDays(millis1, timeZone) @@ -906,8 +896,12 @@ object DateTimeUtils { val timeInDay2 = millis2 - daysToMillis(date2, timeZone) val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 - // rounding to 8 digits - math.round(diff * 1e8) / 1e8 + if (roundOff) { + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } else { + diff + } } // Thursday = 0 since 1970/Jan/01 => Thursday diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 080ec487cfa6a..63b24fb9eb13a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -464,34 +464,47 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MonthsBetween( Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), - timeZoneId), - 3.94959677) + Literal.TrueLiteral, + timeZoneId = timeZoneId), 3.94959677) checkEvaluation( MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), - timeZoneId), - 0.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - timeZoneId), - -2.0) - checkEvaluation( - MonthsBetween( - Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), - Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), - timeZoneId), - 1.0) + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + Literal.FalseLiteral, + timeZoneId = timeZoneId), 3.9495967741935485) + + Seq(Literal.FalseLiteral, Literal.TrueLiteral). foreach { roundOff => + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + roundOff, + timeZoneId = timeZoneId), 1.0) + } val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) - checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(t, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(tnull, tnull, Literal.TrueLiteral, timeZoneId = timeZoneId), null) + checkEvaluation( + MonthsBetween(t, t, Literal.create(null, BooleanType), timeZoneId = timeZoneId), null) checkConsistencyBetweenInterpretedAndCodegen( - (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), - TimestampType, TimestampType) + (time1: Expression, time2: Expression, roundOff: Expression) => + MonthsBetween(time1, time2, roundOff, timeZoneId = timeZoneId), + TimestampType, TimestampType, BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa3..cbf6106697f30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -490,24 +490,36 @@ class DateTimeUtilsSuite extends SparkFunSuite { c1.set(1997, 1, 28, 10, 30, 0) val c2 = Calendar.getInstance() c2.set(1996, 9, 30, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) - c2.set(2000, 1, 28, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(2000, 1, 29, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) - c2.set(1996, 2, 31, 0, 0, 0) - assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, true, c1.getTimeZone) === 3.94959677) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, false, c1.getTimeZone) + === 3.9495967741935485) + Seq(true, false).foreach { roundOff => + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween( + c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L, roundOff, c1.getTimeZone) === 11) + } val c3 = Calendar.getInstance(TimeZonePST) c3.set(2000, 1, 28, 16, 0, 0) val c4 = Calendar.getInstance(TimeZonePST) c4.set(1997, 1, 28, 16, 0, 0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZonePST) === 36.0) assert( - monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT) + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, true, TimeZoneGMT) === 35.90322581) + assert( + monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, false, TimeZoneGMT) + === 35.903225806451616) } test("from UTC timestamp") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2f057310f89b..f1587cd032adc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2691,11 +2691,22 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. + * The result is rounded off to 8 digits. * @group datetime_funcs * @since 1.5.0 */ def months_between(date1: Column, date2: Column): Column = withExpr { - MonthsBetween(date1.expr, date2.expr) + new MonthsBetween(date1.expr, date2.expr) + } + + /** + * Returns number of months between dates `date1` and `date2`. If `roundOff` is set to true, the + * result is rounded off to 8 digits; it is not rounded otherwise. + * @group datetime_funcs + * @since 2.4.0 + */ + def months_between(date1: Column, date2: Column, roundOff: Boolean): Column = withExpr { + MonthsBetween(date1.expr, date2.expr, lit(roundOff).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 6bbf38516cdf6..f712baa7a9134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -327,6 +327,13 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + checkAnswer(df.selectExpr("months_between(t, s, true)"), Seq(Row(0.5), Row(-0.5))) + Seq(true, false).foreach { roundOff => + checkAnswer(df.select(months_between(col("t"), col("d"), roundOff)), + Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.withColumn("r", lit(false)).selectExpr("months_between(t, s, r)"), + Seq(Row(0.5), Row(-0.5))) + } } test("function last_day") { From cd10f9df8284ee8a5d287b2cd204c70b8ba87f5e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 26 Apr 2018 13:37:13 +0900 Subject: [PATCH 0694/1161] [SPARK-23916][SQL] Add array_join function ## What changes were proposed in this pull request? The PR adds the SQL function `array_join`. The behavior of the function is based on Presto's one. The function accepts an `array` of `string` which is to be joined, a `string` which is the delimiter to use between the items of the first argument and optionally a `string` which is used to replace `null` values. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21011 from mgaido91/SPARK-23916. --- python/pyspark/sql/functions.py | 21 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 169 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 35 ++++ .../org/apache/spark/sql/functions.scala | 19 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 23 +++ 6 files changed, 268 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 38ae41a5dafe6..ad4bd6f5089e9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,27 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def array_join(col, delimiter, null_replacement=None): + """ + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')] + """ + sc = SparkContext._active_spark_context + if null_replacement is None: + return Column(sc._jvm.functions.array_join(_to_java_column(col), delimiter)) + else: + return Column(sc._jvm.functions.array_join( + _to_java_column(col), delimiter, null_replacement)) + + @since(1.5) @ignore_unicode_prefix def concat(*cols): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6afcf309bd690..6bc7b4e4f7cb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -401,6 +401,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bc71b5f34ce4a..90223b9126555 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -378,6 +378,175 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Creates a String containing all the elements of the input array separated by the delimiter. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array + using the delimiter and an optional string to replace nulls. If no value is set for + nullReplacement, any null value is filtered.""", + examples = """ + Examples: + > SELECT _FUNC_(array('hello', 'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' '); + hello world + > SELECT _FUNC_(array('hello', null ,'world'), ' ', ','); + hello , world + """, since = "2.4.0") +case class ArrayJoin( + array: Expression, + delimiter: Expression, + nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes { + + def this(array: Expression, delimiter: Expression) = this(array, delimiter, None) + + def this(array: Expression, delimiter: Expression, nullReplacement: Expression) = + this(array, delimiter, Some(nullReplacement)) + + override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { + Seq(ArrayType(StringType), StringType, StringType) + } else { + Seq(ArrayType(StringType), StringType) + } + + override def children: Seq[Expression] = if (nullReplacement.isDefined) { + Seq(array, delimiter, nullReplacement.get) + } else { + Seq(array, delimiter) + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = { + val arrayEval = array.eval(input) + if (arrayEval == null) return null + val delimiterEval = delimiter.eval(input) + if (delimiterEval == null) return null + val nullReplacementEval = nullReplacement.map(_.eval(input)) + if (nullReplacementEval.contains(null)) return null + + val buffer = new UTF8StringBuilder() + var firstItem = true + val nullHandling = nullReplacementEval match { + case Some(rep) => (prependDelimiter: Boolean) => { + if (!prependDelimiter) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(rep.asInstanceOf[UTF8String]) + true + } + case None => (_: Boolean) => false + } + arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => { + if (item == null) { + if (nullHandling(firstItem)) { + firstItem = false + } + } else { + if (!firstItem) { + buffer.append(delimiterEval.asInstanceOf[UTF8String]) + } + buffer.append(item.asInstanceOf[UTF8String]) + firstItem = false + } + }) + buffer.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val code = nullReplacement match { + case Some(replacement) => + val replacementGen = replacement.genCode(ctx) + val nullHandling = (buffer: String, delimiter: String, firstItem: String) => { + s""" + |if (!$firstItem) { + | $buffer.append($delimiter); + |} + |$buffer.append(${replacementGen.value}); + |$firstItem = false; + """.stripMargin + } + val execCode = if (replacement.nullable) { + ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + } else { + genCodeForArrayAndDelimiter(ctx, ev, nullHandling) + } + s""" + |${replacementGen.code} + |$execCode + """.stripMargin + case None => genCodeForArrayAndDelimiter(ctx, ev, + (_: String, _: String, _: String) => "// nulls are ignored") + } + if (nullable) { + ev.copy( + s""" + |boolean ${ev.isNull} = true; + |UTF8String ${ev.value} = null; + |$code + """.stripMargin) + } else { + ev.copy( + s""" + |UTF8String ${ev.value} = null; + |$code + """.stripMargin, FalseLiteral) + } + } + + private def genCodeForArrayAndDelimiter( + ctx: CodegenContext, + ev: ExprCode, + nullEval: (String, String, String) => String): String = { + val arrayGen = array.genCode(ctx) + val delimiterGen = delimiter.genCode(ctx) + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val i = ctx.freshName("i") + val firstItem = ctx.freshName("firstItem") + val resultCode = + s""" + |$bufferClass $buffer = new $bufferClass(); + |boolean $firstItem = true; + |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) { + | if (${arrayGen.value}.isNullAt($i)) { + | ${nullEval(buffer, delimiterGen.value, firstItem)} + | } else { + | if (!$firstItem) { + | $buffer.append(${delimiterGen.value}); + | } + | $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)}); + | $firstItem = false; + | } + |} + |${ev.value} = $buffer.build();""".stripMargin + + if (array.nullable || delimiter.nullable) { + arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) { + delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode""".stripMargin + } + } + } else { + s""" + |${arrayGen.code} + |${delimiterGen.code} + |$resultCode""".stripMargin + } + } + + override def dataType: DataType = StringType + +} + /** * Returns the minimum value in the array. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b49fa76b2a781..7048d93fd5649 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("ArrayJoin") { + def testArrays( + arrays: Seq[Expression], + nullReplacement: Option[Expression], + expected: Seq[String]): Unit = { + assert(arrays.length == expected.length) + arrays.zip(expected).foreach { case (arr, exp) => + checkEvaluation(ArrayJoin(arr, Literal(","), nullReplacement), exp) + } + } + + val arrays = Seq(Literal.create(Seq[String]("a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a", null, "b"), ArrayType(StringType)), + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(Seq[String]("a", "b", null), ArrayType(StringType)), + Literal.create(Seq[String](null, "a", "b"), ArrayType(StringType)), + Literal.create(Seq[String]("a"), ArrayType(StringType))) + + val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a") + val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a") + testArrays(arrays, None, withoutNullReplacement) + testArrays(arrays, Some(Literal("NULL")), withNullReplacement) + + checkEvaluation(ArrayJoin( + Literal.create(null, ArrayType(StringType)), Literal(","), None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal.create(null, StringType), + None), null) + checkEvaluation(ArrayJoin( + Literal.create(Seq[String](null), ArrayType(StringType)), + Literal(","), + Some(Literal.create(null, StringType))), null) + } + test("Array Min") { checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f1587cd032adc..25afaacc38d6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,25 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + * `nullReplacement`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String, nullReplacement: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), Some(Literal(nullReplacement))) + } + + /** + * Concatenates the elements of `column` using the `delimiter`. + * @group collection_funcs + * @since 2.4.0 + */ + def array_join(column: Column, delimiter: String): Column = withExpr { + ArrayJoin(column.expr, Literal(delimiter), None) + } + /** * Concatenates multiple input columns together into a single column. * The function works with strings, binary and compatible array columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03605c30036a3..c216d1322a06c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_join function") { + val df = Seq( + (Seq[String]("a", "b"), ","), + (Seq[String]("a", null, "b"), ","), + (Seq.empty[String], ",") + ).toDF("x", "delimiter") + + checkAnswer( + df.select(array_join(df("x"), ";")), + Seq(Row("a;b"), Row("a;b"), Row("")) + ) + checkAnswer( + df.select(array_join(df("x"), ";", "NULL")), + Seq(Row("a;b"), Row("a;NULL;b"), Row("")) + ) + checkAnswer( + df.selectExpr("array_join(x, delimiter)"), + Seq(Row("a,b"), Row("a,b"), Row(""))) + checkAnswer( + df.selectExpr("array_join(x, delimiter, 'NULL')"), + Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + } + test("array_min function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From ffaf0f9fd407aeba7006f3d785ea8a0e51187357 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 26 Apr 2018 13:27:33 +0800 Subject: [PATCH 0695/1161] [SPARK-24062][THRIFT SERVER] Fix SASL encryption cannot enabled issue in thrift server ## What changes were proposed in this pull request? For the details of the exception please see [SPARK-24062](https://issues.apache.org/jira/browse/SPARK-24062). The issue is: Spark on Yarn stores SASL secret in current UGI's credentials, this credentials will be distributed to AM and executors, so that executors and drive share the same secret to communicate. But STS/Hive library code will refresh the current UGI by UGI's loginFromKeytab() after Spark application is started, this will create a new UGI in the current driver's context with empty tokens and secret keys, so secret key is lost in the current context's UGI, that's why Spark driver throws secret key not found exception. In Spark 2.2 code, Spark also stores this secret key in SecurityManager's class variable, so even UGI is refreshed, the secret is still existed in the object, so STS with SASL can still be worked in Spark 2.2. But in Spark 2.3, we always search key from current UGI, which makes it fail to work in Spark 2.3. To fix this issue, there're two possible solutions: 1. Fix in STS/Hive library, when a new UGI is refreshed, copy the secret key from original UGI to the new one. The difficulty is that some codes to refresh the UGI is existed in Hive library, which makes us hard to change the code. 2. Roll back the logics in SecurityManager to match Spark 2.2, so that this issue can be fixed. 2nd solution seems a simple one. So I will propose a PR with 2nd solution. ## How was this patch tested? Verified in local cluster. CC vanzin tgravescs please help to review. Thanks! Author: jerryshao Closes #21138 from jerryshao/SPARK-24062. --- .../main/scala/org/apache/spark/SecurityManager.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 09ec8932353a0..dbfd5a514c189 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -89,6 +89,7 @@ private[spark] class SecurityManager( setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private var secretKey: String = _ logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -321,6 +322,12 @@ private[spark] class SecurityManager( val creds = UserGroupInformation.getCurrentUser().getCredentials() Option(creds.getSecretKey(SECRET_LOOKUP_KEY)) .map { bytes => new String(bytes, UTF_8) } + // Secret key may not be found in current UGI's credentials. + // This happens when UGI is refreshed in the driver side by UGI's loginFromKeytab but not + // copy secret key from original UGI to the new one. This exists in ThriftServer's Hive + // logic. So as a workaround, storing secret key in a local variable to make it visible + // in different context. + .orElse(Option(secretKey)) .orElse(Option(sparkConf.getenv(ENV_AUTH_SECRET))) .orElse(sparkConf.getOption(SPARK_AUTH_SECRET_CONF)) .getOrElse { @@ -364,8 +371,8 @@ private[spark] class SecurityManager( rnd.nextBytes(secretBytes) val creds = new Credentials() - val secretStr = HashCodes.fromBytes(secretBytes).toString() - creds.addSecretKey(SECRET_LOOKUP_KEY, secretStr.getBytes(UTF_8)) + secretKey = HashCodes.fromBytes(secretBytes).toString() + creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } From d1eb8d3ddc877958512194cc8f5dd8119b41bed0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 25 Apr 2018 23:24:05 -0700 Subject: [PATCH 0696/1161] [SPARK-24094][SS][MINOR] Change description strings of v2 streaming sources to reflect the change ## What changes were proposed in this pull request? This makes it easy to understand at runtime which version is running. Great for debugging production issues. ## How was this patch tested? Not necessary. Author: Tathagata Das Closes #21160 from tdas/SPARK-24094. --- .../org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala | 2 +- .../streaming/sources/RateStreamMicroBatchReader.scala | 2 +- .../apache/spark/sql/execution/streaming/sources/socket.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 2ed49ba3f5495..cbe655f9bff1f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -169,7 +169,7 @@ private[kafka010] class KafkaMicroBatchReader( kafkaOffsetReader.close() } - override def toString(): String = s"Kafka[$kafkaOffsetReader]" + override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" /** * Read initial partition offsets from the checkpoint, or decide the offsets and write them to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 6cf8520fc544f..f54291bea6678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -177,7 +177,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: override def stop(): Unit = {} - override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 5aae46b463398..90f4a5ba4234d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -214,7 +214,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def toString: String = s"TextSocket[host: $host, port: $port]" + override def toString: String = s"TextSocketV2[host: $host, port: $port]" } class TextSocketSourceProvider extends DataSourceV2 From ce2f919f8df1b794ceaa23e1a59d5d541ed47bf5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 26 Apr 2018 19:07:13 +0800 Subject: [PATCH 0697/1161] [SPARK-23799][SQL][FOLLOW-UP] FilterEstimation.evaluateInSet produces wrong stats for STRING ## What changes were proposed in this pull request? `colStat.min` AND `colStat.max` are empty for string type. Thus, `evaluateInSet` should not return zero when either `colStat.min` or `colStat.max`. ## How was this patch tested? Added a test case. Author: gatorsmile Closes #21147 from gatorsmile/cached. --- .../logical/statsEstimation/FilterEstimation.scala | 12 ++++++++---- .../statsEstimation/FilterEstimationSuite.scala | 12 ++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 263c9ba60d145..5a3eeefaedb18 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,13 +392,13 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv - if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { - return Some(0.0) - } - // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + val statsInterval = ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval] val validQuerySet = hSet.filter { v => @@ -422,6 +422,10 @@ case class FilterEstimation(plan: Filter) extends Logging { // We assume the whole set since there is no min/max information for String/Binary type case StringType | BinaryType => + if (ndv.toDouble == 0) { + return Some(0.0) + } + newNdv = ndv.min(BigInt(hSet.size)) if (update) { val newStats = colStat.copy(distinctCount = Some(newNdv), nullCount = Some(0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 16cb5d032cf57..47bfa62569583 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -368,6 +368,18 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 0) } + test("evaluateInSet with string") { + validateEstimatedStats( + Filter(InSet(attrString, Set("A0")), + StatsTestPlan(Seq(attrString), 10, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), + expectedRowCount = 1) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), From 4f1e38649ebc7710850b7c40e6fb355775e7bb7f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 26 Apr 2018 14:21:22 -0700 Subject: [PATCH 0698/1161] [SPARK-24057][PYTHON] put the real data type in the AssertionError message ## What changes were proposed in this pull request? Print out the data type in the AssertionError message to make it more meaningful. ## How was this patch tested? I manually tested the changed code on my local, but didn't add any test. Author: Huaxin Gao Closes #21159 from huaxingao/spark-24057. --- python/pyspark/sql/types.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 1f6534836d64a..3cd7a2ef115af 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -289,7 +289,8 @@ def __init__(self, elementType, containsNull=True): >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ - assert isinstance(elementType, DataType), "elementType should be DataType" + assert isinstance(elementType, DataType),\ + "elementType %s should be an instance of %s" % (elementType, DataType) self.elementType = elementType self.containsNull = containsNull @@ -343,8 +344,10 @@ def __init__(self, keyType, valueType, valueContainsNull=True): ... == MapType(StringType(), FloatType())) False """ - assert isinstance(keyType, DataType), "keyType should be DataType" - assert isinstance(valueType, DataType), "valueType should be DataType" + assert isinstance(keyType, DataType),\ + "keyType %s should be an instance of %s" % (keyType, DataType) + assert isinstance(valueType, DataType),\ + "valueType %s should be an instance of %s" % (valueType, DataType) self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -402,8 +405,9 @@ def __init__(self, name, dataType, nullable=True, metadata=None): ... == StructField("f2", StringType(), True)) False """ - assert isinstance(dataType, DataType), "dataType should be DataType" - assert isinstance(name, basestring), "field name should be string" + assert isinstance(dataType, DataType),\ + "dataType %s should be an instance of %s" % (dataType, DataType) + assert isinstance(name, basestring), "field name %s should be string" % (name) if not isinstance(name, str): name = name.encode('utf-8') self.name = name From f7435bec6a9348cfbbe26b13c230c08545d16067 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 26 Apr 2018 15:11:42 -0700 Subject: [PATCH 0699/1161] [SPARK-24044][PYTHON] Explicitly print out skipped tests from unittest module ## What changes were proposed in this pull request? This PR proposes to remove duplicated dependency checking logics and also print out skipped tests from unittests. For example, as below: ``` Skipped tests in pyspark.sql.tests with pypy: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'Pandas >= 0.19.2 must be installed; however, it was not found.' ... Skipped tests in pyspark.sql.tests with python3: test_createDataFrame_column_name_encoding (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' test_createDataFrame_does_not_modify_input (pyspark.sql.tests.ArrowTests) ... skipped 'PyArrow >= 0.8.0 must be installed; however, it was not found.' ... ``` Currently, it's not printed out in the console. I think we should better print out skipped tests in the console. ## How was this patch tested? Manually tested. Also, fortunately, Jenkins has good environment to test the skipped output. Author: hyukjinkwon Closes #21107 from HyukjinKwon/skipped-tests-print. --- python/pyspark/ml/tests.py | 16 +++-- python/pyspark/mllib/tests.py | 4 +- python/pyspark/sql/tests.py | 51 +++++++------ python/pyspark/streaming/tests.py | 4 +- python/pyspark/tests.py | 12 +--- python/run-tests.py | 115 +++++++++++++----------------- 6 files changed, 98 insertions(+), 104 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2ec0be60e9fa9..093593132e56d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2136,17 +2136,23 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): super(ImageReaderTest2, cls).setUpClass() + cls.hive_available = True # Note that here we enable Hive's support. cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") - cls.spark = HiveContext._createForTesting(cls.sc) + cls.hive_available = False + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -2662,6 +2668,6 @@ def testDefaultFitMultiple(self): if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1037bab7f1088..14d788b0bef60 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -1767,9 +1767,9 @@ def test_pca(self): if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") sc.stop() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 98fa1b54b0a17..6b28c557a803e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3096,23 +3096,28 @@ def setUpClass(cls): filename_pattern = ( "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" "TestQueryExecutionListener.class") - if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): - raise unittest.SkipTest( + cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))) + + if cls.has_listener: + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + def setUp(self): + if not self.has_listener: + raise self.skipTest( "'org.apache.spark.sql.TestQueryExecutionListener' is not " "available. Will skip the related tests.") - # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. - cls.spark = SparkSession.builder \ - .master("local[4]") \ - .appName(cls.__name__) \ - .config( - "spark.sql.queryExecutionListeners", - "org.apache.spark.sql.TestQueryExecutionListener") \ - .getOrCreate() - @classmethod def tearDownClass(cls): - cls.spark.stop() + if hasattr(cls, "spark"): + cls.spark.stop() def tearDown(self): self.spark._jvm.OnSuccessCall.clear() @@ -3196,18 +3201,22 @@ class HiveContextSQLTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + cls.hive_available = True try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False except TypeError: - cls.tearDownClass() - raise unittest.SkipTest("Hive is not available") + cls.hive_available = False os.unlink(cls.tempdir.name) - cls.spark = HiveContext._createForTesting(cls.sc) - cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - cls.df = cls.sc.parallelize(cls.testData).toDF() + if cls.hive_available: + cls.spark = HiveContext._createForTesting(cls.sc) + cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.df = cls.sc.parallelize(cls.testData).toDF() + + def setUp(self): + if not self.hive_available: + self.skipTest("Hive is not available.") @classmethod def tearDownClass(cls): @@ -5316,6 +5325,6 @@ def test_invalid_args(self): if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() + unittest.main(verbosity=2) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 103940923dd4d..d77f1baa1f344 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1590,11 +1590,11 @@ def search_kinesis_asl_assembly_jar(): sys.stderr.write("[Running %s]\n" % (testcase)) tests = unittest.TestLoader().loadTestsFromTestCase(testcase) if xmlrunner: - result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=3).run(tests) + result = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2).run(tests) if not result.wasSuccessful(): failed = True else: - result = unittest.TextTestRunner(verbosity=3).run(tests) + result = unittest.TextTestRunner(verbosity=2).run(tests) if not result.wasSuccessful(): failed = True sys.exit(failed) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9111dbbed5929..8392d7f29af53 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2353,15 +2353,7 @@ def test_statcounter_array(self): if __name__ == "__main__": from pyspark.tests import * - if not _have_scipy: - print("NOTE: Skipping SciPy tests as it does not seem to be installed") - if not _have_numpy: - print("NOTE: Skipping NumPy tests as it does not seem to be installed") if xmlrunner: - unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports')) + unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2) else: - unittest.main() - if not _have_scipy: - print("NOTE: SciPy tests were skipped as it does not seem to be installed") - if not _have_numpy: - print("NOTE: NumPy tests were skipped as it does not seem to be installed") + unittest.main(verbosity=2) diff --git a/python/run-tests.py b/python/run-tests.py index 6b41b5ee22814..f408fc5082b3d 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -32,6 +32,7 @@ else: import queue as Queue from distutils.version import LooseVersion +from multiprocessing import Manager # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -50,6 +51,7 @@ def print_red(text): print('\033[31m' + text + '\033[0m') +SKIPPED_TESTS = Manager().dict() LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() @@ -109,8 +111,34 @@ def run_individual_python_test(test_name, pyspark_python): # this code is invoked from a thread other than the main thread. os._exit(-1) else: - per_test_output.close() - LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) + skipped_counts = 0 + try: + per_test_output.seek(0) + # Here expects skipped test output from unittest when verbosity level is + # 2 (or --verbose option is enabled). + decoded_lines = map(lambda line: line.decode(), iter(per_test_output)) + skipped_tests = list(filter( + lambda line: re.search('test_.* \(pyspark\..*\) ... skipped ', line), + decoded_lines)) + skipped_counts = len(skipped_tests) + if skipped_counts > 0: + key = (pyspark_python, test_name) + SKIPPED_TESTS[key] = skipped_tests + per_test_output.close() + except: + import traceback + print_red("\nGot an exception while trying to store " + "skipped test output:\n%s" % traceback.format_exc()) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(-1) + if skipped_counts != 0: + LOGGER.info( + "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name, + duration, skipped_counts) + else: + LOGGER.info( + "Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): @@ -152,65 +180,17 @@ def parse_opts(): return opts -def _check_dependencies(python_exec, modules_to_test): - if "COVERAGE_PROCESS_START" in os.environ: - # Make sure if coverage is installed. - try: - subprocess_check_output( - [python_exec, "-c", "import coverage"], - stderr=open(os.devnull, 'w')) - except: - print_red("Coverage is not installed in Python executable '%s' " - "but 'COVERAGE_PROCESS_START' environment variable is set, " - "exiting." % python_exec) - sys.exit(-1) - - # If we should test 'pyspark-sql', it checks if PyArrow and Pandas are installed and - # explicitly prints out. See SPARK-23300. - if pyspark_sql in modules_to_test: - # TODO(HyukjinKwon): Relocate and deduplicate these version specifications. - minimum_pyarrow_version = '0.8.0' - minimum_pandas_version = '0.19.2' - - try: - pyarrow_version = subprocess_check_output( - [python_exec, "-c", "import pyarrow; print(pyarrow.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pyarrow_version) >= LooseVersion(minimum_pyarrow_version): - LOGGER.info("Will test PyArrow related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pyarrow_version, pyarrow_version)) - except: - LOGGER.warning( - "Will skip PyArrow related features against Python executable " - "'%s' in '%s' module. PyArrow >= %s is required; however, PyArrow " - "was not found." % (python_exec, pyspark_sql.name, minimum_pyarrow_version)) - - try: - pandas_version = subprocess_check_output( - [python_exec, "-c", "import pandas; print(pandas.__version__)"], - universal_newlines=True, - stderr=open(os.devnull, 'w')).strip() - if LooseVersion(pandas_version) >= LooseVersion(minimum_pandas_version): - LOGGER.info("Will test Pandas related features against Python executable " - "'%s' in '%s' module." % (python_exec, pyspark_sql.name)) - else: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "%s was found." % ( - python_exec, pyspark_sql.name, minimum_pandas_version, pandas_version)) - except: - LOGGER.warning( - "Will skip Pandas related features against Python executable " - "'%s' in '%s' module. Pandas >= %s is required; however, Pandas " - "was not found." % (python_exec, pyspark_sql.name, minimum_pandas_version)) +def _check_coverage(python_exec): + # Make sure if coverage is installed. + try: + subprocess_check_output( + [python_exec, "-c", "import coverage"], + stderr=open(os.devnull, 'w')) + except: + print_red("Coverage is not installed in Python executable '%s' " + "but 'COVERAGE_PROCESS_START' environment variable is set, " + "exiting." % python_exec) + sys.exit(-1) def main(): @@ -237,9 +217,10 @@ def main(): task_queue = Queue.PriorityQueue() for python_exec in python_execs: - # Check if the python executable has proper dependencies installed to run tests - # for given modules properly. - _check_dependencies(python_exec, modules_to_test) + # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START' + # environmental variable is set. + if "COVERAGE_PROCESS_START" in os.environ: + _check_coverage(python_exec) python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], @@ -281,6 +262,12 @@ def process_queue(task_queue): total_duration = time.time() - start_time LOGGER.info("Tests passed in %i seconds", total_duration) + for key, lines in sorted(SKIPPED_TESTS.items()): + pyspark_python, test_name = key + LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python)) + for line in lines: + LOGGER.info(" %s" % line.rstrip()) + if __name__ == "__main__": main() From 9ee9fcf5223efdf7543161b7bc99131111876b92 Mon Sep 17 00:00:00 2001 From: zhoukang Date: Thu, 26 Apr 2018 15:38:11 -0700 Subject: [PATCH 0700/1161] [SPARK-24083][YARN] Log stacktrace for uncaught exception ## What changes were proposed in this pull request? Log stacktrace for uncaught exception ## How was this patch tested? UT and manually test Author: zhoukang Closes #21151 from caneGuy/zhoukang/log-stacktrace. --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index d04989e138f83..650840045361c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -308,7 +308,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends logError("Uncaught exception: ", e) finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, - "Uncaught exception: " + e) + "Uncaught exception: " + StringUtils.stringifyException(e)) } } From 8aa1d7b0ede5115297541d29eab4ce5f4fe905cb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 27 Apr 2018 11:00:41 +0800 Subject: [PATCH 0701/1161] [SPARK-23355][SQL] convertMetastore should not ignore table properties ## What changes were proposed in this pull request? Previously, SPARK-22158 fixed for `USING hive` syntax. This PR aims to fix for `STORED AS` syntax. Although the test case covers ORC part, the patch considers both `convertMetastoreOrc` and `convertMetastoreParquet`. ## How was this patch tested? Pass newly added test cases. Author: Dongjoon Hyun Closes #20522 from dongjoon-hyun/SPARK-22158-2. --- .../spark/sql/hive/HiveStrategies.scala | 17 +++- .../sql/hive/CompressionCodecSuite.scala | 7 +- .../sql/hive/execution/HiveDDLSuite.scala | 81 +++++++++++++++++++ 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8df05cbb20361..a0c197b06ddab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -186,15 +186,28 @@ case class RelationConversions( serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } + // Return true for Apache ORC and Hive ORC-related configuration names. + // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`. + private def isOrcProperty(key: String) = + key.startsWith("orc.") || key.contains(".orc.") + + private def isParquetProperty(key: String) = + key.startsWith("parquet.") || key.contains(".parquet.") + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) + + // Consider table and storage properties. For properties existing in both sides, storage + // properties will supersede table properties. if (serde.contains("parquet")) { - val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++ + relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = relation.tableMeta.storage.properties + val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++ + relation.tableMeta.storage.properties if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") { sessionCatalog.metastoreCatalog.convertToLogicalRelation( relation, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index d10a6f25c64fc..4550d350f6db2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -268,12 +268,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodecs = compressCodecs, tableCompressionCodecs = compressCodecs) { case (tableCodec, sessionCodec, realCodec, tableSize) => - // For non-partitioned table and when convertMetastore is true, Expect session-level - // take effect, and in other cases expect table-level take effect - // TODO: It should always be table-level taking effect when the bug(SPARK-22926) - // is fixed - val expectCodec = - if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + val expectCodec = tableCodec.get assert(expectCodec == realCodec) assert(checkTableSize( format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c85db78c732de..daac6af9b557f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METAS import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -2144,6 +2145,86 @@ class HiveDDLSuite } } + private def getReader(path: String): org.apache.orc.Reader = { + val conf = spark.sessionState.newHadoopConf() + val files = org.apache.spark.sql.execution.datasources.orc.OrcUtils.listOrcFiles(path, conf) + assert(files.length == 1) + val file = files.head + val fs = file.getFileSystem(conf) + val readerOptions = org.apache.orc.OrcFile.readerOptions(conf).filesystem(fs) + org.apache.orc.OrcFile.createReader(file, readerOptions) + } + + test("SPARK-23355 convertMetastoreOrc should not ignore table properties - STORED AS") { + Seq("native", "hive").foreach { orcImpl => + withSQLConf(ORC_IMPLEMENTATION.key -> orcImpl, CONVERT_METASTORE_ORC.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS ORC + |TBLPROPERTIES ( + | orc.compress 'ZLIB', + | orc.compress.size '1001', + | orc.row.index.stride '2002', + | hive.exec.orc.default.block.size '3003', + | hive.exec.orc.compression.strategy 'COMPRESSION') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("orc")) + val properties = table.properties + assert(properties.get("orc.compress") == Some("ZLIB")) + assert(properties.get("orc.compress.size") == Some("1001")) + assert(properties.get("orc.row.index.stride") == Some("2002")) + assert(properties.get("hive.exec.orc.default.block.size") == Some("3003")) + assert(properties.get("hive.exec.orc.compression.strategy") == Some("COMPRESSION")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + val reader = getReader(maybeFile.head.getCanonicalPath) + assert(reader.getCompressionKind.name === "ZLIB") + assert(reader.getCompressionSize == 1001) + assert(reader.getRowIndexStride == 2002) + } + } + } + } + } + + test("SPARK-23355 convertMetastoreParquet should not ignore table properties - STORED AS") { + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) STORED AS PARQUET + |TBLPROPERTIES ( + | parquet.compression 'GZIP' + |) + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains("parquet")) + val properties = table.properties + assert(properties.get("parquet.compression") == Some("GZIP")) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + + assertCompression(maybeFile, "parquet", "GZIP") + } + } + } + } + test("load command for non local invalid path validation") { withTable("tbl") { sql("CREATE TABLE tbl(i INT, j STRING)") From 109935fc5d8b3d381bb1b09a4a570040a0a1846f Mon Sep 17 00:00:00 2001 From: eric-maynard Date: Fri, 27 Apr 2018 15:25:07 +0800 Subject: [PATCH 0702/1161] [SPARK-23830][YARN] added check to ensure main method is found ## What changes were proposed in this pull request? When a user specifies the wrong class -- or, in fact, a class instead of an object -- Spark throws an NPE which is not useful for debugging. This was reported in [SPARK-23830](https://issues.apache.org/jira/browse/SPARK-23830). This PR adds a check to ensure the main method was found and logs a useful error in the event that it's null. ## How was this patch tested? * Unit tests + Manual testing * The scope of the changes is very limited Author: eric-maynard Author: Eric Maynard Closes #21168 from eric-maynard/feature/SPARK-23830. --- .../spark/deploy/yarn/ApplicationMaster.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 650840045361c..595077e7e809f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{InvocationTargetException, Modifier} import java.net.{Socket, URI, URL} import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} @@ -675,9 +675,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val userThread = new Thread { override def run() { try { - mainMethod.invoke(null, userArgs.toArray) - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) - logDebug("Done running users class") + if (!Modifier.isStatic(mainMethod.getModifiers)) { + logError(s"Could not find static main method in object ${args.userClass}") + finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_EXCEPTION_USER_CLASS) + } else { + mainMethod.invoke(null, userArgs.toArray) + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + logDebug("Done running user class") + } } catch { case e: InvocationTargetException => e.getCause match { From 2824f12b8bac5d86a82339d4dfb4d2625e978a15 Mon Sep 17 00:00:00 2001 From: Patrick McGloin Date: Fri, 27 Apr 2018 23:04:14 +0800 Subject: [PATCH 0703/1161] [SPARK-23565][SS] New error message for structured streaming sources assertion ## What changes were proposed in this pull request? A more informative message to tell you why a structured streaming query cannot continue if you have added more sources, than there are in the existing checkpoint offsets. ## How was this patch tested? I added a Unit Test. Author: Patrick McGloin Closes #20946 from patrickmcgloin/master. --- .../org/apache/spark/sql/execution/streaming/OffsetSeq.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73945b39b8967..787174481ff08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -39,7 +39,9 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * cannot be serialized). */ def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { - assert(sources.size == offsets.size) + assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + + s"Cannot continue.") new StreamProgress ++ sources.zip(offsets).collect { case (s, Some(o)) => (s, o) } } From 3fd297af6dc568357c97abf86760c570309d6597 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 27 Apr 2018 11:43:29 -0700 Subject: [PATCH 0704/1161] [SPARK-24085][SQL] Query returns UnsupportedOperationException when scalar subquery is present in partitioning expression ## What changes were proposed in this pull request? In this case, the partition pruning happens before the planning phase of scalar subquery expressions. For scalar subquery expressions, the planning occurs late in the cycle (after the physical planning) in "PlanSubqueries" just before execution. Currently we try to execute the scalar subquery expression as part of partition pruning and fail as it implements Unevaluable. The fix attempts to ignore the Subquery expressions from partition pruning computation. Another option can be to somehow plan the subqueries before the partition pruning. Since this may not be a commonly occuring expression, i am opting for a simpler fix. Repro ``` SQL CREATE TABLE test_prc_bug ( id_value string ) partitioned by (id_type string) location '/tmp/test_prc_bug' stored as parquet; insert into test_prc_bug values ('1','a'); insert into test_prc_bug values ('2','a'); insert into test_prc_bug values ('3','b'); insert into test_prc_bug values ('4','b'); select * from test_prc_bug where id_type = (select 'b'); ``` ## How was this patch tested? Added test in SubquerySuite and hive/SQLQuerySuite Author: Dilip Biswal Closes #21174 from dilipbiswal/spark-24085. --- .../datasources/FileSourceStrategy.scala | 5 ++- .../PruneFileSourcePartitions.scala | 4 ++- .../org/apache/spark/sql/SubquerySuite.scala | 15 +++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 16b22717b8d92..0a568d6b8adce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -76,7 +76,10 @@ object FileSourceStrategy extends Strategy with Logging { fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 3b830accb83f0..16b2367bfdd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -55,7 +55,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { partitionSchema, sparkSession.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters + .filterNot(SubqueryExpression.hasSubquery(_)) + .filter(_.references.subsetOf(partitionSet))) if (partitionKeyFilters.nonEmpty) { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 31e8b0e8dede0..acef62d81ee12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -955,4 +955,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext { // before the fix this would throw AnalysisException spark.range(10).where("(id,id) in (select id, null from range(3))").count } + + test("SPARK-24085 scalar subquery in partitioning expression") { + withTable("parquet_part") { + Seq("1" -> "a", "2" -> "a", "3" -> "b", "4" -> "b") + .toDF("id_value", "id_type") + .write + .mode(SaveMode.Overwrite) + .partitionBy("id_type") + .format("parquet") + .saveAsTable("parquet_part") + checkAnswer( + sql("SELECT * FROM parquet_part WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 73f83d593bbfb..704a410b6a37b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2156,4 +2156,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("SPARK-24085 scalar subquery in partitioning expression") { + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted", + "hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTable(format) { + withTempPath { tempDir => + sql( + s""" + |CREATE TABLE ${format} (id_value string) + |PARTITIONED BY (id_type string) + |LOCATION '${tempDir.toURI}' + |STORED AS ${format} + """.stripMargin) + sql(s"insert into $format values ('1','a')") + sql(s"insert into $format values ('2','a')") + sql(s"insert into $format values ('3','b')") + sql(s"insert into $format values ('4','b')") + checkAnswer( + sql(s"SELECT * FROM $format WHERE id_type = (SELECT 'b')"), + Row("3", "b") :: Row("4", "b") :: Nil) + } + } + } + } + } + } + } From 8614edd445264007144caa6743a8c2ca2b5082e0 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 27 Apr 2018 14:14:28 -0700 Subject: [PATCH 0705/1161] [SPARK-24104] SQLAppStatusListener overwrites metrics onDriverAccumUpdates instead of updating them ## What changes were proposed in this pull request? Event `SparkListenerDriverAccumUpdates` may happen multiple times in a query - e.g. every `FileSourceScanExec` and `BroadcastExchangeExec` call `postDriverMetricUpdates`. In Spark 2.2 `SQLListener` updated the map with new values. `SQLAppStatusListener` overwrites it. Unless `update` preserved it in the KV store (dependant on `exec.lastWriteTime`), only the metrics from the last operator that does `postDriverMetricUpdates` are preserved. ## How was this patch tested? Unit test added. Author: Juliusz Sompolski Closes #21171 from juliuszsompolski/SPARK-24104. --- .../execution/ui/SQLAppStatusListener.scala | 2 +- .../ui/SQLAppStatusListenerSuite.scala | 24 +++++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2b6bb48467eb3..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -289,7 +289,7 @@ class SQLAppStatusListener( private def onDriverAccumUpdates(event: SparkListenerDriverAccumUpdates): Unit = { val SparkListenerDriverAccumUpdates(executionId, accumUpdates) = event Option(liveExecutions.get(executionId)).foreach { exec => - exec.driverAccumUpdates = accumUpdates.toMap + exec.driverAccumUpdates = exec.driverAccumUpdates ++ accumUpdates update(exec) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index f3f08839c1d3a..02df45d1b7989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -443,7 +443,8 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val oldCount = statusStore.executionsList().size val expectedAccumValue = 12345 - val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + val expectedAccumValue2 = 54321 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue, expectedAccumValue2) val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { override lazy val sparkPlan = physicalPlan override lazy val executedPlan = physicalPlan @@ -466,10 +467,14 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with val execId = statusStore.executionsList().last.executionId val metrics = statusStore.executionMetrics(execId) val driverMetric = physicalPlan.metrics("dummy") + val driverMetric2 = physicalPlan.metrics("dummy2") val expectedValue = SQLMetrics.stringValue(driverMetric.metricType, Seq(expectedAccumValue)) + val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType, Seq(expectedAccumValue2)) assert(metrics.contains(driverMetric.id)) assert(metrics(driverMetric.id) === expectedValue) + assert(metrics.contains(driverMetric2.id)) + assert(metrics(driverMetric2.id) === expectedValue2) } test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") { @@ -562,20 +567,31 @@ class SQLAppStatusListenerSuite extends SparkFunSuite with SharedSQLContext with * A dummy [[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] * on the driver. */ -private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { +private case class MyPlan(sc: SparkContext, expectedValue: Long, expectedValue2: Long) + extends LeafExecNode { + override def sparkContext: SparkContext = sc override def output: Seq[Attribute] = Seq() override val metrics: Map[String, SQLMetric] = Map( - "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + "dummy" -> SQLMetrics.createMetric(sc, "dummy"), + "dummy2" -> SQLMetrics.createMetric(sc, "dummy2")) override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue + longMetric("dummy2") += expectedValue2 + + // postDriverMetricUpdates may happen multiple time in a query. + // (normally from different operators, but for the sake of testing, from one operator) + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + Seq(metrics("dummy"))) SQLMetrics.postDriverMetricUpdates( sc, sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), - metrics.values.toSeq) + Seq(metrics("dummy2"))) sc.emptyRDD } } From 1fb46f30f83e4751169ff288ad406f26b7c11f7e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sat, 28 Apr 2018 09:55:56 +0800 Subject: [PATCH 0706/1161] [SPARK-23688][SS] Refactor tests away from rate source ## What changes were proposed in this pull request? Replace rate source with memory source in continuous mode test suite. Keep using "rate" source if the tests intend to put data periodically in background, or need to put short source name to load, since "memory" doesn't have provider for source. ## How was this patch tested? Ran relevant test suite from IDE. Author: Jungtaek Lim Closes #21152 from HeartSaVioR/SPARK-23688. --- .../continuous/ContinuousSuite.scala | 163 +++++++----------- 1 file changed, 61 insertions(+), 102 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index c318b951ff992..5f222e7885994 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -75,73 +75,50 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("map") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .map(r => r.getLong(0) * 2) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().map(_.getInt(0) * 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(0, 2), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(0, 2, 4, 6, 8)) } test("flatMap") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2)) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().flatMap(r => Seq(0, r.getInt(0), r.getInt(0) * 2)) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer((0 to 1).flatMap(n => Seq(0, n, n * 2)): _*), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer((0 to 4).flatMap(n => Seq(0, n, n * 2)): _*)) } test("filter") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .where('value > 5) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().where('value > 2) - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(df)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("deduplicate") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) - .dropDuplicates() + val input = ContinuousMemoryStream[Int] + val df = input.toDF().dropDuplicates() val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -149,15 +126,11 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("timestamp") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select(current_timestamp()) + val input = ContinuousMemoryStream[Int] + val df = input.toDF().select(current_timestamp()) val except = intercept[AnalysisException] { - testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger)) + testStream(df)(StartStream()) } assert(except.message.contains( @@ -165,58 +138,43 @@ class ContinuousSuite extends ContinuousSuiteBase { } test("subquery alias") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .createOrReplaceTempView("rate") - val test = spark.sql("select value from rate where value > 5") + val input = ContinuousMemoryStream[Int] + input.toDF().createOrReplaceTempView("memory") + val test = spark.sql("select value from memory where value > 2") - testStream(test, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - Execute(waitForRateSourceTriggers(_, 4)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + testStream(test)( + AddData(input, 0, 1), + CheckAnswer(), + StopStream, + AddData(input, 2, 3, 4), + StartStream(), + CheckAnswer(3, 4)) } test("repeatedly restart") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(df)( + StartStream(), + AddData(input, 0, 1), + CheckAnswer(0, 1), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), + StartStream(), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), + StartStream(), + StopStream, + AddData(input, 2, 3), + StartStream(), + CheckAnswer(0, 1, 2, 3), StopStream) } test("task failure kills the query") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + val input = ContinuousMemoryStream[Int] + val df = input.toDF() // Get an arbitrary task from this query to kill. It doesn't matter which one. var taskId: Long = -1 @@ -227,9 +185,9 @@ class ContinuousSuite extends ContinuousSuiteBase { } spark.sparkContext.addSparkListener(listener) try { - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(100)), - Execute(waitForRateSourceTriggers(_, 2)), + AddData(input, 0, 1, 2, 3), Execute { _ => // Wait until a task is started, then kill its first attempt. eventually(timeout(streamingTimeout)) { @@ -252,6 +210,7 @@ class ContinuousSuite extends ContinuousSuiteBase { .option("rowsPerSecond", "2") .load() .select('value) + val query = df.writeStream .format("memory") .queryName("noharness") From ad94e8592b2e8f4c1bdbd958e110797c6658af84 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 28 Apr 2018 10:47:43 +0800 Subject: [PATCH 0707/1161] [SPARK-23736][SQL][FOLLOWUP] Error message should contains SQL types ## What changes were proposed in this pull request? In the error messages we should return the SQL types (like `string` rather than the internal types like `StringType`). ## How was this patch tested? added UT Author: Marco Gaido Closes #21181 from mgaido91/SPARK-23736_followup. --- .../sql/catalyst/expressions/collectionOperations.scala | 5 +++-- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 90223b9126555..6d63a531e3b74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -863,8 +863,9 @@ case class Concat(children: Seq[Expression]) extends Expression { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + - s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + s"input to function $prettyName should have been ${StringType.simpleString}," + + s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c216d1322a06c..470a1c8e331ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -712,6 +712,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { df.selectExpr("concat(i1, array(i1, i2))") } + + val e = intercept[AnalysisException] { + df.selectExpr("concat(map(1, 2), map(3, 4))") + } + assert(e.getMessage.contains("string, binary or array")) } test("flatten function") { From 4df51361a5ff1fba20524f1b580f4049b328ed32 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 28 Apr 2018 16:57:41 +0800 Subject: [PATCH 0708/1161] [SPARK-22732][SS][FOLLOW-UP] Fix MemorySinkV2 toString error ## What changes were proposed in this pull request? Fix `MemorySinkV2` toString() error ## How was this patch tested? N/A Author: Yuming Wang Closes #21170 from wangyum/SPARK-22732. --- .../spark/sql/execution/streaming/sources/memoryV2.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 5f58246083bb2..d871d37ad37c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -96,7 +96,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { case _ => throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") + s"Output mode $outputMode is not supported by MemorySinkV2") } } else { logDebug(s"Skipping already committed batch: $batchId") @@ -107,7 +107,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.clear() } - override def toString(): String = "MemorySink" + override def toString(): String = "MemorySinkV2" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} @@ -175,7 +175,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) /** - * Used to query the data that has been written into a [[MemorySink]]. + * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { private val sizePerRow = output.map(_.dataType.defaultSize).sum From bd14da6fd5a77cc03efff193a84ffccbe892cc13 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 29 Apr 2018 11:25:31 +0800 Subject: [PATCH 0709/1161] [SPARK-23094][SPARK-23723][SPARK-23724][SQL] Support custom encoding for json files ## What changes were proposed in this pull request? I propose new option for JSON datasource which allows to specify encoding (charset) of input and output files. Here is an example of using of the option: ``` spark.read.schema(schema) .option("multiline", "true") .option("encoding", "UTF-16LE") .json(fileName) ``` If the option is not specified, charset auto-detection mechanism is used by default. The option can be used for saving datasets to jsons. Currently Spark is able to save datasets into json files in `UTF-8` charset only. The changes allow to save data in any supported charset. Here is the approximate list of supported charsets by Oracle Java SE: https://docs.oracle.com/javase/8/docs/technotes/guides/intl/encoding.doc.html . An user can specify the charset of output jsons via the charset option like `.option("charset", "UTF-16BE")`. By default the output charset is still `UTF-8` to keep backward compatibility. The solution has the following restrictions for per-line mode (`multiline = false`): - If charset is different from UTF-8, the lineSep option must be specified. The option required because Hadoop LineReader cannot detect the line separator correctly. Here is the ticket for solving the issue: https://issues.apache.org/jira/browse/SPARK-23725 - Encoding with [BOM](https://en.wikipedia.org/wiki/Byte_order_mark) are not supported. For example, the `UTF-16` and `UTF-32` encodings are blacklisted. The problem can be solved by https://github.com/MaxGekk/spark-1/pull/2 ## How was this patch tested? I added the following tests: - reads an json file in `UTF-16LE` encoding with BOM in `multiline` mode - read json file by using charset auto detection (`UTF-32BE` with BOM) - read json file using of user's charset (`UTF-16LE`) - saving in `UTF-32BE` and read the result by standard library (not by Spark) - checking that default charset is `UTF-8` - handling wrong (unsupported) charset Author: Maxim Gekk Author: Maxim Gekk Closes #20937 from MaxGekk/json-encoding-line-sep. --- python/pyspark/sql/readwriter.py | 15 +- python/pyspark/sql/tests.py | 7 + .../sql/people_array_utf16le.json | Bin 0 -> 182 bytes .../catalyst/json/CreateJacksonParser.scala | 49 +++- .../spark/sql/catalyst/json/JSONOptions.scala | 39 ++- .../sql/catalyst/json/JacksonParser.scala | 10 +- .../apache/spark/sql/DataFrameReader.scala | 3 + .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../datasources/json/JsonDataSource.scala | 60 +++-- .../datasources/json/JsonFileFormat.scala | 10 +- .../datasources/text/TextOptions.scala | 18 +- .../src/test/resources/test-data/utf16LE.json | Bin 0 -> 98 bytes .../resources/test-data/utf16WithBOM.json | Bin 0 -> 200 bytes .../resources/test-data/utf32BEWithBOM.json | Bin 0 -> 172 bytes .../datasources/json/JsonBenchmarks.scala | 179 +++++++++++++ .../datasources/json/JsonSuite.scala | 245 +++++++++++++++++- 16 files changed, 599 insertions(+), 44 deletions(-) create mode 100644 python/test_support/sql/people_array_utf16le.json create mode 100644 sql/core/src/test/resources/test-data/utf16LE.json create mode 100644 sql/core/src/test/resources/test-data/utf16WithBOM.json create mode 100644 sql/core/src/test/resources/test-data/utf32BEWithBOM.json create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index df176c579fc8b..6811fa6b3b156 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -176,7 +176,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None): + multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, + encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -237,6 +238,10 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, :param allowUnquotedControlChars: allows JSON Strings to contain unquoted control characters (ASCII characters with value less than 32, including tab and line feed characters) or not. + :param encoding: allows to forcibly set one of standard basic or extended encoding for + the JSON files. For example UTF-16BE, UTF-32LE. If None is set, + the encoding of input JSON will be detected automatically + when the multiLine option is set to ``true``. :param lineSep: defines the line separator that should be used for parsing. If None is set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. @@ -259,7 +264,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, timestampFormat=timestampFormat, multiLine=multiLine, allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, - samplingRatio=samplingRatio) + samplingRatio=samplingRatio, encoding=encoding) if isinstance(path, basestring): path = [path] if type(path) == list: @@ -752,7 +757,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) @since(1.4) def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None, - lineSep=None): + lineSep=None, encoding=None): """Saves the content of the :class:`DataFrame` in JSON format (`JSON Lines text format or newline-delimited JSON `_) at the specified path. @@ -776,6 +781,8 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm formats follow the formats at ``java.text.SimpleDateFormat``. This applies to timestamp type. If None is set, it uses the default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSXXX``. + :param encoding: specifies encoding (charset) of saved json files. If None is set, + the default UTF-8 charset will be used. :param lineSep: defines the line separator that should be used for writing. If None is set, it uses the default value, ``\\n``. @@ -784,7 +791,7 @@ def json(self, path, mode=None, compression=None, dateFormat=None, timestampForm self.mode(mode) self._set_opts( compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat, - lineSep=lineSep) + lineSep=lineSep, encoding=encoding) self._jwrite.json(path) @since(1.4) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6b28c557a803e..e0cd2aa41a2d0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -685,6 +685,13 @@ def test_multiline_json(self): multiLine=True) self.assertEqual(people1.collect(), people_array.collect()) + def test_encoding_json(self): + people_array = self.spark.read\ + .json("python/test_support/sql/people_array_utf16le.json", + multiLine=True, encoding="UTF-16LE") + expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')] + self.assertEqual(people_array.collect(), expected) + def test_linesep_json(self): df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",") expected = [Row(_corrupt_record=None, name=u'Michael'), diff --git a/python/test_support/sql/people_array_utf16le.json b/python/test_support/sql/people_array_utf16le.json new file mode 100644 index 0000000000000000000000000000000000000000..9c657fa30ac9c651076ff8aa3676baa400b121fb GIT binary patch literal 182 zcma!M;9^h!!fGfDVk require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } - // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) - // Note that JSON uses writer with UTF-8 charset. This string will be written out as UTF-8. + + /** + * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. + * If the encoding is not specified (None), it will be detected automatically + * when the multiLine option is set to `true`. + */ + val encoding: Option[String] = parameters.get("encoding") + .orElse(parameters.get("charset")).map { enc => + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) + val isBlacklisted = blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + | ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = !(multiLine == false && + Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } + + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.getOrElse("UTF-8")) + } val lineSeparatorInWrite: String = lineSeparator.getOrElse("\n") /** Sets config options on a Jackson [[JsonFactory]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 7f6956994f31f..a5a4a13eb608b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.ByteArrayOutputStream +import java.io.{ByteArrayOutputStream, CharConversionException} import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -361,6 +361,14 @@ class JacksonParser( // For such records, all fields other than the field configured by // `columnNameOfCorruptRecord` are set to `null`. throw BadRecordException(() => recordLiteral(record), () => None, e) + case e: CharConversionException if options.encoding.isEmpty => + val msg = + """JSON parser cannot handle a character in its input. + |Specifying encoding as an input option explicitly might help to resolve the issue. + |""".stripMargin + e.getMessage + val wrappedCharException = new CharConversionException(msg) + wrappedCharException.initCause(e) + throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b44552f0eb17b..6b2ea6c06d3ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -372,6 +372,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • + *
  • `encoding` (by default it is not set): allows to forcibly set one of standard basic + * or extended encoding for the JSON files. For example UTF-16BE, UTF-32LE. If the encoding + * is not specified and `multiLine` is set to `true`, it will be detected automatically.
  • *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index bbc063148a72c..e183fa6f9542b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -518,8 +518,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `encoding` (by default it is not set): specifies encoding (charset) of saved json + * files. If it is not set, the UTF-8 charset will be used.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.4.0 @@ -589,8 +590,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default `null`): compression codec to use when saving to file. This can be * one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`, * `snappy` and `deflate`).
  • - *
  • `lineSep` (default `\n`): defines the line separator that should - * be used for writing.
  • + *
  • `lineSep` (default `\n`): defines the line separator that should be used for writing.
  • * * * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 5769c09c9a1d9..983a5f0dcade2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -31,11 +31,11 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} -import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession} +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.text.{TextFileFormat, TextOptions} +import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -92,26 +92,30 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: Dataset[String] = createBaseDataset( - sparkSession, inputPaths, parsedOptions.lineSeparator) + val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0)) - JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], - lineSeparator: Option[String]): Dataset[String] = { - val textOptions = lineSeparator.map { lineSep => - Map(TextOptions.LINE_SEPARATOR -> lineSep) - }.getOrElse(Map.empty[String, String]) - + parsedOptions: JSONOptions): Dataset[String] = { val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, @@ -129,8 +133,12 @@ object TextInputJsonDataSource extends JsonDataSource { schema: StructType): Iterator[InternalRow] = { val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val textParser = parser.options.encoding + .map(enc => CreateJacksonParser.text(enc, _: JsonFactory, _: Text)) + .getOrElse(CreateJacksonParser.text(_: JsonFactory, _: Text)) + val safeParser = new FailureSafeParser[Text]( - input => parser.parse(input, CreateJacksonParser.text, textToUTF8String), + input => parser.parse(input, textParser, textToUTF8String), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -153,7 +161,11 @@ object MultiLineJsonDataSource extends JsonDataSource { parsedOptions: JSONOptions): StructType = { val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, createParser) + val parser = parsedOptions.encoding + .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) + .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) + + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( @@ -175,11 +187,18 @@ object MultiLineJsonDataSource extends JsonDataSource { .values } - private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { - val path = new Path(record.getPath()) - CreateJacksonParser.inputStream( - jsonFactory, - CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, path)) + private def dataToInputStream(dataStream: PortableDataStream): InputStream = { + val path = new Path(dataStream.getPath()) + CodecStreams.createInputStreamWithCloseResource(dataStream.getConfiguration, path) + } + + private def createParser(jsonFactory: JsonFactory, stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(jsonFactory, dataToInputStream(stream)) + } + + private def createParser(enc: String, jsonFactory: JsonFactory, + stream: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream(enc, jsonFactory, dataToInputStream(stream)) } override def readFile( @@ -194,9 +213,12 @@ object MultiLineJsonDataSource extends JsonDataSource { UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) } } + val streamParser = parser.options.encoding + .map(enc => CreateJacksonParser.inputStream(enc, _: JsonFactory, _: InputStream)) + .getOrElse(CreateJacksonParser.inputStream(_: JsonFactory, _: InputStream)) val safeParser = new FailureSafeParser[InputStream]( - input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString), + input => parser.parse[InputStream](input, streamParser, partitionedFileString), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 0862c746fffad..3b04510d29695 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.json +import java.nio.charset.{Charset, StandardCharsets} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -151,7 +153,13 @@ private[json] class JsonOutputWriter( context: TaskAttemptContext) extends OutputWriter with Logging { - private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) + private val encoding = options.encoding match { + case Some(charsetName) => Charset.forName(charsetName) + case None => StandardCharsets.UTF_8 + } + + private val writer = CodecStreams.createOutputStreamWriter( + context, new Path(path), encoding) // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index 5c1a35434f7b5..e4e201995faa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.text -import java.nio.charset.StandardCharsets +import java.nio.charset.{Charset, StandardCharsets} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} @@ -41,13 +41,18 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti */ val wholeText = parameters.getOrElse(WHOLETEXT, "false").toBoolean - private val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { sep => - require(sep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") - sep + val encoding: Option[String] = parameters.get(ENCODING) + + val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + + lineSep } + // Note that the option 'lineSep' uses a different default value in read and write. - val lineSeparatorInRead: Option[Array[Byte]] = - lineSeparator.map(_.getBytes(StandardCharsets.UTF_8)) + val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => + lineSep.getBytes(encoding.map(Charset.forName(_)).getOrElse(StandardCharsets.UTF_8)) + } val lineSeparatorInWrite: Array[Byte] = lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } @@ -55,5 +60,6 @@ private[text] class TextOptions(@transient private val parameters: CaseInsensiti private[datasources] object TextOptions { val COMPRESSION = "compression" val WHOLETEXT = "wholetext" + val ENCODING = "encoding" val LINE_SEPARATOR = "lineSep" } diff --git a/sql/core/src/test/resources/test-data/utf16LE.json b/sql/core/src/test/resources/test-data/utf16LE.json new file mode 100644 index 0000000000000000000000000000000000000000..ce4117fd299dfcbc7089e7c0530098bfcaf5a27e GIT binary patch literal 98 zcmbi20w;GhFpeJpqLd9J2PYe#WR62N(?$cl`!==KvkHk Roq(bsb5ek+xfp7J7y!-s4k`cu literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/test-data/utf16WithBOM.json b/sql/core/src/test/resources/test-data/utf16WithBOM.json new file mode 100644 index 0000000000000000000000000000000000000000..cf4d29328b860ffe8288edea437222c6d432a100 GIT binary patch literal 200 zcmezWFPedufr~)_3ac5E7}6Lr8HyN+8A=%Z7!nzB8B&2_RzU2`kO36W1j;Be=m6C# tG2{T{G1WN%ML{N{09DiiRT68y3qw9bDMLB|(}RGj@}XvfOpXPc4*2#AY;xCDs(fH)C|bAdP&h(YSCptLiP&H!SNdXPSl f9+12a5Gz30IY1hupBVF;plV@mNP(JB3#7RK --jars + */ +object JSONBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-json-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + + def schemaInferring(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON schema inferring", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + + benchmark.addCase("No encoding", 3) { _ => + spark.read.json(path.getAbsolutePath) + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .json(path.getAbsolutePath) + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON schema inferring: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 38902 / 39282 2.6 389.0 1.0X + UTF-8 is set 56959 / 57261 1.8 569.6 0.7X + */ + benchmark.run() + } + } + + def perlineParsing(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON per-line parsing", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map(_ => "a") + .toDF("fieldA") + .write.json(path.getAbsolutePath) + val schema = new StructType().add("fieldA", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON per-line parsing: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 25947 / 26188 3.9 259.5 1.0X + UTF-8 is set 46319 / 46417 2.2 463.2 0.6X + */ + benchmark.run() + } + } + + def perlineParsingOfWideColumn(rowsNum: Int): Unit = { + val benchmark = new Benchmark("JSON parsing of wide lines", rowsNum) + + withTempPath { path => + // scalastyle:off println + benchmark.out.println("Preparing data for benchmarking ...") + // scalastyle:on println + + spark.sparkContext.range(0, rowsNum, 1) + .map { i => + val s = "abcdef0123456789ABCDEF" * 20 + s"""{"a":"$s","b": $i,"c":"$s","d":$i,"e":"$s","f":$i,"x":"$s","y":$i,"z":"$s"}""" + } + .toDF().write.text(path.getAbsolutePath) + val schema = new StructType() + .add("a", StringType).add("b", LongType) + .add("c", StringType).add("d", LongType) + .add("e", StringType).add("f", LongType) + .add("x", StringType).add("y", LongType) + .add("z", StringType) + + benchmark.addCase("No encoding", 3) { _ => + spark.read + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + benchmark.addCase("UTF-8 is set", 3) { _ => + spark.read + .option("encoding", "UTF-8") + .schema(schema) + .json(path.getAbsolutePath) + .count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + JSON parsing of wide lines: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + No encoding 45543 / 45660 0.2 4554.3 1.0X + UTF-8 is set 65737 / 65957 0.2 6573.7 0.7X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + schemaInferring(100 * 1000 * 1000) + perlineParsing(100 * 1000 * 1000) + perlineParsingOfWideColumn(10 * 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a58dff827b92d..0db688fec9a67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, StringWriter} -import java.nio.charset.StandardCharsets +import java.io.{File, FileOutputStream, StringWriter} +import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -48,6 +48,10 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ + def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } + test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, @@ -2167,4 +2171,241 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val sampled = spark.read.option("samplingRatio", 1.0).json(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-23723: json in UTF-16 with BOM") { + val fileName = "test-data/utf16WithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .option("encoding", "UTF-16") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"), Row("Doug", "Rood"))) + } + + test("SPARK-23723: multi-line json in UTF-32BE with BOM") { + val fileName = "test-data/utf32BEWithBOM.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Use user's encoding in reading of multi-line json in UTF-16LE") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16LE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + test("SPARK-23723: Unsupported encoding name") { + val invalidCharset = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + spark.read + .options(Map("encoding" -> invalidCharset, "lineSep" -> "\n")) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains(invalidCharset)) + } + + test("SPARK-23723: checking that the encoding option is case agnostic") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val jsonDF = spark.read.schema(schema) + .option("multiline", "true") + .options(Map("encoding" -> "uTf-16lE")) + .json(testFile(fileName)) + + checkAnswer(jsonDF, Seq(Row("Chris", "Baird"))) + } + + + test("SPARK-23723: specified encoding is not matched to actual encoding") { + val fileName = "test-data/utf16LE.json" + val schema = new StructType().add("firstName", StringType).add("lastName", StringType) + val exception = intercept[SparkException] { + spark.read.schema(schema) + .option("mode", "FAILFAST") + .option("multiline", "true") + .options(Map("encoding" -> "UTF-16BE")) + .json(testFile(fileName)) + .count() + } + val errMsg = exception.getMessage + + assert(errMsg.contains("Malformed records are detected in record parsing")) + } + + def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, + expectedContent: String): Unit = { + val jsonFiles = new File(pathToJsonFiles) + .listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("json")) + val actualContent = jsonFiles.map { file => + new String(Files.readAllBytes(file.toPath), expectedEncoding) + }.mkString.trim + + assert(actualContent == expectedContent) + } + + test("SPARK-23723: save json in UTF-32BE") { + val encoding = "UTF-32BE" + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = encoding, + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: save json in default encoding - UTF-8") { + withTempPath { path => + val df = spark.createDataset(Seq(("Dog", 42))) + df.write.json(path.getCanonicalPath) + + checkEncoding( + expectedEncoding = "UTF-8", + pathToJsonFiles = path.getCanonicalPath, + expectedContent = """{"_1":"Dog","_2":42}""") + } + } + + test("SPARK-23723: wrong output encoding") { + val encoding = "UTF-128" + val exception = intercept[UnsupportedCharsetException] { + withTempPath { path => + val df = spark.createDataset(Seq((0))) + df.write + .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .json(path.getCanonicalPath) + } + } + + assert(exception.getMessage == encoding) + } + + test("SPARK-23723: read back json in UTF-16LE") { + val options = Map("encoding" -> "UTF-16LE", "lineSep" -> "\n") + withTempPath { path => + val ds = spark.createDataset(Seq(("a", 1), ("b", 2), ("c", 3))).repartition(2) + ds.write.options(options).json(path.getCanonicalPath) + + val readBack = spark + .read + .options(options) + .json(path.getCanonicalPath) + + checkAnswer(readBack.toDF(), ds.toDF()) + } + } + + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { + test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { + val schema = new StructType().add("f1", StringType).add("f2", IntegerType) + withTempPath { path => + val records = List(("a", 1), ("b", 2)) + val data = records + .map(rec => s"""{"f1":"${rec._1}", "f2":${rec._2}}""".getBytes(encoding)) + .reduce((a1, a2) => a1 ++ lineSep.getBytes(encoding) ++ a2) + val os = new FileOutputStream(path) + os.write(data) + os.close() + val reader = if (inferSchema) { + spark.read + } else { + spark.read.schema(schema) + } + val readBack = reader + .option("encoding", encoding) + .option("lineSep", lineSep) + .json(path.getCanonicalPath) + checkAnswer(readBack, records.map(rec => Row(rec._1, rec._2))) + } + } + } + + // scalastyle:off nonascii + List( + (0, "|", "UTF-8", false), + (1, "^", "UTF-16BE", true), + (2, "::", "ISO-8859-1", true), + (3, "!!!@3", "UTF-32LE", false), + (4, 0x1E.toChar.toString, "UTF-8", true), + (5, "아", "UTF-32BE", false), + (6, "куку", "CP1251", true), + (7, "sep", "utf-8", false), + (8, "\r\n", "UTF-16LE", false), + (9, "\r\n", "utf-16be", true), + (10, "\u000d\u000a", "UTF-32BE", false), + (11, "\u000a\u000d", "UTF-8", true), + (12, "===", "US-ASCII", false), + (13, "$^+", "utf-32le", true) + ).foreach { + case (testNum, sep, encoding, inferSchema) => checkReadJson(sep, encoding, inferSchema, testNum) + } + // scalastyle:on nonascii + + test("SPARK-23724: lineSep should be set if encoding if different from UTF-8") { + val encoding = "UTF-16LE" + val exception = intercept[IllegalArgumentException] { + spark.read + .options(Map("encoding" -> encoding)) + .json(testFile("test-data/utf16LE.json")) + .count() + } + + assert(exception.getMessage.contains( + s"""The lineSep option must be specified for the $encoding encoding""")) + } + + private val badJson = "\u0000\u0000\u0000A\u0001AAA" + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is enabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson + """{"a":1}""").toDS().write.text(path) + val expected = s"""${badJson}{"a":1}\n""" + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", true) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Row(null, expected)) + } + } + + test("SPARK-23094: permissively read JSON file with leading nulls when multiLine is disabled") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + Seq(badJson, """{"a":1}""").toDS().write.text(path) + val schema = new StructType().add("a", IntegerType).add("_corrupt_record", StringType) + val df = spark.read.format("json") + .option("mode", "PERMISSIVE") + .option("multiLine", false) + .option("encoding", "UTF-8") + .schema(schema).load(path) + checkAnswer(df, Seq(Row(1, null), Row(null, badJson))) + } + } + + test("SPARK-23094: permissively parse a dataset contains JSON with leading nulls") { + checkAnswer( + spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), + Row(badJson)) + } } From 56f501e1c0cec3be7d13008bd2c0182ec83ed2a2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 30 Apr 2018 09:40:46 +0800 Subject: [PATCH 0710/1161] [MINOR][DOCS] Fix a broken link for Arrow's supported types in the programming guide ## What changes were proposed in this pull request? This PR fixes a broken link for Arrow's supported types in the programming guide. ## How was this patch tested? Manually tested via `SKIP_API=1 jekyll watch`. "Supported SQL Types" here in https://spark.apache.org/docs/latest/sql-programming-guide.html#enabling-for-conversion-tofrom-pandas is broken. It should be https://spark.apache.org/docs/latest/sql-programming-guide.html#supported-sql-types Author: hyukjinkwon Closes #21191 from HyukjinKwon/minor-arrow-link. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8ff1470970f7..836ce990205a9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1703,7 +1703,7 @@ Using the above optimizations with Arrow will produce the same results as when A enabled. Note that even with Arrow, `toPandas()` results in the collection of all records in the DataFrame to the driver program and should be done on a small subset of the data. Not all Spark data types are currently supported and an error can be raised if a column has an unsupported type, -see [Supported SQL Types](#supported-sql-arrow-types). If an error occurs during `createDataFrame()`, +see [Supported SQL Types](#supported-sql-types). If an error occurs during `createDataFrame()`, Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) From 3121b411f748859ed3ed1c97cbc21e6ae980a35c Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 30 Apr 2018 09:45:22 +0800 Subject: [PATCH 0711/1161] [SPARK-23846][SQL] The samplingRatio option for CSV datasource ## What changes were proposed in this pull request? I propose to support the `samplingRatio` option for schema inferring of CSV datasource similar to the same option of JSON datasource: https://github.com/apache/spark/blob/b14993e1fcb68e1c946a671c6048605ab4afdf58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala#L49-L50 ## How was this patch tested? Added 2 tests for json and 2 tests for csv datasources. The tests checks that only subset of input dataset is used for schema inferring. Author: Maxim Gekk Author: Maxim Gekk Closes #20959 from MaxGekk/csv-sampling. --- python/pyspark/sql/readwriter.py | 7 ++- python/pyspark/sql/tests.py | 7 +++ .../apache/spark/sql/DataFrameReader.scala | 1 + .../datasources/csv/CSVDataSource.scala | 6 ++- .../datasources/csv/CSVOptions.scala | 3 ++ .../execution/datasources/csv/CSVUtils.scala | 28 +++++++++++ .../execution/datasources/csv/CSVSuite.scala | 47 ++++++++++++++++++- .../datasources/csv/TestCsvData.scala | 36 ++++++++++++++ 8 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 6811fa6b3b156..9899eb5058b82 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -345,7 +345,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + samplingRatio=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -428,6 +429,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non the quote character. If None is set, the default value is escape character when escape and quote characters are different, ``\0`` otherwise. + :param samplingRatio: defines fraction of rows used for schema inferring. + If None is set, it uses the default value, ``1.0``. >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes @@ -446,7 +449,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e0cd2aa41a2d0..bc3eaf16b4de7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3033,6 +3033,13 @@ def test_json_sampling_ratio(self): .json(rdd).schema self.assertEquals(schema, StructType([StructField("a", LongType(), True)])) + def test_csv_sampling_ratio(self): + rdd = self.spark.sparkContext.range(0, 100, 1, 1) \ + .map(lambda x: '0.1' if x == 1 else str(x)) + schema = self.spark.read.option('inferSchema', True)\ + .csv(rdd, samplingRatio=0.5).schema + self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6b2ea6c06d3ae..53f44888ebaff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -539,6 +539,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `header` (default `false`): uses the first line as names of columns.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • + *
  • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
  • *
  • `ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading * whitespaces from values being read should be skipped.
  • *
  • `ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 4870d75fc5f08..bc1f4ab3bb053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -161,7 +161,8 @@ object TextInputCSVDataSource extends CSVDataSource { val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => + val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) + val tokenRDD = sampled.rdd.mapPartitions { iter => val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) @@ -235,7 +236,8 @@ object MultiLineCSVDataSource extends CSVDataSource { parsedOptions.headerFlag, new CsvParser(parsedOptions.asParserSettings)) } - CSVInferSchema.infer(tokenRDD, header, parsedOptions) + val sampled = CSVUtils.sample(tokenRDD, parsedOptions) + CSVInferSchema.infer(sampled, header, parsedOptions) case None => // If the first row could not be read, just return the empty schema. StructType(Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index c16790630ce17..2ec0fc605a84b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -150,6 +150,9 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' + val samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 72b053d2092ca..31464f1bcc68e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.datasources.csv +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -131,4 +134,29 @@ object CSVUtils { schema.foreach(field => verifyType(field.dataType)) } + /** + * Sample CSV dataset as configured by `samplingRatio`. + */ + def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } + + /** + * Sample CSV RDD as configured by `samplingRatio`. + */ + def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = { + require(options.samplingRatio > 0, + s"samplingRatio (${options.samplingRatio}) should be greater than 0") + if (options.samplingRatio > 0.99) { + csv + } else { + csv.sample(withReplacement = false, options.samplingRatio, 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 4398e547d9217..461abdd96d3f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -30,12 +30,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions.{col, regexp_replace} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData { import testImplicits._ private val carsFile = "test-data/cars.csv" @@ -1279,4 +1278,48 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil ) } + + test("SPARK-23846: schema inferring touches less data if samplingRatio < 1.0") { + // Set default values for the DataSource parameters to make sure + // that whole test file is mapped to only one partition. This will guarantee + // reliable sampling of the input file. + withSQLConf( + "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString, + "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString + )(withTempPath { path => + val ds = sampledTestData.coalesce(1) + ds.write.text(path.getAbsolutePath) + + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(path.getCanonicalPath) + assert(readback.schema == new StructType().add("_c0", IntegerType)) + }) + } + + test("SPARK-23846: usage of samplingRatio while parsing a dataset of strings") { + val ds = sampledTestData.coalesce(1) + val readback = spark.read + .option("inferSchema", true).option("samplingRatio", 0.1) + .csv(ds) + + assert(readback.schema == new StructType().add("_c0", IntegerType)) + } + + test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") { + val ds = spark.range(0, 100, 1, 1).map(_.toString) + + val errorMsg0 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", -1).csv(ds) + }.getMessage + assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0")) + + val errorMsg1 = intercept[IllegalArgumentException] { + spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds) + }.getMessage + assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0")) + + val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) + assert(sampled.count() == ds.count()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala new file mode 100644 index 0000000000000..3e20cc47dca2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} + +private[csv] trait TestCsvData { + protected def spark: SparkSession + + def sampledTestData: Dataset[String] = { + spark.range(0, 100, 1).map { index => + val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(index)) { + index.toString + } else { + (index.toDouble + 0.1).toString + } + }(Encoders.STRING) + } +} From b42ad165bb93c96cc5be9ed05b5026f9baafdfa2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Apr 2018 09:13:32 -0700 Subject: [PATCH 0712/1161] [SPARK-24072][SQL] clearly define pushed filters ## What changes were proposed in this pull request? filters like parquet row group filter, which is actually pushed to the data source but still to be evaluated by Spark, should also count as `pushedFilters`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21143 from cloud-fan/step1. --- .../SupportsPushDownCatalystFilters.java | 11 +++- .../v2/reader/SupportsPushDownFilters.java | 10 ++- .../datasources/v2/DataSourceV2Relation.scala | 63 +++++++++++-------- .../datasources/v2/DataSourceV2ScanExec.scala | 1 + .../datasources/v2/DataSourceV2Strategy.scala | 4 +- .../v2/DataSourceV2StringFormat.scala | 19 ++---- .../v2/PushDownOperatorsToDataSource.scala | 6 +- .../continuous/ContinuousSuite.scala | 2 +- 8 files changed, 68 insertions(+), 48 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 290d614805ac7..4543c143a9aca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -39,7 +39,16 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader { Expression[] pushCatalystFilters(Expression[] filters); /** - * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * Returns the catalyst filters that are pushed to the data source via + * {@link #pushCatalystFilters(Expression[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for * this case. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 1cff024232a44..b6a90a3d0b681 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -37,7 +37,15 @@ public interface SupportsPushDownFilters extends DataSourceReader { Filter[] pushFilters(Filter[] filters); /** - * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * Returns the filters that are pushed to the data source via {@link #pushFilters(Filter[])}. + * + * There are 3 kinds of filters: + * 1. pushable filters which don't need to be evaluated again after scanning. + * 2. pushable filters which still need to be evaluated after scanning, e.g. parquet + * row group filter. + * 3. non-pushable filters. + * Both case 1 and 2 should be considered as pushed filters and should be returned by this method. + * * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} * is never called, empty array should be returned for this case. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2b282ffae2390..90fb5a14c9fc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -54,9 +54,12 @@ case class DataSourceV2Relation( private lazy val v2Options: DataSourceOptions = makeV2Options(options) + // postScanFilters: filters that need to be evaluated after the scan. + // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. + // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. lazy val ( reader: DataSourceReader, - unsupportedFilters: Seq[Expression], + postScanFilters: Seq[Expression], pushedFilters: Seq[Expression]) = { val newReader = userSpecifiedSchema match { case Some(s) => @@ -67,14 +70,16 @@ case class DataSourceV2Relation( DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - val (remainingFilters, pushedFilters) = filters match { + val (postScanFilters, pushedFilters) = filters match { case Some(filterSeq) => DataSourceV2Relation.pushFilters(newReader, filterSeq) case _ => (Nil, Nil) } + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - (newReader, remainingFilters, pushedFilters) + (newReader, postScanFilters, pushedFilters) } override def doCanonicalize(): LogicalPlan = { @@ -121,6 +126,8 @@ case class StreamingDataSourceV2Relation( override def simpleString: String = "Streaming RelationV2 " + metadataString + override def pushedFilters: Seq[Expression] = Nil + override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance())) // TODO: unify the equal/hashCode implementation for all data source v2 query plans. @@ -217,31 +224,35 @@ object DataSourceV2Relation { reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { reader match { - case catalystFilterSupport: SupportsPushDownCatalystFilters => - ( - catalystFilterSupport.pushCatalystFilters(filters.toArray), - catalystFilterSupport.pushedCatalystFilters() - ) - - case filterSupport: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet - val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) => - unhandledFilters.contains(f) + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (postScanFilters, pushedFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } } - (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq) + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + + (untranslatableExprs ++ postScanFilters, pushedFilters) case _ => (filters, Nil) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 3a5e7bf89e142..41bdda47c8c3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -41,6 +41,7 @@ case class DataSourceV2ScanExec( output: Seq[AttributeReference], @transient source: DataSourceV2, @transient options: Map[String, String], + @transient pushedFilters: Seq[Expression], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceV2StringFormat with ColumnarBatchScan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c2a31442d2be5..1b7c639f10f98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDat object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case r: StreamingDataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.reader) :: Nil + DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index aed55a429bfd7..693e67dcd108e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -19,11 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.util.Utils /** @@ -49,16 +47,9 @@ trait DataSourceV2StringFormat { def options: Map[String, String] /** - * The created data source reader. Here we use it to get the filters that has been pushed down - * so far, itself doesn't take part in the equals/hashCode. + * The filters which have been pushed to the data source. */ - def reader: DataSourceReader - - private lazy val filters = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet - case s: SupportsPushDownFilters => s.pushedFilters().toSet - case _ => Set.empty - } + def pushedFilters: Seq[Expression] private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() @@ -68,8 +59,8 @@ trait DataSourceV2StringFormat { def metadataString: String = { val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - if (filters.nonEmpty) { - entries += "Filters" -> filters.mkString("[", ", ", "]") + if (pushedFilters.nonEmpty) { + entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") } // TODO: we should only display some standard options like path, table, etc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index f23d228567241..9293d4f831bff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -57,9 +57,9 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { projection = projection.asInstanceOf[Seq[AttributeReference]], filters = Some(filters)) - // Add a Filter for any filters that could not be pushed - val unpushedFilter = newRelation.unsupportedFilters.reduceLeftOption(And) - val filtered = unpushedFilter.map(Filter(_, newRelation)).getOrElse(newRelation) + // Add a Filter for any filters that need to be evaluated after scan. + val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) + val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) // Add a Project to ensure the output matches the required projection if (newRelation.output != projectAttrs) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 5f222e7885994..cd1704ac2fdad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest { case s: ContinuousExecution => assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized") val reader = s.lastExecution.executedPlan.collectFirst { - case DataSourceV2ScanExec(_, _, _, r: RateStreamContinuousReader) => r + case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r }.get val deltaMs = numTriggers * 1000 + 300 From 007ae6878f4b4defe1f08114212fa7289fc9ee4a Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Mon, 30 Apr 2018 13:40:03 -0700 Subject: [PATCH 0713/1161] [SPARK-24003][CORE] Add support to provide spark.executor.extraJavaOptions in terms of App Id and/or Executor Id's ## What changes were proposed in this pull request? Added support to specify the 'spark.executor.extraJavaOptions' value in terms of the `{{APP_ID}}` and/or `{{EXECUTOR_ID}}`, `{{APP_ID}}` will be replaced by Application Id and `{{EXECUTOR_ID}}` will be replaced by Executor Id while starting the executor. ## How was this patch tested? I have verified this by checking the executor process command and gc logs. I verified the same in different deployment modes(Standalone, YARN, Mesos) client and cluster modes. Author: Devaraj K Closes #21088 from devaraj-kavali/SPARK-24003. --- .../spark/deploy/worker/ExecutorRunner.scala | 8 ++++++-- .../main/scala/org/apache/spark/util/Utils.scala | 15 +++++++++++++++ docs/configuration.md | 5 +++++ .../k8s/features/BasicExecutorFeatureStep.scala | 4 +++- .../MesosCoarseGrainedSchedulerBackend.scala | 4 +++- .../mesos/MesosFineGrainedSchedulerBackend.scala | 4 +++- .../org/apache/spark/deploy/yarn/Client.scala | 8 ++++++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 ++- 8 files changed, 43 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index d4d8521cc8204..dc6a3076a5113 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.Files import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef @@ -142,7 +142,11 @@ private[deploy] class ExecutorRunner( private def fetchAndRunExecutor() { try { // Launch the process - val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), + val subsOpts = appDesc.command.javaOpts.map { + Utils.substituteAppNExecIds(_, appId, execId.toString) + } + val subsCommand = appDesc.command.copy(javaOpts = subsOpts) + val builder = CommandUtils.buildProcessBuilder(subsCommand, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d2be93226e2a2..dcad1b914038f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2689,6 +2689,21 @@ private[spark] object Utils extends Logging { s"k8s://$resolvedURL" } + + /** + * Replaces all the {{EXECUTOR_ID}} occurrences with the Executor Id + * and {{APP_ID}} occurrences with the App Id. + */ + def substituteAppNExecIds(opt: String, appId: String, execId: String): String = { + opt.replace("{{APP_ID}}", appId).replace("{{EXECUTOR_ID}}", execId) + } + + /** + * Replaces all the {{APP_ID}} occurrences with the App Id. + */ + def substituteAppId(opt: String, appId: String): String = { + opt.replace("{{APP_ID}}", appId) + } } private[util] object CallerContext extends Logging { diff --git a/docs/configuration.md b/docs/configuration.md index fb02d7ea1d4ea..8a1aacef85760 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -328,6 +328,11 @@ Apart from these, the following properties are also available, and may be useful Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory. + + The following symbols, if present will be interpolated: {{APP_ID}} will be replaced by + application ID and {{EXECUTOR_ID}} will be replaced by executor ID. For example, to enable + verbose gc logging to a file named for the executor ID of the app in /tmp, pass a 'value' of: + -verbose:gc -Xloggc:/tmp/{{APP_ID}}-{{EXECUTOR_ID}}.gc
  • diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index d22097587aafe..529069d3b8a0c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -89,7 +89,9 @@ private[spark] class BasicExecutorFeatureStep( val executorExtraJavaOptionsEnv = kubernetesConf .get(EXECUTOR_JAVA_OPTIONS) .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) + val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId, + kubernetesConf.roleSpecificConf.executorId) + val delimitedOpts = Utils.splitCommandString(subsOpts) delimitedOpts.zipWithIndex.map { case (opt, index) => new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 53f5f61cca486..9b75e4c98344a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -227,7 +227,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") + val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, taskId) + }.getOrElse("") // Set the environment variable through a command prefix // to append to the existing value of the variable diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index d6d939d246109..71a70ff048ccc 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -111,7 +111,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").map { + Utils.substituteAppNExecIds(_, appId, execId) + }.getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => Utils.libraryPathEnvPrefix(Seq(p)) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5763c3dbc5a8a..bafb129032b49 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -892,7 +892,9 @@ private[spark] class Client( // Include driver-specific java options if we are launching a driver if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten @@ -914,7 +916,9 @@ private[spark] class Client( s"(was '$opts'). Use spark.yarn.am.memory instead." throw new SparkException(msg) } - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + javaOpts ++= Utils.splitCommandString(opts) + .map(Utils.substituteAppId(_, appId.toString)) + .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ab08698035c98..a2a18cdff65af 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -141,7 +141,8 @@ private[yarn] class ExecutorRunnable( // Set extra Java options for the executor, if defined sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts => - javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) + val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) + javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) From b857fb549f3bf4e6f289ba11f3903db0a3696dec Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 1 May 2018 09:06:23 +0800 Subject: [PATCH 0714/1161] [SPARK-23853][PYSPARK][TEST] Run Hive-related PySpark tests only for `-Phive` ## What changes were proposed in this pull request? When `PyArrow` or `Pandas` are not available, the corresponding PySpark tests are skipped automatically. Currently, PySpark tests fail when we are not using `-Phive`. This PR aims to skip Hive related PySpark tests when `-Phive` is not given. **BEFORE** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql File "/Users/dongjoon/spark/python/pyspark/sql/readwriter.py", line 295, in pyspark.sql.readwriter.DataFrameReader.table ... IllegalArgumentException: u"Error while instantiating 'org.apache.spark.sql.hive.HiveExternalCatalog':" ********************************************************************** 1 of 3 in pyspark.sql.readwriter.DataFrameReader.table ***Test Failed*** 1 failures. ``` **AFTER** ```bash $ build/mvn -DskipTests clean package $ python/run-tests.py --python-executables python2.7 --modules pyspark-sql ... Tests passed in 138 seconds Skipped tests in pyspark.sql.tests with python2.7: ... test_hivecontext (pyspark.sql.tests.HiveSparkSubmitTests) ... skipped 'Hive is not available.' ``` ## How was this patch tested? This is a test-only change. First, this should pass the Jenkins. Then, manually do the following. ```bash build/mvn -DskipTests clean package python/run-tests.py --python-executables python2.7 --modules pyspark-sql ``` Author: Dongjoon Hyun Closes #21141 from dongjoon-hyun/SPARK-23853. --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/tests.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9899eb5058b82..448a4732001b5 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -979,7 +979,7 @@ def _test(): globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') try: - spark = SparkSession.builder.enableHiveSupport().getOrCreate() + spark = SparkSession.builder.getOrCreate() except py4j.protocol.Py4JError: spark = SparkSession(sc) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bc3eaf16b4de7..cc6acfdb07d99 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3043,6 +3043,26 @@ def test_csv_sampling_ratio(self): class HiveSparkSubmitTests(SparkSubmitTests): + @classmethod + def setUpClass(cls): + # get a SparkContext to check for availability of Hive + sc = SparkContext('local[4]', cls.__name__) + cls.hive_available = True + try: + sc._jvm.org.apache.hadoop.hive.conf.HiveConf() + except py4j.protocol.Py4JError: + cls.hive_available = False + except TypeError: + cls.hive_available = False + finally: + # we don't need this SparkContext for the test + sc.stop() + + def setUp(self): + super(HiveSparkSubmitTests, self).setUp() + if not self.hive_available: + self.skipTest("Hive is not available.") + def test_hivecontext(self): # This test checks that HiveContext is using Hive metastore (SPARK-16224). # It sets a metastore url and checks if there is a derby dir created by From 7bbec0dced35aeed79c1a24b6f7a1e0a3508b0fb Mon Sep 17 00:00:00 2001 From: wangyanlin01 Date: Tue, 1 May 2018 16:22:52 +0800 Subject: [PATCH 0715/1161] [SPARK-24061][SS] Add TypedFilter support for continuous processing ## What changes were proposed in this pull request? Add TypedFilter support for continuous processing application. ## How was this patch tested? unit tests Author: wangyanlin01 Closes #21136 from yanlin-Lynn/SPARK-24061. --- .../UnsupportedOperationChecker.scala | 3 ++- .../analysis/UnsupportedOperationsSuite.scala | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index ff9d6d7a7dded..d3d6c636c4ba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,8 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | + _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 60d1351fda264..cb487c8893541 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -621,6 +621,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("monotonically_increasing_id")) + assertSupportedForContinuousProcessing( + "TypedFilter", TypedFilter( + null, + null, + null, + null, + new TestStreamingRelationV2(attribute)), OutputMode.Append()) /* ======================================================================================= @@ -771,6 +778,16 @@ class UnsupportedOperationsSuite extends SparkFunSuite { } } + /** Assert that the logical plan is supported for continuous procsssing mode */ + def assertSupportedForContinuousProcessing( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"continuous processing - $name: supported") { + UnsupportedOperationChecker.checkForContinuous(plan, outputMode) + } + } + /** * Assert that the logical plan is not supported inside a streaming plan. * @@ -840,4 +857,10 @@ class UnsupportedOperationsSuite extends SparkFunSuite { def this(attribute: Attribute) = this(Seq(attribute)) override def isStreaming: Boolean = true } + + case class TestStreamingRelationV2(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + override def nodeName: String = "StreamingRelationV2" + } } From 6782359a04356e4cde32940861bf2410ef37f445 Mon Sep 17 00:00:00 2001 From: Bounkong Khamphousone Date: Tue, 1 May 2018 08:28:21 -0700 Subject: [PATCH 0716/1161] [SPARK-23941][MESOS] Mesos task failed on specific spark app name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Shell escaped the name passed to spark-submit and change how conf attributes are shell escaped. ## How was this patch tested? This test has been tested manually with Hive-on-spark with mesos or with the use case described in the issue with the sparkPi application with a custom name which contains illegal shell characters. With this PR, hive-on-spark on mesos works like a charm with hive 3.0.0-SNAPSHOT. I state that this contribution is my original work and that I license the work to the project under the project’s open source license Author: Bounkong Khamphousone Closes #21014 from tiboun/fix/SPARK-23941. --- .../spark/scheduler/cluster/mesos/MesosClusterScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index d224a7325820a..b36f46456f9a5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -530,9 +530,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** From e15850be6e0210614a734a307f5b83bdf44e2456 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 2 May 2018 10:55:01 +0800 Subject: [PATCH 0717/1161] [SPARK-24131][PYSPARK] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? We need to determine Spark major and minor versions in PySpark. We can add a `majorMinorVersion` API to PySpark which is similar to the Scala API in `VersionUtils.majorMinorVersion`. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21203 from viirya/SPARK-24131. --- python/pyspark/util.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 49afc13640332..04df835bf6717 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import re import sys import inspect from py4j.protocol import Py4JJavaError @@ -61,6 +62,26 @@ def _get_argspec(f): return argspec +def majorMinorVersion(version): + """ + Get major and minor version numbers for given Spark version string. + + >>> version = "2.4.0" + >>> majorMinorVersion(version) + (2, 4) + + >>> version = "abc" + >>> majorMinorVersion(version) is None + True + + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', version) + if m is None: + return None + else: + return (int(m.group(1)), int(m.group(2))) + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From 9215ee7a16b57c56ae927d65e024cf7afe542cbb Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 10:41:34 +0200 Subject: [PATCH 0718/1161] [SPARK-23976][CORE] Detect length overflow in UTF8String.concat()/ByteArray.concat() ## What changes were proposed in this pull request? This PR detects length overflow if total elements in inputs are not acceptable. For example, when the three inputs has `0x7FFF_FF00`, `0x7FFF_FF00`, and `0xE00`, we should detect length overflow since we cannot allocate such a large structure on `byte[]`. On the other hand, the current algorithm can allocate the result structure with `0x1000`-byte length due to integer sum overflow. ## How was this patch tested? Existing UTs. If we would create UTs, we need large heap (6-8GB). It may make test environment unstable. If it is necessary to create UTs, I will create them. Author: Kazuaki Ishizaki Closes #21064 from kiszk/SPARK-23976. --- .../org/apache/spark/unsafe/types/ByteArray.java | 12 +++++++----- .../org/apache/spark/unsafe/types/UTF8String.java | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c03caf0076f61..ecd7c19f2c634 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -17,10 +17,12 @@ package org.apache.spark.unsafe.types; -import org.apache.spark.unsafe.Platform; - import java.util.Arrays; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + public final class ByteArray { public static final byte[] EMPTY_BYTE = new byte[0]; @@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) { public static byte[] concat(byte[]... inputs) { // Compute the total length of the result - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].length; + totalLength += (long)inputs[i].length; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].length; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e9b3d9b045af5..e91fc4391425c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -29,8 +29,8 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; - import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; @@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) { */ public static UTF8String concat(UTF8String... inputs) { // Compute the total length of the result. - int totalLength = 0; + long totalLength = 0; for (int i = 0; i < inputs.length; i++) { if (inputs[i] != null) { - totalLength += inputs[i].numBytes; + totalLength += (long)inputs[i].numBytes; } else { return null; } } // Allocate a new byte array, and copy the inputs one by one into it. - final byte[] result = new byte[totalLength]; + final byte[] result = new byte[Ints.checkedCast(totalLength)]; int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; From 152eaf6ae698cd0df7f5a5be3f17ee46e0be929d Mon Sep 17 00:00:00 2001 From: WangJinhai02 Date: Wed, 2 May 2018 22:40:14 +0800 Subject: [PATCH 0719/1161] [SPARK-24107][CORE] ChunkedByteBuffer.writeFully method has not reset the limit value MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit JIRA Issue: https://issues.apache.org/jira/browse/SPARK-24107?jql=text%20~%20%22ChunkedByteBuffer%22 ChunkedByteBuffer.writeFully method has not reset the limit value. When chunks larger than bufferWriteChunkSize, such as 80 * 1024 * 1024 larger than config.BUFFER_WRITE_CHUNK_SIZE(64 * 1024 * 1024),only while once, will lost 16 * 1024 * 1024 byte Author: WangJinhai02 Closes #21175 from manbuyun/bugfix-ChunkedByteBuffer. --- .../spark/util/io/ChunkedByteBuffer.scala | 13 +++++++++---- .../spark/io/ChunkedByteBufferSuite.scala | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 7367af7888bd8..3ae8dfcc1cb66 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,10 +63,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - while (bytes.remaining() > 0) { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) + val curChunkLimit = bytes.limit() + while (bytes.hasRemaining) { + try { + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + } finally { + bytes.limit(curChunkLimit) + } } } } diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index 3b798e36b0499..2107559572d78 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import com.google.common.io.ByteStreams -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.network.util.ByteArrayWritableChannel import org.apache.spark.util.io.ChunkedByteBuffer -class ChunkedByteBufferSuite extends SparkFunSuite { +class ChunkedByteBufferSuite extends SparkFunSuite with SharedSparkContext { test("no chunks") { val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer]) @@ -56,6 +57,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite { assert(chunkedByteBuffer.getChunks().head.position() === 0) } + test("SPARK-24107: writeFully() write buffer which is larger than bufferWriteChunkSize") { + try { + sc.conf.set(config.BUFFER_WRITE_CHUNK_SIZE, 32L * 1024L * 1024L) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(40 * 1024 * 1024))) + val byteArrayWritableChannel = new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt) + chunkedByteBuffer.writeFully(byteArrayWritableChannel) + assert(byteArrayWritableChannel.length() === chunkedByteBuffer.size) + } finally { + sc.conf.remove(config.BUFFER_WRITE_CHUNK_SIZE) + } + } + test("toArray()") { val empty = ByteBuffer.wrap(Array.empty[Byte]) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) From 8dbf56c055218ff0f3fabae84b63c022f43afbfd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 11:58:55 -0700 Subject: [PATCH 0720/1161] [SPARK-24013][SQL] Remove unneeded compress in ApproximatePercentile ## What changes were proposed in this pull request? `ApproximatePercentile` contains a workaround logic to compress the samples since at the beginning `QuantileSummaries` was ignoring the compression threshold. This problem was fixed in SPARK-17439, but the workaround logic was not removed. So we are compressing the samples many more times than needed: this could lead to critical performance degradation. This can create serious performance issues in queries like: ``` select approx_percentile(id, array(0.1)) from range(10000000) ``` ## How was this patch tested? added UT Author: Marco Gaido Closes #21133 from mgaido91/SPARK-24013. --- .../aggregate/ApproximatePercentile.scala | 33 ++++--------------- .../sql/catalyst/util/QuantileSummaries.scala | 11 ++++--- .../sql/ApproximatePercentileQuerySuite.scala | 13 ++++++++ 3 files changed, 26 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index a45854a3b5146..f1bbbdabb41f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -206,27 +206,15 @@ object ApproximatePercentile { * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. * * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. - * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the - * underlying quantileSummaries is compressed. */ - class PercentileDigest( - private var summaries: QuantileSummaries, - private var isCompressed: Boolean) { - - // Trigger compression if the QuantileSummaries's buffer length exceeds - // compressThresHoldBufferLength. The buffer length can be get by - // quantileSummaries.sampled.length - private[this] final val compressThresHoldBufferLength: Int = { - // Max buffer length after compression. - val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 - // A safe upper bound for buffer length before compression - maxBufferLengthAfterCompression * 2 - } + class PercentileDigest(private var summaries: QuantileSummaries) { def this(relativeError: Double) = { - this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true)) } + private[sql] def isCompressed: Boolean = summaries.compressed + /** Returns compressed object of [[QuantileSummaries]] */ def quantileSummaries: QuantileSummaries = { if (!isCompressed) compress() @@ -236,14 +224,6 @@ object ApproximatePercentile { /** Insert an observation value into the PercentileDigest data structure. */ def add(value: Double): Unit = { summaries = summaries.insert(value) - // The result of QuantileSummaries.insert is un-compressed - isCompressed = false - - // Currently, QuantileSummaries ignores the construction parameter compressThresHold, - // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here - // to make sure QuantileSummaries doesn't occupy infinite memory. - // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold - if (summaries.sampled.length >= compressThresHoldBufferLength) compress() } /** In-place merges in another PercentileDigest. */ @@ -280,7 +260,6 @@ object ApproximatePercentile { private final def compress(): Unit = { summaries = summaries.compress() - isCompressed = true } } @@ -335,8 +314,8 @@ object ApproximatePercentile { sampled(i) = Stats(value, g, delta) i += 1 } - val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) - new PercentileDigest(summary, isCompressed = true) + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true) + new PercentileDigest(summary) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index b013add9c9778..3190e511e2cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -40,12 +40,14 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats * See the G-K article for more details. * @param count the count of all the elements *inserted in the sampled buffer* * (excluding the head buffer) + * @param compressed whether the statistics have been compressed */ class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, val sampled: Array[Stats] = Array.empty, - val count: Long = 0L) extends Serializable { + val count: Long = 0L, + var compressed: Boolean = false) extends Serializable { // a buffer of latest samples seen so far private val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty @@ -60,6 +62,7 @@ class QuantileSummaries( */ def insert(x: Double): QuantileSummaries = { headSampled += x + compressed = false if (headSampled.size >= defaultHeadSize) { val result = this.withHeadBufferInserted if (result.sampled.length >= compressThreshold) { @@ -135,11 +138,11 @@ class QuantileSummaries( assert(inserted.count == count + headSampled.size) val compressed = compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count) - new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count) + new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count, true) } private def shallowCopy: QuantileSummaries = { - new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new QuantileSummaries(compressThreshold, relativeError, sampled, count, compressed) } /** @@ -163,7 +166,7 @@ class QuantileSummaries( val res = (sampled ++ other.sampled).sortBy(_.value) val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count) new QuantileSummaries( - other.compressThreshold, other.relativeError, comp, other.count + count) + other.compressThreshold, other.relativeError, comp, other.count + count, true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 137c5bea2abb9..d635912cf7205 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -279,4 +280,16 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } } + + test("SPARK-24013: unneeded compress can cause performance issues with sorted input") { + val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY) + var compressCounts = 0 + (1 to 10000000).foreach { i => + buffer.add(i) + if (buffer.isCompressed) compressCounts += 1 + } + assert(compressCounts > 0) + buffer.quantileSummaries + assert(buffer.isCompressed) + } } From 8bd27025b7cf0b44726b6f4020d294ef14dbbb7e Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Wed, 2 May 2018 12:43:19 -0700 Subject: [PATCH 0721/1161] [SPARK-24133][SQL] Check for integer overflows when resizing WritableColumnVectors ## What changes were proposed in this pull request? `ColumnVector`s store string data in one big byte array. Since the array size is capped at just under Integer.MAX_VALUE, a single `ColumnVector` cannot store more than 2GB of string data. But since the Parquet files commonly contain large blobs stored as strings, and `ColumnVector`s by default carry 4096 values, it's entirely possible to go past that limit. In such cases a negative capacity is requested from `WritableColumnVector.reserve()`. The call succeeds (requested capacity is smaller than already allocated capacity), and consequently `java.lang.ArrayIndexOutOfBoundsException` is thrown when the reader actually attempts to put the data into the array. This change introduces a simple check for integer overflow to `WritableColumnVector.reserve()` which should help catch the error earlier and provide more informative exception. Additionally, the error message in `WritableColumnVector.throwUnsupportedException()` was corrected, as it previously encouraged users to increase rather than reduce the batch size. ## How was this patch tested? New units tests were added. Author: Ala Luszczak Closes #21206 from ala/overflow-reserve. --- .../vectorized/WritableColumnVector.java | 21 ++++++++++++------- .../vectorized/ColumnarBatchSuite.scala | 7 +++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5275e4a91eac0..b0e119d658cb4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -81,7 +81,9 @@ public void close() { } public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { + if (requiredCapacity < 0) { + throwUnsupportedException(requiredCapacity, null); + } else if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); if (requiredCapacity <= newCapacity) { try { @@ -96,13 +98,16 @@ public void reserve(int requiredCapacity) { } private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader, or increase the vectorized reader batch size. For parquet file " + - "format, refer to " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + "; for orc file format, refer to " + - SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + " and " + - SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + "."; + String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" + + (requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") + + "). As a workaround, you can reduce the vectorized reader batch size, or disable the " + + "vectorized reader. For parquet file format, refer to " + + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " + + "refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() + + " (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() + + ") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + "."; throw new RuntimeException(message, cause); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 772f687526008..f57f07b498261 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1333,4 +1333,11 @@ class ColumnarBatchSuite extends SparkFunSuite { column.close() } + + testVector("WritableColumnVector.reserve(): requested capacity is negative", 1024, ByteType) { + column => + val ex = intercept[RuntimeException] { column.reserve(-1) } + assert(ex.getMessage.contains( + "Cannot reserve additional contiguous bytes in the vectorized reader (integer overflow)")) + } } From 504c9cfd21ef45a13d9428fef3b197dcbf6786cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 2 May 2018 13:49:15 -0700 Subject: [PATCH 0722/1161] [SPARK-24123][SQL] Fix precision issues in monthsBetween with more than 8 digits ## What changes were proposed in this pull request? SPARK-23902 introduced the ability to retrieve more than 8 digits in `monthsBetween`. Unfortunately, current implementation can cause precision loss in such a case. This was causing also a flaky UT. This PR mirrors Hive's implementation in order to avoid precision loss also when more than 8 digits are returned. ## How was this patch tested? running 10000000 times the flaky UT Author: Marco Gaido Closes #21196 from mgaido91/SPARK-24123. --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 4b00a61c6cf91..d2fe15c48c6dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -888,14 +888,19 @@ object DateTimeUtils { val months1 = year1 * 12 + monthInYear1 val months2 = year2 * 12 + monthInYear2 + val monthDiff = (months1 - months2).toDouble + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { - return (months1 - months2).toDouble + return monthDiff } - // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1, timeZone) - val timeInDay2 = millis2 - daysToMillis(date2, timeZone) - val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY - val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 + // using milliseconds can cause precision loss with more than 8 digits + // we follow Hive's implementation which uses seconds + val secondsInDay1 = (millis1 - daysToMillis(date1, timeZone)) / 1000L + val secondsInDay2 = (millis2 - daysToMillis(date2, timeZone)) / 1000L + val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 + // 2678400D is the number of seconds in 31 days + // every month is considered to be 31 days long in this function + val diff = monthDiff + secondsDiff / 2678400D if (roundOff) { // rounding to 8 digits math.round(diff * 1e8) / 1e8 From 5be8aab14468e55b1049a0c83f02dcec0651162f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 2 May 2018 13:53:10 -0700 Subject: [PATCH 0723/1161] [SPARK-23923][SQL] Add cardinality function ## What changes were proposed in this pull request? The PR adds the SQL function `cardinality`. The behavior of the function is based on Presto's one. The function returns the length of the array or map stored in the column as `int` while the Presto version returns the value as `BigInt` (`long` in Spark). The discussions regarding the difference of return type are [here](https://github.com/apache/spark/pull/21031#issuecomment-381284638) and [there](https://github.com/apache/spark/pull/21031#discussion_r181622107). ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21031 from kiszk/SPARK-23923. --- .../spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6bc7b4e4f7cb3..3ffbc9c8069fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 470a1c8e331ba..a5163accb1bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -341,6 +341,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(-1)) ) + + checkAnswer( + df.selectExpr("cardinality(a)"), + Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) + ) } test("map size function") { From e4c91c089a701117af82f585d14d8afc5245fc64 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 2 May 2018 16:12:21 -0700 Subject: [PATCH 0724/1161] [SPARK-24111][SQL] Add the TPCDS v2.7 (latest) queries in TPCDSQueryBenchmark ## What changes were proposed in this pull request? This pr added the TPCDS v2.7 (latest) queries in `TPCDSQueryBenchmark`. These query files have been added in `SPARK-23167`. ## How was this patch tested? Manually checked. Author: Takeshi Yamamuro Closes #21177 from maropu/AddTpcdsV2_7InBenchmark. --- .../benchmark/TPCDSQueryBenchmark.scala | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 69247d7f4e9aa..abe61a2c2b9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -58,10 +58,13 @@ object TPCDSQueryBenchmark extends Logging { }.toMap } - def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { - val tableSizes = setupTables(dataLocation) + def runTpcdsQueries( + queryLocation: String, + queries: Seq[String], + tableSizes: Map[String, Long], + nameSuffix: String = ""): Unit = { queries.foreach { name => - val queryString = resourceToString(s"tpcds/$name.sql", + val queryString = resourceToString(s"$queryLocation/$name.sql", classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the @@ -78,7 +81,7 @@ object TPCDSQueryBenchmark extends Logging { } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) - benchmark.addCase(name) { i => + benchmark.addCase(s"$name$nameSuffix") { _ => spark.sql(queryString).collect() } logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n") @@ -87,10 +90,20 @@ object TPCDSQueryBenchmark extends Logging { } } + def filterQueries( + origQueries: Seq[String], + args: TPCDSQueryBenchmarkArguments): Seq[String] = { + if (args.queryFilter.nonEmpty) { + origQueries.filter(args.queryFilter.contains) + } else { + origQueries + } + } + def main(args: Array[String]): Unit = { val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args) - // List of all TPC-DS queries + // List of all TPC-DS v1.4 queries val tpcdsQueries = Seq( "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", @@ -103,20 +116,25 @@ object TPCDSQueryBenchmark extends Logging { "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + // This list only includes TPC-DS v2.7 queries that are different from v1.4 ones + val tpcdsQueriesV2_7 = Seq( + "q5a", "q6", "q10a", "q11", "q12", "q14", "q14a", "q18a", + "q20", "q22", "q22a", "q24", "q27a", "q34", "q35", "q35a", "q36a", "q47", "q49", + "q51a", "q57", "q64", "q67a", "q70a", "q72", "q74", "q75", "q77a", "q78", + "q80a", "q86a", "q98") + // If `--query-filter` defined, filters the queries that this option selects - val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) { - val queries = tpcdsQueries.filter { case queryName => - benchmarkArgs.queryFilter.contains(queryName) - } - if (queries.isEmpty) { - throw new RuntimeException( - s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") - } - queries - } else { - tpcdsQueries + val queriesV1_4ToRun = filterQueries(tpcdsQueries, benchmarkArgs) + val queriesV2_7ToRun = filterQueries(tpcdsQueriesV2_7, benchmarkArgs) + + if ((queriesV1_4ToRun ++ queriesV2_7ToRun).isEmpty) { + throw new RuntimeException( + s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}") } - tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun) + val tableSizes = setupTables(benchmarkArgs.dataLocation) + runTpcdsQueries(queryLocation = "tpcds", queries = queriesV1_4ToRun, tableSizes) + runTpcdsQueries(queryLocation = "tpcds-v2.7.0", queries = queriesV2_7ToRun, tableSizes, + nameSuffix = "-v2.7") } } From bf4352ca6c96dfab16b286c54720685e32b216f1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 3 May 2018 09:28:14 +0800 Subject: [PATCH 0725/1161] [SPARK-24110][THRIFT-SERVER] Avoid UGI.loginUserFromKeytab in STS ## What changes were proposed in this pull request? Spark ThriftServer will call UGI.loginUserFromKeytab twice in initialization. This is unnecessary and will cause various potential problems, like Hadoop IPC failure after 7 days, or RM failover issue and so on. So here we need to remove all the unnecessary login logics and make sure UGI in the context never be created again. Note this is actually a HS2 issue, If later on we upgrade supported Hive version, the issue may already be fixed in Hive side. ## How was this patch tested? Local verification in secure cluster. Author: jerryshao Closes #21178 from jerryshao/SPARK-24110. --- .../hive/service/auth/HiveAuthFactory.java | 62 +++++++++++++++++-- .../thriftserver/SparkSQLCLIService.scala | 20 +++++- 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java index c5ade65283045..10000f12ab329 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -18,6 +18,8 @@ package org.apache.hive.service.auth; import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -26,6 +28,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import javax.net.ssl.SSLServerSocket; import javax.security.auth.login.LoginException; @@ -92,7 +95,30 @@ public String getAuthName() { public static final String HS2_PROXY_USER = "hive.server2.proxy.user"; public static final String HS2_CLIENT_TOKEN = "hiveserver2ClientToken"; - public HiveAuthFactory(HiveConf conf) throws TTransportException { + private static Field keytabFile = null; + private static Method getKeytab = null; + static { + Class clz = UserGroupInformation.class; + try { + keytabFile = clz.getDeclaredField("keytabFile"); + keytabFile.setAccessible(true); + } catch (NoSuchFieldException nfe) { + LOG.debug("Cannot find private field \"keytabFile\" in class: " + + UserGroupInformation.class.getCanonicalName(), nfe); + keytabFile = null; + } + + try { + getKeytab = clz.getDeclaredMethod("getKeytab"); + getKeytab.setAccessible(true); + } catch(NoSuchMethodException nme) { + LOG.debug("Cannot find private method \"getKeytab\" in class:" + + UserGroupInformation.class.getCanonicalName(), nme); + getKeytab = null; + } + } + + public HiveAuthFactory(HiveConf conf) throws TTransportException, IOException { this.conf = conf; transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE); authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION); @@ -107,9 +133,16 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException { authTypeStr = AuthTypes.NONE.getAuthName(); } if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { - saslServer = ShimLoader.getHadoopThriftAuthBridge() - .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB), - conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)); + String principal = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL); + String keytab = conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + if (needUgiLogin(UserGroupInformation.getCurrentUser(), + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keytab)) { + saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(principal, keytab); + } else { + // Using the default constructor to avoid unnecessary UGI login. + saslServer = new HadoopThriftAuthBridge.Server(); + } + // start delegation token manager try { // rawStore is only necessary for DBTokenStore @@ -362,4 +395,25 @@ public static void verifyProxyAccess(String realUser, String proxyUser, String i } } + public static boolean needUgiLogin(UserGroupInformation ugi, String principal, String keytab) { + return null == ugi || !ugi.hasKerberosCredentials() || !ugi.getUserName().equals(principal) || + !Objects.equals(keytab, getKeytabFromUgi()); + } + + private static String getKeytabFromUgi() { + synchronized (UserGroupInformation.class) { + try { + if (keytabFile != null) { + return (String) keytabFile.get(null); + } else if (getKeytab != null) { + return (String) getKeytab.invoke(UserGroupInformation.getCurrentUser()); + } else { + return null; + } + } catch (Exception e) { + LOG.debug("Fail to get keytabFile path via reflection", e); + return null; + } + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index ad1f5eb9ca3a7..1335e16e35882 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,7 +27,7 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.shims.Utils -import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation} import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory @@ -52,8 +52,22 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC if (UserGroupInformation.isSecurityEnabled) { try { - HiveAuthFactory.loginFromKeytab(hiveConf) - sparkServiceUGI = Utils.getUGI() + val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL) + val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB) + if (principal.isEmpty || keyTabFile.isEmpty) { + throw new IOException( + "HiveServer2 Kerberos principal or keytab is not correctly configured") + } + + val originalUgi = UserGroupInformation.getCurrentUser + sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi, + SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) { + HiveAuthFactory.loginFromKeytab(hiveConf) + Utils.getUGI() + } else { + originalUgi + } + setSuperField(this, "serviceUGI", sparkServiceUGI) } catch { case e @ (_: IOException | _: LoginException) => From c9bfd1c6f8d16890ea1e5bc2bcb654a3afb32591 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 May 2018 15:15:05 +0800 Subject: [PATCH 0726/1161] [SPARK-23489][SQL][TEST] HiveExternalCatalogVersionsSuite should verify the downloaded file ## What changes were proposed in this pull request? Although [SPARK-22654](https://issues.apache.org/jira/browse/SPARK-22654) made `HiveExternalCatalogVersionsSuite` download from Apache mirrors three times, it has been flaky because it didn't verify the downloaded file. Some Apache mirrors terminate the downloading abnormally, the *corrupted* file shows the following errors. ``` gzip: stdin: not in gzip format tar: Child returned status 1 tar: Error is not recoverable: exiting now 22:46:32.700 WARN org.apache.spark.sql.hive.HiveExternalCatalogVersionsSuite: ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.hive.HiveExternalCatalogVersionsSuite, thread names: Keep-Alive-Timer ===== *** RUN ABORTED *** java.io.IOException: Cannot run program "./bin/spark-submit" (in directory "/tmp/test-spark/spark-2.2.0"): error=2, No such file or directory ``` This has been reported weirdly in two ways. For example, the above case is reported as Case 2 `no failures`. - Case 1. [Test Result (1 failure / +1)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-sbt-hadoop-2.7/4389/) - Case 2. [Test Result (no failures)](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Test%20(Dashboard)/job/spark-master-test-maven-hadoop-2.6/4811/) This PR aims to make `HiveExternalCatalogVersionsSuite` more robust by verifying the downloaded `tgz` file by extracting and checking the existence of `bin/spark-submit`. If it turns out that the file is empty or corrupted, `HiveExternalCatalogVersionsSuite` will do retry logic like the download failure. ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21210 from dongjoon-hyun/SPARK-23489. --- .../HiveExternalCatalogVersionsSuite.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6ca58e68d31eb..ea86ab9772bc7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -67,7 +67,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) - return + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } } catch { case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } @@ -75,20 +89,6 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { fail(s"Unable to download Spark $version") } - - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) - - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath - - Seq("mkdir", targetDir).! - - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! - - Seq("rm", downloaded).! - } - private def genDataDir(name: String): String = { new File(tmpDataDir, name).getCanonicalPath } @@ -161,7 +161,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq( From 417ad92502e714da71552f64d0e1257d2fd5d3d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:27:01 +0800 Subject: [PATCH 0727/1161] [SPARK-23715][SQL] the input of to/from_utc_timestamp can not have timezone ## What changes were proposed in this pull request? `from_utc_timestamp` assumes its input is in UTC timezone and shifts it to the specified timezone. When the timestamp contains timezone(e.g. `2018-03-13T06:18:23+00:00`), Spark breaks the semantic and respect the timezone in the string. This is not what user expects and the result is different from Hive/Impala. `to_utc_timestamp` has the same problem. More details please refer to the JIRA ticket. This PR fixes this by returning null if the input timestamp contains timezone. ## How was this patch tested? new tests Author: Wenchen Fan Closes #21169 from cloud-fan/from_utc_timezone. --- docs/sql-programming-guide.md | 13 +- .../sql/catalyst/analysis/TypeCoercion.scala | 30 +++- .../expressions/datetimeExpressions.scala | 42 ++++++ .../sql/catalyst/util/DateTimeUtils.scala | 22 ++- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../catalyst/analysis/TypeCoercionSuite.scala | 12 +- .../resources/sql-tests/inputs/datetime.sql | 33 +++++ .../sql-tests/results/datetime.sql.out | 135 +++++++++++++++++- .../apache/spark/sql/DateFunctionsSuite.scala | 8 ++ 9 files changed, 283 insertions(+), 19 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 836ce990205a9..075b953a0898e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1805,12 +1805,13 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. + - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. + - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. + - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 25bad28a2a209..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -776,12 +776,33 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + + private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + // Special rules for `from/to_utc_timestamp`. These 2 functions assume the input timestamp + // string is in a specific timezone, so the string itself should not contain timezone. + // TODO: We should move the type coercion logic to expressions instead of a central + // place to put all the rules. + case e: FromUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + + case e: ToUTCTimestamp if e.left.dataType == StringType => + if (rejectTzInString) { + e.copy(left = StringToTimestampWithoutTimezone(e.left)) + } else { + e.copy(left = Cast(e.left, TimestampType)) + } + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { @@ -798,7 +819,7 @@ object TypeCoercion { case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. - implicitCast(in, expected).getOrElse(in) + ImplicitTypeCasts.implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) @@ -814,6 +835,9 @@ object TypeCoercion { } e.withNewChildren(children) } + } + + object ImplicitTypeCasts { /** * Given an expected data type, try to cast the expression and return the cast expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d882d06cfd625..76aa61415a11f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1016,6 +1016,48 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S } } +/** + * A special expression used to convert the string input of `to/from_utc_timestamp` to timestamp, + * which requires the timestamp string to not have timezone information, otherwise null is returned. + */ +case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def dataType: DataType = TimestampType + override def nullable: Boolean = true + override def toString: String = child.toString + override def sql: String = child.sql + + override def nullSafeEval(input: Any): Any = { + DateTimeUtils.stringToTimestamp( + input.asInstanceOf[UTF8String], timeZone, rejectTzInString = true).orNull + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceObj("timeZone", timeZone) + val longOpt = ctx.freshName("longOpt") + val eval = child.genCode(ctx) + val code = s""" + |${eval.code} + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; + |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; + |if (!${eval.isNull}) { + | scala.Option $longOpt = $dtu.stringToTimestamp(${eval.value}, $tz, true); + | if ($longOpt.isDefined()) { + | ${ev.value} = ((Long) $longOpt.get()).longValue(); + | ${ev.isNull} = false; + | } + |} + """.stripMargin + ev.copy(code = code) + } +} + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d2fe15c48c6dd..e646da0659e85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -296,10 +296,28 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { - stringToTimestamp(s, defaultTimeZone()) + stringToTimestamp(s, defaultTimeZone(), rejectTzInString = false) } def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { + stringToTimestamp(s, timeZone, rejectTzInString = false) + } + + /** + * Converts a timestamp string to microseconds from the unix epoch, w.r.t. the given timezone. + * Returns None if the input string is not a valid timestamp format. + * + * @param s the input timestamp string. + * @param timeZone the timezone of the timestamp string, will be ignored if the timestamp string + * already contains timezone information and `forceTimezone` is false. + * @param rejectTzInString if true, rejects timezone in the input string, i.e., if the + * timestamp string contains timezone, like `2000-10-10 00:00:00+00:00`, + * return None. + */ + def stringToTimestamp( + s: UTF8String, + timeZone: TimeZone, + rejectTzInString: Boolean): Option[SQLTimestamp] = { if (s == null) { return None } @@ -417,6 +435,8 @@ object DateTimeUtils { return None } + if (tz.isDefined && rejectTzInString) return None + val c = if (tz.isEmpty) { Calendar.getInstance(timeZone) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3729bd5293eca..3942240c442b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1208,6 +1208,13 @@ object SQLConf { .stringConf .createWithDefault("") + val REJECT_TIMEZONE_IN_STRING = buildConf("spark.sql.function.rejectTimezoneInString") + .internal() + .doc("If true, `to_utc_timestamp` and `from_utc_timestamp` return null if the input string " + + "contains a timezone part, e.g. `2000-10-10 00:00:00+00:00`.") + .booleanConf + .createWithDefault(true) + object PartitionOverwriteMode extends Enumeration { val STATIC, DYNAMIC = Value } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1cc431aaf0a60..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,11 +536,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } @@ -823,7 +823,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts, + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index 547c2bef02b24..4950a4b7a4e5a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -27,3 +27,36 @@ select current_date = current_date(), current_timestamp = current_timestamp(), a select a, b from ttf2 order by a, current_date; select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); + +select from_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select from_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select from_utc_timestamp(null, 'PST'); + +select from_utc_timestamp('2015-07-24 00:00:00', null); + +select from_utc_timestamp(null, null); + +select from_utc_timestamp(cast(0 as timestamp), 'PST'); + +select from_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', 'PST'); + +select to_utc_timestamp('2015-01-24 00:00:00', 'PST'); + +select to_utc_timestamp(null, 'PST'); + +select to_utc_timestamp('2015-07-24 00:00:00', null); + +select to_utc_timestamp(null, null); + +select to_utc_timestamp(cast(0 as timestamp), 'PST'); + +select to_utc_timestamp(cast('2015-01-24' as date), 'PST'); + +-- SPARK-23715: the input of to/from_utc_timestamp can not have timezone +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); + +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index 4e1cfa6e48c1c..9eede305dbdcc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 26 -- !query 0 @@ -82,9 +82,138 @@ struct 1 2 2 3 + -- !query 9 select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') --- !query 3 schema +-- !query 9 schema struct --- !query 3 output +-- !query 9 output 5 3 5 NULL 4 + + +-- !query 10 +select from_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 10 schema +struct +-- !query 10 output +2015-07-23 17:00:00 + + +-- !query 11 +select from_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 11 schema +struct +-- !query 11 output +2015-01-23 16:00:00 + + +-- !query 12 +select from_utc_timestamp(null, 'PST') +-- !query 12 schema +struct +-- !query 12 output +NULL + + +-- !query 13 +select from_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 13 schema +struct +-- !query 13 output +NULL + + +-- !query 14 +select from_utc_timestamp(null, null) +-- !query 14 schema +struct +-- !query 14 output +NULL + + +-- !query 15 +select from_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 15 schema +struct +-- !query 15 output +1969-12-31 08:00:00 + + +-- !query 16 +select from_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 16 schema +struct +-- !query 16 output +2015-01-23 16:00:00 + + +-- !query 17 +select to_utc_timestamp('2015-07-24 00:00:00', 'PST') +-- !query 17 schema +struct +-- !query 17 output +2015-07-24 07:00:00 + + +-- !query 18 +select to_utc_timestamp('2015-01-24 00:00:00', 'PST') +-- !query 18 schema +struct +-- !query 18 output +2015-01-24 08:00:00 + + +-- !query 19 +select to_utc_timestamp(null, 'PST') +-- !query 19 schema +struct +-- !query 19 output +NULL + + +-- !query 20 +select to_utc_timestamp('2015-07-24 00:00:00', null) +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select to_utc_timestamp(null, null) +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select to_utc_timestamp(cast(0 as timestamp), 'PST') +-- !query 22 schema +struct +-- !query 22 output +1970-01-01 00:00:00 + + +-- !query 23 +select to_utc_timestamp(cast('2015-01-24' as date), 'PST') +-- !query 23 schema +struct +-- !query 23 output +2015-01-24 08:00:00 + + +-- !query 24 +select from_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 24 schema +struct +-- !query 24 output +NULL + + +-- !query 25 +select to_utc_timestamp('2000-10-10 00:00:00+00:00', 'PST') +-- !query 25 schema +struct +-- !query 25 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index f712baa7a9134..237412aa692e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -23,6 +23,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval @@ -696,4 +697,11 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { + withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { + checkAnswer( + sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')"), + Row(Timestamp.valueOf("2000-10-09 18:00:00"))) + } + } } From 991b526992bcf1dc1268578b650916569b12f583 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 19:56:30 +0800 Subject: [PATCH 0728/1161] [SPARK-24166][SQL] InMemoryTableScanExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from https://github.com/apache/spark/pull/21190 , to make it easier to backport. `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21223 from cloud-fan/minor1. --- .../execution/columnar/InMemoryTableScanExec.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index ea315fb71617c..0b4dd76c7d860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -78,10 +78,12 @@ case class InMemoryTableScanExec( private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + private def createAndDecompressColumn( + cachedColumnarBatch: CachedBatch, + offHeapColumnVectorEnabled: Boolean): ColumnarBatch = { val rowCount = cachedColumnarBatch.numRows val taskContext = Option(TaskContext.get()) - val columnVectors = if (!conf.offHeapColumnVectorEnabled || taskContext.isEmpty) { + val columnVectors = if (!offHeapColumnVectorEnabled || taskContext.isEmpty) { OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) } else { OffHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) @@ -101,10 +103,13 @@ case class InMemoryTableScanExec( private lazy val inputRDD: RDD[InternalRow] = { val buffers = filteredCachedBatches() + val offHeapColumnVectorEnabled = conf.offHeapColumnVectorEnabled if (supportsBatch) { // HACK ALERT: This is actually an RDD[ColumnarBatch]. // We're taking advantage of Scala's type erasure here to pass these batches along. - buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]] + buffers + .map(createAndDecompressColumn(_, offHeapColumnVectorEnabled)) + .asInstanceOf[RDD[InternalRow]] } else { val numOutputRows = longMetric("numOutputRows") From 96a50016bb0fb1cc57823a6706bff2467d671efd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 23:36:09 +0800 Subject: [PATCH 0729/1161] [SPARK-24169][SQL] JsonToStructs should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21226 from cloud-fan/minor4. --- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/jsonExpressions.scala | 16 ++-- .../expressions/JsonExpressionsSuite.scala | 76 +++++++++---------- .../org/apache/spark/sql/functions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 +- 5 files changed, 54 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3ffbc9c8069fd..51bb6b0abe408 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -534,7 +534,9 @@ object FunctionRegistry { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { - val validParametersCount = constructors.map(_.getParameterCount).distinct.sorted + val validParametersCount = constructors + .filter(_.getParameterTypes.forall(_ == classOf[Expression])) + .map(_.getParameterCount).distinct.sorted val expectedNumberOfParameters = if (validParametersCount.length == 1) { validParametersCount.head.toString } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index fdd672c416a03..34161f0f03f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -514,11 +514,10 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String] = None) + timeZoneId: Option[String], + forceNullableSchema: Boolean) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { - val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) - // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -532,14 +531,21 @@ case class JsonToStructs( schema = JsonExprUtils.validateSchemaLiteral(schema), options = Map.empty[String, String], child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.validateSchemaLiteral(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None) + timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + + // Used in `org.apache.spark.sql.functions` + def this(schema: DataType, options: Map[String, String], child: Expression) = + this(schema, options, child, timeZoneId = None, + forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 7812319756eae..00e97637eee7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), null ) } @@ -479,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,7 +512,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID)), + Option(tz.getID), + true), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -521,7 +522,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId), + gmtId, + true), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -530,7 +532,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), null ) } @@ -685,27 +687,23 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { - val input = - """{ - | "a": 1, - | "c": "foo" - |} - |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - checkEvaluation( - JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId), - output - ) - val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) - .dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) - } + val input = + """{ + | "a": 1, + | "c": "foo" + |} + |""".stripMargin + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs( + jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 25afaacc38d6f..d2e22fa355514 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3179,7 +3179,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - JsonToStructs(schema, options, e.expr) + new JsonToStructs(schema, options, e.expr) } /** diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 581dddc89d0bb..14a69128ffb41 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -129,7 +129,7 @@ select to_json() struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function to_json. Expected: one of 1, 2 and 3; Found: 0; line 1 pos 7 +Invalid number of arguments for function to_json. Expected: one of 1 and 2; Found: 0; line 1 pos 7 -- !query 13 @@ -225,7 +225,7 @@ select from_json() struct<> -- !query 21 output org.apache.spark.sql.AnalysisException -Invalid number of arguments for function from_json. Expected: one of 2, 3 and 4; Found: 0; line 1 pos 7 +Invalid number of arguments for function from_json. Expected: one of 2 and 3; Found: 0; line 1 pos 7 -- !query 22 From 94641fe6cc68e5977dd8663b8f232a287a783acb Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 3 May 2018 10:59:18 -0500 Subject: [PATCH 0730/1161] [SPARK-23433][CORE] Late zombie task completions update all tasksets Fetch failure lead to multiple tasksets which are active for a given stage. While there is only one "active" version of the taskset, the earlier attempts can still have running tasks, which can complete successfully. So a task completion needs to update every taskset so that it knows the partition is completed. That way the final active taskset does not try to submit another task for the same partition, and so that it knows when it is completed and when it should be marked as a "zombie". Added a regression test. Author: Imran Rashid Closes #21131 from squito/SPARK-23433. --- .../spark/scheduler/TaskSchedulerImpl.scala | 14 +++ .../spark/scheduler/TaskSetManager.scala | 20 +++- .../scheduler/TaskSchedulerImplSuite.scala | 104 ++++++++++++++++++ 3 files changed, 137 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c11806b3981b..8e97b3da33820 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -689,6 +689,20 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. + */ + private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8a96a7692f614..195fc8025e4b5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -73,6 +73,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -153,7 +155,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead @@ -754,6 +756,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -764,6 +769,19 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 6003899bb7bef..33f2ea1c94e75 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -917,4 +917,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. + zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } } From e3201e165e41f076ec72175af246d12c0da529cf Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 3 May 2018 17:05:02 -0700 Subject: [PATCH 0731/1161] [SPARK-24035][SQL] SQL syntax for Pivot ## What changes were proposed in this pull request? Add SQL support for Pivot according to Pivot grammar defined by Oracle (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_clause.htm) with some simplifications, based on our existing functionality and limitations for Pivot at the backend: 1. For pivot_for_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_for_clause.htm), the column list form is not supported, which means the pivot column can only be one single column. 2. For pivot_in_clause (https://docs.oracle.com/database/121/SQLRF/img_text/pivot_in_clause.htm), the sub-query form and "ANY" is not supported (this is only supported by Oracle for XML anyway). 3. For pivot_in_clause, aliases for the constant values are not supported. The code changes are: 1. Add parser support for Pivot. Note that according to https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#i2076542, Pivot cannot be used together with lateral views in the from clause. This restriction has been implemented in the Parser rule. 2. Infer group-by expressions: group-by expressions are not explicitly specified in SQL Pivot clause and need to be deduced based on this rule: https://docs.oracle.com/database/121/SQLRF/statements_10002.htm#CHDFAFIE, so we have to post-fix it at query analysis stage. 3. Override Pivot.resolved as "false": for the reason mentioned in [2] and the fact that output attributes change after Pivot being replaced by Project or Aggregate, we avoid resolving parent references until after Pivot has been resolved and replaced. 4. Verify aggregate expressions: only aggregate expressions with or without aliases can appear in the first part of the Pivot clause, and this check is performed as analysis stage. ## How was this patch tested? A new test suite PivotSuite is added. Author: maryannxue Closes #21187 from maryannxue/spark-24035. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 12 +- .../sql/catalyst/analysis/Analyzer.scala | 35 +++- .../sql/catalyst/parser/AstBuilder.scala | 20 +- .../plans/logical/basicLogicalOperators.scala | 27 ++- .../parser/TableIdentifierParserSuite.scala | 6 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../test/resources/sql-tests/inputs/pivot.sql | 113 ++++++++++ .../resources/sql-tests/results/pivot.sql.out | 194 ++++++++++++++++++ 8 files changed, 386 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/pivot.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/pivot.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5fa75fe348e68..f7f921ec22c35 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* lateralView* + : FROM relation (',' relation)* (pivotClause | lateralView*)? ; aggregation @@ -413,6 +413,10 @@ groupingSet | expression ; +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + ; + lateralView : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? ; @@ -725,7 +729,7 @@ nonReserved | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER + | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP @@ -745,7 +749,7 @@ nonReserved | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE + | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN @@ -760,6 +764,7 @@ FROM: 'FROM'; ADD: 'ADD'; AS: 'AS'; ALL: 'ALL'; +ANY: 'ANY'; DISTINCT: 'DISTINCT'; WHERE: 'WHERE'; GROUP: 'GROUP'; @@ -805,6 +810,7 @@ RIGHT: 'RIGHT'; FULL: 'FULL'; NATURAL: 'NATURAL'; ON: 'ON'; +PIVOT: 'PIVOT'; LATERAL: 'LATERAL'; WINDOW: 'WINDOW'; OVER: 'OVER'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e821e96522f7c..dfdcdbc1eb2c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -275,9 +275,9 @@ class Analyzer( case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.copy(aggregations = assignAliases(g.aggregations)) - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) - if child.resolved && hasUnresolvedAlias(groupByExprs) => - Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) + if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => + Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) @@ -504,9 +504,20 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) - | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p - case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) + || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) + || !p.pivotColumn.resolved => p + case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + // Check all aggregate expressions. + aggregates.foreach { e => + if (!isAggregateExpression(e)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$e'") + } + } + // Group-by expressions coming from SQL are implicit and need to be deduced. + val groupByExprs = groupByExprsOpt.getOrElse( + (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 def outputName(value: Literal, aggregate: Expression): String = { val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) @@ -568,16 +579,20 @@ class Analyzer( // TODO: Don't construct the physical container until after analysis. case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") - } Alias(filteredAggregate, outputName(value, aggregate))() } } Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } } + + private def isAggregateExpression(expr: Expression): Boolean = { + expr match { + case Alias(e, _) => isAggregateExpression(e) + case AggregateExpression(_, _, _, _) => true + case _ => false + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bdc357d54a878..64eed23884584 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -503,7 +503,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val join = right.optionalMap(left)(Join(_, _, Inner, None)) withJoinRelations(join, relation) } - ctx.lateralView.asScala.foldLeft(from)(withGenerate) + if (ctx.pivotClause() != null) { + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } } /** @@ -614,6 +618,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging plan } + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) + val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 10df504795430..3bf32ef7884e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -686,17 +686,34 @@ case class GroupingSets( override lazy val resolved: Boolean = false } +/** + * A constructor for creating a pivot, which will later be converted to a [[Project]] + * or an [[Aggregate]] during the query analysis. + * + * @param groupByExprsOpt A sequence of group by expressions. This field should be None if coming + * from SQL, in which group by expressions are not explicitly specified. + * @param pivotColumn The pivot column. + * @param pivotValues A sequence of values for the pivot column. + * @param aggregates The aggregation expressions, each with or without an alias. + * @param child Child operator + */ case class Pivot( - groupByExprs: Seq[NamedExpression], + groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, pivotValues: Seq[Literal], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { - case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) - case _ => pivotValues.flatMap{ value => - aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + override lazy val resolved = false // Pivot will be replaced after being resolved. + override def output: Seq[Attribute] = { + val pivotAgg = aggregates match { + case agg :: Nil => + pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => + pivotValues.flatMap { value => + aggregates.map(agg => AttributeReference(value + "_" + agg.sql, agg.dataType)()) + } } + groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index cc80a41df998d..89903c2825125 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -41,12 +41,12 @@ class TableIdentifierParserSuite extends SparkFunSuite { "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", "level", - "snapshot", "autocommit", "all", "alter", "array", "as", "authorization", "between", "bigint", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "lateral", "like", "local", "none", "null", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7147798d99533..6c2be3610ae30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -73,7 +73,7 @@ class RelationalGroupedDataset protected[sql]( case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql new file mode 100644 index 0000000000000..01dea6c81c11b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -0,0 +1,113 @@ +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s); + +-- pivot courses +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot years with no subquery +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); + +-- pivot courses with multiple aggregations +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with no group by column and with multiple aggregations on different columns +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple group by columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +); + +-- pivot on join query with multiple aggregations on different columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on join query with multiple columns in one aggregation +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +); + +-- pivot with aliases and projection +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +); + +-- pivot years with non-aggregate function +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +); + +-- pivot with unresolvable columns +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out new file mode 100644 index 0000000000000..85e3488990e20 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -0,0 +1,194 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 13 + + +-- !query 0 +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view years as select * from values + (2012, 1), + (2013, 2) + as years(y, s) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 2 schema +struct +-- !query 2 output +2012 15000 20000 +2013 48000 30000 + + +-- !query 3 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 3 schema +struct +-- !query 3 output +Java 20000 30000 +dotNET 15000 48000 + + +-- !query 4 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), avg(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 4 schema +struct +-- !query 4 output +2012 15000 7500.0 20000 20000.0 +2013 48000 48000.0 30000 30000.0 + + +-- !query 5 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR course IN ('dotNET', 'Java') +) +-- !query 5 schema +struct +-- !query 5 output +63000 50000 + + +-- !query 6 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), min(year) + FOR course IN ('dotNET', 'Java') +) +-- !query 6 schema +struct +-- !query 6 output +63000 2012 50000 2012 + + +-- !query 7 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN (1, 2) +) +-- !query 7 schema +struct +-- !query 7 output +Java 2012 20000 NULL +Java 2013 NULL 30000 +dotNET 2012 15000 NULL +dotNET 2013 NULL 48000 + + +-- !query 8 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings), min(s) + FOR course IN ('dotNET', 'Java') +) +-- !query 8 schema +struct +-- !query 8 output +2012 15000 1 20000 1 +2013 48000 2 30000 2 + + +-- !query 9 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings * s) + FOR course IN ('dotNET', 'Java') +) +-- !query 9 schema +struct +-- !query 9 output +2012 15000 20000 +2013 96000 60000 + + +-- !query 10 +SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012, 2013) +) +-- !query 10 schema +struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> +-- !query 10 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 11 +SELECT * FROM courseSales +PIVOT ( + abs(earnings) + FOR year IN (2012, 2013) +) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, found 'abs(earnings#x)'; + + +-- !query 12 +SELECT * FROM ( + SELECT course, earnings FROM courseSales +) +PIVOT ( + sum(earnings) + FOR year IN (2012, 2013) +) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 From e646ae67f2e793204bc819ab2b90815214c2bbf3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 3 May 2018 17:27:13 -0700 Subject: [PATCH 0732/1161] [SPARK-24168][SQL] WindowExec should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. ## How was this patch tested? tested in #21190 Author: Wenchen Fan Closes #21225 from cloud-fan/minor3. --- .../spark/sql/execution/window/WindowExec.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 800a2ea3f3996..626f39d9e95cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -112,9 +112,11 @@ case class WindowExec( * * @param frame to evaluate. This can either be a Row or Range frame. * @param bound with respect to the row. + * @param timeZone the session local timezone for time related calculations. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + private[this] def createBoundOrdering( + frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = { (frame, bound) match { case (RowFrame, CurrentRow) => RowBoundOrdering(0) @@ -144,7 +146,7 @@ case class WindowExec( val boundExpr = (expr.dataType, boundOffset.dataType) match { case (DateType, IntegerType) => DateAdd(expr, boundOffset) case (TimestampType, CalendarIntervalType) => - TimeAdd(expr, boundOffset, Some(conf.sessionLocalTimeZone)) + TimeAdd(expr, boundOffset, Some(timeZone)) case (a, b) if a== b => Add(expr, boundOffset) } val bound = newMutableProjection(boundExpr :: Nil, child.output) @@ -197,6 +199,7 @@ case class WindowExec( // Map the groups to a (unbound) expression and frame factory pair. var numExpressions = 0 + val timeZone = conf.sessionLocalTimeZone framedFunctions.toSeq.map { case (key, (expressions, functionSeq)) => val ordinal = numExpressions @@ -237,7 +240,7 @@ case class WindowExec( new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, upper, timeZone)) } // Shrinking Frame. @@ -246,7 +249,7 @@ case class WindowExec( new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower)) + createBoundOrdering(frameType, lower, timeZone)) } // Moving Frame. @@ -255,8 +258,8 @@ case class WindowExec( new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, lower), - createBoundOrdering(frameType, upper)) + createBoundOrdering(frameType, lower, timeZone), + createBoundOrdering(frameType, upper, timeZone)) } } From 0c23e254c38d4a9210939e1e1b0074278568abed Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 09:27:14 +0800 Subject: [PATCH 0733/1161] [SPARK-24167][SQL] ParquetFilters should not access SQLConf at executor side ## What changes were proposed in this pull request? This PR is extracted from #21190 , to make it easier to backport. `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't call `conf.parquetFilterPushDownDate` there. ## How was this patch tested? it's tested in #21190 Author: Wenchen Fan Closes #21224 from cloud-fan/minor2. --- .../datasources/parquet/ParquetFileFormat.scala | 3 ++- .../datasources/parquet/ParquetFilters.scala | 15 +++++++-------- .../datasources/parquet/ParquetFilterSuite.scala | 10 ++++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d8f47eec952de..d1f9e11ed4225 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -342,6 +342,7 @@ class ParquetFileFormat sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) + val pushDownDate = sqlConf.parquetFilterPushDownDate (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -352,7 +353,7 @@ class ParquetFileFormat // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index ccc8306866d68..310626197a763 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -25,14 +25,13 @@ import org.apache.parquet.io.api.Binary import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types._ /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] object ParquetFilters { +private[parquet] class ParquetFilters(pushDownDate: Boolean) { private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -59,7 +58,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -85,7 +84,7 @@ private[parquet] object ParquetFilters { (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -108,7 +107,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.lt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -131,7 +130,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.ltEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -154,7 +153,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gt( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) @@ -177,7 +176,7 @@ private[parquet] object ParquetFilters { case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if SQLConf.get.parquetFilterPushDownDate => + case DateType if pushDownDate => (n: String, v: Any) => FilterApi.gtEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 1d3476e747046..667e0b1760e3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,6 +55,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { + private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + override def beforeEach(): Unit = { super.beforeEach() // Note that there are many tests here that require record-level filtering set to be true. @@ -99,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -517,7 +519,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -525,7 +527,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.And( sources.LessThan("a", 10), @@ -533,7 +535,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } assertResult(None) { - ParquetFilters.createFilter( + parquetFilters.createFilter( schema, sources.Not( sources.And( From 7f1b6b182e3cf3cbf29399e7bfbe03fa869e0bc8 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 4 May 2018 16:02:21 +0800 Subject: [PATCH 0734/1161] [SPARK-24136][SS] Fix MemoryStreamDataReader.next to skip sleeping if record is available ## What changes were proposed in this pull request? Avoid unnecessary sleep (10 ms) in each invocation of MemoryStreamDataReader.next. ## How was this patch tested? Ran ContinuousSuite from IDE. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21207 from arunmahadevan/memorystream. --- .../streaming/sources/ContinuousMemoryStream.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index c28919b8b729b..a8fca3c19a2d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -183,11 +183,10 @@ class ContinuousMemoryStreamDataReader( private var current: Option[Row] = None override def next(): Boolean = { - current = None + current = getRecord while (current.isEmpty) { Thread.sleep(10) - current = endpoint.askSync[Option[Row]]( - GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + current = getRecord } currentOffset += 1 true @@ -199,6 +198,10 @@ class ContinuousMemoryStreamDataReader( override def getOffset: ContinuousMemoryStreamPartitionOffset = ContinuousMemoryStreamPartitionOffset(partition, currentOffset) + + private def getRecord: Option[Row] = + endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) } case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) From 4d5de4d303a773b1c18c350072344bd7efca9fc4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 May 2018 19:20:15 +0800 Subject: [PATCH 0735/1161] [SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly ## What changes were proposed in this pull request? It's possible that Accumulators of Spark 1.x may no longer work with Spark 2.x. This is because `LegacyAccumulatorWrapper.isZero` may return wrong answer if `AccumulableParam` doesn't define equals/hashCode. This PR fixes this by using reference equality check in `LegacyAccumulatorWrapper.isZero`. ## How was this patch tested? a new test Author: Wenchen Fan Closes #21229 from cloud-fan/accumulator. --- .../org/apache/spark/util/AccumulatorV2.scala | 6 ++++-- .../spark/util/AccumulatorV2Suite.scala | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 0f84ea9752cf5..2bc84953a56eb 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed88..fe0a9a471a651 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + class MyData(val i: Int) extends Serializable + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } From d04806a23c1843a7f0dcc4fa236ed1b40ae113a5 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 4 May 2018 13:29:47 -0700 Subject: [PATCH 0736/1161] =?UTF-8?q?[SPARK-24124]=20Spark=20history=20ser?= =?UTF-8?q?ver=20should=20create=20spark.history.store.=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …path and set permissions properly ## What changes were proposed in this pull request? Spark history server should create spark.history.store.path and set permissions properly. Note createdDirectories doesn't do anything if the directories are already created. This does not stomp on the permissions if the user had manually created the directory before the history server could. ## How was this patch tested? Manually tested in a 100 node cluster. Ensured directories created with proper permissions. Ensured restarted worked apps/temp directories worked as apps were read. Author: Thomas Graves Closes #21234 from tgravescs/SPARK-24124. --- .../apache/spark/deploy/history/FsHistoryProvider.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 56db9359e033f..bf1eeb0c1bf59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,8 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} +import java.nio.file.Files +import java.nio.file.attribute.PosixFilePermissions import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -130,8 +132,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => - require(path.isDirectory(), s"Configured store directory ($path) does not exist.") - val dbPath = new File(path, "listing.ldb") + val perms = PosixFilePermissions.fromString("rwx------") + val dbPath = Files.createDirectories(new File(path, "listing.ldb").toPath(), + PosixFilePermissions.asFileAttribute(perms)).toFile() + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, AppStatusStore.CURRENT_VERSION, logDir.toString()) From af4dc50280ffcdeda208ef2dc5f8b843389732e5 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Fri, 4 May 2018 14:14:40 -0700 Subject: [PATCH 0737/1161] [SPARK-24039][SS] Do continuous processing writes with multiple compute() calls ## What changes were proposed in this pull request? Do continuous processing writes with multiple compute() calls. The current strategy (before this PR) is hacky; we just call next() on an iterator which has already returned hasNext = false, knowing that all the nodes we whitelist handle this properly. This will have to be changed before we can support more complex query plans. (In particular, I have a WIP https://github.com/jose-torres/spark/pull/13 which should be able to support aggregates in a single partition with minimal additional work.) Most of the changes here are just refactoring to accommodate the new model. The behavioral changes are: * The writer now calls prev.compute(split, context) once per epoch within the epoch loop. * ContinuousDataSourceRDD now spawns a ContinuousQueuedDataReader which is shared across multiple calls to compute() for the same partition. ## How was this patch tested? existing unit tests Author: Jose Torres Closes #21200 from jose-torres/noAggr. --- .../datasources/v2/DataSourceV2ScanExec.scala | 6 +- .../continuous/ContinuousDataSourceRDD.scala | 114 +++++++++ .../ContinuousDataSourceRDDIter.scala | 222 ------------------ .../ContinuousQueuedDataReader.scala | 211 +++++++++++++++++ .../continuous/ContinuousWriteRDD.scala | 90 +++++++ .../WriteToContinuousDataSourceExec.scala | 57 +---- .../ContinuousQueuedDataReaderSuite.scala | 167 +++++++++++++ 7 files changed, 592 insertions(+), 275 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 41bdda47c8c3e..77cb707340b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -96,7 +96,11 @@ case class DataSourceV2ScanExec( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) .askSync[Unit](SetReaderPartitions(readerFactories.size)) - new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories) + new ContinuousDataSourceRDD( + sparkContext, + sqlContext.conf.continuousStreamingExecutorQueueSize, + sqlContext.conf.continuousStreamingExecutorPollIntervalMs, + readerFactories) .asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala new file mode 100644 index 0000000000000..0a3b9dcccb6c5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.util.{NextIterator, ThreadUtils} + +class ContinuousDataSourceRDDPartition( + val index: Int, + val readerFactory: DataReaderFactory[UnsafeRow]) + extends Partition with Serializable { + + // This is semantically a lazy val - it's initialized once the first time a call to + // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across + // all compute() calls for a partition. This ensures that one compute() picks up where the + // previous one ended. + // We don't make it actually a lazy val because it needs input which isn't available here. + // This will only be initialized on the executors. + private[continuous] var queueReader: ContinuousQueuedDataReader = _ +} + +/** + * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]] + * to read from the remote source, and polls that queue for incoming rows. + * + * Note that continuous processing calls compute() multiple times, and the same + * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split. + */ +class ContinuousDataSourceRDD( + sc: SparkContext, + dataQueueSize: Int, + epochPollIntervalMs: Long, + @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readerFactories.zipWithIndex.map { + case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + }.toArray + } + + /** + * Initialize the shared reader for this partition if needed, then read rows from it until + * it returns null to signal the end of the epoch. + */ + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + // If attempt number isn't 0, this is a task retry, which we don't support. + if (context.attemptNumber() != 0) { + throw new ContinuousTaskRetryException() + } + + val readerForPartition = { + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + if (partition.queueReader == null) { + partition.queueReader = + new ContinuousQueuedDataReader( + partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + } + + partition.queueReader + } + + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = { + readerForPartition.next() match { + case null => + finished = true + null + case row => row + } + } + + override def close(): Unit = {} + } + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader match { + case r: ContinuousDataReader[UnsafeRow] => r + case wrapped: RowToUnsafeDataReader => + wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala deleted file mode 100644 index 06754f01657d3..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.continuous - -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} -import org.apache.spark.util.ThreadUtils - -class ContinuousDataSourceRDD( - sc: SparkContext, - sqlContext: SQLContext, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) - extends RDD[UnsafeRow](sc, Nil) { - - private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize - private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs - - override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) - }.toArray - } - - override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { - // If attempt number isn't 0, this is a task retry, which we don't support. - if (context.attemptNumber() != 0) { - throw new ContinuousTaskRetryException() - } - - val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]] - .readerFactory.createDataReader() - - val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) - - // This queue contains two types of messages: - // * (null, null) representing an epoch boundary. - // * (row, off) containing a data row and its corresponding PartitionOffset. - val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize) - - val epochPollFailed = new AtomicBoolean(false) - val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( - s"epoch-poll--$coordinatorId--${context.partitionId()}") - val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed) - epochPollExecutor.scheduleWithFixedDelay( - epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - - // Important sequencing - we must get start offset before the data reader thread begins - val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset - - val dataReaderFailed = new AtomicBoolean(false) - val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed) - dataReaderThread.setDaemon(true) - dataReaderThread.start() - - context.addTaskCompletionListener(_ => { - dataReaderThread.interrupt() - epochPollExecutor.shutdown() - }) - - val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get) - new Iterator[UnsafeRow] { - private val POLL_TIMEOUT_MS = 1000 - - private var currentEntry: (UnsafeRow, PartitionOffset) = _ - private var currentOffset: PartitionOffset = startOffset - private var currentEpoch = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def hasNext(): Boolean = { - while (currentEntry == null) { - if (context.isInterrupted() || context.isCompleted()) { - currentEntry = (null, null) - } - if (dataReaderFailed.get()) { - throw new SparkException("data read failed", dataReaderThread.failureReason) - } - if (epochPollFailed.get()) { - throw new SparkException("epoch poll failed", epochPollRunnable.failureReason) - } - currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) - } - - currentEntry match { - // epoch boundary marker - case (null, null) => - epochEndpoint.send(ReportPartitionOffset( - context.partitionId(), - currentEpoch, - currentOffset)) - currentEpoch += 1 - currentEntry = null - false - // real row - case (_, offset) => - currentOffset = offset - true - } - } - - override def next(): UnsafeRow = { - if (currentEntry == null) throw new NoSuchElementException("No current row was set") - val r = currentEntry._1 - currentEntry = null - r - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations() - } -} - -case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset - -class EpochPollRunnable( - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread with Logging { - private[continuous] var failureReason: Throwable = _ - - private val epochEndpoint = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - override def run(): Unit = { - try { - val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch) - for (i <- currentEpoch to newEpoch - 1) { - queue.put((null, null)) - logDebug(s"Sent marker to start epoch ${i + 1}") - } - currentEpoch = newEpoch - } catch { - case t: Throwable => - failureReason = t - failedFlag.set(true) - throw t - } - } -} - -class DataReaderThread( - reader: DataReader[UnsafeRow], - queue: BlockingQueue[(UnsafeRow, PartitionOffset)], - context: TaskContext, - failedFlag: AtomicBoolean) - extends Thread( - s"continuous-reader--${context.partitionId()}--" + - s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") { - private[continuous] var failureReason: Throwable = _ - - override def run(): Unit = { - TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getBaseReader(reader) - try { - while (!context.isInterrupted && !context.isCompleted()) { - if (!reader.next()) { - // Check again, since reader.next() might have blocked through an incoming interrupt. - if (!context.isInterrupted && !context.isCompleted()) { - throw new IllegalStateException( - "Continuous reader reported no elements! Reader should have blocked waiting.") - } else { - return - } - } - - queue.put((reader.get().copy(), baseReader.getOffset)) - } - } catch { - case _: InterruptedException if context.isInterrupted() => - // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. - - case t: Throwable => - failureReason = t - failedFlag.set(true) - // Don't rethrow the exception in this thread. It's not needed, and the default Spark - // exception handler will kill the executor. - } finally { - reader.close() - } - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { - reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala new file mode 100644 index 0000000000000..01a999f6505fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.io.Closeable +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.NonFatal + +import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.util.ThreadUtils + +/** + * A wrapper for a continuous processing data reader, including a reading queue and epoch markers. + * + * This will be instantiated once per partition - successive calls to compute() in the + * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of + * offsets across epochs. Each compute() should call the next() method here until null is returned. + */ +class ContinuousQueuedDataReader( + factory: DataReaderFactory[UnsafeRow], + context: TaskContext, + dataQueueSize: Int, + epochPollIntervalMs: Long) extends Closeable { + private val reader = factory.createDataReader() + + // Important sequencing - we must get our starting point before the provider threads start running + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentEpoch: Long = + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + /** + * The record types in the read buffer. + */ + sealed trait ContinuousRecord + case object EpochMarker extends ContinuousRecord + case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord + + private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize) + + private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY) + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + + private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + s"epoch-poll--$coordinatorId--${context.partitionId()}") + private val epochMarkerGenerator = new EpochMarkerGenerator + epochMarkerExecutor.scheduleWithFixedDelay( + epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) + + private val dataReaderThread = new DataReaderThread + dataReaderThread.setDaemon(true) + dataReaderThread.start() + + context.addTaskCompletionListener(_ => { + this.close() + }) + + private def shouldStop() = { + context.isInterrupted() || context.isCompleted() + } + + /** + * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done. + * + * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch + * will call next() again to start getting rows. + */ + def next(): UnsafeRow = { + val POLL_TIMEOUT_MS = 1000 + var currentEntry: ContinuousRecord = null + + while (currentEntry == null) { + if (shouldStop()) { + // Force the epoch to end here. The writer will notice the context is interrupted + // or completed and not start a new one. This makes it possible to achieve clean + // shutdown of the streaming query. + // TODO: The obvious generalization of this logic to multiple stages won't work. It's + // invalid to send an epoch marker from the bottom of a task if all its child tasks + // haven't sent one. + currentEntry = EpochMarker + } else { + if (dataReaderThread.failureReason != null) { + throw new SparkException("Data read failed", dataReaderThread.failureReason) + } + if (epochMarkerGenerator.failureReason != null) { + throw new SparkException( + "Epoch marker generation failed", + epochMarkerGenerator.failureReason) + } + currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } + } + + currentEntry match { + case EpochMarker => + epochCoordEndpoint.send(ReportPartitionOffset( + context.partitionId(), currentEpoch, currentOffset)) + currentEpoch += 1 + null + case ContinuousRow(row, offset) => + currentOffset = offset + row + } + } + + override def close(): Unit = { + dataReaderThread.interrupt() + epochMarkerExecutor.shutdown() + } + + /** + * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when + * a new row arrives to the [[DataReader]]. + */ + class DataReaderThread extends Thread( + s"continuous-reader--${context.partitionId()}--" + + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + override def run(): Unit = { + TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) + try { + while (!shouldStop()) { + if (!reader.next()) { + // Check again, since reader.next() might have blocked through an incoming interrupt. + if (!shouldStop()) { + throw new IllegalStateException( + "Continuous reader reported no elements! Reader should have blocked waiting.") + } else { + return + } + } + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + } + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt; do nothing and shut down quietly. + logInfo(s"shutting down interrupted data reader thread $getName") + + case NonFatal(t) => + failureReason = t + logWarning("data reader thread failed", t) + // If we throw from this thread, we may kill the executor. Let the parent thread handle + // it. + + case t: Throwable => + failureReason = t + throw t + } finally { + reader.close() + } + } + } + + /** + * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with + * EpochMarker when a new epoch marker arrives. + */ + class EpochMarkerGenerator extends Runnable with Logging { + @volatile private[continuous] var failureReason: Throwable = _ + + private val epochCoordEndpoint = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) + // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // field represents the epoch wrt the data being processed. The currentEpoch here is just a + // counter to ensure we send the appropriate number of markers if we fall behind the driver. + private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + override def run(): Unit = { + try { + val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch) + // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking + // a while. We catch up by injecting enough epoch markers immediately to catch up. This will + // result in some epochs being empty for this partition, but that's fine. + for (i <- currentEpoch to newEpoch - 1) { + queue.put(EpochMarker) + logDebug(s"Sent marker to start epoch ${i + 1}") + } + currentEpoch = newEpoch + } catch { + case t: Throwable => + failureReason = t + throw t + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala new file mode 100644 index 0000000000000..91f1576581511 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.util.Utils + +/** + * The RDD writing to a sink in continuous processing. + * + * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data + * to be written for one epoch, which we commit and forward to the driver. + * + * We keep repeating prev.compute() and writing new epochs until the query is shut down. + */ +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) + extends RDD[Unit](prev) { + + override val partitioner = prev.partitioner + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[Unit] = { + val epochCoordinator = EpochCoordinatorRef.get( + context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + SparkEnv.get) + var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + + while (!context.isInterrupted() && !context.isCompleted()) { + var dataWriter: DataWriter[InternalRow] = null + // write the data and commit this writer. + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + try { + val dataIterator = prev.compute(split, context) + dataWriter = writeTask.createDataWriter( + context.partitionId(), context.attemptNumber(), currentEpoch) + while (dataIterator.hasNext) { + dataWriter.write(dataIterator.next()) + } + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch is committing.") + val msg = dataWriter.commit() + epochCoordinator.send( + CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + ) + logInfo(s"Writer for partition ${context.partitionId()} " + + s"in epoch $currentEpoch committed.") + currentEpoch += 1 + } catch { + case _: InterruptedException => + // Continuous shutdown always involves an interrupt. Just finish the task. + } + })(catchBlock = { + // If there is an error, abort this writer. We enter this callback in the middle of + // rethrowing an exception, so compute() will stop executing at this point. + logError(s"Writer for partition ${context.partitionId()} is aborting.") + if (dataWriter != null) dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } + + Iterator() + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index ba88ae1af469a..e0af3a2f1b85d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -46,24 +46,19 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) } - val rdd = query.execute() + val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) logInfo(s"Start processing data source writer: $writer. " + - s"The input RDD has ${rdd.getNumPartitions} partitions.") - // Let the epoch coordinator know how many partitions the write RDD has. + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( - sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - sparkContext.env) + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), + sparkContext.env) .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions)) try { // Force the RDD to run so continuous processing starts; no data is actually being collected // to the driver, as ContinuousWriteRDD outputs nothing. - sparkContext.runJob( - rdd, - (context: TaskContext, iter: Iterator[InternalRow]) => - WriteToContinuousDataSourceExec.run(writerFactory, context, iter), - rdd.partitions.indices) + rdd.collect() } catch { case _: InterruptedException => // Interruption is how continuous queries are ended, so accept and ignore the exception. @@ -80,45 +75,3 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla sparkContext.emptyRDD } } - -object WriteToContinuousDataSourceExec extends Logging { - def run( - writeTask: DataWriterFactory[InternalRow], - context: TaskContext, - iter: Iterator[InternalRow]): Unit = { - val epochCoordinator = EpochCoordinatorRef.get( - context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), - SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong - - do { - var dataWriter: DataWriter[InternalRow] = null - // write the data and commit this writer. - Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - try { - dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) - while (iter.hasNext) { - dataWriter.write(iter.next()) - } - logInfo(s"Writer for partition ${context.partitionId()} is committing.") - val msg = dataWriter.commit() - logInfo(s"Writer for partition ${context.partitionId()} committed.") - epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) - ) - currentEpoch += 1 - } catch { - case _: InterruptedException => - // Continuous shutdown always involves an interrupt. Just finish the task. - } - })(catchBlock = { - // If there is an error, abort this writer. We enter this callback in the middle of - // rethrowing an exception, so runContinuous will stop executing at this point. - logError(s"Writer for partition ${context.partitionId()} is aborting.") - if (dataWriter != null) dataWriter.abort() - logError(s"Writer for partition ${context.partitionId()} aborted.") - }) - } while (!context.isInterrupted()) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala new file mode 100644 index 0000000000000..e755625d09e0f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} + +import org.mockito.{ArgumentCaptor, Matchers} +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { + case class LongPartitionOffset(offset: Long) extends PartitionOffset + + val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest" + val startEpoch = 0 + + var epochEndpoint: RpcEndpointRef = _ + + override def beforeEach(): Unit = { + super.beforeEach() + epochEndpoint = EpochCoordinatorRef.create( + mock[StreamWriter], + mock[ContinuousReader], + mock[ContinuousExecution], + coordinatorId, + startEpoch, + spark, + SparkEnv.get) + } + + override def afterEach(): Unit = { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochEndpoint = null + super.afterEach() + } + + + private val mockContext = mock[TaskContext] + when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY)) + .thenReturn(startEpoch.toString) + when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)) + .thenReturn(coordinatorId) + + /** + * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send + * rows to the wrapped data reader. + */ + private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { + val queue = new ArrayBlockingQueue[UnsafeRow](1024) + val factory = new DataReaderFactory[UnsafeRow] { + override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } + + override def get = curr + + override def getOffset = LongPartitionOffset(index) + + override def close() = {} + } + } + val reader = new ContinuousQueuedDataReader( + factory, + mockContext, + dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, + epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) + + (queue, reader) + } + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + test("basic data read") { + val (input, reader) = setup() + + input.add(unsafeRow(12345)) + assert(reader.next().getInt(0) == 12345) + } + + test("basic epoch marker") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } + + test("new rows after markers") { + val (input, reader) = setup() + + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + } + + test("new markers after rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 11111) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + assert(reader.next() == null) + assert(reader.next() == null) + } + + test("alternating markers and rows") { + val (input, reader) = setup() + + input.add(unsafeRow(11111)) + assert(reader.next().getInt(0) == 11111) + input.add(unsafeRow(22222)) + assert(reader.next().getInt(0) == 22222) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + input.add(unsafeRow(33333)) + assert(reader.next().getInt(0) == 33333) + input.add(unsafeRow(44444)) + assert(reader.next().getInt(0) == 44444) + epochEndpoint.askSync[Long](IncrementAndGetEpoch) + assert(reader.next() == null) + } +} From 47b5b68528c154d32b3f40f388918836d29462b8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 4 May 2018 16:35:24 -0700 Subject: [PATCH 0738/1161] [SPARK-24157][SS] Enabled no-data batches in MicroBatchExecution for streaming aggregation and deduplication. ## What changes were proposed in this pull request? This PR enables the MicroBatchExecution to run no-data batches if some SparkPlan requires running another batch to output results based on updated watermark / processing time. In this PR, I have enabled streaming aggregations and streaming deduplicates to automatically run addition batch even if new data is available. See https://issues.apache.org/jira/browse/SPARK-24156 for more context. Major changes/refactoring done in this PR. - Refactoring MicroBatchExecution - A major point of confusion in MicroBatchExecution control flow was always (at least to me) was that `populateStartOffsets` internally called `constructNextBatch` which was not obvious from just the name "populateStartOffsets" and made the control flow from the main trigger execution loop very confusing (main loop in `runActivatedStream` called `constructNextBatch` but only if `populateStartOffsets` hadn't already called it). Instead, the refactoring makes it cleaner. - `populateStartOffsets` only the updates `availableOffsets` and `committedOffsets`. Does not call `constructNextBatch`. - Main loop in `runActivatedStream` calls `constructNextBatch` which returns true or false reflecting whether the next batch is ready for executing. This method is now idempotent; if a batch has already been constructed, then it will always return true until the batch has been executed. - If next batch is ready then we call `runBatch` or sleep. - That's it. - Refactoring watermark management logic - This has been refactored out from `MicroBatchExecution` in a separate class to simplify `MicroBatchExecution`. - New method `shouldRunAnotherBatch` in `IncrementalExecution` - This returns true if there is any stateful operation in the last execution plan that requires another batch for state cleanup, etc. This is used to decide whether to construct a batch or not in `constructNextBatch`. - Changes to stream testing framework - Many tests used CheckLastBatch to validate answers. This assumed that there will be no more batches after the last set of input has been processed, so the last batch is the one that has output corresponding to the last input. This is not true anymore. To account for that, I made two changes. - `CheckNewAnswer` is a new test action that verifies the new rows generated since the last time the answer was checked by `CheckAnswer`, `CheckNewAnswer` or `CheckLastBatch`. This is agnostic to how many batches occurred between the last check and now. To do make this easier, I added a common trait between MemorySink and MemorySinkV2 to abstract out some common methods. - `assertNumStateRows` has been updated in the same way to be agnostic to batches while checking what the total rows and how many state rows were updated (sums up updates since the last check). ## How was this patch tested? - Changes made to existing tests - Tests have been changed in one of the following patterns. - Tests where the last input was given again to force another batch to be executed and state cleaned up / output generated, they were simplified by removing the extra input. - Tests using aggregation+watermark where CheckLastBatch were replaced with CheckNewAnswer to make them batch agnostic. - New tests added to check whether the flag works for streaming aggregation and deduplication Author: Tathagata Das Closes #21220 from tdas/SPARK-24157. --- .../apache/spark/sql/internal/SQLConf.scala | 11 + .../streaming/IncrementalExecution.scala | 10 + .../streaming/MicroBatchExecution.scala | 231 ++++++++---------- .../streaming/WatermarkTracker.scala | 73 ++++++ .../sql/execution/streaming/memory.scala | 17 +- .../streaming/sources/memoryV2.scala | 8 +- .../streaming/statefulOperators.scala | 16 ++ .../sources/ForeachWriterSuite.scala | 8 +- .../streaming/EventTimeWatermarkSuite.scala | 112 ++++----- .../sql/streaming/FileStreamSinkSuite.scala | 7 +- .../sql/streaming/StateStoreMetricsTest.scala | 52 +++- .../spark/sql/streaming/StreamTest.scala | 56 ++++- ...cala => StreamingDeduplicationSuite.scala} | 94 +++---- .../sql/streaming/StreamingJoinSuite.scala | 18 +- .../sql/streaming/StreamingQuerySuite.scala | 2 +- 15 files changed, 450 insertions(+), 265 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/{DeduplicateSuite.scala => StreamingDeduplicationSuite.scala} (80%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3942240c442b2..895e150756567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -919,6 +919,14 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10000L) + val STREAMING_NO_DATA_MICRO_BATCHES_ENABLED = + buildConf("spark.sql.streaming.noDataMicroBatchesEnabled") + .doc( + "Whether streaming micro-batch engine will execute batches without data " + + "for eager state management for stateful streaming queries.") + .booleanConf + .createWithDefault(true) + val STREAMING_METRICS_ENABLED = buildConf("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -1313,6 +1321,9 @@ class SQLConf extends Serializable with Logging { def streamingNoDataProgressEventInterval: Long = getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingNoDataMicroBatchesEnabled: Boolean = + getConf(STREAMING_NO_DATA_MICRO_BATCHES_ENABLED) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 1a83c884d55bd..c480b96626f84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -143,4 +143,14 @@ class IncrementalExecution( /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } + + /** + * Should the MicroBatchExecution run another batch based on this execution and the current + * updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + executedPlan.collect { + case p: StateStoreWriter => p.shouldRunAnotherBatch(newMetadata) + }.exists(_ == true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6e231970f4a22..6709e7052f005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,6 +61,8 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } + private val watermarkTracker = new WatermarkTracker() + override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, "logicalPlan must be initialized in QueryExecutionThread " + @@ -128,40 +130,55 @@ class MicroBatchExecution( * Repeatedly attempts to run batches as data arrives. */ protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = { - triggerExecutor.execute(() => { - startTrigger() + val noDataBatchesEnabled = + sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled + + triggerExecutor.execute(() => { if (isActive) { + var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run + var currentBatchHasNewData = false // Whether the current batch had new data + + startTrigger() + reportTimeTaken("triggerExecution") { + // We'll do this initialization only once every start / restart if (currentBatchId < 0) { - // We'll do this initialization only once populateStartOffsets(sparkSessionForStream) - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - logDebug(s"Stream running from $committedOffsets to $availableOffsets") - } else { - constructNextBatch() + logInfo(s"Stream started from $committedOffsets") } - if (dataAvailable) { - currentStatus = currentStatus.copy(isDataAvailable = true) - updateStatusMessage("Processing new data") + + // Set this before calling constructNextBatch() so any Spark jobs executed by sources + // while getting new data have the correct description + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) + + // Try to construct the next batch. This will return true only if the next batch is + // ready and runnable. Note that the current batch may be runnable even without + // new data to process as `constructNextBatch` may decide to run a batch for + // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data + // is available or not. + currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + + // Remember whether the current batch has data or not. This will be required later + // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed + // to false as the batch would have already processed the available data. + currentBatchHasNewData = isNewDataAvailable + + currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) + if (currentBatchIsRunnable) { + if (currentBatchHasNewData) updateStatusMessage("Processing new data") + else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) + } else { + updateStatusMessage("Waiting for data to arrive") } } - // Report trigger as finished and construct progress object. - finishTrigger(dataAvailable) - if (dataAvailable) { - // Update committed offsets. - commitLog.add(currentBatchId) - committedOffsets ++= availableOffsets - logDebug(s"batch ${currentBatchId} committed") - // We'll increase currentBatchId after we complete processing current batch's data - currentBatchId += 1 - sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) - } else { - currentStatus = currentStatus.copy(isDataAvailable = false) - updateStatusMessage("Waiting for data to arrive") - Thread.sleep(pollingDelayMs) - } + + finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded + + // If the current batch has been executed, then increment the batch id, else there was + // no data to execute the batch + if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -211,6 +228,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker.setWatermark(metadata.batchWatermarkMs) } /* identify the current batch id: if commit log indicates we successfully processed the @@ -235,7 +253,6 @@ class MicroBatchExecution( currentBatchId = latestCommittedBatchId + 1 committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets - constructNextBatch() } else if (latestCommittedBatchId < latestBatchId - 1) { logWarning(s"Batch completion log latest batch id is " + s"${latestCommittedBatchId}, which is not trailing " + @@ -243,19 +260,18 @@ class MicroBatchExecution( } case None => logInfo("no commit log present") } - logDebug(s"Resuming at batch $currentBatchId with committed offsets " + + logInfo(s"Resuming at batch $currentBatchId with committed offsets " + s"$committedOffsets and available offsets $availableOffsets") case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 - constructNextBatch() } } /** * Returns true if there is any new data available to be processed. */ - private def dataAvailable: Boolean = { + private def isNewDataAvailable: Boolean = { availableOffsets.exists { case (source, available) => committedOffsets @@ -266,93 +282,63 @@ class MicroBatchExecution( } /** - * Queries all of the sources to see if any new data is available. When there is new data the - * batchId counter is incremented and a new log entry is written with the newest offsets. + * Attempts to construct a batch according to: + * - Availability of new data + * - Need for timeouts and state cleanups in stateful operators + * + * Returns true only if the next batch should be executed. + * + * Here is the high-level logic on how this constructs the next batch. + * - Check each source whether new data is available + * - Updated the query's metadata and check using the last execution whether there is any need + * to run another batch (for state clean up, etc.) + * - If either of the above is true, then construct the next batch by committing to the offset + * log that range of offsets that the next batch will process. */ - private def constructNextBatch(): Unit = { - // Check to see what new data is available. - val hasNewData = { - awaitProgressLock.lock() - try { - // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { - case s: Source => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("getOffset") { - (s, s.getOffset) - } - case s: MicroBatchReader => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) - } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) - }.toMap - availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) - - if (dataAvailable) { - true - } else { - noNewData = true - false + private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { + // If new data is already available that means this method has already been called before + // and it must have already committed the offset range of next batch to the offset log. + // Hence do nothing, just return true. + if (isNewDataAvailable) return true + + // Generate a map from each unique source to the next available offset. + val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + case s: Source => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("getOffset") { + (s, s.getOffset) } - } finally { - awaitProgressLock.unlock() - } - } - if (hasNewData) { - var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermarks if we find any in the plan. - if (lastExecution != null) { - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => e - }.zipWithIndex.foreach { - case (e, index) if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") - val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs - val prevWatermarkMs = watermarkMsMap.get(index) - if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { - watermarkMsMap.put(index, newWatermarkMs) - } - - // Populate 0 if we haven't seen any data yet for this watermark node. - case (_, index) => - if (!watermarkMsMap.isDefinedAt(index)) { - watermarkMsMap.put(index, 0) - } + case s: MicroBatchReader => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } - // Update the global watermark to the minimum of all watermark nodes. - // This is the safest option, because only the global watermark is fault-tolerant. Making - // it the minimum of all individual watermarks guarantees it will never advance past where - // any individual watermark operator would be if it were in a plan by itself. - if(!watermarkMsMap.isEmpty) { - val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 - if (newWatermarkMs > batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - batchWatermarkMs = newWatermarkMs - } else { - logDebug( - s"Event time didn't move: $newWatermarkMs < " + - s"$batchWatermarkMs") - } - } - } - offsetSeqMetadata = offsetSeqMetadata.copy( - batchWatermarkMs = batchWatermarkMs, - batchTimestampMs = triggerClock.getTimeMillis()) // Current batch timestamp in milliseconds + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) + }.toMap + availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) + + // Update the query metadata + offsetSeqMetadata = offsetSeqMetadata.copy( + batchWatermarkMs = watermarkTracker.currentWatermark, + batchTimestampMs = triggerClock.getTimeMillis()) + + // Check whether next batch should be constructed + val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) + val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + if (shouldConstructNextBatch) { + // Commit the next batch offset range to the offset log updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { - assert(offsetLog.add( - currentBatchId, + assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + @@ -373,7 +359,7 @@ class MicroBatchExecution( reader.commit(reader.deserializeOffset(off.json)) } } else { - throw new IllegalStateException(s"batch $currentBatchId doesn't exist") + throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist") } } @@ -384,15 +370,12 @@ class MicroBatchExecution( commitLog.purge(currentBatchId - minLogEntriesToMaintain) } } + noNewData = false } else { - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. - awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() - } + noNewData = true + awaitProgressLockCondition.signalAll() } + shouldConstructNextBatch } /** @@ -400,6 +383,8 @@ class MicroBatchExecution( * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with. */ private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = { + logDebug(s"Running batch $currentBatchId") + // Request unprocessed data from all sources. newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { @@ -513,17 +498,17 @@ class MicroBatchExecution( } } - awaitProgressLock.lock() - try { - // Wake up any threads that are waiting for the stream to progress. + withProgressLocked { + commitLog.add(currentBatchId) + committedOffsets ++= availableOffsets awaitProgressLockCondition.signalAll() - } finally { - awaitProgressLock.unlock() } + watermarkTracker.updateWatermark(lastExecution.executedPlan) + logDebug(s"Completed batch ${currentBatchId}") } /** Execute a function while locking the stream from making an progress */ - private[sql] def withProgressLocked(f: => Unit): Unit = { + private[sql] def withProgressLocked[T](f: => T): T = { awaitProgressLock.lock() try { f diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala new file mode 100644 index 0000000000000..80865669558dd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan + +class WatermarkTracker extends Logging { + private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() + private var watermarkMs: Long = 0 + private var updated = false + + def setWatermark(newWatermarkMs: Long): Unit = synchronized { + watermarkMs = newWatermarkMs + } + + def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { + val watermarkOperators = executedPlan.collect { + case e: EventTimeWatermarkExec => e + } + if (watermarkOperators.isEmpty) return + + + watermarkOperators.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = operatorToWatermarkMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + operatorToWatermarkMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!operatorToWatermarkMap.isDefinedAt(index)) { + operatorToWatermarkMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. + val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 + if (newWatermarkMs > watermarkMs) { + logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") + watermarkMs = newWatermarkMs + updated = true + } else { + logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") + updated = false + } + } + + def currentWatermark: Long = synchronized { watermarkMs } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 628923d367ce7..22258274c70c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -222,11 +222,20 @@ class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) } } +/** A common trait for MemorySinks with methods used for testing */ +trait MemorySinkBase extends BaseStreamingSink { + def allData: Seq[Row] + def latestBatchData: Seq[Row] + def dataSinceBatch(sinceBatchId: Long): Seq[Row] + def latestBatchId: Option[Long] +} + /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -236,7 +245,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.map(_.data).flatten + batches.flatMap(_.data) } def latestBatchId: Option[Long] = synchronized { @@ -245,6 +254,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index d871d37ad37c1..0d6c239274dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { override def createStreamWriter( queryId: String, schema: StructType, @@ -67,6 +67,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with Logging { batches.lastOption.toSeq.flatten(_.data) } + def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { + batches.filter(_.batchId > sinceBatchId).flatMap(_.data) + } + def toDebugString: String = synchronized { batches.map { case AddedData(batchId, data) => val dataStr = try data.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index c9354ac0ec78a..1691a6320a526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -126,6 +126,12 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => name -> SQLMetrics.createTimingMetric(sparkContext, desc) }.toMap } + + /** + * Should the MicroBatchExecution run another batch based on this stateful operator and the + * current updated metadata. + */ + def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = false } /** An operator that supports watermark. */ @@ -388,6 +394,12 @@ case class StateStoreSaveExec( ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } /** Physical operator for executing streaming Deduplicate. */ @@ -454,6 +466,10 @@ case class StreamingDeduplicateExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + } } object StreamingDeduplicateExec { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 03bf71b3f4b78..e60c339bc9cc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -211,14 +211,12 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd try { inputData.addData(10, 11, 12) query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() inputData.addData(25) // Evict items less than previous watermark query.processAllAvailable() // There should be 3 batches and only does the last batch contain a value. val allEvents = ForeachWriterSuite.allEvents() - assert(allEvents.size === 3) + assert(allEvents.size === 4) val expectedEvents = Seq( Seq( ForeachWriterSuite.Open(partition = 0, version = 0), @@ -230,6 +228,10 @@ class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAnd ), Seq( ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 3), ForeachWriterSuite.Process(value = 3), ForeachWriterSuite.Close(None) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index d6bef9ce07379..7e8fde1ff8e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -137,20 +138,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche assert(e.get("watermark") === formatTimestamp(5)) }, AddData(inputData2, 25), - CheckAnswer(), - assertEventStats { e => - assert(e.get("max") === formatTimestamp(25)) - assert(e.get("min") === formatTimestamp(25)) - assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(5)) - }, - AddData(inputData2, 25), CheckAnswer((10, 3)), assertEventStats { e => assert(e.get("max") === formatTimestamp(25)) assert(e.get("min") === formatTimestamp(25)) assert(e.get("avg") === formatTimestamp(25)) - assert(e.get("watermark") === formatTimestamp(15)) + assert(e.get("watermark") === formatTimestamp(5)) } ) } @@ -167,15 +160,12 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckNewAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - assertNumStateRows(3), - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10, 5)), + CheckNewAnswer((10, 5)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -193,15 +183,15 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(windowedAggregation, OutputMode.Update)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch((10, 5), (15, 1)), + CheckNewAnswer((10, 5), (15, 1)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch((25, 1)), - assertNumStateRows(3), + CheckNewAnswer((25, 1)), + assertNumStateRows(2), AddData(inputData, 10, 25), // Ignore 10 as its less than watermark - CheckLastBatch((25, 2)), + CheckNewAnswer((25, 2)), assertNumStateRows(2), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(2) ) } @@ -251,56 +241,25 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testStream(df)( AddData(inputData, 10, 11, 12, 13, 14, 15), - CheckLastBatch(), + CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - StopStream, - StartStream(), - CheckLastBatch(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckLastBatch((10, 5)), + CheckAnswer((10, 5)), StopStream, AssertOnQuery { q => // purge commit and clear the sink - val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + 1L + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) q.commitLog.purge(commit) q.sink.asInstanceOf[MemorySink].clear() true }, StartStream(), - CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10 - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), + AddData(inputData, 10, 27, 30), // Advance watermark to 20 seconds, 10 should be ignored + CheckAnswer((15, 1)), StopStream, - StartStream(), // Watermark should still be 15 seconds - AddData(inputData, 17), - CheckLastBatch(), // We still do not see next batch - AddData(inputData, 30), // Advance watermark to 20 seconds - CheckLastBatch(), - AddData(inputData, 30), // Evict items less than previous watermark. - CheckLastBatch((15, 2)) // Ensure we see next window - ) - } - - test("dropping old data") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 10, 11, 12), - CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. - CheckAnswer((10, 3)), - AddData(inputData, 10), // 10 is later than 15 second watermark - CheckAnswer((10, 3)), - AddData(inputData, 25), - CheckAnswer((10, 3)) // Should not emit an incorrect partial result. + StartStream(), + AddData(inputData, 17), // Watermark should still be 20 seconds, 17 should be ignored + CheckAnswer((15, 1)), + AddData(inputData, 40), // Advance watermark to 30 seconds, emit first data 25 + CheckNewAnswer((25, 2)) ) } @@ -421,8 +380,6 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche AddData(inputData, 10), CheckAnswer(), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. CheckAnswer((10, 1)) ) } @@ -501,8 +458,35 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche } } + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + // Check if there is new answer if flag is set, no new answer otherwise + if (flag) CheckNewAnswer((10, 5)) else CheckNewAnswer() + ) + } + + testWithFlag(true) + testWithFlag(false) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index cf41d7e0e4fe1..ed53def556cb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -279,13 +279,10 @@ class FileStreamSinkSuite extends StreamTest { check() // nothing emitted yet addTimestamp(104, 123) // watermark = 90 before this, watermark = 123 - 10 = 113 after this - check() // nothing emitted yet + check((100L, 105L) -> 2L) // no-data-batch emits results on 100-105, addTimestamp(140) // wm = 113 before this, emit results on 100-105, wm = 130 after this - check((100L, 105L) -> 2L) - - addTimestamp(150) // wm = 130s before this, emit results on 120-125, wm = 150 after this - check((100L, 105L) -> 2L, (120L, 125L) -> 1L) + check((100L, 105L) -> 2L, (120L, 125L) -> 1L) // no-data-batch emits results on 120-125 } finally { if (query != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala index 368c4604dfca8..e45f9d3e2e97b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -17,20 +17,58 @@ package org.apache.spark.sql.streaming +import org.apache.spark.sql.execution.streaming.StreamExecution + trait StateStoreMetricsTest extends StreamTest { + private var lastCheckedRecentProgressIndex = -1 + private var lastQuery: StreamExecution = null + + override def beforeEach(): Unit = { + super.beforeEach() + lastCheckedRecentProgressIndex = -1 + } + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert( - progressWithData.stateOperators.map(_.numRowsTotal) === total, - "incorrect total rows") - assert( - progressWithData.stateOperators.map(_.numRowsUpdated) === updated, - "incorrect updates rows") + val recentProgress = q.recentProgress + require(recentProgress.nonEmpty, "No progress made, cannot check num state rows") + require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention, + "This test assumes that all progresses are present in q.recentProgress but " + + "some may have been dropped due to retention limits") + + if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1 + lastQuery = q + + val numStateOperators = recentProgress.last.stateOperators.length + val progressesSinceLastCheck = recentProgress + .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length) + .filter(_.stateOperators.length == numStateOperators) + + val allNumUpdatedRowsSinceLastCheck = + progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated)) + + lazy val debugString = "recent progresses:\n" + + progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n") + + val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal) + assert(numTotalRows === total, s"incorrect total rows, $debugString") + + val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators) + assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString") + + lastCheckedRecentProgressIndex = recentProgress.length - 1 true } def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = assertNumStateRows(Seq(total), Seq(updated)) + + def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = { + if (arraySeq.isEmpty) return Seq.fill(arrayLength)(0L) + + assert(arraySeq.forall(_.length == arrayLength), + "Arrays are of different lengths:\n" + arraySeq.map(_.toSeq).mkString("\n")) + (0 until arrayLength).map { index => arraySeq.map(_.apply(index)).sum } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af0268fa47871..9d139a927bea5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -192,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" + private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains" } case class CheckAnswerRowsByFunc( @@ -202,6 +203,23 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } + case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) + extends StreamAction with StreamMustBeRunning { + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + + private def operatorName = "CheckNewAnswer" + } + + object CheckNewAnswer { + def apply(): CheckNewAnswerRows = CheckNewAnswerRows(Seq.empty) + + def apply[A: Encoder](data: A, moreData: A*): CheckNewAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) + } + } + /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning @@ -435,13 +453,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = { + var lastFetchedMemorySinkLastBatchId: Long = -1 + + def fetchStreamAnswer( + currentStream: StreamExecution, + lastOnly: Boolean = false, + sinceLastFetchOnly: Boolean = false) = { + verify( + !(lastOnly && sinceLastFetchOnly), "both lastOnly and sinceLastFetchOnly cannot be true") verify(currentStream != null, "stream not running") // Block until all data added has been processed for all the source awaiting.foreach { case (sourceIndex, offset) => failAfter(streamingTimeout) { currentStream.awaitOffset(sourceIndex, offset) + // Make sure all processing including no-data-batches have been executed + if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) { + currentStream.processAllAvailable() + } } } @@ -463,14 +492,21 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } } - val (latestBatchData, allData) = sink match { - case s: MemorySink => (s.latestBatchData, s.allData) - case s: MemorySinkV2 => (s.latestBatchData, s.allData) - } - try if (lastOnly) latestBatchData else allData catch { + val rows = try { + if (sinceLastFetchOnly) { + if (sink.latestBatchId.getOrElse(-1L) < lastFetchedMemorySinkLastBatchId) { + failTest("MemorySink was probably cleared since last fetch. Use CheckAnswer instead.") + } + sink.dataSinceBatch(lastFetchedMemorySinkLastBatchId) + } else { + if (lastOnly) sink.latestBatchData else sink.allData + } + } catch { case e: Exception => failTest("Exception while getting data from sink", e) } + lastFetchedMemorySinkLastBatchId = sink.latestBatchId.getOrElse(-1L) + rows } def executeAction(action: StreamAction): Unit = { @@ -704,6 +740,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } catch { case e: Throwable => failTest(e.toString) } + + case CheckNewAnswerRows(expectedAnswer) => + val sparkAnswer = fetchStreamAnswer(currentStream, sinceLastFetchOnly = true) + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + error => failTest(error) + } } pos += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala similarity index 80% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 0088b64d6195e..42ffd472eb843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -24,8 +24,9 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class StreamingDeduplicationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -97,28 +98,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { testStream(result, Append)( AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), - CheckLastBatch(10 to 15: _*), + CheckAnswer(10 to 15: _*), assertNumStateRows(total = 6, updated = 6), - AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(25), - assertNumStateRows(total = 7, updated = 1), - - AddData(inputData, 25), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0), + AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch drops rows <= 15 + CheckNewAnswer(25), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, 10), // Should not emit anything as data less than watermark - CheckLastBatch(), + CheckNewAnswer(), assertNumStateRows(total = 1, updated = 0), - AddData(inputData, 45), // Advance watermark to 35 seconds - CheckLastBatch(45), - assertNumStateRows(total = 2, updated = 1), - - AddData(inputData, 45), // Drop states less than watermark - CheckLastBatch(), - assertNumStateRows(total = 1, updated = 0) + AddData(inputData, 45), // Advance watermark to 35 seconds, no-data-batch drops row 25 + CheckNewAnswer(45), + assertNumStateRows(total = 1, updated = 1) ) } @@ -141,33 +134,20 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), AddData(inputData, 25), // Advance watermark to 15 seconds - CheckLastBatch(), - // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) - // states in deduplicate is 10 to 15 and 25 - assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), - - AddData(inputData, 25), // Emit items less than watermark and drop their state - CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate - // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of - // window to evict items, so [15, 20) is still in the state store) - // states in deduplicate is 25 - assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate, emitted with no-data-batch + // states in aggregate in [15, 20) and [25, 30); no-data-batch removed [10, 14) + // states in deduplicate is 25, no-data-batch removed 10 to 14 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(1L, 1L)), AddData(inputData, 10), // Should not emit anything as data less than watermark CheckLastBatch(), assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), AddData(inputData, 40), // Advance watermark to 30 seconds - CheckLastBatch(), - // states in aggregate in [15, 20), [25, 30) and [40, 45) - // states in deduplicate is 25 and 40, - assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), - - AddData(inputData, 40), // Emit items less than watermark and drop their state CheckLastBatch((15 -> 1), (25 -> 1)), - // states in aggregate in [40, 45) - // states in deduplicate is 40, - assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + // states in aggregate is [40, 45); no-data-batch removed [15, 20) and [25, 30) + // states in deduplicate is 40; no-data-batch removed 25 + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)) ) } @@ -260,13 +240,13 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id") testStream(df)( AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), - CheckLastBatch(1, 2), + CheckAnswer(1, 2), AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), - CheckLastBatch(3, 4), + CheckNewAnswer(3, 4), AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark - CheckLastBatch(5, 6), + CheckNewAnswer(5, 6), AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark - CheckLastBatch(7) + CheckNewAnswer(7) ) } @@ -279,7 +259,37 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { .select($"id", $"time".cast("long")) testStream(df)( AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), - CheckLastBatch(1 -> 1, 2 -> 2) + CheckAnswer(1 -> 1, 2 -> 2) ) } + + test("test no-data flag") { + val flagKey = SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key + + def testWithFlag(flag: Boolean): Unit = withClue(s"with $flagKey = $flag") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + StartStream(additionalConfs = Map(flagKey -> flag.toString)), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckAnswer(10, 11, 12, 13, 14, 15), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer(25), + { // State should have been cleaned if flag is set, otherwise should not have been cleaned + if (flag) assertNumStateRows(total = 1, updated = 1) + else assertNumStateRows(total = 7, updated = 1) + } + ) + } + + testWithFlag(true) + testWithFlag(false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 11bdd13942dcb..da8f9608c1e9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -192,7 +192,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 5, 11)), AddData(rightInput, (1, 10)), CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 - assertNumStateRows(total = 3, updated = 1), + assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), @@ -276,14 +276,14 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), - assertNumStateRows(total = 7, updated = 6), + assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), CheckLastBatch(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 4), + assertNumStateRows(total = 12, updated = 5), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) @@ -573,7 +573,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 3, updated = 1) @@ -591,7 +591,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with // nulls won't show up until the next batch after the watermark advances. MultiAddData(leftInput, 21)(rightInput, 22), CheckLastBatch(), - assertNumStateRows(total = 12, updated = 2), + assertNumStateRows(total = 12, updated = 12), AddData(leftInput, 22), CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), assertNumStateRows(total = 3, updated = 1) @@ -630,7 +630,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((1, 1, 5, 10)), AddData(rightInput, (1, 11)), CheckLastBatch(), // no match as left time is too low - assertNumStateRows(total = 5, updated = 1), + assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), @@ -668,7 +668,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), MultiAddData(leftInput, 20)(rightInput, 21), CheckLastBatch(), - assertNumStateRows(total = 5, updated = 2), + assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added AddData(rightInput, 20), CheckLastBatch( Row(20, 30, 40, 60)), @@ -678,7 +678,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), MultiAddData(leftInput, 70)(rightInput, 71), CheckLastBatch(), - assertNumStateRows(total = 6, updated = 2), + assertNumStateRows(total = 6, updated = 6), // all inputs added since last check AddData(rightInput, 70), CheckLastBatch((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), @@ -687,7 +687,7 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with CheckLastBatch(), MultiAddData(leftInput, 1000)(rightInput, 1001), CheckLastBatch(), - assertNumStateRows(total = 8, updated = 2), + assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added AddData(rightInput, 1000), CheckLastBatch( Row(1000, 1010, 2000, 3000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 390d67d1feb27..0cb2375e0a49a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -334,7 +334,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === null) + assert(progress.sources(0).startOffset === "0") assert(progress.sources(0).endOffset !== null) assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms From dd4b1b9c7ccad3363a6a21524aed047fcd282f68 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 6 May 2018 10:25:01 +0800 Subject: [PATCH 0739/1161] [SPARK-24185][SPARKR][SQL] add flatten function to SparkR ## What changes were proposed in this pull request? add array flatten function to SparkR ## How was this patch tested? Unit tests were added in R/pkg/tests/fulltests/test_sparkSQL.R Author: Huaxin Gao Closes #21244 from huaxingao/spark-24185. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 14 ++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 6 ++++++ 4 files changed, 25 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f36d462a83cb0..8cd00352d1956 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -258,6 +258,7 @@ exportMethods("%<=>%", "expr", "factorial", "first", + "flatten", "floor", "format_number", "format_string", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ec4bd4e73c7e5..0ec99d19e21e4 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,6 +208,7 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -3035,6 +3036,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{flatten}: Transforms an array of arrays into a single array. +#' +#' @rdname column_collection_functions +#' @aliases flatten flatten,Column-method +#' @note flatten since 2.4.0 +setMethod("flatten", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 562d3399ee9c8..4ef12d19b3575 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -918,6 +918,10 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) #' @name NULL setGeneric("expr", function(x) { standardGeneric("expr") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("flatten", function(x) { standardGeneric("flatten") }) + #' @rdname column_datetime_diff_functions #' @name NULL setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 8cc2db7a140f9..3a8866bf2a88a 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1502,6 +1502,12 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + # Test flattern + df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), + list(list(list(5L, 6L), list(7L, 8L))))) + result <- collect(select(df, flatten(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] From f38ea00e83099a5ae8d3afdec2e896e43c2db612 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 6 May 2018 20:41:32 -0700 Subject: [PATCH 0740/1161] [SPARK-24017][SQL] Refactor ExternalCatalog to be an interface ## What changes were proposed in this pull request? This refactors the external catalog to be an interface. It can be easier for the future work in the catalog federation. After the refactoring, `ExternalCatalog` is much cleaner without mixing the listener event generation logic. ## How was this patch tested? The existing tests Author: gatorsmile Closes #21122 from gatorsmile/refactorExternalCatalog. --- .../catalyst/catalog/ExternalCatalog.scala | 134 ++------ .../catalog/ExternalCatalogWithListener.scala | 298 ++++++++++++++++++ .../catalyst/catalog/InMemoryCatalog.scala | 26 +- .../catalog/ExternalCatalogEventSuite.scala | 2 +- .../spark/sql/internal/SharedState.scala | 9 +- .../sql/hive/thriftserver/SparkSQLEnv.scala | 2 +- .../spark/sql/hive/HiveExternalCatalog.scala | 26 +- .../spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../sql/hive/HiveSessionStateBuilder.scala | 7 +- .../sql/hive/execution/SaveAsHiveFile.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 11 +- .../HiveExternalSessionCatalogSuite.scala | 2 +- .../sql/hive/HiveSchemaInferenceSuite.scala | 3 +- .../sql/hive/HiveSessionStateSuite.scala | 3 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 5 +- .../spark/sql/hive/ShowCreateTableSuite.scala | 3 +- .../spark/sql/hive/client/VersionsSuite.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- .../sql/hive/test/TestHiveSingleton.scala | 2 +- 20 files changed, 384 insertions(+), 168 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 45b4f013620c1..1a145c24d78cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchPartitionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType -import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -31,10 +30,13 @@ import org.apache.spark.util.ListenerBus * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog - extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { +trait ExternalCatalog { import CatalogTypes.TablePartitionSpec + // -------------------------------------------------------------------------- + // Utils + // -------------------------------------------------------------------------- + protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { throw new NoSuchDatabaseException(db) @@ -63,22 +65,9 @@ abstract class ExternalCatalog // Databases // -------------------------------------------------------------------------- - final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { - val db = dbDefinition.name - postToAll(CreateDatabasePreEvent(db)) - doCreateDatabase(dbDefinition, ignoreIfExists) - postToAll(CreateDatabaseEvent(db)) - } + def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit - - final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - postToAll(DropDatabasePreEvent(db)) - doDropDatabase(db, ignoreIfNotExists, cascade) - postToAll(DropDatabaseEvent(db)) - } - - protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -87,14 +76,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - val db = dbDefinition.name - postToAll(AlterDatabasePreEvent(db)) - doAlterDatabase(dbDefinition) - postToAll(AlterDatabaseEvent(db)) - } - - protected def doAlterDatabase(dbDefinition: CatalogDatabase): Unit + def alterDatabase(dbDefinition: CatalogDatabase): Unit def getDatabase(db: String): CatalogDatabase @@ -110,41 +92,15 @@ abstract class ExternalCatalog // Tables // -------------------------------------------------------------------------- - final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - val tableDefinitionWithVersion = - tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) - postToAll(CreateTablePreEvent(db, name)) - doCreateTable(tableDefinitionWithVersion, ignoreIfExists) - postToAll(CreateTableEvent(db, name)) - } - - protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - - final def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean, - purge: Boolean): Unit = { - postToAll(DropTablePreEvent(db, table)) - doDropTable(db, table, ignoreIfNotExists, purge) - postToAll(DropTableEvent(db, table)) - } + def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - protected def doDropTable( + def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit - final def renameTable(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameTablePreEvent(db, oldName, newName)) - doRenameTable(db, oldName, newName) - postToAll(RenameTableEvent(db, oldName, newName)) - } - - protected def doRenameTable(db: String, oldName: String, newName: String): Unit + def renameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -154,15 +110,7 @@ abstract class ExternalCatalog * Note: If the underlying implementation does not support altering a certain field, * this becomes a no-op. */ - final def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.database - val name = tableDefinition.identifier.table - postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) - doAlterTable(tableDefinition) - postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) - } - - protected def doAlterTable(tableDefinition: CatalogTable): Unit + def alterTable(tableDefinition: CatalogTable): Unit /** * Alter the data schema of a table identified by the provided database and table name. The new @@ -173,22 +121,10 @@ abstract class ExternalCatalog * @param table Name of table to alter schema for * @param newDataSchema Updated data schema to be used for the table. */ - final def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) - doAlterTableDataSchema(db, table, newDataSchema) - postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) - } - - protected def doAlterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit + def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - final def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit = { - postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) - doAlterTableStats(db, table, stats) - postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) - } - - protected def doAlterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit + def alterTableStats(db: String, table: String, stats: Option[CatalogStatistics]): Unit def getTable(db: String, table: String): CatalogTable @@ -340,49 +276,17 @@ abstract class ExternalCatalog // Functions // -------------------------------------------------------------------------- - final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(CreateFunctionPreEvent(db, name)) - doCreateFunction(db, funcDefinition) - postToAll(CreateFunctionEvent(db, name)) - } + def createFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + def dropFunction(db: String, funcName: String): Unit - final def dropFunction(db: String, funcName: String): Unit = { - postToAll(DropFunctionPreEvent(db, funcName)) - doDropFunction(db, funcName) - postToAll(DropFunctionEvent(db, funcName)) - } + def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - protected def doDropFunction(db: String, funcName: String): Unit - - final def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { - val name = funcDefinition.identifier.funcName - postToAll(AlterFunctionPreEvent(db, name)) - doAlterFunction(db, funcDefinition) - postToAll(AlterFunctionEvent(db, name)) - } - - protected def doAlterFunction(db: String, funcDefinition: CatalogFunction): Unit - - final def renameFunction(db: String, oldName: String, newName: String): Unit = { - postToAll(RenameFunctionPreEvent(db, oldName, newName)) - doRenameFunction(db, oldName, newName) - postToAll(RenameFunctionEvent(db, oldName, newName)) - } - - protected def doRenameFunction(db: String, oldName: String, newName: String): Unit + def renameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction def functionExists(db: String, funcName: String): Boolean def listFunctions(db: String, pattern: String): Seq[String] - - override protected def doPostEvent( - listener: ExternalCatalogEventListener, - event: ExternalCatalogEvent): Unit = { - listener.onEvent(event) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala new file mode 100644 index 0000000000000..2f009be5816fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogWithListener.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus + +/** + * Wraps an ExternalCatalog to provide listener events. + */ +class ExternalCatalogWithListener(delegate: ExternalCatalog) + extends ExternalCatalog + with ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { + import CatalogTypes.TablePartitionSpec + + def unwrapped: ExternalCatalog = delegate + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + delegate.createDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + override def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + delegate.dropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } + + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = { + val db = dbDefinition.name + postToAll(AlterDatabasePreEvent(db)) + delegate.alterDatabase(dbDefinition) + postToAll(AlterDatabaseEvent(db)) + } + + override def getDatabase(db: String): CatalogDatabase = { + delegate.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = { + delegate.databaseExists(db) + } + + override def listDatabases(): Seq[String] = { + delegate.listDatabases() + } + + override def listDatabases(pattern: String): Seq[String] = { + delegate.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = { + delegate.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + val tableDefinitionWithVersion = + tableDefinition.copy(createVersion = org.apache.spark.SPARK_VERSION) + postToAll(CreateTablePreEvent(db, name)) + delegate.createTable(tableDefinitionWithVersion, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + delegate.dropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + delegate.renameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + override def alterTable(tableDefinition: CatalogTable): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(AlterTablePreEvent(db, name, AlterTableKind.TABLE)) + delegate.alterTable(tableDefinition) + postToAll(AlterTableEvent(db, name, AlterTableKind.TABLE)) + } + + override def alterTableDataSchema(db: String, table: String, newDataSchema: StructType): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.DATASCHEMA)) + delegate.alterTableDataSchema(db, table, newDataSchema) + postToAll(AlterTableEvent(db, table, AlterTableKind.DATASCHEMA)) + } + + override def alterTableStats( + db: String, + table: String, + stats: Option[CatalogStatistics]): Unit = { + postToAll(AlterTablePreEvent(db, table, AlterTableKind.STATS)) + delegate.alterTableStats(db, table, stats) + postToAll(AlterTableEvent(db, table, AlterTableKind.STATS)) + } + + override def getTable(db: String, table: String): CatalogTable = { + delegate.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = { + delegate.tableExists(db, table) + } + + override def listTables(db: String): Seq[String] = { + delegate.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = { + delegate.listTables(db, pattern) + } + + override def loadTable( + db: String, + table: String, + loadPath: String, + isOverwrite: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadTable(db, table, loadPath, isOverwrite, isSrcLocal) + } + + override def loadPartition( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + isOverwrite: Boolean, + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { + delegate.loadPartition( + db, table, loadPath, partition, isOverwrite, inheritTableSpecs, isSrcLocal) + } + + override def loadDynamicPartitions( + db: String, + table: String, + loadPath: String, + partition: TablePartitionSpec, + replace: Boolean, + numDP: Int): Unit = { + delegate.loadDynamicPartitions(db, table, loadPath, partition, replace, numDP) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = { + delegate.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + partSpecs: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean, + purge: Boolean, + retainData: Boolean): Unit = { + delegate.dropPartitions(db, table, partSpecs, ignoreIfNotExists, purge, retainData) + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = { + delegate.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition]): Unit = { + delegate.alterPartitions(db, table, parts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = { + delegate.getPartition(db, table, spec) + } + + override def getPartitionOption( + db: String, + table: String, + spec: TablePartitionSpec): Option[CatalogTablePartition] = { + delegate.getPartitionOption(db, table, spec) + } + + override def listPartitionNames( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + delegate.listPartitionNames(db, table, partialSpec) + } + + override def listPartitions( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { + delegate.listPartitions(db, table, partialSpec) + } + + override def listPartitionsByFilter( + db: String, + table: String, + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { + delegate.listPartitionsByFilter(db, table, predicates, defaultTimeZoneId) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + delegate.createFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + override def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + delegate.dropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(AlterFunctionPreEvent(db, name)) + delegate.alterFunction(db, funcDefinition) + postToAll(AlterFunctionEvent(db, name)) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + delegate.renameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = { + delegate.getFunction(db, funcName) + } + + override def functionExists(db: String, funcName: String): Boolean = { + delegate.functionExists(db, funcName) + } + + override def listFunctions(db: String, pattern: String): Seq[String] = { + delegate.listFunctions(db, pattern) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 8eacfa058bd52..741dc46b07382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -152,7 +152,7 @@ class InMemoryCatalog( } } - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = synchronized { requireDbExists(dbDefinition.name) catalog(dbDefinition.name).db = dbDefinition } @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,7 @@ class InMemoryCatalog( } } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = synchronized { @@ -294,7 +294,7 @@ class InMemoryCatalog( catalog(db).tables.remove(oldName) } - override def doAlterTable(tableDefinition: CatalogTable): Unit = synchronized { + override def alterTable(tableDefinition: CatalogTable): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -303,7 +303,7 @@ class InMemoryCatalog( catalog(db).tables(tableDefinition.identifier.table).table = newTableDefinition } - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = synchronized { @@ -313,7 +313,7 @@ class InMemoryCatalog( catalog(db).tables(table).table = origTable.copy(schema = newSchema) } - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = synchronized { @@ -564,24 +564,24 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { + override def dropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override protected def doAlterFunction(db: String, func: CatalogFunction): Unit = synchronized { + override def alterFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala index 1acbe34d9a075..2fcaeca34db3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -36,7 +36,7 @@ class ExternalCatalogEventSuite extends SparkFunSuite { private def testWithCatalog( name: String)( f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { - val catalog = newCatalog + val catalog = new ExternalCatalogWithListener(newCatalog) val recorder = mutable.Buffer.empty[ExternalCatalogEvent] catalog.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index baea4ceebf8e3..5b6160e2b408f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -99,7 +99,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = { + lazy val externalCatalog: ExternalCatalogWithListener = { val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, @@ -117,14 +117,17 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + // Wrap to provide catalog events + val wrapped = new ExternalCatalogWithListener(externalCatalog) + // Make sure we propagate external catalog events to the spark listener bus - externalCatalog.addListener(new ExternalCatalogEventListener { + wrapped.addListener(new ExternalCatalogEventListener { override def onEvent(event: ExternalCatalogEvent): Unit = { sparkContext.listenerBus.post(event) } }) - externalCatalog + wrapped } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index cbd75ad12d430..8980bcf885589 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -50,7 +50,7 @@ private[hive] object SparkSQLEnv extends Logging { sqlContext = sparkSession.sqlContext val metadataHive = sparkSession - .sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + .sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 28c340a176d91..011a3ba553cb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -158,13 +158,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Databases // -------------------------------------------------------------------------- - override protected def doCreateDatabase( + override def createDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override protected def doDropDatabase( + override def dropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -177,7 +177,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * * Note: As of now, this only supports altering database properties! */ - override def doAlterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { val existingDb = getDatabase(dbDefinition.name) if (existingDb.properties == dbDefinition.properties) { logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + @@ -211,7 +211,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override protected def doCreateTable( + override def createTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -480,7 +480,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override protected def doDropTable( + override def dropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -489,7 +489,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override protected def doRenameTable( + override def renameTable( db: String, oldName: String, newName: String): Unit = withClient { @@ -540,7 +540,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * Note: As of now, this doesn't support altering table schema, partition column names and bucket * specification. We will ignore them even if users do specify different values for these fields. */ - override def doAlterTable(tableDefinition: CatalogTable): Unit = withClient { + override def alterTable(tableDefinition: CatalogTable): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) val db = tableDefinition.identifier.database.get requireTableExists(db, tableDefinition.identifier.table) @@ -624,7 +624,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * data schema should not have conflict column names with the existing partition columns, and * should still contain all the existing data columns. */ - override def doAlterTableDataSchema( + override def alterTableDataSchema( db: String, table: String, newDataSchema: StructType): Unit = withClient { @@ -656,7 +656,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** Alter the statistics of a table. If `stats` is None, then remove all existing statistics. */ - override def doAlterTableStats( + override def alterTableStats( db: String, table: String, stats: Option[CatalogStatistics]): Unit = withClient { @@ -1208,7 +1208,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override protected def doCreateFunction( + override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1221,12 +1221,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doDropFunction(db: String, name: String): Unit = withClient { + override def dropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override protected def doAlterFunction( + override def alterFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) val functionName = funcDefinition.identifier.funcName.toLowerCase(Locale.ROOT) @@ -1235,7 +1235,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.alterFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override protected def doRenameFunction( + override def renameFunction( db: String, oldName: String, newName: String): Unit = withClient { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index e5aff3b99d0b9..94ddeae1bf547 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, ExternalCatalog, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper @@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType} private[sql] class HiveSessionCatalog( - externalCatalogBuilder: () => HiveExternalCatalog, + externalCatalogBuilder: () => ExternalCatalog, globalTempViewManagerBuilder: () => GlobalTempViewManager, val metastoreCatalog: HiveMetastoreCatalog, functionRegistry: FunctionRegistry, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 40b9bb51ca9a0..2882672f327c4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Analyzer +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -35,14 +36,14 @@ import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLo class HiveSessionStateBuilder(session: SparkSession, parentState: Option[SessionState] = None) extends BaseSessionStateBuilder(session, parentState) { - private def externalCatalog: HiveExternalCatalog = - session.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private def externalCatalog: ExternalCatalogWithListener = session.sharedState.externalCatalog /** * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - new HiveSessionResourceLoader(session, () => externalCatalog.client) + new HiveSessionResourceLoader( + session, () => externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client) } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 6a7b25b36d9a5..e0f7375387d24 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -122,7 +122,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { allSupportedHiveVersions) val externalCatalog = sparkSession.sharedState.externalCatalog - val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version + val hiveVersion = externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.version val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging") val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 965aea2b61456..ee3f99ab7e9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, ExternalCatalogWithListener} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand @@ -83,11 +84,11 @@ private[hive] class TestHiveSharedState( hiveClient: Option[HiveClient] = None) extends SharedState(sc) { - override lazy val externalCatalog: TestHiveExternalCatalog = { - new TestHiveExternalCatalog( + override lazy val externalCatalog: ExternalCatalogWithListener = { + new ExternalCatalogWithListener(new TestHiveExternalCatalog( sc.conf, sc.hadoopConfiguration, - hiveClient) + hiveClient)) } } @@ -208,7 +209,9 @@ private[hive] class TestHiveSparkSession( new TestHiveSessionStateBuilder(this, parentSessionState).build() } - lazy val metadataHive: HiveClient = sharedState.externalCatalog.client.newSession() + lazy val metadataHive: HiveClient = { + sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.newSession() + } override def newSession(): TestHiveSparkSession = { new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala index 285f35b0b0eac..fd5f47e428239 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalSessionCatalogSuite.scala @@ -26,7 +26,7 @@ class HiveExternalSessionCatalogSuite extends SessionCatalogSuite with TestHiveS private val externalCatalog = { val catalog = spark.sharedState.externalCatalog - catalog.asInstanceOf[HiveExternalCatalog].client.reset() + catalog.unwrapped.asInstanceOf[HiveExternalCatalog].client.reset() catalog } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index f2d27671094d7..51a48a20daaa2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -50,7 +50,8 @@ class HiveSchemaInferenceSuite FileStatusCache.resetForTesting() } - private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog] + private val externalCatalog = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] private val client = externalCatalog.client // Return a copy of the given schema with all field names converted to lower case. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala index ecc09cdcdbeaf..a3579862c9e59 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala @@ -44,7 +44,8 @@ class HiveSessionStateSuite extends SessionStateSuite with TestHiveSingleton { val conf = sparkSession.sparkContext.hadoopConfiguration val oldValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) sparkSession.cloneSession() - sparkSession.sharedState.externalCatalog.client.newSession() + sparkSession.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.newSession() val newValue = conf.get(ConfVars.METASTORECONNECTURLKEY.varname) assert(oldValue == newValue, "cloneSession and then newSession should not affect the Derby directory") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 079fe45860544..aa5b531992613 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object SetMetastoreURLTest extends Logging { // HiveExternalCatalog is used when Hive support is enabled. val actualMetastoreURL = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") @@ -780,7 +780,8 @@ object SPARK_18360 { val defaultDbLocation = spark.catalog.getDatabase("default").locationUri assert(new Path(defaultDbLocation) == new Path(spark.sharedState.warehousePath)) - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val hiveClient = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client try { val tableMeta = CatalogTable( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index fad81c7e9474e..473bbced41b31 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -289,7 +289,8 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing } private def createRawHiveTable(ddl: String): Unit = { - hiveContext.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.runSqlHive(ddl) + hiveContext.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.runSqlHive(ddl) } private def checkCreateTable(table: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6176273c88db1..dc96ec416afd8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -134,8 +134,8 @@ class VersionsSuite extends SparkFunSuite with Logging { client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) - assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - .version.fullVersion.startsWith(version)) + assert(versionSpark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog] + .client.version.fullVersion.startsWith(version)) } def table(database: String, tableName: String): CatalogTable = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index daac6af9b557f..0341c3b378918 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1355,7 +1355,8 @@ class HiveDDLSuite val indexName = tabName + "_index" withTable(tabName) { // Spark SQL does not support creating index. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client sql(s"CREATE TABLE $tabName(a int)") try { @@ -1393,7 +1394,8 @@ class HiveDDLSuite val tabName = "tab1" withTable(tabName) { // Spark SQL does not support creating skewed table. Thus, we have to use Hive client. - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client client.runSqlHive( s""" |CREATE Table $tabName(col1 int, col2 int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 704a410b6a37b..828c18a770c80 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -2099,7 +2099,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("orc", "parquet").foreach { format => test(s"SPARK-18355 Read data from a hive table with a new column - $format") { - val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + val client = + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client Seq("true", "false").foreach { value => withSQLConf( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index d3fff37c3424d..d50bf0b8fd603 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -30,7 +30,7 @@ trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { From a634d66ce767bd5e1d8553d1a2c32e2b1a80f642 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 7 May 2018 13:00:18 +0800 Subject: [PATCH 0741/1161] [SPARK-24126][PYSPARK] Use build-specific temp directory for pyspark tests. This avoids polluting and leaving garbage behind in /tmp, and allows the usual build tools to clean up any leftover files. Author: Marcelo Vanzin Closes #21198 from vanzin/SPARK-24126. --- python/pyspark/sql/tests.py | 4 ++-- python/pyspark/streaming/tests.py | 6 ++++-- python/pyspark/tests.py | 33 ++++++++++++++++++++----------- python/run-tests.py | 29 +++++++++++++++++++++++++-- 4 files changed, 54 insertions(+), 18 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cc6acfdb07d99..16aa9378ad8ee 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3092,8 +3092,8 @@ def test_hivecontext(self): |print(hive_context.sql("show databases").collect()) """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", - "--driver-class-path", hive_site_dir, script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d77f1baa1f344..e4a428a0b27e7 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -63,7 +63,7 @@ def setUpClass(cls): class_name = cls.__name__ conf = SparkConf().set("spark.default.parallelism", 1) cls.sc = SparkContext(appName=class_name, conf=conf) - cls.sc.setCheckpointDir("/tmp") + cls.sc.setCheckpointDir(tempfile.mkdtemp()) @classmethod def tearDownClass(cls): @@ -1549,7 +1549,9 @@ def search_kinesis_asl_assembly_jar(): kinesis_jar_present = True jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % jars + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args]) testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, StreamingListenerTests] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8392d7f29af53..7b8ce2c6b799f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1951,7 +1951,12 @@ class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() - self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") + tmp_dir = tempfile.gettempdir() + self.sparkSubmit = [ + os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"), + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + ] def tearDown(self): shutil.rmtree(self.programDir) @@ -2017,7 +2022,7 @@ def test_single_script(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 4, 6]", out.decode('utf-8')) @@ -2033,7 +2038,7 @@ def test_script_with_local_functions(self): |sc = SparkContext() |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) - proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) + proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[3, 6, 9]", out.decode('utf-8')) @@ -2051,7 +2056,7 @@ def test_module_dependency(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script], + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2070,7 +2075,7 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() @@ -2087,8 +2092,10 @@ def test_package_dependency(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2103,9 +2110,11 @@ def test_package_dependency_on_cluster(self): |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") - proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", - "file:" + self.programDir, "--master", - "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) + proc = subprocess.Popen( + self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories", + "file:" + self.programDir, "--master", "local-cluster[1,1,1024]", + script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -2124,7 +2133,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], + self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -2144,7 +2153,7 @@ def test_user_configuration(self): | sc.stop() """) proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local", script], + self.sparkSubmit + ["--master", "local", script], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, err = proc.communicate() diff --git a/python/run-tests.py b/python/run-tests.py index f408fc5082b3d..4c90926cfa350 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -22,11 +22,13 @@ from optparse import OptionParser import os import re +import shutil import subprocess import sys import tempfile from threading import Thread, Lock import time +import uuid if sys.version < '3': import Queue else: @@ -68,7 +70,7 @@ def print_red(text): raise Exception("Cannot find assembly build directory, please build Spark first.") -def run_individual_python_test(test_name, pyspark_python): +def run_individual_python_test(target_dir, test_name, pyspark_python): env = dict(os.environ) env.update({ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, @@ -77,6 +79,23 @@ def run_individual_python_test(test_name, pyspark_python): 'PYSPARK_PYTHON': which(pyspark_python), 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) }) + + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is + # recognized by the tempfile module to override the default system temp directory. + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + while os.path.isdir(tmp_dir): + tmp_dir = os.path.join(target_dir, str(uuid.uuid4())) + os.mkdir(tmp_dir) + env["TMPDIR"] = tmp_dir + + # Also override the JVM's temp directory by setting driver and executor options. + spark_args = [ + "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir), + "pyspark-shell" + ] + env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args) + LOGGER.info("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: @@ -84,6 +103,7 @@ def run_individual_python_test(test_name, pyspark_python): retcode = subprocess.Popen( [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], stderr=per_test_output, stdout=per_test_output, env=env).wait() + shutil.rmtree(tmp_dir, ignore_errors=True) except: LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if @@ -238,6 +258,11 @@ def main(): priority = 100 task_queue.put((priority, (python_exec, test_goal))) + # Create the target directory before starting tasks to avoid races. + target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target')) + if not os.path.isdir(target_dir): + os.mkdir(target_dir) + def process_queue(task_queue): while True: try: @@ -245,7 +270,7 @@ def process_queue(task_queue): except Queue.Empty: break try: - run_individual_python_test(test_goal, python_exec) + run_individual_python_test(target_dir, test_goal, python_exec) finally: task_queue.task_done() From 889f6cc10cbd7781df04f468674a61f0ac5a870b Mon Sep 17 00:00:00 2001 From: jinxing Date: Mon, 7 May 2018 14:16:27 +0800 Subject: [PATCH 0742/1161] [SPARK-24143] filter empty blocks when convert mapstatus to (blockId, size) pair MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code(`MapOutputTracker.convertMapStatuses`), mapstatus are converted to (blockId, size) pair for all blocks – no matter the block is empty or not, which result in OOM when there are lots of consecutive empty blocks, especially when adaptive execution is enabled. (blockId, size) pair is only used in `ShuffleBlockFetcherIterator` to control shuffle-read and only non-empty block request is sent. Can we just filter out the empty blocks in MapOutputTracker.convertMapStatuses and save memory? ## How was this patch tested? not added yet. Author: jinxing Closes #21212 from jinxing64/SPARK-24143. --- .../org/apache/spark/MapOutputTracker.scala | 31 +++++++++------- .../storage/ShuffleBlockFetcherIterator.scala | 35 +++++++++++-------- .../apache/spark/MapOutputTrackerSuite.scala | 31 +++++++++++++++- .../BlockStoreShuffleReaderSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 19 +++++----- 5 files changed, 80 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 195fd4f818b36..73646051f264c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -296,7 +296,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -632,9 +632,10 @@ private[spark] class MapOutputTrackerMaster( } } + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -642,7 +643,7 @@ private[spark] class MapOutputTrackerMaster( MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } case None => - Seq.empty + Iterator.empty } } @@ -669,8 +670,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr /** Remembers which map output locations are currently being fetched on an executor. */ private val fetching = new HashSet[Int] + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -841,6 +843,7 @@ private[spark] object MapOutputTracker extends Logging { * Given an array of map statuses and a range of map output partitions, returns a sequence that, * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes * stored at that block manager. + * Note that empty blocks are filtered in the result. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. @@ -857,22 +860,24 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { for (part <- startPartition until endPartition) { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + val size = status.getSizeForBlock(part) + if (size != 0) { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), size)) + } } } } - - splitsByAddress.toSeq + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index dd9df74689a13..6971efd2504c2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. + * order to throttle the memory usage. Note that zero-sized blocks are + * already excluded, which happened in + * [[MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. @@ -62,7 +64,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: ShuffleClient, blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -74,8 +76,8 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * Total number of blocks to fetch. This should be equal to the total number of blocks + * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. * * This should equal localBlocks.size + remoteBlocks.size. */ @@ -267,13 +269,16 @@ final class ShuffleBlockFetcherIterator( // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] - // Tracks total number of blocks (including zero sized blocks) - var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId) { - // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + blockInfos.find(_._2 <= 0) match { + case Some((blockId, size)) if size < 0 => + throw new BlockException(blockId, "Negative block size " + size) + case Some((blockId, size)) if size == 0 => + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + case None => // do nothing. + } + localBlocks ++= blockInfos.map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator @@ -281,14 +286,15 @@ final class ShuffleBlockFetcherIterator( var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { + if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } else if (size == 0) { + throw new BlockException(blockId, "Zero-sized blocks should be excluded.") + } else { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -306,7 +312,8 @@ final class ShuffleBlockFetcherIterator( } } } - logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + logInfo(s"Getting $numBlocksToFetch non-empty blocks including ${localBlocks.size}" + + s" local blocks and ${remoteBlocks.size} remote blocks") remoteRequests } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 50b8ea754d8d9..21f481d477242 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -147,7 +147,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) @@ -298,4 +298,33 @@ class MapOutputTrackerSuite extends SparkFunSuite { } } + test("zero-sized blocks should be excluded when getMapSizesByExecutorId") { + val rpcEnv = createRpcEnv("test") + val tracker = newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + + val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0))) + assert(tracker.containsShuffle(10)) + assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + Seq( + (BlockManagerId("a", "hostA", 1000), + Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + (BlockManagerId("b", "hostB", 1000), + Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + ) + ) + + tracker.unregisterShuffle(10) + tracker.stop() + rpcEnv.shutdown() + } + } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..2d8a83c6fabed 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -108,7 +108,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 692ae3bf597e0..cefebfa51b8b9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -99,7 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) - ) + ).toIterator val iterator = new ShuffleBlockFetcherIterator( TaskContext.empty(), @@ -176,7 +176,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -244,7 +244,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -310,7 +310,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -378,7 +378,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) - ) + ).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -437,7 +437,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -495,7 +495,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + def fetchShuffleBlock( + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -513,14 +514,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. From 7564a9a70695dac2f0b5f51493d37cbc93691663 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 7 May 2018 15:22:23 +0900 Subject: [PATCH 0743/1161] [SPARK-23921][SQL] Add array_sort function ## What changes were proposed in this pull request? The PR adds the SQL function `array_sort`. The behavior of the function is based on Presto's one. The function sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21021 from kiszk/SPARK-23921. --- python/pyspark/sql/functions.py | 26 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 240 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 34 ++- .../org/apache/spark/sql/functions.scala | 12 + .../spark/sql/DataFrameFunctionsSuite.scala | 34 ++- 6 files changed, 292 insertions(+), 55 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad4bd6f5089e9..bd55b5f73b4d0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2183,20 +2183,38 @@ def array_max(col): def sort_array(col, asc=True): """ Collection function: sorts the input array in ascending or descending order according - to the natural ordering of the array elements. + to the natural ordering of the array elements. Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. :param col: name of column or expression - >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() - [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() - [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(2.4) +def array_sort(col): + """ + Collection function: sorts the input array in ascending order. The elements of the input array + must be orderable. Null elements will be placed at the end of the returned array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(array_sort(df.data).alias('r')).collect() + [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_sort(_to_java_column(col))) + + @since(1.5) @ignore_unicode_prefix def reverse(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 51bb6b0abe408..01776b85e6f53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -403,6 +403,7 @@ object FunctionRegistry { expression[ArrayContains]("array_contains"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), + expression[ArraySort]("array_sort"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6d63a531e3b74..23c09bc3b49d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,6 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ @@ -119,47 +120,16 @@ case class MapValues(child: Expression) } /** - * Sorts the input array in ascending / descending order according to the natural ordering of - * the array elements and returns it. + * Common base class for [[SortArray]] and [[ArraySort]]. */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.", - examples = """ - Examples: - > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true); - ["a","b","c","d"] - """) -// scalastyle:on line.size.limit -case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +trait ArraySortLike extends ExpectsInputTypes { + protected def arrayExpression: Expression - def this(e: Expression) = this(e, Literal(true)) - - override def left: Expression = base - override def right: Expression = ascendingOrder - override def dataType: DataType = base.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) - - override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - ascendingOrder match { - case Literal(_: Boolean, BooleanType) => - TypeCheckResult.TypeCheckSuccess - case _ => - TypeCheckResult.TypeCheckFailure( - "Sort order in second argument requires a boolean literal.") - } - case ArrayType(dt, _) => - TypeCheckResult.TypeCheckFailure( - s"$prettyName does not support sorting array of type ${dt.simpleString}") - case _ => - TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") - } + protected def nullOrder: NullOrder @transient private lazy val lt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -170,9 +140,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - -1 + nullOrder } else if (o2 == null) { - 1 + -nullOrder } else { ordering.compare(o1, o2) } @@ -182,7 +152,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val gt: Comparator[Any] = { - val ordering = base.dataType match { + val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -193,9 +163,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) if (o1 == null && o2 == null) { 0 } else if (o1 == null) { - 1 + -nullOrder } else if (o2 == null) { - -1 + nullOrder } else { -ordering.compare(o1, o2) } @@ -203,18 +173,200 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } } - override def nullSafeEval(array: Any, ascending: Any): Any = { - val elementType = base.dataType.asInstanceOf[ArrayType].elementType + def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull + + def sortEval(array: Any, ascending: Boolean): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementType != NullType) { - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + java.util.Arrays.sort(data, if (ascending) lt else gt) } new GenericArrayData(data.asInstanceOf[Array[Any]]) } + def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = { + val arrayData = classOf[ArrayData].getName + val genericArrayData = classOf[GenericArrayData].getName + val unsafeArrayData = classOf[UnsafeArrayData].getName + val array = ctx.freshName("array") + val c = ctx.freshName("c") + if (elementType == NullType) { + s"${ev.value} = $base.copy();" + } else { + val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType) + val sortOrder = ctx.freshName("sortOrder") + val o1 = ctx.freshName("o1") + val o2 = ctx.freshName("o2") + val jt = CodeGenerator.javaType(elementType) + val comp = if (CodeGenerator.isPrimitiveType(elementType)) { + val bt = CodeGenerator.boxedType(elementType) + val v1 = ctx.freshName("v1") + val v2 = ctx.freshName("v2") + s""" + |$jt $v1 = (($bt) $o1).${jt}Value(); + |$jt $v2 = (($bt) $o2).${jt}Value(); + |int $c = ${ctx.genComp(elementType, v1, v2)}; + """.stripMargin + } else { + s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};" + } + val nonNullPrimitiveAscendingSort = + if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) { + val javaType = CodeGenerator.javaType(elementType) + val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($order) { + | $javaType[] $array = $base.to${primitiveTypeName}Array(); + | java.util.Arrays.sort($array); + | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array); + |} else + """.stripMargin + } else { + "" + } + s""" + |$nonNullPrimitiveAscendingSort + |{ + | Object[] $array = $base.toObjectArray($elementTypeTerm); + | final int $sortOrder = $order ? 1 : -1; + | java.util.Arrays.sort($array, new java.util.Comparator() { + | @Override public int compare(Object $o1, Object $o2) { + | if ($o1 == null && $o2 == null) { + | return 0; + | } else if ($o1 == null) { + | return $sortOrder * $nullOrder; + | } else if ($o2 == null) { + | return -$sortOrder * $nullOrder; + | } + | $comp + | return $sortOrder * $c; + | } + | }); + | ${ev.value} = new $genericArrayData($array); + |} + """.stripMargin + } + } + +} + +object ArraySortLike { + type NullOrder = Int + // Least: place null element at the first of the array for ascending order + // Greatest: place null element at the end of the array for ascending order + object NullOrder { + val Least: NullOrder = -1 + val Greatest: NullOrder = 1 + } +} + +/** + * Sorts the input array in ascending / descending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order + according to the natural ordering of the array elements. Null elements will be placed + at the beginning of the returned array in ascending order or at the end of the returned + array in descending order. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true); + [null,"a","b","c","d"] + """) +// scalastyle:on line.size.limit +case class SortArray(base: Expression, ascendingOrder: Expression) + extends BinaryExpression with ArraySortLike { + + def this(e: Expression) = this(e, Literal(true)) + + override def left: Expression = base + override def right: Expression = ascendingOrder + override def dataType: DataType = base.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) + + override def arrayExpression: Expression = base + override def nullOrder: NullOrder = NullOrder.Least + + override def checkInputDataTypes(): TypeCheckResult = base.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any, ascending: Any): Any = { + sortEval(array, ascending.asInstanceOf[Boolean]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order)) + } + override def prettyName: String = "sort_array" } + +/** + * Sorts the input array in ascending order according to the natural ordering of + * the array elements and returns it. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must + be orderable. Null elements will be placed at the end of the returned array. + """, + examples = """ + Examples: + > SELECT _FUNC_(array('b', 'd', null, 'c', 'a')); + ["a","b","c","d",null] + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike { + + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def arrayExpression: Expression = child + override def nullOrder: NullOrder = NullOrder.Greatest + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + TypeCheckResult.TypeCheckSuccess + case ArrayType(dt, _) => + val dtSimple = dt.simpleString + TypeCheckResult.TypeCheckFailure( + s"$prettyName does not support sorting array of type $dtSimple which is not orderable") + case _ => + TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.") + } + + override def nullSafeEval(array: Any): Any = { + sortEval(array, true) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true")) + } + + override def prettyName: String = "array_sort" +} + /** * Returns a reversed string or an array with reverse order of elements. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7048d93fd5649..749374f1a14a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -61,28 +61,58 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + val d1 = new Decimal().set(10) + val d2 = new Decimal().set(100) + val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0))) + val a5 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a2), Seq("a", "b")) checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(new SortArray(a4), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2)) checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) + checkEvaluation(new SortArray(a5), Seq(null, null)) val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + + val typeAA = ArrayType(ArrayType(IntegerType)) + val aa1 = Array[java.lang.Integer](1, 2) + val aa2 = Array[java.lang.Integer](3, null, 4) + val arrayArray = Literal.create(Seq(aa2, aa1), typeAA) + + checkEvaluation(new SortArray(arrayArray), Seq(aa1, aa2)) + + val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil))) + val aas1 = Array(create_row(1)) + val aas2 = Array(create_row(2)) + val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS) + + checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2)) + + checkEvaluation(ArraySort(a0), Seq(1, 2, 3)) + checkEvaluation(ArraySort(a1), Seq[Integer]()) + checkEvaluation(ArraySort(a2), Seq("a", "b")) + checkEvaluation(ArraySort(a3), Seq("a", "b", null)) + checkEvaluation(ArraySort(a4), Seq(d1, d2)) + checkEvaluation(ArraySort(a5), Seq(null, null)) + checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2))) + checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2)) + checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2)) } test("Array contains") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d2e22fa355514..10b6dcc0608c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3093,6 +3093,15 @@ object functions { ElementAt(column.expr, Literal(value)) } + /** + * Sorts the input array in ascending order. The elements of the input array must be orderable. + * Null elements will be placed at the end of the returned array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** * Creates a new row for each element in the given array or map column. * @@ -3332,6 +3341,7 @@ object functions { /** * Sorts the input array for the given column in ascending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array. * * @group collection_funcs * @since 1.5.0 @@ -3341,6 +3351,8 @@ object functions { /** * Sorts the input array for the given column in ascending or descending order, * according to the natural ordering of the array elements. + * Null elements will be placed at the beginning of the returned array in ascending order or + * at the end of the returned array in descending order. * * @group collection_funcs * @since 1.5.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a5163accb1bb3..ae21cbc802d0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,7 +276,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("sort_array function") { + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), (Array.empty[Int], Array.empty[String]), @@ -286,28 +286,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(sort_array($"a"), sort_array($"b")), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.select(sort_array($"a", false), sort_array($"b", false)), Seq( Row(Seq(3, 2, 1), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a)", "sort_array(b)"), Seq( Row(Seq(1, 2, 3), Seq("a", "b", "c")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) checkAnswer( df.selectExpr("sort_array(a, true)", "sort_array(b, false)"), Seq( Row(Seq(1, 2, 3), Seq("c", "b", "a")), - Row(Seq[Int](), Seq[String]()), + Row(Seq.empty[Int], Seq.empty[String]), Row(null, null)) ) @@ -324,6 +324,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(intercept[AnalysisException] { df3.selectExpr("sort_array(a)").collect() }.getMessage().contains("only supports array input")) + + checkAnswer( + df.select(array_sort($"a"), array_sort($"b")), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_sort(a)", "array_sort(b)"), + Seq( + Row(Seq(1, 2, 3), Seq("a", "b", "c")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + + checkAnswer( + df2.selectExpr("array_sort(a)"), + Seq(Row(Seq[Seq[Int]](Seq(1), Seq(2), Seq(2, 4), null))) + ) + + assert(intercept[AnalysisException] { + df3.selectExpr("array_sort(a)").collect() + }.getMessage().contains("only supports array input")) } test("array size function") { From d2aa859b4faeda03e32a7574dd0c5b4ed367fae4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 7 May 2018 14:34:03 +0800 Subject: [PATCH 0744/1161] [SPARK-24160] ShuffleBlockFetcherIterator should fail if it receives zero-size blocks ## What changes were proposed in this pull request? This patch modifies `ShuffleBlockFetcherIterator` so that the receipt of zero-size blocks is treated as an error. This is done as a preventative measure to guard against a potential source of data loss bugs. In the shuffle layer, we guarantee that zero-size blocks will never be requested (a block containing zero records is always 0 bytes in size and is marked as empty such that it will never be legitimately requested by executors). However, the existing code does not fully take advantage of this invariant in the shuffle-read path: the existing code did not explicitly check whether blocks are non-zero-size. Additionally, our decompression and deserialization streams treat zero-size inputs as empty streams rather than errors (EOF might actually be treated as "end-of-stream" in certain layers (longstanding behavior dating to earliest versions of Spark) and decompressors like Snappy may be tolerant to zero-size inputs). As a result, if some other bug causes legitimate buffers to be replaced with zero-sized buffers (due to corruption on either the send or receive sides) then this would translate into silent data loss rather than an explicit fail-fast error. This patch addresses this problem by adding a `buf.size != 0` check. See code comments for pointers to tests which guarantee the invariants relied on here. ## How was this patch tested? Existing tests (which required modifications, since some were creating empty buffers in mocks). I also added a test to make sure we fail on zero-size blocks. To test that the zero-size blocks are indeed a potential corruption source, I manually ran a workload in `spark-shell` with a modified build which replaces all buffers with zero-size buffers in the receive path. Author: Josh Rosen Closes #21219 from JoshRosen/SPARK-24160. --- .../storage/ShuffleBlockFetcherIterator.scala | 19 +++++ .../ShuffleBlockFetcherIteratorSuite.scala | 71 +++++++++++++------ 2 files changed, 70 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 6971efd2504c2..b31862323a895 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -414,6 +414,25 @@ final class ShuffleBlockFetcherIterator( logDebug("Number of requests in flight " + reqsInFlight) } + if (buf.size == 0) { + // We will never legitimately receive a zero-size block. All blocks with zero records + // have zero size and all zero-size blocks have no records (and hence should never + // have been requested in the first place). This statement relies on behaviors of the + // shuffle writers, which are guaranteed by the following test cases: + // + // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions" + // - UnsafeShuffleWriterSuite: "writeEmptyIterator" + // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing" + // + // There is not an explicit test for SortShuffleWriter but the underlying APIs that + // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter + // which returns a zero-size from commitAndGet() in case no records were written + // since the last call. + val msg = s"Received a zero-size buffer for block $blockId from $address " + + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" + throwFetchFailedException(blockId, address, new IOException(msg)) + } + val in = try { buf.createInputStream() } catch { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index cefebfa51b8b9..8e9374b768adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Create a mock managed buffer for testing - def createMockManagedBuffer(): ManagedBuffer = { + def createMockManagedBuffer(size: Int = 1): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) val in = mock(classOf[InputStream]) when(in.read(any())).thenReturn(1) when(in.read(any(), any(), any())).thenReturn(1) when(mockManagedBuffer.createInputStream()).thenReturn(in) + when(mockManagedBuffer.size()).thenReturn(size) mockManagedBuffer } @@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.size()).thenReturn(size) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + corruptBuffer + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) @@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) sem.release() @@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("big blocks are not checked for corruption") { - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - doReturn(10000L).when(corruptBuffer).size() + val corruptBuffer = mockCorruptBuffer(10000L) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - val corruptBuffer = mock(classOf[ManagedBuffer]) - when(corruptBuffer.createInputStream()).thenReturn(corruptStream) - val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { @@ -428,9 +424,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) sem.release() } } @@ -527,4 +523,39 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // shuffle block to disk. assert(tempFileManager != null) } + + test("fail zero-size blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() + ) + + val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + + // All blocks fetched return zero length and should trigger a receive-side error: + val e = intercept[FetchFailedException] { iterator.next() } + assert(e.getMessage.contains("Received a zero-size buffer")) + } } From c5981976f1d514a3ad8a684b9a21cebe38b786fa Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 7 May 2018 14:45:14 +0800 Subject: [PATCH 0745/1161] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `onTaskStart` where stage ID is provided. In order to make sure cancelStage called for all stages `waitUntilEmpty` is called on `ListenerBus`. In [PR20888](https://github.com/apache/spark/pull/20888) this tried to get solved by: * Stopping the executor thread with `wait` * Wait for all `cancelStage` called * Kill the executor thread by setting `SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL` but the thread killing left the shared `SparkContext` sometimes in a state where further jobs can't be submitted. As a result DataFrameRangeSuite.test("Cancelling stage in a query with Range.") test passed properly but the next test inside the suite was hanging. ## How was this patch tested? Existing unit test executed 10k times. Author: Gabor Somogyi Closes #21214 from gaborgsomogyi/SPARK-23775_1. --- .../spark/sql/DataFrameRangeSuite.scala | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..b0b46640ff317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -23,8 +23,8 @@ import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.SparkException +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -153,23 +153,17 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall test("Cancelling stage in a query with Range.") { val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) - } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) } } sparkContext.addSparkListener(listener) for (codegen <- Seq(true, false)) { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + spark.range(0, 100000000000L, 1, 1) + .toDF("id").agg(sum("id")).collect() } ex.getCause() match { case null => @@ -180,6 +174,8 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } } + // Wait until all ListenerBus events consumed to make sure cancelStage called for all stages + sparkContext.listenerBus.waitUntilEmpty(20.seconds.toMillis) eventually(timeout(20.seconds)) { assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } @@ -204,7 +200,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From f06528015d5856d6dc5cce00309bc2ae985e080f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 15:42:10 +0800 Subject: [PATCH 0746/1161] [SPARK-24160][FOLLOWUP] Fix compilation failure ## What changes were proposed in this pull request? SPARK-24160 is causing a compilation failure (after SPARK-24143 was merged). This fixes the issue. ## How was this patch tested? building successfully Author: Marco Gaido Closes #21256 from mgaido91/SPARK-24160_FOLLOWUP. --- .../apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 8e9374b768adc..a2997dbd1b1ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -546,7 +546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext, transfer, blockManager, - blocksByAddress, + blocksByAddress.toIterator, (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, From e35ad3caddeaa4b0d4c8524dcfb9e9f56dc7fe3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 16:57:37 +0900 Subject: [PATCH 0747/1161] [SPARK-23930][SQL] Add slice function ## What changes were proposed in this pull request? The PR add the `slice` function. The behavior of the function is based on Presto's one. The function slices an array according to the requested start index and length. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21040 from mgaido91/SPARK-23930. --- python/pyspark/sql/functions.py | 13 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 34 ++++ .../expressions/collectionOperations.scala | 163 ++++++++++++++---- .../CollectionExpressionsSuite.scala | 28 +++ .../expressions/ExpressionEvalHelper.scala | 6 + .../expressions/ObjectExpressionsSuite.scala | 1 - .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++ 9 files changed, 233 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bd55b5f73b4d0..ac3c79766702c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1834,6 +1834,19 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def slice(x, start, length): + """ + Collection function: returns an array containing all the elements in `x` from index `start` + (or starting from the end if `start` is negative) with the specified `length`. + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() + [Row(sliced=[2, 3]), Row(sliced=[5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.slice(_to_java_column(x), start, length)) + + @ignore_unicode_prefix @since(2.4) def array_join(col, delimiter, null_replacement=None): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 01776b85e6f53..87b0911e150c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), + expression[Slice]("slice"), expression[Size]("cardinality"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..4dda525294259 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types._ import org.apache.spark.util.{ParentClassLoader, Utils} @@ -730,6 +731,39 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. + * + * @param arrayName name of the array to create + * @param numElements code representing the number of elements the array should contain + * @param elementType data type of the elements in the array + * @param additionalErrorMessage string to include in the error message + */ + def createUnsafeArray( + arrayName: String, + numElements: String, + elementType: DataType, + additionalErrorMessage: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + + s""" + |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | ${elementType.defaultSize}); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $arraySize + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." + + | "$additionalErrorMessage"); + |} + |byte[] $arrayBytes = new byte[(int)$arraySize]; + |UnsafeArrayData $arrayName = new UnsafeArrayData(); + |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 23c09bc3b49d7..12b9ab2b272ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -530,6 +529,129 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { + val startInt = startVal.asInstanceOf[Int] + val lengthInt = lengthVal.asInstanceOf[Int] + val arr = xVal.asInstanceOf[ArrayData] + val startIndex = if (startInt == 0) { + throw new RuntimeException( + s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") + } else if (startInt < 0) { + startInt + arr.numElements() + } else { + startInt - 1 + } + if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + + "length must be greater than or equal to 0.") + } + // startIndex can be negative if start is negative and its absolute value is greater than the + // number of elements in the array + if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) + } + val data = arr.toSeq[AnyRef](elementType) + new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + | + "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + | + "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + | $values[$i] = $getValue; + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $resLength = 0; + |} + |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")} + |for (int $i = 0; $i < $resLength; $i ++) { + | if ($inputArray.isNullAt($i + $startIdx)) { + | $values.setNullAt($i); + | } else { + | $values.set$primitiveValueTypeName($i, $getValue); + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } +} + /** * Creates a String containing all the elements of the input array separated by the delimiter. */ @@ -1127,24 +1249,11 @@ case class Concat(children: Seq[Expression]) extends Expression { } private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + - | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + - | " for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" @@ -1152,11 +1261,7 @@ case class Concat(children: Seq[Expression]) extends Expression { | public ArrayData concat($javaType[] args) { | ${nullArgumentProtection()} | $numElemCode - | $unsafeArraySizeInBytes - | byte[] $arrayName = new byte[(int)$arraySizeName]; - | UnsafeArrayData $arrayData = new UnsafeArrayData(); - | Platform.putLong($arrayName, $baseOffset, $numElemName); - | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} | int $counter = 0; | for (int y = 0; y < ${children.length}; y++) { | for (int z = 0; z < args[y].numElements(); z++) { @@ -1308,34 +1413,16 @@ case class Flatten(child: Expression) extends UnaryExpression { ctx: CodegenContext, childVariableName: String, arrayDataName: String): String = { - val arrayName = ctx.freshName("array") - val arraySizeName = ctx.freshName("size") val counter = ctx.freshName("counter") val tempArrayDataName = ctx.freshName("tempArrayData") val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName) - val unsafeArraySizeInBytes = s""" - |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( - | $numElemName, - | ${elementType.defaultSize}); - |if ($arraySizeName > $MAX_ARRAY_LENGTH) { - | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" + - | " bytes for UnsafeArrayData."); - |} - """.stripMargin - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) s""" |$numElemCode - |$unsafeArraySizeInBytes - |byte[] $arrayName = new byte[(int)$arraySizeName]; - |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); - |Platform.putLong($arrayName, $baseOffset, $numElemName); - |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")} |int $counter = 0; |for (int k = 0; k < $childVariableName.numElements(); k++) { | ArrayData arr = $childVariableName.getArray(k); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 749374f1a14a1..a2851d071c7c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,6 +136,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Slice") { + val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) + + checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) + checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) + checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) + checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") + checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") + checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) + checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) + checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) + + checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) + checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) + checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) + checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) + checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) + } + test("ArrayJoin") { def testArrays( arrays: Seq[Expression], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b4bf6d7107d7e..a22e9d4655e8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -104,6 +104,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: => Expression, + expectedErrMsg: String): Unit = { + checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg) + } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( expression: => Expression, inputRow: InternalRow, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 730b36c32333c..77ca640f2e0bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.fromObject(new java.util.LinkedList[Int]), Map("nonexisting" -> Literal(1))) checkExceptionInExpression[Exception](initializeWithNonexistingMethod, - InternalRow.fromSeq(Seq()), """A method named "nonexisting" is not declared in any enclosing class """ + "nor any supertype") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b6dcc0608c2..8f9e4ae18b3f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3039,6 +3039,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns an array containing all the elements in `x` from index `start` (or starting from the + * end if `start` is negative) with the specified `length`. + * @group collection_funcs + * @since 2.4.0 + */ + def slice(x: Column, start: Int, length: Int): Column = withExpr { + Slice(x.expr, Literal(start), Literal(length)) + } + /** * Concatenates the elements of `column` using the `delimiter`. Null values are replaced with * `nullReplacement`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ae21cbc802d0a..ecce06f4c0755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -442,6 +442,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("slice function") { + val df = Seq( + Seq(1, 2, 3), + Seq(4, 5) + ).toDF("x") + + val answer = Seq(Row(Seq(2, 3)), Row(Seq(5))) + + checkAnswer(df.select(slice(df("x"), 2, 2)), answer) + checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer) + + val answerNegative = Seq(Row(Seq(3)), Row(Seq(5))) + checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative) + checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative) + } + test("array_join function") { val df = Seq( (Seq[String]("a", "b"), ","), From 4e861db5f149e10fd8dfe6b3c1484821a590b1e8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 7 May 2018 11:21:22 +0200 Subject: [PATCH 0748/1161] [SPARK-16406][SQL] Improve performance of LogicalPlan.resolve ## What changes were proposed in this pull request? `LogicalPlan.resolve(...)` uses linear searches to find an attribute matching a name. This is fine in normal cases, but gets problematic when you try to resolve a large number of columns on a plan with a large number of attributes. This PR adds an indexing structure to `resolve(...)` in order to find potential matches quicker. This PR improves the reference resolution time for the following code by 4x (11.8s -> 2.4s): ``` scala val n = 4000 val values = (1 to n).map(_.toString).mkString(", ") val columns = (1 to n).map("column" + _).mkString(", ") val query = s""" |SELECT $columns |FROM VALUES ($values) T($columns) |WHERE 1=2 AND 1 IN ($columns) |GROUP BY $columns |ORDER BY $columns |""".stripMargin spark.time(sql(query)) ``` ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #14083 from hvanhovell/SPARK-16406. --- .../sql/catalyst/expressions/package.scala | 86 ++++++++++++++ .../catalyst/plans/logical/LogicalPlan.scala | 108 ++---------------- 2 files changed, 93 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1a48995358af7..8a06daa37132d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -138,6 +142,88 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups. */ + @transient private val qualified: Map[(String, String), Seq[Attribute]] = { + val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => + (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, + // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + val matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.get) + } + (attributes, nestedFields) + case all => + (Nil, all) + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and aliased it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match. + val referenceNames = ambiguousReferences.map(_.qualifiedName).mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 42034403d6d03..e487693927ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -86,6 +86,10 @@ abstract class LogicalPlan } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -94,7 +98,7 @@ abstract class LogicalPlan def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -104,7 +108,7 @@ abstract class LogicalPlan def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -114,105 +118,7 @@ abstract class LogicalPlan def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. - case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** From d83e9637246b05eea202add07a168688f6c0481b Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 7 May 2018 17:54:39 +0200 Subject: [PATCH 0749/1161] [SPARK-24043][SQL] Interpreted Predicate should initialize nondeterministic expressions ## What changes were proposed in this pull request? When creating an InterpretedPredicate instance, initialize any Nondeterministic expressions in the expression tree to avoid java.lang.IllegalArgumentException on later call to eval(). ## How was this patch tested? - sbt SQL tests - python SQL tests - new unit test Author: Bruce Robbins Closes #21144 from bersprockets/interpretedpredicate. --- .../spark/sql/catalyst/expressions/predicates.scala | 8 ++++++++ .../spark/sql/catalyst/expressions/PredicateSuite.scala | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bcf..f8c6dc4e6adc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -36,6 +36,14 @@ object InterpretedPredicate { case class InterpretedPredicate(expression: Expression) extends BasePredicate { override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] + + override def initialize(partitionIndex: Int): Unit = { + super.initialize(partitionIndex) + expression.foreach { + case n: Nondeterministic => n.initialize(partitionIndex) + case _ => + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bfd180ae4393..ac76b17ef4761 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -449,4 +449,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) } + + test("Interpreted Predicate should initialize nondeterministic expressions") { + val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0))) + interpreted.initialize(0) + assert(interpreted.eval(new UnsafeRow())) + } } From 56a52e0a58fc82ea69e47d0d8c4f905565be7c8b Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Mon, 7 May 2018 14:47:58 -0700 Subject: [PATCH 0750/1161] [SPARK-15750][MLLIB][PYSPARK] Constructing FPGrowth fails when no numPartitions specified in pyspark ## What changes were proposed in this pull request? Change FPGrowth from private to private[spark]. If no numPartitions is specified, then default value -1 is used. But -1 is only valid in the construction function of FPGrowth, but not in setNumPartitions. So I make this change and use the constructor directly rather than using set method. ## How was this patch tested? Unit test is added Author: Jeff Zhang Closes #13493 from zjffdu/SPARK-15750. --- .../spark/mllib/api/python/PythonMLLibAPI.scala | 5 +---- .../scala/org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- python/pyspark/mllib/tests.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b32d3f252ae59..db3f074ecfbac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -572,10 +572,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[java.lang.Iterable[Any]], minSupport: Double, numPartitions: Int): FPGrowthModel[Any] = { - val fpg = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartitions) - + val fpg = new FPGrowth(minSupport, numPartitions) val model = fpg.run(data.rdd.map(_.asScala.toArray)) new FPGrowthModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index f6b1143272d16..4f2b7e6f0764e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -162,7 +162,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { * */ @Since("1.3.0") -class FPGrowth private ( +class FPGrowth private[spark] ( private var minSupport: Double, private var numPartitions: Int) extends Logging with Serializable { diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 14d788b0bef60..4c2ce137e331c 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -57,6 +57,7 @@ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs @@ -1762,6 +1763,17 @@ def test_pca(self): self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) +class FPGrowthTest(MLlibTestCase): + + def test_fpgrowth(self): + data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] + rdd = self.sc.parallelize(data, 2) + model1 = FPGrowth.train(rdd, 0.6, 2) + # use default data partition number when numPartitions is not specified + model2 = FPGrowth.train(rdd, 0.6) + self.assertEqual(sorted(model1.freqItemsets().collect()), + sorted(model2.freqItemsets().collect())) + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: From 1c9c5de951ed86290bcd7d8edaab952b8cacd290 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 7 May 2018 14:52:14 -0700 Subject: [PATCH 0751/1161] [SPARK-23291][SPARK-23291][R][FOLLOWUP] Update SparkR migration note for ## What changes were proposed in this pull request? This PR fixes the migration note for SPARK-23291 since it's going to backport to 2.3.1. See the discussion in https://issues.apache.org/jira/browse/SPARK-23291 ## How was this patch tested? N/A Author: hyukjinkwon Closes #21249 from HyukjinKwon/SPARK-23291. --- docs/sparkr.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sparkr.md b/docs/sparkr.md index 7fabab5d38f16..4faad2c4c1824 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -664,6 +664,6 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma - For `summary`, option for statistics to compute has been added. Its output is changed from that from `describe`. - A warning can be raised if versions of SparkR package and the Spark JVM do not match. -## Upgrading to Spark 2.4.0 +## Upgrading to SparkR 2.3.1 and above - - The `start` parameter of `substr` method was wrongly subtracted by one, previously. In other words, the index specified by `start` parameter was considered as 0-base. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. It has been fixed so the `start` parameter of `substr` method is now 1-base, e.g., therefore to get the same result as `substr(df$a, 2, 5)`, it should be changed to `substr(df$a, 1, 4)`. + - In SparkR 2.3.0 and earlier, the `start` parameter of `substr` method was wrongly subtracted by one and considered as 0-based. This can lead to inconsistent substring results and also does not match with the behaviour with `substr` in R. In version 2.3.1 and later, it has been fixed so the `start` parameter of `substr` method is now 1-base. As an example, `substr(lit('abcdef'), 2, 4))` would result to `abc` in SparkR 2.3.0, and the result would be `bcd` in SparkR 2.3.1. From f48bd6bdc5aefd9ec43e2d0ee648d17add7ef554 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:55:41 -0700 Subject: [PATCH 0752/1161] [SPARK-22885][ML][TEST] ML test for StructuredStreaming: spark.ml.tuning ## What changes were proposed in this pull request? ML test for StructuredStreaming: spark.ml.tuning ## How was this patch tested? N/A Author: WeichenXu Closes #20261 from WeichenXu123/ml_stream_tuning_test. --- .../spark/ml/tuning/CrossValidatorSuite.scala | 15 +++++++++++---- .../ml/tuning/TrainValidationSplitSuite.scala | 15 +++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 15dade2627090..e6ee7220d2279 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -25,17 +25,17 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class CrossValidatorSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -66,6 +66,13 @@ class CrossValidatorSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(cvModel.avgMetrics.length === lrParamMaps.length) + + val result = cvModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("cross validation with linear regression") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9024342d9c831..cd76acf9c67bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -24,17 +24,17 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} -import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -64,6 +64,13 @@ class TrainValidationSplitSuite assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) assert(tvsModel.validationMetrics.length === lrParamMaps.length) + + val result = tvsModel.transform(dataset).select("prediction").as[Double].collect() + testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), tvsModel, "prediction") { + rows => + val result2 = rows.map(_.getDouble(0)) + assert(result === result2) + } } test("train validation with linear regression") { From 76ecd095024a658bf68e5db658e4416565b30c17 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 7 May 2018 14:57:14 -0700 Subject: [PATCH 0753/1161] [SPARK-20114][ML] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? PrefixSpan API for spark.ml. New implementation instead of #20810 ## How was this patch tested? TestSuite added. Author: WeichenXu Closes #20973 from WeichenXu123/prefixSpan2. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 96 +++++++++++++ .../apache/spark/mllib/fpm/PrefixSpan.scala | 3 +- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 136 ++++++++++++++++++ 3 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala new file mode 100644 index 0000000000000..02168fee16dbf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.fpm + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} + +/** + * :: Experimental :: + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth + * (see here). + * + * @see Sequential Pattern Mining + * (Wikipedia) + */ +@Since("2.4.0") +@Experimental +object PrefixSpan { + + /** + * :: Experimental :: + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * + * @param dataset A dataset or a dataframe containing a sequence column which is + * {{{Seq[Seq[_]]}}} type + * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column + * are ignored + * @param minSupport the minimal support level of the sequential pattern, any pattern that + * appears more than (minSupport * size-of-the-dataset) times will be output + * (recommended value: `0.1`). + * @param maxPatternLength the maximal length of the sequential pattern + * (recommended value: `10`). + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the + * internal storage format) allowed in a projected database before + * local processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run + * (recommended value: `32000000`). + * @return A `DataFrame` that contains columns of sequence and corresponding frequency. + * The schema of it will be: + * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `freq: Long` + */ + @Since("2.4.0") + def findFrequentSequentialPatterns( + dataset: Dataset[_], + sequenceCol: String, + minSupport: Double, + maxPatternLength: Int, + maxLocalProjDBSize: Long): DataFrame = { + + val inputType = dataset.schema(sequenceCol).dataType + require(inputType.isInstanceOf[ArrayType] && + inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], + s"The input column must be ArrayType and the array element type must also be ArrayType, " + + s"but got $inputType.") + + + val data = dataset.select(sequenceCol) + val sequences = data.where(col(sequenceCol).isNotNull).rdd + .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) + + val mllibPrefixSpan = new mllibPrefixSpan() + .setMinSupport(minSupport) + .setMaxPatternLength(maxPatternLength) + .setMaxLocalProjDBSize(maxLocalProjDBSize) + + val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) + val schema = StructType(Seq( + StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("freq", LongType, nullable = false))) + val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) + + freqSequences + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 3f8d65a378e2c..7aed2f3bd8a61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -49,8 +49,7 @@ import org.apache.spark.storage.StorageLevel * * @param minSupport the minimal support level of the sequential pattern, any pattern that appears * more than (minSupport * size-of-the-dataset) times will be output - * @param maxPatternLength the maximal length of the sequential pattern, any pattern that appears - * less than maxPatternLength will be output + * @param maxPatternLength the maximal length of the sequential pattern * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal * storage format) allowed in a projected database before local * processing. If a projected database exceeds this size, another diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala new file mode 100644 index 0000000000000..9e538696cbcf7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.fpm + +import org.apache.spark.ml.util.MLTest +import org.apache.spark.sql.DataFrame + +class PrefixSpanSuite extends MLTest { + + import testImplicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("PrefixSpan projections with multiple partial starts") { + val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", + minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + val expected = Array( + (Seq(Seq(1)), 1L), + (Seq(Seq(1, 2)), 1L), + (Seq(Seq(1), Seq(1)), 1L), + (Seq(Seq(1), Seq(2)), 1L), + (Seq(Seq(1), Seq(3)), 1L), + (Seq(Seq(1, 3)), 1L), + (Seq(Seq(2)), 1L), + (Seq(Seq(2, 3)), 1L), + (Seq(Seq(2), Seq(1)), 1L), + (Seq(Seq(2), Seq(2)), 1L), + (Seq(Seq(2), Seq(3)), 1L), + (Seq(Seq(3)), 1L)) + compareResults[Int](expected, result) + } + + /* + To verify expected results for `smallTestData`, create file "prefixSpanSeqs2" with content + (format = (transactionID, idxInTransaction, numItemsinItemset, itemset)): + 1 1 2 1 2 + 1 2 1 3 + 2 1 1 1 + 2 2 2 3 2 + 2 3 2 1 2 + 3 1 2 1 2 + 3 2 1 5 + 4 1 1 6 + In R, run: + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(prefixSpanSeqs, + parameter = 0.5, maxlen = 5 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + + sequence support + 1 <{1}> 0.75 + 2 <{2}> 0.75 + 3 <{3}> 0.50 + 4 <{1},{3}> 0.50 + 5 <{1,2}> 0.75 + */ + val smallTestData = Seq( + Seq(Seq(1, 2), Seq(3)), + Seq(Seq(1), Seq(3, 2), Seq(1, 2)), + Seq(Seq(1, 2), Seq(5)), + Seq(Seq(6))) + + val smallTestDataExpectedResult = Array( + (Seq(Seq(1)), 3L), + (Seq(Seq(2)), 3L), + (Seq(Seq(3)), 2L), + (Seq(Seq(1), Seq(3)), 2L), + (Seq(Seq(1, 2)), 3L) + ) + + test("PrefixSpan Integer type, variable-size itemsets") { + val df = smallTestData.toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan input row with nulls") { + val df = (smallTestData :+ null).toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[Int]], Long)].collect() + + compareResults[Int](smallTestDataExpectedResult, result) + } + + test("PrefixSpan String type, variable-size itemsets") { + val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap + val df = smallTestData + .map(seq => seq.map(itemSet => itemSet.map(intToString))) + .toDF("sequence") + val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", + minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + .as[(Seq[Seq[String]], Long)].collect() + + val expected = smallTestDataExpectedResult.map { case (seq, freq) => + (seq.map(itemSet => itemSet.map(intToString)), freq) + } + compareResults[String](expected, result) + } + + private def compareResults[Item]( + expectedValue: Array[(Seq[Seq[Item]], Long)], + actualValue: Array[(Seq[Seq[Item]], Long)]): Unit = { + val expectedSet = expectedValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + val actualSet = actualValue.map { x => + (x._1.map(_.toSet), x._2) + }.toSet + assert(expectedSet === actualSet) + } +} + From 0d63eb8888d17df747fb41d7ba254718bb7af3ae Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 7 May 2018 20:08:41 -0700 Subject: [PATCH 0754/1161] [SPARK-23975][ML] Add support of array input for all clustering methods ## What changes were proposed in this pull request? Add support for all of the clustering methods ## How was this patch tested? unit tests added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21195 from ludatabricks/SPARK-23975-1. --- .../spark/ml/clustering/BisectingKMeans.scala | 21 ++++---- .../spark/ml/clustering/GaussianMixture.scala | 12 +++-- .../apache/spark/ml/clustering/KMeans.scala | 31 +++--------- .../org/apache/spark/ml/clustering/LDA.scala | 9 ++-- .../apache/spark/ml/util/DatasetUtils.scala | 13 ++++- .../apache/spark/ml/util/SchemaUtils.scala | 16 ++++++- .../ml/clustering/BisectingKMeansSuite.scala | 21 +++++++- .../ml/clustering/GaussianMixtureSuite.scala | 21 +++++++- .../spark/ml/clustering/KMeansSuite.scala | 48 ++++++------------- .../apache/spark/ml/clustering/LDASuite.scala | 20 +++++++- .../apache/spark/ml/util/MLTestingUtils.scala | 23 ++++++++- 11 files changed, 147 insertions(+), 88 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index addc12ac52ec1..438e53ba6197c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -22,17 +22,15 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -75,7 +73,7 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -113,7 +111,8 @@ class BisectingKMeansModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + dataset.withColumn($(predictionCol), + predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -132,9 +131,9 @@ class BisectingKMeansModel private[ml] ( */ @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data.map(OldVectors.fromML)) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) + parentModel.computeCost(data) } @Since("2.0.0") @@ -260,9 +259,7 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) val instr = Instrumentation.create(this, rdd) instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index b5804900c0358..88d618c3a03a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => OldMatr Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -63,7 +63,7 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } @@ -109,8 +109,9 @@ class GaussianMixtureModel private[ml] ( transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) - dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) - .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + dataset + .withColumn($(predictionCol), predUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) + .withColumn($(probabilityCol), probUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol))) } @Since("2.0.0") @@ -340,7 +341,8 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[Vector] = dataset + .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index de61c9c089a36..97f246fbfd859 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, PipelineStage} -import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -34,7 +34,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion @@ -86,24 +86,13 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") def getInitSteps: Int = $(initSteps) - /** - * Validates the input schema. - * @param schema input schema - */ - private[clustering] def validateSchema(schema: StructType): Unit = { - val typeCandidates = List( new VectorUDT, - new ArrayType(DoubleType, false), - new ArrayType(FloatType, false)) - - SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates) - } /** * Validates and transforms the input schema. * @param schema input schema * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - validateSchema(schema) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } } @@ -160,12 +149,8 @@ class KMeansModel private[ml] ( // TODO: Replace the temp fix when we have proper evaluators defined for clustering. @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { - validateSchema(dataset.schema) - - val data: RDD[OldVector] = dataset.select(DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, getFeaturesCol) + val data = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) parentModel.computeCost(data) } @@ -351,11 +336,7 @@ class KMeans @Since("1.5.0") ( transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val instances: RDD[OldVector] = dataset.select( - DatasetUtils.columnToVector(dataset, getFeaturesCol)) - .rdd.map { - case Row(point: Vector) => OldVectors.fromML(point) - } + val instances = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 47077230fac0a..afe599cd167cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, StructType} import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils @@ -345,7 +345,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" must be >= 1. Found value: $getTopicConcentration") } } - SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(schema, getFeaturesCol) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } @@ -461,7 +461,8 @@ abstract class LDAModel private[ml] ( val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() + dataset.withColumn($(topicDistributionCol), + t(DatasetUtils.columnToVector(dataset, getFeaturesCol))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") @@ -938,7 +939,7 @@ object LDA extends MLReadable[LDA] { featuresCol: String): RDD[(Long, OldVector)] = { dataset .withColumn("docId", monotonically_increasing_id()) - .select("docId", featuresCol) + .select(col("docId"), DatasetUtils.columnToVector(dataset, featuresCol)) .rdd .map { case Row(docId: Long, features: Vector) => (docId, OldVectors.fromML(features)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala index 52619cb65489a..6af4b3ebc2cc2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.ml.util -import org.apache.spark.ml.linalg.{Vectors, VectorUDT} -import org.apache.spark.sql.{Column, Dataset} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Column, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} @@ -60,4 +62,11 @@ private[spark] object DatasetUtils { throw new IllegalArgumentException(s"$other column cannot be cast to Vector") } } + + def columnToOldVector(dataset: Dataset[_], colName: String): RDD[OldVector] = { + dataset.select(columnToVector(dataset, colName)) + .rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 334410c9620de..d9a3f85ef9a24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,8 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.sql.types._ /** @@ -101,4 +102,17 @@ private[spark] object SchemaUtils { require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.") StructType(schema.fields :+ col) } + + /** + * Check whether the given column in the schema is one of the supporting vector type: Vector, + * Array[Float]. Array[Double] + * @param schema input schema + * @param colName column name + */ + def validateVectorCompatibleColumn(schema: StructType, colName: String): Unit = { + val typeCandidates = List( new VectorUDT, + new ArrayType(DoubleType, false), + new ArrayType(FloatType, false)) + checkColumnTypes(schema, colName, typeCandidates) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 02880f96ae6d9..f3ff2afcad2cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -17,13 +17,16 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -182,6 +185,22 @@ class BisectingKMeansSuite model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0) } + + test("BisectingKMeans with Array input") { + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new BisectingKMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 08b800b7e4183..d0d461a42711a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap @@ -24,8 +26,7 @@ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Dataset, Row} - +import org.apache.spark.sql.{DataFrame, Dataset, Row} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -256,6 +257,22 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext val expectedMatrix = GaussianMixture.unpackUpperTriangularMatrix(4, triangularValues) assert(symmetricMatrix === expectedMatrix) } + + test("GaussianMixture with Array input") { + def trainAndComputlogLikelihood(dataset: Dataset[_]): Double = { + val model = new GaussianMixture().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.summary.logLikelihood + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueLikelihood = trainAndComputlogLikelihood(newDataset) + val doubleLikelihood = trainAndComputlogLikelihood(newDatasetD) + val floatLikelihood = trainAndComputlogLikelihood(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) + assert(trueLikelihood ~== floatLikelihood absTol 1e-6) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 5445ebe5c95eb..680a7c2034083 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.clustering +import scala.language.existentials import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} @@ -25,13 +26,11 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, - KMeansModel => MLlibKMeansModel} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType} private[clustering] case class TestRow(features: Vector) @@ -202,38 +201,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } test("KMean with Array input") { - val featuresColNameD = "array_double_features" - val featuresColNameF = "array_float_features" - - val doubleUDF = udf { (features: Vector) => - val featureArray = Array.fill[Double](features.size)(0.0) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray - } - val floatUDF = udf { (features: Vector) => - val featureArray = Array.fill[Float](features.size)(0.0f) - features.foreachActive((idx, value) => featureArray(idx) = value.toFloat) - featureArray + def trainAndComputeCost(dataset: Dataset[_]): Double = { + val model = new KMeans().setK(k).setMaxIter(1).setSeed(1).fit(dataset) + model.computeCost(dataset) } - val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features"))) - .drop("features") - val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features"))) - .drop("features") - assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false))) - assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false))) - - val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1) - val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1) - val modelD = kmeansD.fit(newdatasetD) - val modelF = kmeansF.fit(newdatasetF) - val transformedD = modelD.transform(newdatasetD) - val transformedF = modelF.transform(newdatasetF) - - val predictDifference = transformedD.select("prediction") - .except(transformedF.select("prediction")) - assert(predictDifference.count() == 0) - assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) ) + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val trueCost = trainAndComputeCost(newDataset) + val doubleArrayCost = trainAndComputeCost(newDatasetD) + val floatArrayCost = trainAndComputeCost(newDatasetF) + + // checking the cost is fine enough as a sanity check + assert(trueCost ~== doubleArrayCost absTol 1e-6) + assert(trueCost ~== floatArrayCost absTol 1e-6) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index e73bbc18d76bd..8d728f063dd8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkFunSuite @@ -26,7 +28,6 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ - object LDASuite { def generateLDAData( spark: SparkSession, @@ -323,4 +324,21 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getOptimizer === optimizer) } } + + test("LDA with Array input") { + def trainAndLogLikelihoodAndPerplexity(dataset: Dataset[_]): (Double, Double) = { + val model = new LDA().setK(k).setOptimizer("online").setMaxIter(1).setSeed(1).fit(dataset) + (model.logLikelihood(dataset), model.logPerplexity(dataset)) + } + + val (newDataset, newDatasetD, newDatasetF) = MLTestingUtils.generateArrayFeatureDataset(dataset) + val (ll, lp) = trainAndLogLikelihoodAndPerplexity(newDataset) + val (llD, lpD) = trainAndLogLikelihoodAndPerplexity(newDatasetD) + val (llF, lpF) = trainAndLogLikelihoodAndPerplexity(newDatasetF) + // TODO: need to compare the results once we fix the seed issue for LDA (SPARK-22210) + assert(llD <= 0.0 && llD != Double.NegativeInfinity) + assert(llF <= 0.0 && llF != Double.NegativeInfinity) + assert(lpD >= 0.0 && lpD != Double.NegativeInfinity) + assert(lpF >= 0.0 && lpF != Double.NegativeInfinity) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index c328d81b4bc3a..5e72b4d864c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} @@ -247,4 +247,25 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for "features" column. Given a DataFrame, + * generate three output DataFrames: one having vector "features" column with float precision, + * one having double array "features" column with float precision, and one having float array + * "features" column. + */ + def generateArrayFeatureDataset(dataset: Dataset[_], + featuresColName: String = "features"): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => + Vectors.dense(features.toArray.map(_.toFloat.toDouble))} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn(featuresColName, toFloatVectorUDF(col(featuresColName))) + val newDatasetD = newDataset.withColumn(featuresColName, toDoubleArrayUDF(col(featuresColName))) + val newDatasetF = newDataset.withColumn(featuresColName, toFloatArrayUDF(col(featuresColName))) + assert(newDataset.schema(featuresColName).dataType.equals(new VectorUDT)) + assert(newDatasetD.schema(featuresColName).dataType.equals(new ArrayType(DoubleType, false))) + assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) + (newDataset, newDatasetD, newDatasetF) + } } From cd12c5c3ecf28f7b04f566c2057f9b65eb456b7d Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Tue, 8 May 2018 12:21:33 +0800 Subject: [PATCH 0755/1161] [SPARK-24128][SQL] Mention configuration option in implicit CROSS JOIN error ## What changes were proposed in this pull request? Mention `spark.sql.crossJoin.enabled` in error message when an implicit `CROSS JOIN` is detected. ## How was this patch tested? `CartesianProductSuite` and `JoinSuite`. Author: Henry Robinson Closes #21201 from henryr/spark-24128. --- R/pkg/tests/fulltests/test_sparkSQL.R | 4 ++-- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 ++++-- .../catalyst/optimizer/CheckCartesianProductsSuite.scala | 2 +- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 3a8866bf2a88a..43725e0ebd3bf 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -2210,8 +2210,8 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { expect_equal(count(where(join(df, df2), df$name == df2$name)), 3) # cartesian join expect_error(tryCatch(count(join(df, df2)), error = function(e) { stop(e) }), - paste0(".*(org.apache.spark.sql.AnalysisException: Detected cartesian product for", - " INNER join between logical plans).*")) + paste0(".*(org.apache.spark.sql.AnalysisException: Detected implicit cartesian", + " product for INNER join between logical plans).*")) joined <- crossJoin(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 45f13956a0a85..bfa61116a6658 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1182,12 +1182,14 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _) if isCartesianProduct(j) => throw new AnalysisException( - s"""Detected cartesian product for ${j.joinType.sql} join between logical plans + s"""Detected implicit cartesian product for ${j.joinType.sql} join between logical plans |${left.treeString(false).trim} |and |${right.treeString(false).trim} |Join condition is missing or trivial. - |Use the CROSS JOIN syntax to allow cartesian products between these relations.""" + |Either: use the CROSS JOIN syntax to allow cartesian products between these + |relations, or: enable implicit cartesian products by setting the configuration + |variable spark.sql.crossJoin.enabled=true""" .stripMargin) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index 21220b38968e8..788fedb3c8e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -56,7 +56,7 @@ class CheckCartesianProductsSuite extends PlanTest { val thrownException = the [AnalysisException] thrownBy { performCartesianProductCheck(joinType) } - assert(thrownException.message.contains("Detected cartesian product")) + assert(thrownException.message.contains("Detected implicit cartesian product")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63ab..8fa747465cb1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -239,7 +239,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil) } - assert(e.getMessage.contains("Detected cartesian product for INNER join " + + assert(e.getMessage.contains("Detected implicit cartesian product for INNER join " + "between logical plans")) } } @@ -611,7 +611,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val e = intercept[Exception] { checkAnswer(sql(query), Nil); } - assert(e.getMessage.contains("Detected cartesian product")) + assert(e.getMessage.contains("Detected implicit cartesian product")) } cartesianQueries.foreach(checkCartesianDetection) From 05eb19b6e09065265358eec2db2ff3b42806dfc9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 8 May 2018 14:32:04 +0800 Subject: [PATCH 0756/1161] [SPARK-24188][CORE] Restore "/version" API endpoint. It was missing the jax-rs annotation. Author: Marcelo Vanzin Closes #21245 from vanzin/SPARK-24188. Change-Id: Ib338e34b363d7c729cc92202df020dc51033b719 --- .../org/apache/spark/status/api/v1/ApiRootResource.scala | 1 + .../org/apache/spark/deploy/history/HistoryServerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 7127397f6205c..d121068718b8a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -49,6 +49,7 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}") def application(): Class[OneApplicationResource] = classOf[OneApplicationResource] + @GET @Path("version") def version(): VersionInfo = new VersionInfo(org.apache.spark.SPARK_VERSION) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 87f12f303cd5e..a871b1c717837 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -296,6 +296,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("/version api endpoint") { + val response = getUrl("version") + assert(response.contains(SPARK_VERSION)) + } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) From e17567ca78dbb416039c17da212957c8955bfa65 Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 8 May 2018 11:34:27 +0200 Subject: [PATCH 0757/1161] [SPARK-24076][SQL] Use different seed in HashAggregate to avoid hash conflict ## What changes were proposed in this pull request? HashAggregate uses the same hash algorithm and seed as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n. Considering below example: ``` SET spark.sql.shuffle.partitions=8192; INSERT OVERWRITE TABLE target_xxx SELECT item_id, auct_end_dt FROM from source_xxx GROUP BY item_id, auct_end_dt; ``` In the shuffle stage, if user sets the shuffle.partition = 8192, all tuples in the same partition will meet the following relationship: ``` hash(tuple x) = hash(tuple y) + n * 8192 ``` Then in the next HashAggregate stage, all tuples from the same partition need be put into a 16K BytesToBytesMap (unsafeRowAggBuffer). Here, the HashAggregate uses the same hash algorithm on the same expression as shuffle, and uses the same seed, and 16K = 8192 * 2, so actually, all tuples in the same parititon will only be hashed to 2 different places in the BytesToBytesMap. It is bad hash conflict. With BytesToBytesMap growing, the conflict will always exist. Before change: hash_conflict After change: no_hash_conflict ## How was this patch tested? Unit tests and production cases. Author: yucai Closes #21149 from yucai/SPARK-24076. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..6a8ec4f722aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -755,7 +755,10 @@ case class HashAggregateExec( } // generate hash code for key - val hashExpr = Murmur3Hash(groupingExpressions, 42) + // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions + // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, + // pick a different seed to avoid this conflict + val hashExpr = Murmur3Hash(groupingExpressions, 48) val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, From b54bbe57b33b00063596cd9588fa2461745ed571 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 May 2018 21:22:54 +0800 Subject: [PATCH 0758/1161] [SPARK-24131][PYSPARK][FOLLOWUP] Add majorMinorVersion API to PySpark for determining Spark versions ## What changes were proposed in this pull request? More close to Scala API behavior when can't parse input by throwing exception. Add tests. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21211 from viirya/SPARK-24131-followup. --- python/pyspark/tests.py | 4 ++++ python/pyspark/util.py | 37 ++++++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7b8ce2c6b799f..498d6b57e4353 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2312,6 +2312,10 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 04df835bf6717..59cc2a6329350 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -62,24 +62,31 @@ def _get_argspec(f): return argspec -def majorMinorVersion(version): +class VersionUtils(object): """ - Get major and minor version numbers for given Spark version string. - - >>> version = "2.4.0" - >>> majorMinorVersion(version) - (2, 4) + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). - >>> version = "abc" - >>> majorMinorVersion(version) is None - True + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) - """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', version) - if m is None: - return None - else: - return (int(m.group(1)), int(m.group(2))) + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") if __name__ == "__main__": From 2f6fe7d679a878ffd103cac6f06081c5b3888744 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 8 May 2018 21:24:35 +0800 Subject: [PATCH 0759/1161] [SPARK-23094][SPARK-23723][SPARK-23724][SQL][FOLLOW-UP] Support custom encoding for json files ## What changes were proposed in this pull request? This is to add a test case to check the behaviors when users write json in the specified UTF-16/UTF-32 encoding with multiline off. ## How was this patch tested? N/A Author: gatorsmile Closes #21254 from gatorsmile/followupSPARK-23094. --- .../spark/sql/catalyst/json/JSONOptions.scala | 9 +++++---- .../datasources/json/JsonSuite.scala | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5f130af606e19..2579374e3f4e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -110,11 +110,12 @@ private[sql] class JSONOptions( val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) val isBlacklisted = blacklist.contains(Charset.forName(enc)) require(multiLine || !isBlacklisted, - s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: - | ${blacklist.mkString(", ")}""".stripMargin) + s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. + |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - val isLineSepRequired = !(multiLine == false && - Charset.forName(enc) != StandardCharsets.UTF_8 && lineSeparator.isEmpty) require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") enc diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 0db688fec9a67..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2313,6 +2313,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } } + test("SPARK-23723: write json in UTF-16/32 with multiline off") { + Seq("UTF-16", "UTF-32").foreach { encoding => + withTempPath { path => + val ds = spark.createDataset(Seq( + ("a", 1), ("b", 2), ("c", 3)) + ).repartition(2) + val e = intercept[IllegalArgumentException] { + ds.write + .option("encoding", encoding) + .option("multiline", "false") + .format("json").mode("overwrite") + .save(path.getCanonicalPath) + }.getMessage + assert(e.contains( + s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + } + } + } + def checkReadJson(lineSep: String, encoding: String, inferSchema: Boolean, id: Int): Unit = { test(s"SPARK-23724: checks reading json in ${encoding} #${id}") { val schema = new StructType().add("f1", StringType).add("f2", IntegerType) From 487faf17ab96c8edb729501dfb1ff82f7b2c6031 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 8 May 2018 23:43:02 +0800 Subject: [PATCH 0760/1161] [SPARK-24117][SQL] Unified the getSizePerRow ## What changes were proposed in this pull request? This pr unified the `getSizePerRow` because `getSizePerRow` is used in many places. For example: 1. [LocalRelation.scala#L80](https://github.com/wangyum/spark/blob/f70f46d1e5bc503e9071707d837df618b7696d32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala#L80) 2. [SizeInBytesOnlyStatsPlanVisitor.scala#L36](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L36) ## How was this patch tested? Exist tests Author: Yuming Wang Closes #21189 from wangyum/SPARK-24117. --- .../sql/catalyst/plans/logical/LocalRelation.scala | 3 ++- .../logical/statsEstimation/EstimationUtils.scala | 14 ++++++++------ .../SizeInBytesOnlyStatsPlanVisitor.scala | 4 ++-- .../spark/sql/execution/streaming/memory.scala | 10 ++++------ .../sql/execution/streaming/sources/memoryV2.scala | 3 ++- .../spark/sql/StatisticsCollectionSuite.scala | 2 +- .../sql/execution/streaming/MemorySinkSuite.scala | 4 ++-- 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 720d42ab409a0..8c4828a4cef23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -77,7 +78,7 @@ case class LocalRelation( } override def computeStats(): Statistics = - Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length) + Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 0f147f0ffb135..211a2a0717371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode @@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} - object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ @@ -73,13 +71,12 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } - def getOutputSize( + def getSizePerRow( attributes: Seq[Attribute], - outputRowCount: BigInt, attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. - val sizePerRow = 8 + attributes.map { attr => + 8 + attributes.map { attr => if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) { attr.dataType match { case StringType => @@ -92,10 +89,15 @@ object EstimationUtils { attr.dataType.defaultSize } }.sum + } + def getOutputSize( + attributes: Seq[Attribute], + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero // (simple computation of statistics returns product of children). - if (outputRowCount > 0) outputRowCount * sizePerRow else 1 + if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1 } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 85f67c7d66075..ee43f9126386b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { private def visitUnaryNode(p: UnaryNode): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. - val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8 - val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8 + val childRowSize = EstimationUtils.getSizePerRow(p.child.output) + val outputRowSize = EstimationUtils.getSizePerRow(p.output) // Assume there will be the same number of rows as child has. var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 22258274c70c1..6720cdd24b1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.{OutputMode, Trigger} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils - object MemoryStream { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) @@ -307,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 0d6c239274dd8..468313bfe8c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} @@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode) * Used to query the data that has been written into a [[MemorySinkV2]]. */ case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { - private val sizePerRow = output.map(_.dataType.defaultSize).sum + private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b91712f4cc25d..60fa951e23178 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") - assert(sizes.head === BigInt(96), + assert(sizes.head === BigInt(128), s"expected exact size 96 for table 'test', got: ${sizes.head}") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index e8420eee7fe9d..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { sink.addBatch(0, 1 to 3) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 12) + assert(plan.stats.sizeInBytes === 36) sink.addBatch(1, 4 to 6) plan.invalidateStatsCache() - assert(plan.stats.sizeInBytes === 24) + assert(plan.stats.sizeInBytes === 72) } ignore("stress test") { From e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 9 May 2018 08:32:20 +0800 Subject: [PATCH 0761/1161] [SPARK-24068] Propagating DataFrameReader's options to Text datasource on schema inferring ## What changes were proposed in this pull request? While reading CSV or JSON files, DataFrameReader's options are converted to Hadoop's parameters, for example there: https://github.com/apache/spark/blob/branch-2.3/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala#L302 but the options are not propagated to Text datasource on schema inferring, for instance: https://github.com/apache/spark/blob/branch-2.3/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala#L184-L188 The PR proposes propagation of user's options to Text datasource on scheme inferring in similar way as user's options are converted to Hadoop parameters if schema is specified. ## How was this patch tested? The changes were tested manually by using https://github.com/twitter/hadoop-lzo: ``` hadoop-lzo> mvn clean package hadoop-lzo> ln -s ./target/hadoop-lzo-0.4.21-SNAPSHOT.jar ./hadoop-lzo.jar ``` Create 2 test files in JSON and CSV format and compress them: ```shell $ cat test.csv col1|col2 a|1 $ lzop test.csv $ cat test.json {"col1":"a","col2":1} $ lzop test.json ``` Run `spark-shell` with hadoop-lzo: ``` bin/spark-shell --jars ~/hadoop-lzo/hadoop-lzo.jar ``` reading compressed CSV and JSON without schema: ```scala spark.read.option("io.compression.codecs", "com.hadoop.compression.lzo.LzopCodec").option("inferSchema",true).option("header",true).option("sep","|").csv("test.csv.lzo").show() +----+----+ |col1|col2| +----+----+ | a| 1| +----+----+ ``` ```scala spark.read.option("io.compression.codecs", "com.hadoop.compression.lzo.LzopCodec").option("multiLine", true).json("test.json.lzo").printSchema root |-- col1: string (nullable = true) |-- col2: long (nullable = true) ``` Author: Maxim Gekk Author: Maxim Gekk Closes #21182 from MaxGekk/text-options. --- .../apache/spark/sql/catalyst/json/JSONOptions.scala | 2 +- .../sql/execution/datasources/csv/CSVDataSource.scala | 6 ++++-- .../sql/execution/datasources/csv/CSVOptions.scala | 2 +- .../spark/sql/execution/datasources/csv/CSVUtils.scala | 2 -- .../execution/datasources/csv/UnivocityParser.scala | 2 -- .../execution/datasources/json/JsonDataSource.scala | 10 ++++++---- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2579374e3f4e1..2ff12acb2946f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util._ * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. */ private[sql] class JSONOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index bc1f4ab3bb053..dc54d182651b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -185,7 +185,8 @@ object TextInputCSVDataSource extends CSVDataSource { DataSource.apply( sparkSession, paths = paths, - className = classOf[TextFileFormat].getName + className = classOf[TextFileFormat].getName, + options = options.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as[String](Encoders.STRING) } else { @@ -250,7 +251,8 @@ object MultiLineCSVDataSource extends CSVDataSource { options: CSVOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) val name = paths.mkString(",") - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + options.parameters)) FileInputFormat.setInputPaths(job, paths: _*) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 2ec0fc605a84b..ed2dc65a47914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ class CSVOptions( - @transient private val parameters: CaseInsensitiveMap[String], + @transient val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 31464f1bcc68e..9dae41b63e810 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution.datasources.csv -import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 3d6cc30f2ba83..99557a1ceb0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.InputStream import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale import scala.util.Try import scala.util.control.NonFatal diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 983a5f0dcade2..ba83df0efebd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -121,7 +121,7 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession, paths = paths, className = classOf[TextFileFormat].getName, - options = textOptions + options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) .select("value").as(Encoders.STRING) } @@ -159,7 +159,7 @@ object MultiLineJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { - val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths) + val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions) val parser = parsedOptions.encoding .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) @@ -170,9 +170,11 @@ object MultiLineJsonDataSource extends JsonDataSource { private def createBaseRdd( sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): RDD[PortableDataStream] = { val paths = inputPaths.map(_.getPath) - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + val job = Job.getInstance(sparkSession.sessionState.newHadoopConfWithOptions( + parsedOptions.parameters)) val conf = job.getConfiguration val name = paths.mkString(",") FileInputFormat.setInputPaths(job, paths: _*) From 9498e528d21e286e496da6ea9bf9c7ad73a7b5bd Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 9 May 2018 08:39:46 +0800 Subject: [PATCH 0762/1161] [SPARK-23355][SQL][DOC][FOLLOWUP] Add migration doc for TBLPROPERTIES ## What changes were proposed in this pull request? In Apache Spark 2.4, [SPARK-23355](https://issues.apache.org/jira/browse/SPARK-23355) fixes a bug which ignores table properties during convertMetastore for tables created by STORED AS ORC/PARQUET. For some Parquet tables having table properties like TBLPROPERTIES (parquet.compression 'NONE'), it was ignored by default before Apache Spark 2.4. After upgrading cluster, Spark will write uncompressed file which is different from Apache Spark 2.3 and old. This PR adds a migration note for that. ## How was this patch tested? N/A Author: Dongjoon Hyun Closes #21269 from dongjoon-hyun/SPARK-23355-DOC. --- docs/sql-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 075b953a0898e..c521f3cb51e58 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1812,6 +1812,8 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. + - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. From 7e7350285dc22764f599671d874617c0eea093e5 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 8 May 2018 21:20:58 -0700 Subject: [PATCH 0763/1161] [SPARK-24132][ML] Instrumentation improvement for classification ## What changes were proposed in this pull request? - Add OptionalInstrumentation as argument for getNumClasses in ml.classification.Classifier - Change the function call for getNumClasses in train() in ml.classification.DecisionTreeClassifier, ml.classification.RandomForestClassifier, and ml.classification.NaiveBayes - Modify the instrumentation creation in ml.classification.LinearSVC - Change the log call in ml.classification.OneVsRest and ml.classification.LinearSVC ## How was this patch tested? Manual. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21204 from ludatabricks/SPARK-23686. --- .../spark/ml/classification/DecisionTreeClassifier.scala | 9 ++++++--- .../org/apache/spark/ml/classification/LinearSVC.scala | 9 ++++++--- .../org/apache/spark/ml/classification/NaiveBayes.scala | 3 ++- .../org/apache/spark/ml/classification/OneVsRest.scala | 4 ++-- .../spark/ml/classification/RandomForestClassifier.scala | 4 +++- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 57797d1cc4978..c9786f1f7ceb1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -97,9 +97,11 @@ class DecisionTreeClassifier @Since("1.4.0") ( override def setSeed(value: Long): this.type = set(seed, value) override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + instr.logNumClasses(numClasses) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -110,8 +112,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) - val instr = Instrumentation.create(this, oldDataset) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) @@ -125,7 +127,8 @@ class DecisionTreeClassifier @Since("1.4.0") ( private[ml] def train(data: RDD[LabeledPoint], oldStrategy: OldStrategy): DecisionTreeClassificationModel = { val instr = Instrumentation.create(this, data) - instr.logParams(params: _*) + instr.logParams(maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, + cacheNodeIds, checkpointInterval, impurity, seed) val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 80c537e1e0eb2..38eb04556b775 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -170,7 +170,7 @@ class LinearSVC @Since("2.2.0") ( Instance(label, weight, features) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth) @@ -187,6 +187,9 @@ class LinearSVC @Since("2.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -209,7 +212,7 @@ class LinearSVC @Since("2.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } @@ -246,7 +249,7 @@ class LinearSVC @Since("2.2.0") ( bcFeaturesStd.destroy(blocking = false) if (state == null) { val msg = s"${optimizer.getClass.getName} failed." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 45fb585ed2262..1dde18d2d1a31 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -126,8 +126,10 @@ class NaiveBayes @Since("1.5.0") ( private[spark] def trainWithLabelCheck( dataset: Dataset[_], positiveLabel: Boolean): NaiveBayesModel = { + val instr = Instrumentation.create(this, dataset) if (positiveLabel && isDefined(thresholds)) { val numClasses = getNumClasses(dataset) + instr.logNumClasses(numClasses) require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") @@ -146,7 +148,6 @@ class NaiveBayes @Since("1.5.0") ( } } - val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol, probabilityCol, modelType, smoothing, thresholds) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7df53a6b8ad10..3474b61e40136 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -366,7 +366,7 @@ final class OneVsRest @Since("1.4.0") ( transformSchema(dataset.schema) val instr = Instrumentation.create(this, dataset) - instr.logParams(labelCol, featuresCol, predictionCol, parallelism) + instr.logParams(labelCol, featuresCol, predictionCol, parallelism, rawPredictionCol) instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName) // determine number of classes either from metadata if provided, or via computation. @@ -383,7 +383,7 @@ final class OneVsRest @Since("1.4.0") ( getClassifier match { case _: HasWeightCol => true case c => - logWarning(s"weightCol is ignored, as it is not supported by $c now.") + instr.logWarning(s"weightCol is ignored, as it is not supported by $c now.") false } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index f1ef26a07d3f8..040db3b94b041 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -116,6 +116,7 @@ class RandomForestClassifier @Since("1.4.0") ( set(featureSubsetStrategy, value) override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { + val instr = Instrumentation.create(this, dataset) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) @@ -130,7 +131,6 @@ class RandomForestClassifier @Since("1.4.0") ( val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) - val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) @@ -141,6 +141,8 @@ class RandomForestClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val m = new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) instr.logSuccess(m) m } From cac9b1dea1bb44fa42abf77829c05bf93f70cf20 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 9 May 2018 12:27:32 +0800 Subject: [PATCH 0764/1161] [SPARK-23972][BUILD][SQL] Update Parquet to 1.10.0. ## What changes were proposed in this pull request? This updates Parquet to 1.10.0 and updates the vectorized path for buffer management changes. Parquet 1.10.0 uses ByteBufferInputStream instead of byte arrays in encoders. This allows Parquet to break allocations into smaller chunks that are better for garbage collection. ## How was this patch tested? Existing Parquet tests. Running in production at Netflix for about 3 months. Author: Ryan Blue Closes #21070 from rdblue/SPARK-23972-update-parquet-to-1.10.0. --- dev/deps/spark-deps-hadoop-2.6 | 12 +- dev/deps/spark-deps-hadoop-2.7 | 12 +- dev/deps/spark-deps-hadoop-3.1 | 12 +- docs/sql-programming-guide.md | 2 +- pom.xml | 8 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../SpecificParquetRecordReaderBase.java | 2 +- .../parquet/VectorizedColumnReader.java | 39 ++-- .../parquet/VectorizedPlainValuesReader.java | 166 +++++++++++------- .../parquet/VectorizedRleValuesReader.java | 163 ++++++++--------- .../datasources/parquet/ParquetOptions.scala | 5 +- .../describe-part-after-analyze.sql.out | 12 +- .../columnar/InMemoryColumnarQuerySuite.scala | 4 +- 13 files changed, 241 insertions(+), 198 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index c3d1dd444b506..f479c13f00be6 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -162,13 +162,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 290867035f91d..e7c4599cb5003 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -163,13 +163,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 97ad65a4096cb..3447cd7395d95 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -181,13 +181,13 @@ orc-mapreduce-1.4.3-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar -parquet-column-1.8.2.jar -parquet-common-1.8.2.jar -parquet-encoding-1.8.2.jar -parquet-format-2.3.1.jar -parquet-hadoop-1.8.2.jar +parquet-column-1.10.0.jar +parquet-common-1.10.0.jar +parquet-encoding-1.10.0.jar +parquet-format-2.4.0.jar +parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar -parquet-jackson-1.8.2.jar +parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.6.jar pyrolite-4.13.jar diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c521f3cb51e58..3e8946e424237 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -964,7 +964,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession Sets the compression codec used when writing Parquet files. If either `compression` or `parquet.compression` is specified in the table-specific options/properties, the precedence would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: - none, uncompressed, snappy, gzip, lzo. + none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd. diff --git a/pom.xml b/pom.xml index 88e77ff874748..6e37e518d86e4 100644 --- a/pom.xml +++ b/pom.xml @@ -129,7 +129,7 @@ 1.2.110.12.1.1 - 1.8.2 + 1.10.01.4.3nohive1.6.0 @@ -1778,6 +1778,12 @@ parquet-hadoop${parquet.version}${parquet.deps.scope} + + + commons-pool + commons-pool + + org.apache.parquet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 895e150756567..b00edca97cd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -345,7 +345,7 @@ object SQLConf { "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo", "lz4", "brotli", "zstd")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index e65cd252c3ddf..10d6ed85a4080 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -293,7 +293,7 @@ protected static IntIterator createRLEIterator( return new RLEIntIterator( new RunLengthBitPackingHybridDecoder( BytesUtils.getWidthFromMaxInt(maxLevel), - new ByteArrayInputStream(bytes.toByteArray()))); + bytes.toInputStream())); } catch (IOException e) { throw new IOException("could not read levels in page for col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 72f1d024b08ce..d5969b55eef96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -21,6 +21,8 @@ import java.util.Arrays; import java.util.TimeZone; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -388,7 +390,8 @@ private void decodeDictionaryIds( * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) + throws IOException { if (column.dataType() != DataTypes.BooleanType) { throw constructConvertNotSupportedException(descriptor, column); } @@ -396,7 +399,7 @@ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } - private void readIntBatch(int rowId, int num, WritableColumnVector column) { + private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || @@ -414,7 +417,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { } } - private void readLongBatch(int rowId, int num, WritableColumnVector column) { + private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType()) || @@ -434,7 +437,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } - private void readFloatBatch(int rowId, int num, WritableColumnVector column) { + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? if (column.dataType() == DataTypes.FloatType) { @@ -445,7 +448,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { } } - private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { + private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.DoubleType) { @@ -456,7 +459,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { } } - private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { + private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; @@ -556,7 +559,7 @@ public Void visit(DataPageV2 dataPageV2) { }); } - private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { + private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) throws IOException { this.endOfPageValueCount = valuesRead + pageValueCount; if (dataEncoding.usesDictionary()) { this.dataColumn = null; @@ -581,7 +584,7 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } try { - dataColumn.initFromPage(pageValueCount, bytes, offset); + dataColumn.initFromPage(pageValueCount, in); } catch (IOException e) { throw new IOException("could not read page in col " + descriptor, e); } @@ -602,12 +605,11 @@ private void readPageV1(DataPageV1 page) throws IOException { this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { - byte[] bytes = page.getBytes().toByteArray(); - rlReader.initFromPage(pageValueCount, bytes, 0); - int next = rlReader.getNextOffset(); - dlReader.initFromPage(pageValueCount, bytes, next); - next = dlReader.getNextOffset(); - initDataReader(page.getValueEncoding(), bytes, next); + BytesInput bytes = page.getBytes(); + ByteBufferInputStream in = bytes.toInputStream(); + rlReader.initFromPage(pageValueCount, in); + dlReader.initFromPage(pageValueCount, in); + initDataReader(page.getValueEncoding(), in); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } @@ -619,12 +621,13 @@ private void readPageV2(DataPageV2 page) throws IOException { page.getRepetitionLevels(), descriptor); int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); - this.defColumn = new VectorizedRleValuesReader(bitWidth); + // do not read the length from the stream. v2 pages handle dividing the page bytes. + this.defColumn = new VectorizedRleValuesReader(bitWidth, false); this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); - this.defColumn.initFromBuffer( - this.pageValueCount, page.getDefinitionLevels().toByteArray()); + this.defColumn.initFromPage( + this.pageValueCount, page.getDefinitionLevels().toInputStream()); try { - initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); + initDataReader(page.getDataEncoding(), page.getData().toInputStream()); } catch (IOException e) { throw new IOException("could not read page " + page + " in col " + descriptor, e); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 5b75f719339fb..aacefacfc1c1a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -20,34 +20,30 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import org.apache.parquet.bytes.ByteBufferInputStream; +import org.apache.parquet.io.ParquetDecodingException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; -import org.apache.spark.unsafe.Platform; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.spark.unsafe.Platform; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. */ public class VectorizedPlainValuesReader extends ValuesReader implements VectorizedValuesReader { - private byte[] buffer; - private int offset; - private int bitOffset; // Only used for booleans. - private ByteBuffer byteBuffer; // used to wrap the byte array buffer + private ByteBufferInputStream in = null; - private static final boolean bigEndianPlatform = - ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + // Only used for booleans. + private int bitOffset; + private byte currentByte = 0; public VectorizedPlainValuesReader() { } @Override - public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { - this.buffer = bytes; - this.offset = offset + Platform.BYTE_ARRAY_OFFSET; - if (bigEndianPlatform) { - byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); - } + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; } @Override @@ -63,115 +59,157 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) { } } + private ByteBuffer getBuffer(int length) { + try { + return in.slice(length).order(ByteOrder.LITTLE_ENDIAN); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read " + length + " bytes", e); + } + } + @Override public final void readIntegers(int total, WritableColumnVector c, int rowId) { - c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putIntsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putInt(rowId + i, buffer.getInt()); + } + } } @Override public final void readLongs(int total, WritableColumnVector c, int rowId) { - c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putLongsLittleEndian(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putLong(rowId + i, buffer.getLong()); + } + } } @Override public final void readFloats(int total, WritableColumnVector c, int rowId) { - c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 4 * total; + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putFloats(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putFloat(rowId + i, buffer.getFloat()); + } + } } @Override public final void readDoubles(int total, WritableColumnVector c, int rowId) { - c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); - offset += 8 * total; + int requiredBytes = total * 8; + ByteBuffer buffer = getBuffer(requiredBytes); + + if (buffer.hasArray()) { + int offset = buffer.arrayOffset() + buffer.position(); + c.putDoubles(rowId, total, buffer.array(), offset); + } else { + for (int i = 0; i < total; i += 1) { + c.putDouble(rowId + i, buffer.getDouble()); + } + } } @Override public final void readBytes(int total, WritableColumnVector c, int rowId) { - for (int i = 0; i < total; i++) { - // Bytes are stored as a 4-byte little endian int. Just read the first byte. - // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. - c.putByte(rowId + i, Platform.getByte(buffer, offset)); - offset += 4; + // Bytes are stored as a 4-byte little endian int. Just read the first byte. + // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. + int requiredBytes = total * 4; + ByteBuffer buffer = getBuffer(requiredBytes); + + for (int i = 0; i < total; i += 1) { + c.putByte(rowId + i, buffer.get()); + // skip the next 3 bytes + buffer.position(buffer.position() + 3); } } @Override public final boolean readBoolean() { - byte b = Platform.getByte(buffer, offset); - boolean v = (b & (1 << bitOffset)) != 0; + // TODO: vectorize decoding and keep boolean[] instead of currentByte + if (bitOffset == 0) { + try { + currentByte = (byte) in.read(); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read a byte", e); + } + } + + boolean v = (currentByte & (1 << bitOffset)) != 0; bitOffset += 1; if (bitOffset == 8) { bitOffset = 0; - offset++; } return v; } @Override public final int readInteger() { - int v = Platform.getInt(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Integer.reverseBytes(v); - } - offset += 4; - return v; + return getBuffer(4).getInt(); } @Override public final long readLong() { - long v = Platform.getLong(buffer, offset); - if (bigEndianPlatform) { - v = java.lang.Long.reverseBytes(v); - } - offset += 8; - return v; + return getBuffer(8).getLong(); } @Override public final byte readByte() { - return (byte)readInteger(); + return (byte) readInteger(); } @Override public final float readFloat() { - float v; - if (!bigEndianPlatform) { - v = Platform.getFloat(buffer, offset); - } else { - v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 4; - return v; + return getBuffer(4).getFloat(); } @Override public final double readDouble() { - double v; - if (!bigEndianPlatform) { - v = Platform.getDouble(buffer, offset); - } else { - v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET); - } - offset += 8; - return v; + return getBuffer(8).getDouble(); } @Override public final void readBinary(int total, WritableColumnVector v, int rowId) { for (int i = 0; i < total; i++) { int len = readInteger(); - int start = offset; - offset += len; - v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len); + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + v.putByteArray(rowId + i, buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + v.putByteArray(rowId + i, bytes); + } } } @Override public final Binary readBinary(int len) { - Binary result = Binary.fromConstantByteArray(buffer, offset - Platform.BYTE_ARRAY_OFFSET, len); - offset += len; - return result; + ByteBuffer buffer = getBuffer(len); + if (buffer.hasArray()) { + return Binary.fromConstantByteArray( + buffer.array(), buffer.arrayOffset() + buffer.position(), len); + } else { + byte[] bytes = new byte[len]; + buffer.get(bytes); + return Binary.fromConstantByteArray(bytes); + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fc7fa70c39419..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import org.apache.parquet.Preconditions; +import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.column.values.bitpacking.BytePacker; @@ -27,6 +28,9 @@ import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import java.io.IOException; +import java.nio.ByteBuffer; + /** * A values reader for Parquet's run-length encoded data. This is based off of the version in * parquet-mr with these changes: @@ -49,9 +53,7 @@ private enum MODE { } // Encoded data. - private byte[] in; - private int end; - private int offset; + private ByteBufferInputStream in; // bit/byte width of decoded data and utility to batch unpack them. private int bitWidth; @@ -70,45 +72,40 @@ private enum MODE { // If true, the bit width is fixed. This decoder is used in different places and this also // controls if we need to read the bitwidth from the beginning of the data stream. private final boolean fixedWidth; + private final boolean readLength; public VectorizedRleValuesReader() { - fixedWidth = false; + this.fixedWidth = false; + this.readLength = false; } public VectorizedRleValuesReader(int bitWidth) { - fixedWidth = true; + this.fixedWidth = true; + this.readLength = bitWidth != 0; + init(bitWidth); + } + + public VectorizedRleValuesReader(int bitWidth, boolean readLength) { + this.fixedWidth = true; + this.readLength = readLength; init(bitWidth); } @Override - public void initFromPage(int valueCount, byte[] page, int start) { - this.offset = start; - this.in = page; + public void initFromPage(int valueCount, ByteBufferInputStream in) throws IOException { + this.in = in; if (fixedWidth) { - if (bitWidth != 0) { + // initialize for repetition and definition levels + if (readLength) { int length = readIntLittleEndian(); - this.end = this.offset + length; + this.in = in.sliceStream(length); } } else { - this.end = page.length; - if (this.end != this.offset) init(page[this.offset++] & 255); - } - if (bitWidth == 0) { - // 0 bit width, treat this as an RLE run of valueCount number of 0's. - this.mode = MODE.RLE; - this.currentCount = valueCount; - this.currentValue = 0; - } else { - this.currentCount = 0; + // initialize for values + if (in.available() > 0) { + init(in.read()); + } } - } - - // Initialize the reader from a buffer. This is used for the V2 page encoding where the - // definition are in its own buffer. - public void initFromBuffer(int valueCount, byte[] data) { - this.offset = 0; - this.in = data; - this.end = data.length; if (bitWidth == 0) { // 0 bit width, treat this as an RLE run of valueCount number of 0's. this.mode = MODE.RLE; @@ -129,11 +126,6 @@ private void init(int bitWidth) { this.packer = Packer.LITTLE_ENDIAN.newBytePacker(bitWidth); } - @Override - public int getNextOffset() { - return this.end; - } - @Override public boolean readBoolean() { return this.readInteger() != 0; @@ -182,7 +174,7 @@ public void readIntegers( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -217,7 +209,7 @@ public void readBooleans( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -251,7 +243,7 @@ public void readBytes( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -285,7 +277,7 @@ public void readShorts( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -321,7 +313,7 @@ public void readLongs( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -355,7 +347,7 @@ public void readFloats( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -389,7 +381,7 @@ public void readDoubles( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -423,7 +415,7 @@ public void readBinarys( WritableColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -462,7 +454,7 @@ public void readIntegers( WritableColumnVector nulls, int rowId, int level, - VectorizedValuesReader data) { + VectorizedValuesReader data) throws IOException { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -559,12 +551,12 @@ public Binary readBinary(int len) { /** * Reads the next varint encoded int. */ - private int readUnsignedVarInt() { + private int readUnsignedVarInt() throws IOException { int value = 0; int shift = 0; int b; do { - b = in[offset++] & 255; + b = in.read(); value |= (b & 0x7F) << shift; shift += 7; } while ((b & 0x80) != 0); @@ -574,35 +566,32 @@ private int readUnsignedVarInt() { /** * Reads the next 4 byte little endian int. */ - private int readIntLittleEndian() { - int ch4 = in[offset] & 255; - int ch3 = in[offset + 1] & 255; - int ch2 = in[offset + 2] & 255; - int ch1 = in[offset + 3] & 255; - offset += 4; + private int readIntLittleEndian() throws IOException { + int ch4 = in.read(); + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** * Reads the next byteWidth little endian int. */ - private int readIntLittleEndianPaddedOnBitWidth() { + private int readIntLittleEndianPaddedOnBitWidth() throws IOException { switch (bytesWidth) { case 0: return 0; case 1: - return in[offset++] & 255; + return in.read(); case 2: { - int ch2 = in[offset] & 255; - int ch1 = in[offset + 1] & 255; - offset += 2; + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 8) + ch2; } case 3: { - int ch3 = in[offset] & 255; - int ch2 = in[offset + 1] & 255; - int ch1 = in[offset + 2] & 255; - offset += 3; + int ch3 = in.read(); + int ch2 = in.read(); + int ch1 = in.read(); return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { @@ -619,32 +608,36 @@ private int ceil8(int value) { /** * Reads the next group. */ - private void readNextGroup() { - int header = readUnsignedVarInt(); - this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; - switch (mode) { - case RLE: - this.currentCount = header >>> 1; - this.currentValue = readIntLittleEndianPaddedOnBitWidth(); - return; - case PACKED: - int numGroups = header >>> 1; - this.currentCount = numGroups * 8; - int bytesToRead = ceil8(this.currentCount * this.bitWidth); - - if (this.currentBuffer.length < this.currentCount) { - this.currentBuffer = new int[this.currentCount]; - } - currentBufferIdx = 0; - int valueIndex = 0; - for (int byteIndex = offset; valueIndex < this.currentCount; byteIndex += this.bitWidth) { - this.packer.unpack8Values(in, byteIndex, this.currentBuffer, valueIndex); - valueIndex += 8; - } - offset += bytesToRead; - return; - default: - throw new ParquetDecodingException("not a valid mode " + this.mode); + private void readNextGroup() { + try { + int header = readUnsignedVarInt(); + this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED; + switch (mode) { + case RLE: + this.currentCount = header >>> 1; + this.currentValue = readIntLittleEndianPaddedOnBitWidth(); + return; + case PACKED: + int numGroups = header >>> 1; + this.currentCount = numGroups * 8; + + if (this.currentBuffer.length < this.currentCount) { + this.currentBuffer = new int[this.currentCount]; + } + currentBufferIdx = 0; + int valueIndex = 0; + while (valueIndex < this.currentCount) { + // values are bit packed 8 at a time, so reading bitWidth will always work + ByteBuffer buffer = in.slice(bitWidth); + this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex); + valueIndex += 8; + } + return; + default: + throw new ParquetDecodingException("not a valid mode " + this.mode); + } + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read from input stream", e); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index f36a89a4c3c5f..9cfc30725f03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -81,7 +81,10 @@ object ParquetOptions { "uncompressed" -> CompressionCodecName.UNCOMPRESSED, "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, - "lzo" -> CompressionCodecName.LZO) + "lzo" -> CompressionCodecName.LZO, + "lz4" -> CompressionCodecName.LZ4, + "brotli" -> CompressionCodecName.BROTLI, + "zstd" -> CompressionCodecName.ZSTD) def getParquetCompressionCodecName(name: String): String = { shortParquetCompressionCodecNames(name).name() diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 51dac111029e8..58ed201e2a60f 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -89,7 +89,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -122,7 +122,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -147,7 +147,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -180,7 +180,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 -Partition Statistics 1067 bytes, 3 rows +Partition Statistics 1121 bytes, 3 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -205,7 +205,7 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 -Partition Statistics 1080 bytes, 4 rows +Partition Statistics 1098 bytes, 4 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -230,7 +230,7 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 -Partition Statistics 1054 bytes, 2 rows +Partition Statistics 1144 bytes, 2 rows # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 863703b15f4f1..efc2f20a907f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -503,7 +503,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { case plan: InMemoryRelation => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized - assert(inMemoryRelation.computeStats().sizeInBytes === 740) + assert(inMemoryRelation.computeStats().sizeInBytes === 800) // InMemoryRelation's stats is updated after materializing RDD dfFromFile.collect() @@ -516,7 +516,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's stats // is calculated - assert(inMemoryRelation2.computeStats().sizeInBytes === 740) + assert(inMemoryRelation2.computeStats().sizeInBytes === 800) // InMemoryRelation's stats should be updated after calculating stats of the table // clear cache to simulate a fresh environment From 6ea582e36ab0a2e4e01340f6fc8cfb8d493d567d Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 9 May 2018 09:15:16 -0700 Subject: [PATCH 0765/1161] [SPARK-24181][SQL] Better error message for writing sorted data ## What changes were proposed in this pull request? The exception message should clearly distinguish sorting and bucketing in `save` and `jdbc` write. When a user tries to write a sorted data using save or insertInto, it will throw an exception with message that `s"'$operation' does not support bucketing right now""`. We should throw `s"'$operation' does not support sortBy right now""` instead. ## How was this patch tested? More tests in `DataFrameReaderWriterSuite.scala` Author: DB Tsai Closes #21235 from dbtsai/fixException. --- .../apache/spark/sql/DataFrameWriter.scala | 12 ++++++--- .../sql/sources/BucketedWriteSuite.scala | 27 ++++++++++++++++--- .../sql/test/DataFrameReaderWriterSuite.scala | 16 +++++++++-- 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e183fa6f9542b..90bea2d676e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -330,8 +330,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def getBucketSpec: Option[BucketSpec] = { - if (sortColumnNames.isDefined) { - require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + if (sortColumnNames.isDefined && numBuckets.isEmpty) { + throw new AnalysisException("sortBy must be used together with bucketBy") } numBuckets.map { n => @@ -340,8 +340,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def assertNotBucketed(operation: String): Unit = { - if (numBuckets.isDefined || sortColumnNames.isDefined) { - throw new AnalysisException(s"'$operation' does not support bucketing right now") + if (getBucketSpec.isDefined) { + if (sortColumnNames.isEmpty) { + throw new AnalysisException(s"'$operation' does not support bucketBy right now") + } else { + throw new AnalysisException(s"'$operation' does not support bucketBy and sortBy right now") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 93f3efe2ccc4a..5ff1ea84d9a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -60,7 +60,10 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { test("specify sorting columns without bucketing columns") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt")) + val e = intercept[AnalysisException] { + df.write.sortBy("j").saveAsTable("tt") + } + assert(e.getMessage == "sortBy must be used together with bucketBy;") } test("sorting by non-orderable column") { @@ -74,7 +77,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").parquet("/tmp/path") } - assert(e.getMessage == "'save' does not support bucketing right now;") + assert(e.getMessage == "'save' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using save()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").parquet("/tmp/path") + } + assert(e.getMessage == "'save' does not support bucketBy and sortBy right now;") } test("write bucketed data using insertInto()") { @@ -83,7 +95,16 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { df.write.bucketBy(2, "i").insertInto("tt") } - assert(e.getMessage == "'insertInto' does not support bucketing right now;") + assert(e.getMessage == "'insertInto' does not support bucketBy right now;") + } + + test("write bucketed and sorted data using insertInto()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").sortBy("i").insertInto("tt") + } + assert(e.getMessage == "'insertInto' does not support bucketBy and sortBy right now;") } private lazy val df = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 14b1feb2adc20..b65058fffd339 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be assert(LastOptions.parameters("doubleOpt") == "6.7") } - test("check jdbc() does not support partitioning or bucketing") { + test("check jdbc() does not support partitioning, bucketBy or sortBy") { val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) var w = df.write.partitionBy("value") @@ -287,7 +287,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be w = df.write.bucketBy(2, "value") e = intercept[AnalysisException](w.jdbc(null, null, null)) - Seq("jdbc", "bucketing").foreach { s => + Seq("jdbc", "does not support bucketBy right now").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("sortBy must be used together with bucketBy").foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + + w = df.write.bucketBy(2, "value").sortBy("value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "does not support bucketBy and sortBy right now").foreach { s => assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) } } From 94155d0395324a012db2fc8a57edb3cd90b61e96 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 9 May 2018 10:34:57 -0700 Subject: [PATCH 0766/1161] [MINOR][ML][DOC] Improved Naive Bayes user guide explanation ## What changes were proposed in this pull request? This copies the material from the spark.mllib user guide page for Naive Bayes to the spark.ml user guide page. I also improved the wording and organization slightly. ## How was this patch tested? Built docs locally. Author: Joseph K. Bradley Closes #21272 from jkbradley/nb-doc-update. --- docs/ml-classification-regression.md | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index d660655e193eb..b3d109039da4d 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -455,11 +455,29 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat ## Naive Bayes [Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple -probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence -assumptions between the features. The `spark.ml` implementation currently supports both [multinomial -naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +probabilistic, multiclass classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between every pair of features. + +Naive Bayes can be trained very efficiently. With a single pass over the training data, +it computes the conditional probability distribution of each feature given each label. +For prediction, it applies Bayes' theorem to compute the conditional probability distribution +of each label given an observation. + +MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +*Input data*: +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each feature represents a term. +A feature's value is the frequency of the term (in multinomial Naive Bayes) or +a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes). +Feature values must be *non-negative*. The model type is selected with an optional parameter +"multinomial" or "bernoulli" with "multinomial" as the default. +For document classification, the input feature vectors should usually be sparse vectors. +Since the training data is only used once, it is not necessary to cache it. + +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +setting the parameter $\lambda$ (default to $1.0$). **Examples** From cc613b552e753d03cb62661591de59e1c8d82c74 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 13 Apr 2018 14:28:24 -0700 Subject: [PATCH 0767/1161] [PYSPARK] Update py4j to version 0.10.7. --- LICENSE | 2 +- bin/pyspark | 6 +- bin/pyspark2.cmd | 2 +- core/pom.xml | 2 +- .../org/apache/spark/SecurityManager.scala | 12 +-- .../api/python/PythonGatewayServer.scala | 50 ++++++--- .../apache/spark/api/python/PythonRDD.scala | 29 +++-- .../apache/spark/api/python/PythonUtils.scala | 2 +- .../api/python/PythonWorkerFactory.scala | 20 ++-- .../apache/spark/deploy/PythonRunner.scala | 12 ++- .../spark/internal/config/package.scala | 5 + .../spark/security/SocketAuthHelper.scala | 101 ++++++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 13 ++- .../security/SocketAuthHelperSuite.scala | 97 +++++++++++++++++ dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- dev/run-pip-tests | 2 +- python/README.md | 2 +- python/docs/Makefile | 2 +- python/lib/py4j-0.10.6-src.zip | Bin 80352 -> 0 bytes python/lib/py4j-0.10.7-src.zip | Bin 0 -> 42437 bytes python/pyspark/context.py | 4 +- python/pyspark/daemon.py | 21 +++- python/pyspark/java_gateway.py | 93 +++++++++------- python/pyspark/rdd.py | 21 ++-- python/pyspark/sql/dataframe.py | 12 +-- python/pyspark/worker.py | 7 +- python/setup.py | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 2 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 2 +- sbin/spark-config.sh | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 6 +- 33 files changed, 417 insertions(+), 120 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala create mode 100644 core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala delete mode 100644 python/lib/py4j-0.10.6-src.zip create mode 100644 python/lib/py4j-0.10.7-src.zip diff --git a/LICENSE b/LICENSE index c2b0d72663b55..820f14dbdeed0 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/bin/pyspark b/bin/pyspark index dd286277c1fc1..5d5affb1f97c3 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option -# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython +# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython # and executor Python executables. # Fail noisily if removed options are set if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then - echo "Error in pyspark startup:" + echo "Error in pyspark startup:" echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead." exit 1 fi @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 663670f2fddaf..15fa910c277b3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index 093a9869b6dd7..220522d3a8296 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -350,7 +350,7 @@ net.sf.py4j py4j - 0.10.6 + 0.10.7 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index dbfd5a514c189..b87476322573d 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,15 +17,10 @@ package org.apache.spark -import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import java.security.{KeyStore, SecureRandom} -import java.security.cert.X509Certificate import javax.net.ssl._ -import com.google.common.hash.HashCodes -import com.google.common.io.Files import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} @@ -365,13 +360,8 @@ private[spark] class SecurityManager( return } - val rnd = new SecureRandom() - val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE - val secretBytes = new Array[Byte](length) - rnd.nextBytes(secretBytes) - + secretKey = Utils.createSecret(sparkConf) val creds = new Credentials() - secretKey = HashCodes.fromBytes(secretBytes).toString() creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8)) UserGroupInformation.getCurrentUser().addCredentials(creds) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala index 11f2432575d84..9ddc4a4910180 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -17,26 +17,39 @@ package org.apache.spark.api.python -import java.io.DataOutputStream -import java.net.Socket +import java.io.{DataOutputStream, File, FileOutputStream} +import java.net.InetAddress +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files import py4j.GatewayServer +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port - * back to its caller via a callback port specified by the caller. + * Process that starts a Py4J GatewayServer on an ephemeral port. * * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). */ private[spark] object PythonGatewayServer extends Logging { initializeLogIfNecessary(true) - def main(args: Array[String]): Unit = Utils.tryOrExit { - // Start a GatewayServer on an ephemeral port - val gatewayServer: GatewayServer = new GatewayServer(null, 0) + def main(args: Array[String]): Unit = { + val secret = Utils.createSecret(new SparkConf()) + + // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured + // with the same secret, in case the app needs callbacks from the JVM to the underlying + // python processes. + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() + gatewayServer.start() val boundPort: Int = gatewayServer.getListeningPort if (boundPort == -1) { @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging { logDebug(s"Started PythonGatewayServer on port $boundPort") } - // Communicate the bound port back to the caller via the caller-specified callback port - val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") - val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt - logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") - val callbackSocket = new Socket(callbackHost, callbackPort) - val dos = new DataOutputStream(callbackSocket.getOutputStream) + // Communicate the connection information back to the python process by writing the + // information in the requested file. This needs to match the read side in java_gateway.py. + val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH")) + val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), + "connection", ".info").toFile() + + val dos = new DataOutputStream(new FileOutputStream(tmpPath)) dos.writeInt(boundPort) + + val secretBytes = secret.getBytes(UTF_8) + dos.writeInt(secretBytes.length) + dos.write(secretBytes, 0, secretBytes.length) dos.close() - callbackSocket.close() + + if (!tmpPath.renameTo(connectionInfoPath)) { + logError(s"Unable to write connection information to $connectionInfoPath.") + System.exit(1) + } // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: while (System.in.read() != -1) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f6293c0dc5091..a1ee2f7d1b119 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + // Authentication helper used when serving iterator data. + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new SocketAuthHelper(conf) + } + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) @@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging { * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int]): Int = { + partitions: JArrayList[Int]): Array[Any] = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = @@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging { /** * A helper function to collect an RDD as an iterator, then serve it via socket. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def collectAndServe[T](rdd: RDD[T]): Int = { + def collectAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } - def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = { + def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") } @@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging { * and send them into this connection. * * The thread will terminate after all the data are sent or any exceptions happen. + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging { override def run() { try { val sock = serverSocket.accept() + authHelper.authClient(sock) + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) Utils.tryWithSafeFinally { writeIteratorToStream(items, out) } { out.close() + sock.close() } } catch { case NonFatal(e) => @@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging { } }.start() - serverSocket.getLocalPort + Array(serverSocket.getLocalPort, authHelper.secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 92e228a9dd10c..27a5e19f96a14 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 2340580b54f67..6afa37aa36fd3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -27,6 +27,7 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) @@ -67,6 +68,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String value }.getOrElse("pyspark.worker") + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -108,6 +111,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } + + authHelper.authToServer(socket) daemonWorkers.put(socket, pid) socket } @@ -145,25 +150,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) val worker = pb.start() // Redirect worker stdout and stderr redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) - // Tell the worker our port - val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8) - out.write(serverSocket.getLocalPort + "\n") - out.flush() - - // Wait for it to connect to our socket + // Wait for it to connect to our socket, and validate the auth secret. serverSocket.setSoTimeout(10000) + try { val socket = serverSocket.accept() + authHelper.authClient(socket) simpleWorkers.put(socket, worker) return socket } catch { case e: Exception => - throw new SparkException("Python worker did not connect back in time", e) + throw new SparkException("Python worker failed to connect back.", e) } } finally { if (serverSocket != null) { @@ -187,6 +191,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() @@ -218,7 +223,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) - } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 7aca305783a7f..1b7e031ee0678 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy import java.io.File -import java.net.URI +import java.net.{InetAddress, URI} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -39,6 +39,7 @@ object PythonRunner { val pyFiles = args(1) val otherArgs = args.slice(2, args.length) val sparkConf = new SparkConf() + val secret = Utils.createSecret(sparkConf) val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON) .orElse(sparkConf.get(PYSPARK_PYTHON)) .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON")) @@ -51,7 +52,13 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such - val gatewayServer = new py4j.GatewayServer(null, 0) + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions { gatewayServer.start() @@ -82,6 +89,7 @@ object PythonRunner { // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) + env.put("PYSPARK_GATEWAY_SECRET", secret) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 6bb98c37b4479..82f0a04e94b1c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -352,6 +352,11 @@ package object config { .regexConf .createOptional + private[spark] val AUTH_SECRET_BIT_LENGTH = + ConfigBuilder("spark.authenticate.secretBitLength") + .intConf + .createWithDefault(256) + private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala new file mode 100644 index 0000000000000..d15e7937b0523 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +import java.io.{DataInputStream, DataOutputStream, InputStream} +import java.net.Socket +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A class that can be used to add a simple authentication protocol to socket-based communication. + * + * The protocol is simple: an auth secret is written to the socket, and the other side checks the + * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is + * not expected to be valid anymore. + * + * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. + */ +private[spark] class SocketAuthHelper(conf: SparkConf) { + + val secret = Utils.createSecret(conf) + + /** + * Read the auth secret from the socket and compare to the expected value. Write the reply back + * to the socket. + * + * If authentication fails, this method will close the socket. + * + * @param s The client socket. + * @throws IllegalArgumentException If authentication fails. + */ + def authClient(s: Socket): Unit = { + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + } else { + writeUtf8("err", s) + JavaUtils.closeQuietly(s) + } + } finally { + s.setSoTimeout(currentTimeout) + } + } + + /** + * Authenticate with a server by writing the auth secret and checking the server's reply. + * + * If authentication fails, this method will close the socket. + * + * @param s The socket connected to the server. + * @throws IllegalArgumentException If authentication fails. + */ + def authToServer(s: Socket): Unit = { + writeUtf8(secret, s) + + val reply = readUtf8(s) + if (reply != "ok") { + JavaUtils.closeQuietly(s) + throw new IllegalArgumentException("Authentication failed.") + } + } + + protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + new String(bytes, UTF_8) + } + + protected def writeUtf8(str: String, s: Socket): Unit = { + val bytes = str.getBytes(UTF_8) + val dout = new DataOutputStream(s.getOutputStream()) + dout.writeInt(bytes.length) + dout.write(bytes, 0, bytes.length) + dout.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index dcad1b914038f..13adaa921dc23 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.{Byte => JByte} import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -26,11 +27,11 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files +import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream -import javax.net.ssl.HttpsURLConnection import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -44,6 +45,7 @@ import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.google.common.hash.HashCodes import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -2704,6 +2706,15 @@ private[spark] object Utils extends Logging { def substituteAppId(opt: String, appId: String): String = { opt.replace("{{APP_ID}}", appId) } + + def createSecret(conf: SparkConf): String = { + val bits = conf.get(AUTH_SECRET_BIT_LENGTH) + val rnd = new SecureRandom() + val secretBytes = new Array[Byte](bits / JByte.SIZE) + rnd.nextBytes(secretBytes) + HashCodes.fromBytes(secretBytes).toString() + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala new file mode 100644 index 0000000000000..e57cb701b6284 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import java.io.Closeable +import java.net._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class SocketAuthHelperSuite extends SparkFunSuite { + + private val conf = new SparkConf() + private val authHelper = new SocketAuthHelper(conf) + + test("successful auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + authHelper.authToServer(client) + server.close() + server.join() + assert(server.error == null) + assert(server.authenticated) + } + } + } + + test("failed auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128)) + intercept[IllegalArgumentException] { + badHelper.authToServer(client) + } + server.close() + server.join() + assert(server.error != null) + assert(!server.authenticated) + } + } + } + + private class ServerThread extends Thread with Closeable { + + private val ss = new ServerSocket() + ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) + + @volatile var error: Exception = _ + @volatile var authenticated = false + + setDaemon(true) + start() + + def createClient(): Socket = { + new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort()) + } + + override def run(): Unit = { + var clientConn: Socket = null + try { + clientConn = ss.accept() + authHelper.authClient(clientConn) + authenticated = true + } catch { + case e: Exception => + error = e + } finally { + Option(clientConn).foreach(_.close()) + } + } + + override def close(): Unit = { + try { + ss.close() + } finally { + interrupt() + } + } + + } + +} diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f479c13f00be6..f552b81fde9f4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -170,7 +170,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index e7c4599cb5003..024b1fca717df 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -171,7 +171,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 3447cd7395d95..938de7bc06663 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -189,7 +189,7 @@ parquet-hadoop-1.10.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar -py4j-0.10.6.jar +py4j-0.10.7.jar pyrolite-4.13.jar re2j-1.1.jar scala-compiler-2.11.8.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 1321c2be4c192..7271d1014e4ae 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do source "$VIRTUALENV_PATH"/bin/activate fi # Upgrade pip & friends if using virutal env - if [ ! -n "USE_CONDA" ]; then + if [ ! -n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi diff --git a/python/README.md b/python/README.md index 2e0112da58b94..c020d84b01ffd 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.6), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). +At its core PySpark depends on Py4J (currently version 0.10.7), but some additional sub-packages have their own extra requirements for some features (including numpy, pandas, and pyarrow). diff --git a/python/docs/Makefile b/python/docs/Makefile index 09898f29950ed..b8e079483c90c 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip deleted file mode 100644 index 2f8edcc0c0b886669460642aa650a1b642367439..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 80352 zcmafZ1B@`;wq@J4ZM*wx+qP}nwr$(CZQC}!wmtuSnLBspB`;ISE+lnMcBQhk0v&QGOsvq7S?-^UssrKY{S?;d*)&wieEMdUW<4|230^y}AYwy4HW6Nl6u= zoaJ5-LR@-QR$A^vl6rDZB|J`!nwFAQGA2%Ke42Kgo=QPnnre2Al2#%n?2d5qNx#vg zLS!W4-9uaZ{$8F;zKm8{9bO_C`oE>MI*Aq63km=LhxLD@WoTezWpC%`{QpU7MO`Xx ziw&XoQw?Flz@#5B5@*v;0sWj>WS)p39@G;8D1fR(qEen@k{AUt3$zi z;{dihzth9IJM2krlN>HW3%ZU@f)zh#x}qz1I%iKBe-(PUFCT?dhT0aD|MeuiGxy4R>IH0`W_ZMf4)j+w+$g z&DpJ|DXf0EK#%-2Xw7m&$!gLmp+z{0+|j)9=N-Q-kysDt#|9;^#5;m^kYh*zh&4#| z4df7CH$H;4FfyEB7=l^*M|dV=3w8Kc42583#*=jtV%zL*ueu&U2Oo z&)ds{ur9liY+7?~5wuIlOP?FV65|eTeC@3ZEeknNdgMEe>Yr-~rnGgV#w^%xwuzI{{u3=pNbvv`QwzoNZN_kuvWq3t?@y$ z8MEXqc4>Y2dgo}^sFQb_jN``W;U}$r=9)Y3S(74`h~W{;EmYZKnY^q{u~4q{4VQf~ z79KOih$5^^(v5a|aY3h~Q(t8JFvuJMhg%~`X$_cd@_nb7q3Pt3mWyq?5X=1Cj-+d>H*(Nq>;;*OGB2~ z6}$*}Tnw0Z$!e_-f+JI8IeHLiZWRdIV|J<#VaX-l9)ZP_)z(gWnlBX7X-&o)PRc&< zKOTvBXrbPmD>-L= z`cZbtSy-iQOV||H$!Y_4TiFHsY&F^43qL=mKdxUEPMtU`d(E{95DdDFx%`kp<6>h6 z#S4|1euLb|Jb~|)-B;LQ%_k-gCC-eYUZluQgEXN2=B+z*FU4ld<*LphdGKbhOaGWD zJ*gL?wLTtS-$re89X)Dv*)w3Y=s!~ySv#&?=eaJ+n$M`ks-Q7b-^}BK#+ELLM*q4? zzks;u+q55ct(g4>p8kby<~YP%=coVxzPbPaDF22hBWnv2TW2Q|M^_Wae`CSF!0uo1 z-@xu#`y%P6HRZRq4`{)gA;nzTHCwqvc|~jGVl{f}!l+g2NTZH@q<}!kx?S`GNk;m;nclK2jw~ zLSc`rMZB#`+SWu@jR>*IxhIv>J77;TLtm|dR+R}ET4PTt10Lz#=z$QCuf$}BB%@tC zwatRA;3~_2F&aaf+Te*<7ur#?_x}v1l{r2iA|ziF@?!8en1 z0+yUUz;*&d6itxsh^nKNH%F=JwVFeo={PWkG99)Ej#hIr`S{v^z7Coe;@DNGtt95y z_)cN9NK%NtQ!>4}8q1?h5+aHZ`(SP9qpoAG zp|pO?h_NlcHeZ^L?@WKo@Xbs#0MhQ~q!?NYyOrcBY0sUy6RXp(QPdlD0?Ed>-U7(u zFB51^`qfQV+$imO-#`mYU4RnMoxcI>XkFPQ%(Bm#W;-?yH)njvkwS zb9KgUYiobmy1eW{%z_!8F=R9wi)_mqCoE8czP@cTC(LyM2ypy#DnqSW@i#M#x|%?EQ7L~}E!QUlDq)X4OqO7?x5 z9sCQHB{GA$z}DKWKFje3nu5Bv1+45-_3_2QJ*$^sIYjTKZZ~4;-}~~q*MQ$GjcOuX zC^TyH{mNsx$Nb4$j6J{;SX>b6lrm~!&7X~;@V+H~ZXsNRS;Ha*t1VXq#P^*s%|wt$ z6jU%>%TLBW&Kt)Yd~#`+FzrrSDP9@Xsh-AgS8)uRU={?oBCedNqePLlNcqXYzi$i7 z%2}`ZKc@EncP;Tu#(PgSNT<9E#W>*dMSoZQ^s-Rwp5Kd-h z-XWRydEzVVYG!ZK+8jO5_s~>wf+FX$_m29Zu+IP3U}LytElYKIfk-2acnT}62p3uw zv+t9RQ|gpS$?JU7w*C>bke}#Rn1cG=88$+ygmnt{ei-`8s9^}}LbWE5gb4F~l-KwB zCTSRa<462Kdc<>S(4K39~X>J1BqRJETL>by{2Hkqof4uyPm) z24-Q)&~AJc56Z?}XSFQG=bP#Vi7e;ErG}p2mu;LG@Dww*L zCwP;<%9KE8qz+f3AyB#6w(TMnJ|Vz?d9)(SU(op?8ead#DMy{;eYsSc;1F#x*lJ!C z2CR~rVu}{{HT-vnv;YQ6^+2Mb)9I^oz?f7b z17Iqfu6OUnRtXUWj{WcV6cbBNHwjTYWK4sh3atk=sFHQg>8KWmWO+&(HL#RK3nFJ6 z3(tj#_Yo-2Fa3HCuSEm&)CgMyfi%WHX0c?aqR$A)zb(S!ERUswrmB=4hb5?~Hkl0b;BN8)z0 zHg^Qs;|E>fZn)76JSB~DVuNgTmUfB#IqMeh z6vxngU2+c`kU4|#I`@W-iepJl3|@sN7P8zX-Pnj;akhHD2ZWwKWKU96v#dWF^cItU z%aS4PI>Gk_?AwvoAcbdU;83eb{q;5{LZ+Yez-lM#BO77aSw>%gSVP*LiSa=?`OhR6 z>Yx32=Jyj2=2ch8f%usJ%=l9_%I0$VyUwJ_TtFHV5AYsaL}fn%T zb7qzq5c42X49xd&DK~U-tqf0iQ@q`hQLK|iRzYzP!`ts*_GvTSFZ zH{BpjK;lBcWc4L{l?r4WXN~BNrG=KdD5fR;-mGYN8(J3Gr zt5B8|oMeIkWsEorBX$M%VcnsBgpR%0G_M0cSuf;tj*0v3B?XpXt zG}z7a+yG4#ehaCMCx%&G9LInwL!GSXQk3jY&)B|SKZEyM4ed6!dF|xK!SDgXrIBo9 zT$|HJa>vLydtmJHo^b?$>XSxd-4@FB}`L(=h^|tk}qSaat#V(=3aGd`WDp zy0f^*mHCrN0Ws>ML5{1X-EdZE7(sRG!86)6_!Cbtp+4iIeshH22(O&+#jA5)w4=V? zyBT#9VR+shhEWFwY1d8Ak#Cqfwy9yiYP(7%>IY9t;gjw){ukOIK1A&rr=9yx zBsm7GZiWRfdF^V7YaB!)d?i{Oem8?C2WirE z-_9scgUE^|FZ6M_n;rS}e#^bD;d4~B3jXFGye&AQSKut`<)c85S2i8nXB11FW%WU! zt|b#aI^_B|26Bohn`1G&kfxc{->1sIgOeH^uiU+YoJ(k>C+#*C->J)Afj^JXH1e#~ zTz~$kh4P+q44JH6tVhl<%&#yU@7t2mi=3FL2%5R@+;yr$s%pz{lWkhmD1D^J&+ww8 z-I6|{nD->^BDl%8SZM`#NI~_}IGeHAy!@LwVmS^6>nGe z3@cH_wsy`JZEuF&Tr7PzHLt*P{jASW1vXE8wMy5yfaF2L3{|XTd8<`hOQ7sfNA<{5k8_!YBj_7H`zuvV;uW7hTrtL_|DYLr2KQc#5WKD=BoA-g zFUPrd_}TR|9w&o8?OW(yB~dn~3p+=Hh~_l$graf_`UqbS;y>9%st6lM|RTj6uJGEp(Lhp}dmhT`Jz+OIkLg+f=*N8J$NAG`k?&t0&Rz^nd zPb|DWycx3q)@fAre`I=m-_s1n?e3(zt=;1Y`w+@PW+T39t&}UUZVfxE>mB{zw;Px4 zy;WQr6QfL#VYgzpi;5??@05ncs5S{rsao)2gM8KG&c0K~g^=9Z@yHYTqS)=T9Xs zvQi3j^>TD0jxz0zmngwhHUj{aoV@J-bSKE2Z)4UL8$>)8-%+zil)oOgez9fGUzBh! zB%(jAiHaXO9qZMibWBFJ4rd7?Gf6KY&T%)JR9RI9)RU>;2BEiy8mbh?YGxVJU&nl= zU}A!j$0kk9*4EV4?pI8_nY*}oj%m_?WQW6OC3JU?C0Xq1vgktW6;e<2clxCDd|zJ^ z#0FFM8IhF*3Wx>pY>t=Q34QyCe4QVh?R3Xit_kWO#N?1PSJk6+jBf}dA} zRCKBqRnN8Kg?>HHebxtOauS0t>aYxYoeUMsbpa19Tc0Z)Dr*FelymGA%mF$n#|wwN z=Y)g$@MQJ%^hVR0va^rXnHA+m*ds+*d_C+u-Rf4|UAumr?|g~7z_U@I^^IO!I%E!3 zGC?dBJW8|fCMcdH9X&d^U=MD)4>nz^V%q_(8n&o-)kD6-5^j$93@_nE?d~jm1#gG) z3?J~+8#$P})Wy~PLInnuDs4+!t)<2`43iKPp*L3-%q0nl&dpQD4 zVr)getD`+z5c($5Cs+&HJ$7viXN9o?zdj0fC+%UVAvxq78gZ=aIG1qh@c5g3@9%@6 zMk$G;^VXa-e&3)Hrl=>6~IE<&={1?0W&q7cV3WYMmHlZ?7wYpO$4TK0Us<66 zD!x|x=1%QiBAA$W_}eit8GG$AF-~5(S{+V-abENe)~KXw1K=B<%wv&uz+(qMwDjPz zfWOB5r8Xu}Gx|QiF0hk2SmCY^(#&GUOkV36XK9|xrqcMr1SSV$v*Y{y*0Wx6AAP+| zYQRKvJ{`IS1)7P6bL8uygFE_r=C;2Wlw-rkr-}^s<|WVDf>FdlmDCXkfHi^`{mIt- z{FdgAav&!}m4{g34Hx|2u`q!gKY=Vik;FDBCq(=$e(Y2bUq*<4Z1xoI1@_?%CX^S# zxr{KN-0kY2YxQ$c~tyfThY__XGrb&DzM|)YA(yI%w+a+YdZ#( z@9T%a?XiH8r~6tSsfJ5y=+MY_7{ptNMXx`c!s@fSz99Rd1+-zb+gYoQ-13DCSr0(swB40%H{K6nKqsta;wwM zB0BVw5CYXJE6ltzcbHf~h9aXS@;qy2AJgf!(k8&Uo{f?zQm3=AILmZECaOgR_A&T% zLUzUH&$s64jjze4Ekn_{L{5SIZY+|niknyt-37dQ&~;kVCA+YOoBg)h*yUCMcf#%R z=cv7gw^s(W_ciZ ziELhk>M-5fV}wPqe>P^Ddf8nqo-VGmwtJh(##=-2wlrO#6Yg&Q zXQFv(Dt~n=+|(;lEmvhdpk1FN^ZZ{dvov^Z*ez-0uNQx?jM^68nb(v)MP^@Mz<=TY zU)Go^*?s~87yuv<@*f59|FFh(HueVpM?3r<_~I7Tnb;);gx)hs(H*HyD#S~oU@A}b zf)HRNpnze(8Y1!x>z=I5ifij7^dnGr#!*+_EL-&W^MZ!H_tUYXIR`p|6ijqC>T>Cy zLm7swnLsy3Ev6x4U3(}0ZlsC>Mf35JK*_soBQQwYbz*V?fhIAOyLWvy#eSyBfC!k{ zLQbea)G7!S>UO&NQ25I=xrT(1Rg;O7Z9$x~lnRtl#ycCzlm^IhqG(EXD%Eq)df}-; zCJZl81QuyJ=ZFf6vX6Wp=_ZBoznPoydpHOnapf$uB&S4a^{sFV2!;KBU;^-~PH3yw z5a`hFJbR)1ed&Ze++FZVX?*Tn-e~p`c)Ue`?EwXE)cr$71Lkt4p&rZDleX#Y4rsOC z0WzMM+gA)$lMVc!Dx82C$Foic;;UvKSBAy(Z3aCh>xj)P(jz0dig#Ch=T+=%LW)!C zzKx8<(?xAi+)f9Kg)LulFs{vUj`ZUxN8qHei*lZ7u~*y|^$*40|N3H|5E9Gz*V=V5 z277Ko)JoWgMxY2GktT$~u9YPHu3-MYl9+POo~Yh?x5jdgTQCxLFc>hoq)`#SrCU>T z3CMclF?zer-eH0IRn!ZDZ(B=w%2o3qbz}_fbPw_P+0)EccJ6Nm=QRIHZuQjGI8yfc zM4w#c60a!Y%Z0ZniPg=*Wl|{OS0wzzk#gj3v^WcLmjIYGX7L$uo_MqUlIL(9>QOy+ zK>RATPvBdw4nXrC)YUyVOY|hqvA#wC#EiP-jB3%|f@$Je`mxXVdS%$9n}$YO+fr_h zadYV20oZ|_p}*tlpoeT<-5la8h;Q~UqDTQ_4$*Glzeq2{NCP2m3)>|6 zb{(!`9dazIA49j3u-VL=lssbJ=p);~<`wtDv0lzAiL&4mP?=!Pa*JneH`cqd`VJ4p zC_WW1K`AT>a`LN@R@$#Kc7l{Idn~K?YCFyUNy78KbO5@-cM16rxWoQWef`Tcp!87` zy!20fEg}E_ApX0)Of75;tSvnMBgy;^u7PW{ov1BQ1iv*sMs0jcbgg32c_NN;ohbnV z3&I}WltV2kVk@UAorbQ+qu{UGobPzjjfb6%wJ^a#axa#=zV{iJ53tldT0%9@v@DpP z?jDVwW0HxH5s?Q`)EfSJ)FyfT2~sS5`yo=~FJjAaP-ZZeOi_uUwPH;wnP2v8_2pxP zbQ*E!jAA~C4NS5LdaR5QpyXh7>7h92w}X*a`a}|=nwsR<{Y(q3f__i%zD8Kd(HOak z<;u{NN^6@chZZ1Rc1bOB(lR7_QVvklK^p&>7^IpJm3i^Nk{XJ7>K112PtB9pdy^)w zce)vs9r=!X){ZKfM3RC6$$Q+cOQIC;3n{d-eiX@NI!MM9UfB;Tl#D ze-*E8{iwgYT?^`kq8Z4!Il#coxrUN19xVZrsaui7e~=Al9zH!X-SOK{MPLHj<5XIK zT2Unjv!m0SO;NI+8~9U}IUrIz%y>m-U3Q^LKiA1z9PWZo)-AQu12XNXppo{HRz}K0*);O zph^{MDEO(Q4C(>VF7bq9CbzLrOkms^mjX9JQmva{58$@p@&nb-`rk%mnSu^^UPFuA}Jk2SHU+sCt1nZO;JHqxEK^rXX)% zVSfIzB1Ur{n3GP>ppnTyD?JO|kl6C*$ zLXb&c`d?T%#wEQV{ES5PLAKLv0*_l%uXu5T-6AlRv4TZ=S7d+IoIW<%0?Y}{(azb7 zJWfJ~)gT0#8Y}AwbR`J#Ag47e%uF*Lv1>xXBTio1?-UsV(FjN;d3gwNq08wDPq^8T zSc473#czC0?W;&Em1uQqxq}B;=wrRE4BZAN?>YuWDWph7?qCROlG=_S7c|)Zy@|I3 zE+ZCmxQK0mWr>u!stw+sc|IdN*G(4lD-R0!V*laS3FKasDe7kB)7E*|@N6%9XU z^mJ90I{Cy?h6E#%-fuD=8?<5GQlMfWi$6*KL*=_rc6;UH^|yBf_s@gdi%-_@&l9TG zaUiZ&=b*HHLT)IJ{fR~9?|<{|+%}|ns=xvO0Eq$sVElV^urzQr&@-~Lwl**@PwBM)bb!17sU4vUPClTMJd_zr(4rZClsb zCQB+r>U76UQYPG?w@ozQhfkWvpbYhEUztOh?J4Bhcqo|p{(Z&Z`n0aa4#H3Q>5z!W1)wA)Csb}b5HAFY2Fl2B(#IE zXk=@Oan0JlPYrw%Xi~aw#08;uc7U9vhiGqS+7E@}17Hm}bx)n4cdt-e#5RB z$2@GONPYgX1DM!=A#1pTZ}R&I59H@JL{mYItp|XR{`!DslA|*FWCU^1aZecY+9F+R za03z3MIHM1+VA-K1tb!#z|?@9&BD~x%_^KGe}bv1F=PXIVWO-ng88R?^ntF2Y)^?u z1ylz?R0T2z3?=G>YqLb~f$>WECj-*smfFJ$`6G>2jX>uXiWxZ+K6x@p8pLOw?v*x+ z`&ipdLl`Hsfxun5;2~ukpd|pOX}k)?#$T#fI-B*0cP=~zi_A| z7(B&mB4T5bQGKu|sxwL2lK@O70}0`{e&bH?Iegcw>|9tGtasyb<8qH>_GR4XT!h)9xO)w-enyzW0DaR9|N5DI!O z&sz~X5y`SPkLV=S`QYNBw^d;!10@$ovQ&x-O}&?b>JmjHcSDrM=y7KQQb20+Ket7( zjla{D)`Xr}f!p~Dk1Z)A_Ul%(03^>8ZzcsSMWzQrj5-C#R|N`xeq#dz%Ln94_6R?^ z6_j9iP6`0uxVi?*8SZrZN|^fZ&~;Tey<%Jw#G)96=@8j_>JUbVvkQvpjs7ABj4`23 zfVD(WqDTGJTxRJ-F8}BU{3C?$0^9#4G^DC_$C09yJ{|ufR(+L9eIrjC2^Fb+x&bPUHE(%vVSoybxsasJ|7Ai*`YCY7))oTGFazTEg3%fB7!Lr)D*RVzhKWW zsDo@C`#tSzOav%TG+AF|0^r61Bv2++6o3khwjr6= zu+t_A$(zGg2-56U)f~@v*PW2*Y58RN&dO9yhaRTSzq1jJTW&5t5@SI_&b>EYMwmZY zk(Grx;IVvcJ166XQf;x>w_O99z11AhRxaRLRck!9K2S^G!Md3Md!Vwtg_Mg}54w*Pr3p3}y#nCS12~LxFY= z4I7qvXcO7^@OpI^d}5! zw4!L#`vEczEn6u>JyACk(aGev`CJl|A)4oax*|C+Kwm_@3xq&+HCm2hO6&BVi0;se zqCLNS5>)M!shC~kCjw?pseaUxk|1Sm_2XmyoYsFNah7bcqTPiVoA5^%Og2iBza+o(ceM%ENcj z)Kn2);Ssx;nb=vM$eEZK@XME+oK@P5oVft!Zv>}2q8w(r`cvr#r;>k zYar73#j=a9Z8Bc}yl6;N(NsdoaYZ*`2Wv2Xnh&Ku1TD?{?^=I_r#CW1hRsNJ-me%_ z1i2{}$kh=11&1a35xksSEO?xeF9;N>YkD3P&$s%{F~Q*%SErd!4saA#YP|poj=!dg z+VRvRScT09M9{~FntlS6wJ>N4#nfFsE8ds>A@LJqX{G4j0|%X2aQNL$3!!{OYj`aG zWiVsY76AL99E4C5wRp!@^R_^ctyHoW8T=Ul&ixl$kvrHqib9HNQ$olJ^n-!QHqmhu z9+w}Gu2yD+2f7m-ed{m!dSPD-BZwpnze=h{WpEq;N*Zf*8g)L9x`GmbAmel5HX1C4(7tciiCf} z>&UdHzv82O;fEjjt?j7VU$NOx^p+@Z7Y{%%9m-}filb^48B{`|LLJI&8OZsd9r_W0 zO(V*fYfK$uwB`D7%QgPixX0p;mKMt}{0XD@ZVS`}@BoHU$Tu6>U74&jY!mxw!H1)Z zZtVP=xY>N11Ux+Z_EgJCcS?_t+EO`}x>N6M4zC>H#5b1`I9{M!`Rym`#W%zZtqou* zY4pR-w!iyhuY_F@2FZ{*V$9v|y0tO}QiaSea<3n+%qkpkE%_L?F{+7!@MIW4E&A5k zVcIbhCK}ju|0)Rc4)Axy}V+TtX)i;s(m4nwy$rqAxm5FA|4DXKQqtHMVXm;H$wgtbtM`w}w4es~`EAXwq_1mGEWPABXYRpB33mwlG#kd#t+< zdG-FnBF7>lx_s@`yOb-X82q^8`jIS^D8$?Miyq`e2SgQEAf{pvfiTk-c6t5zJP zu-Y$E1u46oUs^b)YeaN;RAQ3?a8oUfzr8N$sn6zWhdRJL#Q`r`C+yeQSt0BymRf0D zZySZjE5_K9y4TG!CG@_n)PDDmcSz~&_ysVda!mmiWP4KwqmFMsVlQRnLknmnj0qdS zpSNevdvjO#XWc(mG?6p#!9I{f?8``R8+2V`a5)D_s;-Nnk9=|+>tj|#)Q zM&b2!8rk?9$|5&I%Ltwc2frt$KFA#0PPUASx;2yA8T}_!=B{9NlGb%uQaK7R4)S3| zc_kD32`NNqnoKM%h8}5Z)-h=@q(X7AXR2Z=HlXd8>oU&^%9^w0NVb!nW#U|rcCs|S z3s$UC&}hH9Y-;W1@O^lwL4SGUshhfAOqY4wY7erebY|Y%-Ob2D?UGIux%F*p0GVS#1Zt;^8KweJAqgtvpq*K6X5<)C(?Jql zE7Lr*aqqXCIeKqh>jVB|vRh-}qt;E?D|zK_b!RlCh%|cKwT|zw8$()rSngJ}VxvtP z#+<+RG{xtm=bo>P>Nv{u5HG3m0o3Ry=%YF;{o+r?l36*?-qjf~=BFLwNqrFgh|NT(*6uKlx_c%~+G3 z3%&*~wd;2M(xyBIyo|l$<`0W4D;K;s6P7k!%&zvXm-FWt9b3hn-+xG)e_3mJYA4Ce z! zD$#2jTmYU@p|5=n$e{|JMjUczXuAtDjoDLq*mf5arEniAPA66Gb^NCNmA$Wn5_ zsnS=iEafNB_>NP|rq`6%b#i6F#)#{w8iWY4ue=5JEVX}+fa=zO&FDcPkR7=m(M5b8OY zfK^A&nI1HqbkrG>0q0@o%ppwfP4e1ibRfJ;dChGTA+!io%&W zf!Sq=$4REv393fQ01br3gIQdmkDAeoFGD$~JiFXyB?oZ!|5?!4M9&peF-;>D{%F}|=u)yiOsV&tPrVmH_*vI@v1 zg3x!56aSFhigdY&HmJU96J5=)VddU|9GtOc4a&JN^bG}^l`$8s5QefN6#jt*_Pdpv z;GdjKQ(n*;!DQDbSiZo%@YDMI+vCn1ixyq&KJSgbNHFr>l`!48*WZ`5Rfpe`h35+f zPEOw5OL(xm_V1?+zu?ciy`@$kFAoRSV6{2lx1Z13ouRz2I{Mk7sj8?v-xn=zH({@y zpA!bG-0Wxi+Jd1Ag}UF5!>4w?kfU=|+vMGai;JkD-MEE(zh|_!YPT>tzGgZ<-ybrp zFLy`GXF6HA&z^&ckw1(*&mvh{@O8BcyNW6UpD&sz*R|$#_j!>&l&Y2AfXz!(Du79G zp{b4Zk&{$YPWEA8o3{G7{TugY`Dufbx)d-Fn-}0b86-x3IxfNi5s_mt*30Lg`WJem zkn`&&AB|;V{a74n#fS-`-l4QTRzz-b%l2oMB9Jzu(*ZqOSrM>-0QXHJt$IluOEXRf z`+lum)TL1&%bYxlfQ=4!YdXfS_SpxiH4-FZ4t>UUC581j{0G8#pQ6dT=;Qs zk^6ov3Lol{KCaG|Wscb}-P6OS36nmd1bSUjJ4;^AAQFMemGgqDkQ zz>h-~x2;V708BJCezX=6qOxr9DU_p)ItkK%gh7s)17OTVS*M7xyD_+9o+H5Ld{{vC zs|k7G0LuF@p`vM}8Wt9?k|6-Bp)*u_NU;d07PD#(hxeldddY{1(w|s1@|&}9c(^$7 z9xDTT0w1RwJNF|we;{DRVxZ3eP={53r2nFaCt9qy<_DYjPkgk%9#zO$BF!wm0J{3i zXUs)?F}J8lRJ6VjMf8izApM~)C<>Y|xUKYK`73tz?G_I2=}%LrriQ+jmZu>;>YbUo z>L#IXuTlvC0AJn+6Rm;h?p6ubtp#OC=`59Sg5z$Q3^7pF#R2+B*&<8-H;Asn!X8oQ zxenY1gs1dYabnrnw5mAr5yU36ItbGrGnq00n;LK`?vaK>t$Sly`sMhi`61Q z*9TB<5~w*$0aKrj6itE+Ebo!QSb_%+M(_&9(#O;BfZG+61TIB_v=gWX%LnGzAj}4V zD&Qms2W#DT+t1!jA24~fkIf-2zbVD!j~K<*)sNW|do1Z4yHY*> zU6FoAh^A@V-%#nx$E&5eYro}Cc)(#AlNOQGS+uk})T7WsM#oRA)aSZ6UT9jKH~m>l zO;ff%0#FBP{8+a{Rk%fBRr=(RJV;}jKm;GMXq3_l{NL`Clof~8jBoB;(Y8h|YbzJT z{3(O9$;h6gfaKs>DNfyoBl!vUn$g;?h?WqZ2#ERg@xA;GQuzhVVqD`3 zhas51{QBX{d+|W}is2aQBZu9p5A|{^_}x_S+x5}`8qmrd@LeS4?Ujcl+#$lO_l#}?if zH1xi@yfAaJ4d}37YxP(*IfNX6N0rXEe#ACI^$mIS2(2lt3V#Ze!8mJHP_Ju}l*DDP6Hax(>OAHlcp5mMRjKRJnrQ(0%@)9j7c0 z4pI#ei7h9}19J5sRQ#8qM)VrQL2w-`$T4fKI;F;5!ohd_^pipuP2=GU#Qu>5PzBvr z*`U{`>lx;xAmzy|UAI0B`LNiN0WJlyO1b(xL6qce*a8Wn_J-m#vN!3M+WSXN4l6tW z_sIxNM1(OzulZS^hRDD87$aJLKy1ch?Jx)`diw`Q<3rOv)$4Ow8-OOT*5R0x*DTGV zCVLauCPQdY8tP>bTM%)_l2jD@0dG;M6dIMEbM^FFw1?URXAb51`@f6$kUX-P+rtQi zQ@h)V4lo*RE7NtL&?X$g3LN1Html6@gNLs{xr5-(qs${!$-2 zkKM>6y2Dr-q&_YsQMS+VT(e1;uGUS5yl-L)E*MtDEEZow)=YVipgFw05a6NpDO~#T zK=#E{;yzBUJGmGZLO6(ZfkIRF@enlQRS`_0Q?5tUs>5jlS-) zLCYLd<@M43+p6H(HNS351NF`T>lUAB0R#>}2Moz%iRVjh6`(X29A(tq%wKyMY!06;3J0DWe!gm5~4?Wt*V;Huf0Ne_Um26YPJRTkMY92ywM zKjp^F44rbke`KnU)u|dO$I3C<;@FP?hR6`VKQf$x#DZxf(eFbMzrk`;>TvO32faj1 zrBfn!0^Kq^BfW(72p1CpLvyy`v45AGYqCu-)K;KsRyl8M(jaP9@STtQA)7jyVMG^# z!8^0aQ>`px!f3)0>VYE7?YzdydI;Vo35y9S<;$joVK*`kI)}K*NwEZQkBZJ4-m_Vs z3Uy7e7TjzbHqxvof%|ISRY3Wo3SO?Efx>~w3@EB9u9Hq^ANgtamURx&a+I$t9;VztpKu; zWTZ(nlXzN#!p=z5S(*|m#LHmM?L;?&i`I|UygZK`#jmE!DNEl?chq>`Lck`oLVH6I zdI&U?;FkXR5E+_c^QoksE9j^^>=A<9r1)j1Q_?(}IN9v}2GBtOE{GOof>ahM?Im>W z>;?t9dsG!7YWGDBWrKs4)lXNzZbY7eH6o|0S`VbUZ~G(=w79BdE!QS$(*atenHC>0 zj>O|N1X2+k?HZ0sQOl_Wxop;15vmAMy3YH$F zfqIE-Wu+IPB_ib1&&%oQg~HRvVbzZ*Rw=NNL+++Q$VV9)R4F~MMjbkeS}ZC-ZYS*q zKwvnI8LZi-joj=$PA}>11zEjpaxcAp3lc9F5bh2;YYdE1n3{v!-|3d%oTEQb3-*w- z8thqu?orsqg77V%ObvBH_cmg;v3+XsmWo$NrTSrVfg|IBQ)q^JAjYv_H}O#bj~e+2 zKs0fRHl{3SU~?=-H+VLMfj>@kBrfFXR%{P3X>eK5phZS9;(cn;%UUjSW;6g z$JAmKj}vN>2kO;0OTDwvg^J-(@0D+>GofIaQ96fHT`_ZquO7qhF*2ZXgjrRjWKu&a zv^>r0A~L8-wYSX@5cg)o)wlO0;6Y*(@2)}j9PN+?|8Svn3>)_vUetn@)4?|(tPM1+ zhsV4XJl{dvn4M~8>PWlADm$~(L`B|lqe5l2%(6QZoby$X!h`@`z>T=N&Q@u)`EoU3 zJvc=D6XW5N)7-yxoP2B0o2N8?N+|EIO=N~``yCMalE9K%ZUx&YKVPLAy?%TWk($}Vc;wp+6B(R>0rYxcCvMl*dZJc<{z0^|M`q5=bmAqY*z&q78miz z2j#f&x{8m`P*#p=tM(23v`W~49KdN%lsSpeDsQNnAaI+8D}g)G%A~h>Ra5r4Hz)Bs%u)l=T$qjMf4 z2Irx???)idJ|`>m=q=@qBKT$yEdCHn5e=a6d8c;FBI1MA_EZa`#>tdJJ?A7{L60J= zle&uAm8yyul892L0t&+|w3v#KYd#YC7_{O%T2mbPr&&ySM%PMatAx)1jpVAzn47~H z1^`{sGeIU}Z%Dw8$u@RB-FS@L>=__k%`(Y94dM56{BuV-d!ex&3rcAx%U{EK&F#pS zC;ijo?VFMO>Q$fiB6n^@IvK)3=lvaTS88fR+mU1ngNwFQP2IwC4N*BaOM!1v%{!Z4+}pSg9$@_8x*lJK69Jmp8i21%c%z5bNmt*@>|v*R=C0}M^8J>PJGDu!*9ROY{*Aoa z^)6)i&|`UPox!JL`Mj4*7|ni3A4H4Axt|zV9Ss?5#Y_!XThm|LQ4z~kS^1Mu)B1*j z+x1k$9RK)uy9}wi1#V1V{0xWdTwUXF!?--rVqVOGxG{2BF5wjcE;+C}89@W=X-8N0hc}pX0n^dx& zWJRIMJWHpv11g*BZ*xw*0nuk{t__n_=BJGKFWyUb+WI@bH92cZXGgh-@fleDqYEgD zBD+J4;J&zbW69Hp>3Se8TwgrKF(T@tkqIBo1!%|IC{8^R$9&^iR+U=jw3ETMjG@F#K4ij>ouB>-80i5@aIfWgNd-A zzHspRdKg_Dzq|UnsiWISsLz_Zcg-Xwe&`_p70B)PRt6W*>B;%4{uRMC+S%LZ1P4)< zxEc)UfHer9VOv&A;;~zcW=G7sZYlM_T|8l5DPs`&IK&LMU{D%7OT>qB7IxzVE10ZV z1UpiL9=am#E;1Pd1E{ArYrXS8m_h<|bNvJ9FXej7dLxyYL^G~j*^;3ip5`&)411Tp zNGU>ji12mOOQI~zW-|zCPxW`!H9Xkol;ysAGWXt*;W=@DjQ>PHh*J@azl-DVx@PpQ z9kyd;eX>xK8}E{-o%=3nb@K*I4?E~BS-XPitb&}bc@icnzEyZN!2EduDzn7~((6rA zo-`m|3M%7PI&?`f&j73mpgbO(tBNYa7R@x7mF6?n5=KfbGof-PfFp5uNW7 z=w8OWP=>~t52v03eP+)+;eKV6>7(eDTHP%CSvPl=_QU=C{ii0i!|+)uYbG^6;=Q$m zK>fGDvzp`3Zw#*+u8!||8RYOOrxPcC@xY6Y)b!UG5^wx%+}Ete*5e$fhhp4rH=&#L zTtR!t2S?0m+BRuz3!Zegfk)|rWu|AJpyh!zFPjDYaMLEu{Q=Zef%oDk0$7pT- zn4h9;;!yP0I`lPGeK&=mW4g7{Tjz@{1)gE5p|6V6t3$CUK;)c2t5#Xaj70{4&ehbZWuXAI{eB zkp5tX_xlA-1RjE+G8`q4@KlG%A)hiN7VsF4tiU5dG?`0`j>HE3r5HTAY7%pZE*UqZ z3L(?+Ghvg}clpQ)e>ul<)Q{T@IAIOKJFP-xM_RKnN?c2@?eRN3Dxlz!B8S!iKd|SV z*lzHQ5Kz}AmMdLp(2UpfKr_xRIyLU~E+HV zQ6z_JcJbm(zjtpn5`Tb$us(jU?Wz?1Y*bsd!FImb(An2U+k5su$hH3FdsviZHx&#i z__tDP@APOec#u|M(EoRu)O9&kv^%uk8?TWuU(rCo^J<%$%=oiWAT~YRX<#}bO1aG~ z-(r(q2=Gpen4X6;0j7)&?#{Nr{`}TT(#TWl(X>7r_k`A3RCOQgvDJiC{rk7qaJ2X- z>O&VAgdZVAEkp{dAfkc-MgQI2w>GzNBnf`^ujqj`!hjcor3vrF26yOPnxbWnMN&&r z_Bb>Opb0d|wm<+z1EP68|M#nWRb{;zASKzJ-9gM)BvF<1%F4=jBC&dp@I-T>U75bt zRDrx?2WTJ>BE=~MQn^y_V2_DU?M%%fP$5WHkDQoZLiW#1!E03 z4S=~|TX#XXtDwD`vHf#oc}gxl==e~vm!aFdY@|@2G2&=^`XnXd=!leTUpl% zM}{K-JH@1$8qNXG7zVMTuM?rbsmH$YqRE5U0(9|4ebNFSZ_ry%c*@(^{FVeAj4hG; zeMbedE?^d==Q-RH7QI=yW3xomH#gO!E|J-0x`pY0#^$2aV{RSK`)vFur?GRC439T57B?| zyrG6*)x*?g=8r@0?i~JP1H{@Ohwh5oK&FI@3bK8_9o!D;nL77`d^u5P=vKkUh{0<+ z>AfD<5V#Nm&UPCJVVGP9=yI|R6cwxI04&!t?5lcV9N-YMVXiO{8T{B)=wUqq`yk{y z6a!fmBTS4a;RD&U*In=N!6~{+7&_8zzJ(G{AkK(U+#kjagaIBKFDrTQiQ0WQt((>6 z`x?yto)?x~O5q=k*4`Za&3#mSFHX5*SS4JqmDrmb3K42-2&oxa5kizi186^jEzQW} ztl_#D#7F7P#?03H-8PTC@{^rkziCBIyDI}oX4lQobK5A4@X5A9k3@~g&bBj4Ow?aU z(o|*fmtQ!S0H{0HmQaIn-!I{YOWGBANa(;YEr1ISrZ36X3>rJs=FlrUpIz2h;L%VP zBW39xwcvGh9DPX=wHsvJ?4zgfuFkgEFO**f=?gq~_$_VlJgpHr^V>(?8QDUXN2;e*Lg%r`s;bKvbxI= z#q^SAJXCMcfpiU<&rZ9Bxm(qj6xO^V?9eN(r?P+|^lu zScP%F^ZDELg4~80jMlxgS~vn(qbV6wgctr63FB(LlzVv85_!uYEikhNb5!*h4Hhu0 z3WkKwdI2{J$Ga=8YqX<~CbJPGp>kzH+tpjwCYo7OOfh)((i}9(VCh}OTNyBkP0GHn)~IOvj)d3|jW<-%!$rAFH|3-XG6&$4U^g9QjW;v%Wau#F z3a*Zgns4e!#d$!zk(A?0*K=;PiU+hbYcM@P5WWk+Q^EH_y0B0Wo#M^%t&nEkUlva= zZn)uCZVZtEPGyX4IOXL^Y5w5j=(JehS95bU{FK1ferF>l>_P^~w-Y%dd6;(Pi3+$i z3M;bVN3NeG@X*!%l4wM*ofR3j!3yB=Zj8@h{Snz>&BA=qIi2!Zam09FRC0CMK5 zP%=}7I0Vm!%o3VfN=qVk!1Xa6@g9L zgJ<^a3TB9}V;bBNoZn}4bDd(aG#hak@Or_LK)%o$>R%oXDYb6zo+Gto7Vq=M0~GgW zf7*EB2|xIssEt5@gtIKk)YUO0m`m?UZ~bNl;f#VuGNz5Ojo(f;8yLi}Ka^DfyZUdkq#fdoP%`rTr94QWPl7Z{rD2|pP-J%4 z0`Ster|{-?<5XZCn5mXc-NLX${MyH2UhHKw-`bxuO*CQA&4(#M@*(*Ypo-o-0LQG& zuUn{56eE?3BiAXEZP*T+$zU)DJD)r|DNfN^=?bb8U~Gc)0c6Ichav_UfWiS<9`0iI zKIya_3WPv{{MlaPZJ4%a>zkW9tghsRBg=5&U<=j(Dg*{}iqLJZ*UDumizmxVddLh!fF zi*Pq$i&|IBCVTj+-pEp{DcALEwfp#Kx`!TE+eNz2tGDH5F3WL+#U(tK646Q|oJsVC z42D9xr^m|R<4oM<5uN)R3@4j9QCrz9?-)&s~g3OYceHN6v!fnymw)3+F~ zOBWIh+xWOCFG<&;VF9}%nyLBh@ub51Fbr=QbOm0< z4eVfDav&$=I<&mh;NzGj+KmulICMf9jI9o~lNF-35cA?rJ1qGitFHudamsb#cgV?j zCml0g;i7gd$hTFsP+1^FY1P1P~m6A(%1rq=49F_WJZCE|X;9(iL$yjRTGBY05!d#NeqDELU z9IR?*ypdybWS?87DSzTt1lND7dpU^qyB8ykj!^f$2v`A^V^d*LNJ^3jB~;Q?I}PHZ zaOplO^<&MRSi+Wa?Uv!0t0Qmx(!$iX@#Hvg*D3pmjc;R6m zWt=I2)+!w4O@7 zcjF@NUYchMecHL=fE4gfDU<>GLN^GHvv>M-mpnO46uAn#s6sshBoJd+ew2JwY!wgO zVwJbrA`M4taS-`L`5Sz9RCHn4+=B2#O(w#a9&OXrnsR?96dn}6(9CV{qK6eDI(%$J zhyM}ir3&HJq%#|vrsl`E8NHOYkQzbm=Qx`o?o9Hjh7V`$(Y5Kj0<#y78tDp&!UZ5h zJtKU0%Ti283mXvgj!LsU|A5UodT6h_`0ahD*nL09-wbmKZvJK?1hHy39;GD z!q(9V)6S4BaERoti~sF=cY85Qore_Yg!e3w8iexQ+f3T+j~}l`5H%LE>`TY58~wA* z;mb6wKUi8|+q{6*8ME+udU|r|=udD@+L(F1DlI%Q1IyStqK;}X?|3EnCw0i7W>Ly- z%}5;%2#U(oU2na^lVc-qA3t*=IwyCZ|KP|q;8mR#n_|*37In%Hj43x+M~@j&$Xmne z=_QMBX`U+c*j1O8pb|0!-|W?uV>?`)orqCZL+(+%Cr|MYhf^AENO&RUr)gA_O+9xH zC$=`?(+d`AE zoq#JJFVG-I(A$M!__W*YA~TP;0a^D@#W_esUujOF7xNpol$LVZwAy8qXuhK&dc{!* z>HDieTL4Zy(#L%7dxt-Mj_{7G*y-6mD9ZI{pu+>@6>c!1@7V(^G2r zC4h%-l^(k9z(R6W^S?k%OGswn&_7*|!MS+i_ftXYOHFngOQLr>2r+9F!GBwI{4GQq-d5#Xb0r|{V!{yU z%l_rf#iT55m;Ep{cvtO)zu_&>-i0TowzrL9hP93Wm2wHS6a2jig|7=oP-j;*Hw&ND z$i|P4;C76XS3j+&V;pAK-Ry|Tlw>yxyQQVOYb0s(3$iJkY@Ju-)vEPrsa0V;6ScuD%`<*@nVk_c2YoDhp_ zra*PXZSEwGiZk>UupWaK)g{D@a*jngKu)vB2!!RYXki*w7{{NQK@`nNAln)Q5b_l8 zx6&u4;_x1*Pj*UV4eTNMWTiybpj!3G_q7fMW+9n@I|)%Jx?R1Yq;LW&qI(x9zQ4MiOmQ_kNO(&h72ji~S0rO7ykBw=K28r?OBz?oC;`@W^O)U%19T%5cN{QM z?4zfi_wIG=o45J?i+U-0rMF;Uc#~KWr|nRzSZG(iW!QhyScbZ~aI_xvEC z^t4l_A6`G*c@rSTH)x=4IEd4?$xtnalBz)%?u3XX_=nCY+iU=fM`K<_#U-dPS2&YF zc74lw(d~i4PNziOaPYFFc$=MD5I9xB90NrtD+y`D4}!IHlphIiE$}4fo*wKU9DaA; zypUBTi`WHOk`p?Y;`XoKy9oD%>jVR&iY<_DWG$ck6eUJXX%V>fPtnY78%` zB=c2JFw?6`fNe+?5+W{tB5otlO7r6JB3c3p{I%vYL}n17&raHs|1q!$dQlTO)c`_Y=0qw%p0y2%#r{lK0{B2#jc~-(tkDMG&$4 ziRK|C6 zE-Kx-IZY|itY`t#tUwaYjh$X;u*@(%$Q9BEM)AbF;Jp7SH0O5ueWYeM4@htADUO1G zEmQ~@re8RcKm+**(ktVZAV)^23IRGiD!`keYa911nt8z$!P1rz1I|`ZGjp@C`T(Ss zdU0soY>L(GyjZN43ov|#b9&m7gD$#oc~E!T{AUkS!9dm=tRGmU@U zwxR;?FQGNvGiO&{tXDMRC(*f)jmv*MB7H_Te4-NSgf2LE!*_j;dGX~>DhE$oQwRXY z5f~YDMf8vX!ODQcEl4m^{stOKVD2m@r0R!Gr!Gms?R?q5A(20qZnz@aCk+!=z4RJ)u6rNYH^baR?z~x$XiZ;*Bvd&yH!_EB;@6-dcC!G zcpMA^L-B3wknsA+n=nd`j0T)g)>Q?APppgo*w-cIon(UK?CG&`UdFMXtVDpEvo8{m zBSccLVfcX$8bkAn*gyugu@rOVlxLr~nUTlBzyoU?+t7CbLu8@eHmTrH0Hi|OP@HVJw295k z=M}lm^*wOx0HJy+t4)nYR4Q6Mr%JH`%+?GLagm%v{GhGTkMC|qUKZg}WbS@Z4Xy@O ztuqpLThn9kqnX}dTF$NpbhtrPv4jHk@~7D*FAFSvxeG1dORK0o8#XGzV%LUYJF(?J^pWdTeIeFRA2bU#?yuiq;G4S=K{uz$` zx?F})XGk_JHHca3ZM;y2|5M#ihgzZY^N)MgRLgISu)Mck&7p$xv`Xfh!GFOBZ|(^M@L$iKjMl7o1kg~q=3P{ z(a*Mr`de2`+VIx~tp|KoBo97iApyGC@&>zlsoYL*jOR#YArKYD7L!PaNg1f@`Y$Tc zpDz`@f7$Vr#=H|bwK0F~DUN}flG>Ca|9KRpLWaM~UFMAmOK0RXz;?htv<~<|%3pK6 z-rWrANtjsb!)SsHsuB1$)g!mCIy0_)V2H@@@YyGsw)DeB^&!*j*J=7ZFi8340kXA` z;_?8K%7}FH*eB}4NbOsA#zmPb&Rq%`Ow1uyNGBx(j)YLc6WRtMw0&T;!K>qMk0C$B zCcS}7E3F#v5a&139;l-Tt+%WtFk|e0yZ8Jci7%Q8^77h+R9c^AYFk3kWd|^%!qUPX$sF+mU=W=&KPVotfWC%0w^AQjn_8 zt5@ljGjgv4>UH@I5#N4Szap?%cDZ;u*6k)=jGoTPB{xsKAJ&96{Rz`_`V(zQ zn#kYJtS1974fd`*B&(Qc*)2uxk$kI&1;kW0$E~_sLUC|{D%?b&it0Vp&W+5AI;&|YO3cC3Y2&#tRwO|dZn9H&|# zo7Q1*i(RnnyG@-4=n8W79P)`=PODTQ!AnedeYKt%whR-X2p^cCi@fo>YX5>VJNURt z>G>gukYr5TExLqbuix}GzQy-7j=5H_hldB>Z|BC7rSg&R-L|yXPu}!?2WOuTC=0h6 z(=Z>68Y_G241*0K-Io4qe@wULJoYBE$FN*;`hEHF>CPL22yGl>#@d)ctfz0R{qc8U zI<-pU9;ljO)RJ7?!rr_+k03tQo=zO)|&-ReTqlB>V)+FIjAjm@jq} zbNXPv*mU2CxIsD0jnb_1Rqtcb1|uYALg^@gJK%e~sK;DXBRet0?n$|!Oyk^Fo{EFF zEvMoTUXx0O8tMjRvay0?Ng?fvx%tiTovYglb3ie-^4*bUDN7Zp`dJZy7Nrz0u{tkd zX>$vxMid&$kX7kv296mC%p#ZI=vzY_4cc1NVr(}q!(OS}1_Olnk^lx%$xTA5@`i+D z_nnN)WbRGojA3m%576e=as5#VMbHrpEhj9FFrU_E(SfJ2Nq|U+4cPUFBXTLIja63| z{6;s?pA$JrnjfK1E2VyN$KoArpidlGb<|{jZiediMu@kbmkU9$EfqM`%X-Q{QnO__ zsUFWSFD;Uz`lWVhXoKh!3L7>Y)al~m~*;A9UN&*r%)vmPl;GBSy=WX4nvuvPXQ-?D7i3K3u zhwNeQh@N5z9BH#63g8#zto_ZwFCd%Xl8V%^dS6N&$W0T}@ZZN6Gum_YZ&{Hx4?9A$ z1Ow%6wdNhXi>H5yoo3X1bU%hQY4f-Ut_s;N!>wpR=|D<4p6jX(x7AW1Tr)Q;9m0c^ z5$*}(By48mu1tKTe$0wiX!_wQTR z5c6vYZ;ydu16U%erfyfzXcVllCP=ot~sGcO69X`8H zJct*mth6q=qsk?SFx8_L5aV<)CU8SKyE8)N&48ny&g+}%sGd&iCN5;3DfdAypH#$2 zq+xdxL`z=cbtdgKMDpB&quR`hSKDr7eXt6`J)fj}j2GKDblDH@mUmxm4 zpkvRIE@uU}fO9inVdO6-Q{2Njm$%70US@!Y3gWFo^9JTEw719ns>f-E z7`x>>+6rbQ$0bfQzQXNA>x~RQt<1?8_I$ox8o-bcjbB!{$)ZfMDRL7UhzucU<>tK; zbreG#aFBN-vzwdlzxHW}X5!1UZLF#0`7(CKEzfng?Z|^I$kIaar-H``m@4&@>~_ZM zW;MTohpQ8mb;IogWxeT(NmG#_Lr?YEP(vuIM9#&HVYpm_YSyd`JLhmwvR1R>^-VcJ zPe{1=*ufelK^Z%($17K5RrzPb2P`}@N&$jy;SK|yV^}@!5fhtD4!OzutMT zXZ25z2!-l>KtjPu-K^@FSv~+D&ogmKA}0vX8w);Nv>DE5-ROmkPXqQU9lUyys%gI$ zc}5v#8-^h>X6Z5g$O+kck=sQti(k)XFCgH1_D>HCt2Pu*pw?cIuMe^~zojP+vZB_r zDstp$^WBMKPLlj(V7c1BZlX7X*I)Sss>$|zih!9>m>gg$Jo|p}*~D;oGqATcIii4< z1p6cIL)6R2ff%82$UzxJm&moHg!j70O_1WnMFQ~K;w==1=Ia+^dXcf*JOJ$>J0Ns$ zfAza-`=>Y3==Ib7pWlqwbI+jHWIe_NGb8K>zCY&cWmKy>GZf7modsmS{?lX~Mgr4%>fVCH8GMhHm@B%ly~fkLf_33uoSffav<3Qlb@3 zcxV}~RI%1G)1#r9h%0sLJmsJ*+^hFbTKXK2hsONPz;E*9D87)vM0Y^U83Fq?;zdL? z7Nk&_ccfVIu}Kzu?3$79(7SeeNBjg8K5H_MSB7H_4>PQ%x~=#N4w=E4BOM?ywfZmV z@Lyf_Ecz1H<%*`Iwk8m*GDeR~*))%Rk9g|@YG(NFd!y?AB3#T#wdpV`un=PvRa7SM zt|#Q*dkcrLjP^eq$1Uz+12wm|$s3tuz=#bKs)CcAaTdJP4yeF+EmQ3pRzUEcJf(Tvu z#oMY{WZe6FMl(e{hU%)c& z5rlk4M=n@Dq2)*|cLzWTK}q2O1)BJqXATz;2OQ>BRJGdR<8<9XZFs*x&NjO@2q(I+ zhc{~~UT+#FwQmW+)>f@2w8}9733bfLt@Os<+ARl%7^DE^eOP?Ip3GfKdE{(yFEJuw z^9jVT({HT0hpHS*U+Y`^dvjF2rKgq@61+;&*+pui*@6QHwGFw_KFp{A@QSo^4i2uu ztn=csDDFRM5%+OcL{hyeRR+EfHk7OJfIDbPN6N}(#h1LEy_>(4iI8&w zTC3}u8pm1EH>>%)FtX(xuhhNAIx#`UryV|8EkhtkFOyg}JiB^>yO+zV>J>qxVla?8 zvy>JsOvwTmfcZ)3{NI&E7+#bsP^Nid7Cgm=XZzvG2)S}hfHR5_3?sZ8|vS`X!Ii_mUK1UpH%`O-Rf3a~pq z=HIN#@?|}R1YA^I5@QGQ8mrD#wfg&KU(e?qgBy2B+i-Me{@90~)W2Vql-D96;TM?v z*|fm;SD7DR;Z3<3Uz>TAS0!Kv8GdG~8tPBiSF_TLSNZmz^Eq`Q$BdN}3c=tFxf)>G z?9eaA=F9HiKkLN18)b`?MK^9rpnm zroHzdL85RdJ2&x0&4EiODe)_{80hZs6@0a&H)eHHAw%!{W5@Pp2Wp5BmN*xHG_q=~ zT0)@1b@5mUb{dWut`MavPT%Oj1u4!joYrIO>#c9Lz0^z<{tGXwswV?vVp^_e<7=yu z3Al--Pman@;1OAxF|Kzzoj%=q`x_9V`C`@VzJB9BpOd;8&S!?juWi(ZgwG5c^ER$m ziI@xB!U$JPuO_9T-Lv}tRzvxo=@wE!nhN|FJm74_1a19dIS0mWc(t6b7bzz41@1%; z;UxwZ zN0x%_w=e?$Wd*kZ2YW^#i@*>QEyo8O1|X|*L{<$NP~k?jLz)@WATjpM7VtZry6a>Fy91+IbiJYAjUk!?6~4~KsRMrz_&Jm+QsK!8v4!!bfLA zx|@DLw=19V!VNb2VvuNRxb{8^iy)j`R1UF_+kq0DK7rv(DzHKG_0_duaohp1t3ml7 z_RW7Hb+!NEC!?^?4k??KJ!NzL!$ilT?nD;O21PV4uHofgI#N>zra}NySCAw~Eq`FGzbRz}*~0)%QEPCt`yNE5-`;`^#JyQYl-kBS#_gRxRw z3`*QJb6_&f`%_~&HpmGGln)aMDoTjYkmrRD`4Pd*cRs1m$ygj2tsT1|i^38$j41sY z>2m6&Jf>7|cUG@CF4w?O1BOkia!M}T1AmQESix>_3M&aMyLB@ZW0Ps(*KpJgKsp$D ztt=_0V1|U{+6C_fmqOcPlsbS7Ka@bjLShmCL<-W~L59ZltlJxCno(X3(HnJ_-5mxH zgk{hP?fM!G`wp9-lBfA!fNq3R1jT`)f4V$OU;cb{vj6Sy?ELg#??`VexD*;}8S1LM^W(xe`@oK$bKA~Q``#=5S&;DP&@H)U2$<6YIHG2p@fjT%D zE&}en6=aVXx;qZg^rog-GbzEQPhHc{=>v4$fyc&817%4%d+a(x?lMc8_j_e?E!)s8 z_vUm`^{qSNMe)m%Uw3|a`fJZG7L(X*4+wjY*!$Zx6HG+dMR7gQu3z~1IgZM^i%N1j zLg+$6u?q(OjWoVq6zj!*`nbw%cg<~rAG%a2lpNO074e`y;}Qbq+9mEdb7{neZ$V}) zzn%mm!A+JxfRn9%{0V9y(QsBxOddC^pG>$rScqny9Km~LRjQ~(9Ra8jkyX(UKWVj~ zq_8!&6n(+-tC8FBumv5FggOb<#Xv`>-?YDA4B!kzT=PTMOKy*v>FM>nS!G3rYe1d{Rc$DyxOY*c62<-%h4&%^J^*Nk2Ot4IDu^~36nh#}C_d>_bZm}? z$%&E#!pOmzftgh!4e-OndOfIy#^Z142~GZc@`A_CYa1VMI(FybW?9UF_XbbRY;Dp9 z#d{3sV7vd8B!Fr6xy~8Aq7Em_=^H}Z8AMHm=-z+_*`xvvDe-FhFH)};PMVt1v7-Bb zI^9E`n}3Cy?c$UqFss`UxK~EZlDTgZl&h zlXgV3GlQ(q`2S?HjZ8LP$ss7LK_UA|1VlVSTUP@gs=1Xq$^z@t$codwX+MW@Px{3$ z&4kc=1O6kNYs-3Z+jzD4qG!QuAY~5sNgsyocaPEAm!Kqs%=I1IF3Sb)-G-@$ z2O#@iPpTmFz?<)Ev388H+p(MSZ3SxfZfAMXNq1FmKp1yzzsb)znlH{IJ<=w^X%}<3 z9GQbz76CpHP^lxba$J1=F8^E=cjdTj>5(%=tTwM zO};m|Y_@dWLLS(PR(TH~vfoU&Z%ed!Q8F8SG6DXk_aiM8FS-(V8Yy$oqV%Lb^8cg_ZPtr4=v* z-H1CHmi#w;4`!fmX4_FD0^#=@|@ z)K}q-E#}Q?>w6m-Zq*}4iZ`#yK-_A?*4nC3NvDAz2B8b>hFjmCKUuyZdM(A0ikx!d z$$eN>oXCi(D?qFl)v_D!0YAG6+ zPj|fFp#A-|&YCW@>Lt@}`o3mAAX^%GA>tJ^j6?(3EVa!QGh{}DCsx+sy9%#FOv_y6 zj#$Hmn2m}8a3}2cdarjLKYbH{m|y%zHQF8wVRNYKDQ&{U&KD}zSWgD1LHEnkl=coj ztij`+eV%VGe$NbM^IAB>7#u7ihYJ0506`85vxr0<75SvAy<@$5-e+6Kd+)^yb`sG^ zPtssH4+x`x(@_|_q1??p#l9j;uU(<@;Y`@H`KM zp%-maN{i8nVU$9CJN;Z1l=jR3k5r+X865glB|{8K*tIW5t6o0ZVq7?G7k531h3CW< z@Ld#jbT5rVSfS7W480D(t~tAsfc;1r7u`R|6pzWCSwiAOl#(8S^gLu}Ls(wtq}&t? z83{@O!Hor}001XLr6v=rEGM>+AD;|M8Gj{=sP_$XY5qXr#>N_4rJYNQJpmKsX>cPI zo_n4MUF4CX?Fv)(s^=|FK=cCTv7R|jeG8ohD`Wz8WlK&4c|Pkgh9zg|PuojN_VCw+IJfo=$l(<9DHI`@j3eCviZLX zQkGa%4n9UPt%^P!c45U#{y{+f83>`BNYCo;o5D|ZdW4@LUfR=s?Jxnf5BEeKb^(Hp ze^Qy_{LLeG*hJpQ&K*?0gx*O_bT5X|3`%p2%@xMq{00x0K5rv|wmOHfphBXoVbEk$ zQ2m4kw&3TwQf2C_BXTA~z#ewFqvw-Po2G#vFs>WQGT0=P7@u)jj+Otiz_=z%lr*Uf z2MqTUs*g7!J!r|H^z@ZL8M(1W>@$Otk_3_%0v?2YV`vh@NB08PV+#udVtB^N`vj{B zM!V>(wpm{QQ^bMglX|>zkSv0YkkS|`GwK&3;G*UC9{j;ppuqice`1~3sPEQ}#1_7U zMj#m>i5P91al`uU$Djh=Q06nN+S%qJitF7ie)-k^^Q`)LMIz}Fcimxy_8G!B4#O3U zt07lS_pJnb6DYgL7$bp3uqlxPyUvSYo0pS`-^+Hg**_`QuaR6KHG)Bey#tVJTa>O_ zwr$(CZQHhX*|zPfUADc;wr$(SF5NnPUq|25FFHCSXGTUw##)&fbB#H2{A2!KimBnO z17FP?LId~Rcuw0#l_SGMWwXGm7)2e_3v(SQ%qI=EqFT)-q*jk_6lw6NFQ$? zOkBpCzx}pobBEjUI8+4Wm?_f8(rPo1G9;-N89^G^RHoE~-L<#M8*Mg7uGB(=x?nRyOr9345_k0gBiQh|0ZW2R3k*Y9MG; zGNBwOLs11J*?_Xb5+?Y7*Z**Ff4sH38DaP^D)pzQQ|{T@ znC%E{9t{%0#>Mw<`Tuk*wAj|tBtqBz?=bLYZ0jFW`TGuEQ2F*68{5=>WAx=$l%C)-mVixa)jnWPhW3|D1+Xf;&SYd-D!U-R_YYk5V zf{b-{{6lS|dFKI%&bU;XDGBQ_Z>T^`g0q0xWJ{(d!IlYv5&==k1*7TVv+$E6-z~Bi z7Ghr{^S;Pp0#^iR)>)bpAq~OoMwSvn&4$$-F_zEhiqV!aUhT2Qqil&$rq6gmo4#QR z;;~GVK6SE1Op!^#kZv`vc(%Qd<;59LR;hlHRtIH}b^^^@pq}X8(LQ&rEZYD(sVeF+ z#LI-f5&_Mx@3AF-x}lT(BUQp0*W6`czT*~f+)3%t0ZVHhVLZIN+|5;VTfpT_(sY(k ziNQ2n@zhYq`GRI^yNnks5#cA*FK>LNx#EEqX+@lCBLTmw($T>jh`@kkxqIKB^GC7XnyA%asy zcgbbiz5%rJJLCwdXz(1x7LxIDT{)=ufR(_b72lGK_Xe&3)+2tS8Z^_&_cC(b%XcGw z0j33>w3uqvde(Yv%idJ8V@!>KdU(2ioxJJKd~tj$3xN?sdK3uwjp+u>JsGg1_Wk!)6X@A~XfItV`jb_P=AiHt-neV^0fD-+&_q^%M;P$za zldXentX{)8yaA?t@w`Pt>~{UbQ>(Y5p?>XhRyTG;xRqT=TH1iEy9=k?$~5Hi zJt%8ifmU}9v*d^KX{POf<#rzX_h@>t#dCoYr&9yK(EM75*G@3IugV+F_M(JcF>bL% ztaTvK+D1D`@B8<5RN??!H1p-uxNs1u-R9fGU5ZwO|M+UhFMB%OpHg~@UK}{-2iziy zTWQx@+u{NkM;9>F=C-}`+ltN_GNx#sd0Rlb?Vp9Sbmea&?UgRhGO23s5ayld_iLL_ zdGN)B4_7Xr2s365D7r6b-A8t&Ab{lecOGgs_xm<~<#-Olb|@4g^}qnN3Q?=mzx>4T zIQHZ8M3L7|Bkb(Fue^v+?wQDK*!3lKKqCazc@)n45kwF=8A~?-RD%H!&FJk<)DPH+YukiXt6=Ke&{YWug?cH&?$6csDEs#a4lHx5 zg6q<*9h8D>5Q#e5p$Z@ZvHyDRc9#S46VCc}q5Ra{&H zj8+4q8m>*JTQhe6GF*j8oD^)Ykz7>5Ji+aG1ee8D(`3DwHRJV}CjJC`sNdjVj&PEu zV|EehBZ&zC^}=!7>F@wMo?clz^APT^*yvu-lOaKG6&n=J?Bu2VtNr+b3zXDA)j3;* zO7u!twyWsNJns4O$QxNBN;x1Lot9Dy3N=wnnD!MqW8}MIPT6g zdFzLiuub~I`XP30rRvlIX?48V1S+XM*WoN_A&Hr-qSuA0c7&_=0PjKB2R(W4HH~fa z7-*{AS?-!+YXx9di_yXCew=MYjk=#Kw~4rs-MsWniL8QC7WVjlGo-GwNUoI>6;djy z^}?}B`YaTCwYTkheVj+fOlSI*uswRCg_<6x~dNaRB9mro#xsh=8JeMO6~Zs=zm1ONWm0saUBEy3%jK)`@G#eFTdLWsBTj zXmvaSFVU^sL80!_%aR!PbUC#nvmj2k=niC=Ru;O*V;Tm-kWGD-lH%7F*Z%S%Pzmjw zw&=KIYl{4{$Wg&<4M&#yx?z1UD$6xw z4ex--$tBQsHT;b`XjQdaD^aQiTSI0`te~O(wl7Y(BqBYBh;<-jTclj35mG;qKhc;b zEsqLdb0v}lFO-vR%y$4#?59I;y|b*4!rKymyf#He#G0&w{4 zNfaMZZ3HS3B49!Ryvf3PyAfj7M9bjE0KRuQkv}wm6a=9qwnOb5 z{fhJDJK+wcq+!%XM?Vfi=`69!P#V;DLMysU?I|8I7wOTQ4)u`089XHFKgVhGE0pLP3B|ZZIqI%%x7E8|-mJHb@Bs&rhVG%OV#gjTHUZyJkfsv*pnz@dTF04T;~qj@pWyaZh6JlrH~59riQ3`Xy~ zQ~^6DXB4T^;VxQ4wBYGWCbI&BA%Ka@&Azw8``R`%j#KSV{iU_ELf&R(qM$;Z`3qP| zH=9s9a}BBz-}9aUXQqR2Dezz#!U1v}NgMHK{c093)U_!m`{ zO2|&S0E;n=;P+LbN%Y)R1ZM9<F^r#I2CPPdY>`-T=2z;RW;X!^G{*;{ z;&1e_{g^&Q&6t1}`iV$dg)DHW=$Oww{%n!MQKS?VqNYTqrwlgioZHY5JnBjqEj zVr>FTmy0tCtobDlgDFRV@woxk5~(EoM=_!Oe*9R(_|h;D;cfw1t-3`u&UZ8+jtPN? zIVLvG6dMq=*(7RsjTg&a*vVBMJ{Nz#(g@=9RSNshUGlabAF~Jyn0V7ykASL> zB13n=kepo3Ujk_M*_nxh$4Hh4@3@j{1gI;_TdhL^DL4_ zLzX@KpJhuv)_UAh2q=-}X3@P$;CAi)J^b$$0B=4Z7A|OScy2z)B0Kz7zEyf^jpE># zC_XkLlXXVp1ID1Rd1pgF2wUSv=9PrMNFVdc#Tg{M{eJJKQRdymqvDUo+n4V2rzp=>Ok<2{MMVB@$7E@r)r3%VzJofZ+je+1%JPA|3@RXnWRppCjks^*ip z{Z06g?n0}D1<_i?bse1idhIP4@@fW>9cw7zgv}yJ+%PHY!&>Y7DD|#JK@eHy6OV^c zY6-}jP@-QuQAltX<&x5N1Z-Yb)TCN6FEk>b=>hs9Le00|A0C_Fh)gDJ(OYRT)YP;h zLI^rHS39Wgjzm{TCvW>2fW-v#C0cU0h;-MNDf1BiqhEZ+Uj`t??xG+|$p5bETt5*! zy>&Z%^;K^=)0*->`868N*r0JcQ|*9EL7QA#(Kr;S?vSXemXVxBb(NTV>iMU2pxnYx zK4ZZom!dm>GO1|$tT_gx*g0%)AsijJ zLTFOY8WRf7C@>b7KBJF8Ut8Wg59OpoK!C#>;9a{n+RD-Poe;I+H(iF*iZe^{2{Zv# zs79`p7R^$ecz%7J%3u4tKepM2sMG!Qag3<5C(wq$D4W_Y>BZDp?3-Re(v5Y~|Zk3j> zIfGX12`On=R_&yJD9)(e;V8gK`IeByh4(A%3b{!(=ASiE)7Mg6qOjajU({$Lh`=w< zZbc%2he#jC3vv^j^hpj^hl=2ULr?N&JoTuJ_po%Bw@eyN4)ql;WXZow-f=;YN#UJ6 z>Z-3OsM>m}1`tjixG*YS)z|(eh|%=O&(G|eL>?cIM-B0H=~d?5n8&W{Imi80L4P=b zxAA`kq~qWhMmD|sKClWug2;AT5(2;1c3B(Rr;zlD(K13$Q`eFCP~qJv8S~MB4*B>z zLvqZD^P$!LrA1VOsT8%Fk3mDS0*W6X=x_&;jc=WQPvQ46E{kEf(rI|^#{h8mF@rH< z3gy2g7>-}RLTY0w2*Qw??~0Ihk9mEBD_U7ix^9ezXqpAbIJA-j5EDuUR@j`WAS=TG zO8`jsVJ$Ri>u*S>GKw;hi(h@~w0LY!Dt|6a?c`u{@4&P4dsII9R@dR^KwVeffVZ4X zme}lL=*9HufOEUPT?*v_9hC~s7fGoz)N)Bx!RPY?b?Ptt%)%fmJ$HN}T77=~1(L;; zjd;tiOiFIBk(MX@|H38w!@Ly!gKygm3IG6s2mk>8pM?h}OYXRjbR!ZLlG9 zU8p0BiE43d<9SlRyv`EI)(d0;$T~s+(bFojH6oQKs<3^%_y{MUm~xMyRnw(PPOoIK z56n3|?;nF!x)3iJ4)AsQJ?*1wQReb`J(t=I%Q3he#=?sECB)Q4GgH3p%1+ubaFYI1S`sDJ&UFzLypB_AG16{LTHjFGqifzhxA2I0I=5SG>N^)P;rN^ zGp;Ux@wcE~mt|Ql;KuA8I&b<`UxXR0W6SaT>8-SfO}AOMafKI9V9hn;@`vZxjJvA1 zG*e>Wt&u&wG4YDEAqFH18gy-F&NX-FHXA)4eHxTgF-2J9Q*3woJiK3DlWveL*f$}q zQk=(XPi%Y4N_BllI|Nzg-D0;kz=b!Q70rlFq1xa;2nD5SNdQ=AVhT2+9#j~{Gy&xC zdAjUJDNKXCR;LdEa9n*`W#EC?iVB;1*2tMdruNq`8E-bc_7l;8%;dnB6lTIjNPwFK1xIEi>GIHSDJToJ!@9eO6tQ>38Psdx#Wc+D?>oFA|xN^taE|gNyo|P3QPP6g;PW5mg*H;2NHM!sbM1*13 z#e>sje~5|lVthi<%t`J=urHaNdV&`dY!ZM1AdV(ORsi-k`CN=Rr{K9P&v;D~c&8sDvpH8ur#x1P+x^QO#l?Y# zhF(5yQ$hUlWsPk6X+}D-S7@#BjtE)KA{|Y_4m)5l)>H0#QW^UnmO_a84$)-J3gtLJ zZ|JmcSW+);gbhcYvi$=uZdY6#fwNbwb==^>5?*!bdnPnFR`QV({^roBe&ULxvLWf4 zGC&$=S4zP+8z;h+EirKVI8WEi)rbRPE^4|Nx~ZDEctj~puq`L2vWh=GdCFfWaQB4G zf61TYOJ{R7c(w?{hd)IRcgljuA7aTu#`Ma*%(`a%#;T@R;05Y2yFMCf^2(QpBdIpm z0QFBzYZauVZ&*#H3Y&=3Z{&hj>$;Y=zK4$~vfhC~lZ53<+JhTv2Oao*QgRfI5}j3D zB&Es!5<0Qxe4aM@{C!<0wFXTr@>RfzIjym|?o_B&o9L#qoYf`^F z9!{F2yEw#kwO#IrSTZ2Y;^o5MKmB9m&>40qW<}1#@ zaF>vLA0uPPm+t(ygGR-^?!If^J7!yOq_I8Sj+?sTjH|W5BXZj|< z8ThqdMr>RmQW{I#M*&?vRoP)9TWJH@Om~Ti{gYZy{lr;LY0bp@`u?&%al&noY~KN+ zl}GkNqslvQeQu5y4!Iqe`Ba_);+oJQ@H)ZB}e9Iwk2U!u_W+ z@K$U1{!k`Xu8Yh!J%_!wk&}z=yX+z52u^_vSk*O)?`e3ZbEDxERI*dkrig!KV{7Os zce8EMj>98WcTH#WG3{mOSl=-6^!UORQn$Tg9GV!!n9G}U&ZJ<@F7l}&62UE=`U~)% z=nn<(&+Q*r&VM}l{~n<=w%AD}W`d@AsI4P7aobaM3kGt=GRNegRhOZ*nm+W~T; z5lW{bg9tboPM%%M)CsRfzVgCbyqC@hJczn^1O|``g1gk?=)8{-?qY_xsBsDVIuN~Y zZnjv*r_N+h+5m1eYyA)i%BM;sO{rNY?9O1?!de89E@a3>>Wc-$_5S<@Cd5OqjY5B7 z%e82UL8x`YOmW<}wW}3(wh~My4|~^yvmLP(M5ueapJe74$j6XY20h(+U8liFY$_tm z_JV;{=88v?j+sZXBt#25B{(lYRbk@8aIt@nx7B>{Rb!OaeV|t$!sSe%jMEkA+o$I9 z4_hqOo5+s>i&7<_l#z0gW3yFCBN%Kr3KJKBG>C(W22v$M#1NGT5fK~zxUtEUX0#mc z#ST=1Izw>E6Lgalv31wV)Rs_loM%neg&Ot6R5%+JZnH?2i=Y!}7U&#z?b@>&%7XjC zZf%t?ts69MuaKVt#=OcZujVQ`A?zfo1TE)IvK5@LK&mq)!21Hp79c{~NMH#pMAW@e zDByxiMKy^Mr_juGzZkzFj~|YlmWH7V4U4@0@!cVMiN9NIMw>v8igPN{@~=x63MPe< zHC13_z*sc8z5c+=o10~U7k_L;RIDi@Zla3<<8t*X)xSeS<7H9sK!%;Kh>C4(R_|Gr zNtc9l;RhBXi*_;#r0LvyAHi6uHfUGN9Cjb*wR7p-k^>%?h{HTUOaVP3k+nq1sK_;b z?+{q?VwG9SA=ORS0Z{MWm^twRZRlCCP4c83|8v`Jmlf6?kGbh_&k;e!g67Ce2&s9T zQxEwy&d|kt0OEt1D4vTgh8j{L6%>FS2FP6xemPVEVgHb2Ieb^2Y~s_G>^dHC8;V_P zYr065JvRAAoWuQmxW%iQr`%>I;Nw7=79>;!HU1X{1ZP5 zUb(FJ?%0F@i_KX}RVOBhEe2Tqz$a_?olQ?xM*Hw;KghKncw0`w#~Ve>0r2pxa`nie z`7o!K1w3uvs!vCFT)0P=e|^KPXIl@*wzT_xL1VFBD|JY>^<8RWgUQ>yOR5+{>612- z_i)<`d#dU*azp2){kV0oZlVY?n)C!65yL;L;=uOK((a1Pb&JGp@5aWz_PH6`EbRMS zqjTwf8a+hsq4VZ;$EK5;x4U!Qi@yDtzQMwC1N%<8(Gh}DO1_#2H=>jTUo_ogiZMwF zA&)WH0U^DQ1fp}mC=b`rIf?WwOLGTgQy7K*SFJ_}dGz@ZaWJ^FB$3?j_&@xM=tOZ= zs)d1nclJ*{_rmy%7vYsz*jf2^ZEC`Lt_RKof_8{g?K42gIa+`eJ|8;j|Q)6<@!D@Dc2=)jL;^`Mm zDM6H5acyx4YnhwnQR^P96?PvR5~0KIFf(y7yBRJ6XBIAZ<8E)X^x~v&{Q~}X{4B(# zkohNktNg@Jynn}NBM)a&eJ5u}OFMJ@ek2=7C&qJ#uAmxhr>bC@d#QGECzrk}*pL{2SEl{F1#qZY3|3He3~QM@Ey zNj3y8h*$U?iy<3YLS?KL{T)-9NSF>u_d$-yozxrp)4&M7268(Cwa~&AOck+Pj+RcUsGtp>P?#+>QLXewa*x|T>3GOcjNKSXYd0jAKem2}i%>#a$05l&I1wf6Rp^i#DhcCmsDy8?qISqs_TEtM^A@Nohp%>=D z^SVOS!==qVC5E1^EG&9Z7Geuo?)*aW!^k{dgx^k&fiGEHQHrJG+=iOoTR6Eg8^ZIt z5UWHZxlHes;{_haxec~l$gta!mJXSVA1nmR_3lU3zuVIo+WNymer?1zp4H7Oxk;PB zktCI|R-EcyKSl}7-CO&+5n5i0T6<zLTk=tEuC^rqlm` z+q>eicG}=b-1(%gyI0L|uB;;=Z*)uHwa*#(g>;g*ifV1f#G0g}h(be{4?zd0HK)7f zw`0d50YpNoD3_hlg{6vz1i^y!c@5TU{Cv8DS1&Ow+TcF!o6h6+h`;LSoHd5ZLZ0cj zT%XQZFJ)R1m8zN^44sR+wk#_0sqRrFnmHCu4SE|XxU#79sP;zlr2WWrKxm}Eau_k2 zN$n%zurQSrCx`-_YcbTwor@a(E)K(Yay%B}!DLd~-*L+|*^j!yqb%&&H7A{+LJJIj zEPndq`;}QcAPe{a`pJ_f*-uPsa!&-`_!bYziW)I*h;&!piQAw*wb|0jz1J9ZV$R>0 zD^pv$Et{>+FR!Q9m7~|w;rZR@ahL~qILnmyOR{b%NgZ92tTF#BlDdI^iE8rEc{;#w z6VSg&sVS2p`al!IC(}q9@b;cG&e-FR)WNf0D#$iQ!PV1@yq>In&Zj$D7hhH)gfB7LzZ^nC1y4uI1!(12P1+Qn%-;b?V6Vq;S$`5sI zd)7gkVY#9}YyrAsWa3tHOROjg5vxbOSB(nNDyxd2$cc}a$cLWC3P8Iw>_H!)vKtK| z$gLZ5Qxh0CA$P=2ItSlh-M|-WP+^u_%@EHM9$1w;K6t~C@EaoYd=)XEo_Wjj!qSn( z&tIQj&zF&1U72+!>Nk2iGOLG?pD#xjD>nD{fdRVQJ#Bvs(`Ke-#)hUqJoPuzzu5V+ zf9vb;F#C&C+VRjTN~~G#G5npi?|V86XlG%lbEo;KVHoE60Br%-kGas#!p>3WQEvlo z>(MOqI$ti(8<$_bwBxo;VXk*%!!Zm++Xpp8?*CD2 zhEIoU_QaP=9N+9z@B{P)s#u~|Y60-e2_ti$cngNNlhP~a^78812BEZfMo9#Q zD-SR_n6ntP&_ooG#A22*f?!kzCG{U8LtS-ZX5rC52nFph8D0how>{Kh6ui`rc*CKf zw3ntxx2~5mWTXZgY1E6fkOY9iB9MVRNRQ1gcPpXZ6CpJvurzVX=%ryP1-fckv)ird z{^@Bq+2-vxE@ zLHWKj_x%*8olz~c#eRYD>h3u3Qc<=`}s9*v#!eliGVr_&p1fw`V88S^LOrzDx z&Zvu#RCYp>2)C}1i=IM6(_H8QDlDJS10u5B&NN1s0+c}8D6AG62S)2w0ajRDEw-F9 zT{v$7K8D;T^oHbRsC4NRuhHqS|IR&*QGnRHQ8BYNkYa|(J+S2wHeRVI!=RX%T$|Za zaoKJ8YvMErnQ|$C^KE=|u^^l8kd;wh#~bWD)uF)@*6ZZH59Y!^yG)?`nc}qu@{^r> z;2_q7!_2(Q*Ch($SS!%H@UfuWpEkWaCFk_DO>89AZl1$y$MTJ0n(g#iNa8SPTivBV zH24F@n?x&qI)Q0T+A>B)v;(^y=~sHbA^1eNV8K0EL3$g=T{6I*0GBOR=|(_Pp-J88 z3P-%*+y#S+dfCnMRa05Gbws@X=|lCL`R9 zQt*+@Nqd|zM%=rHZOBSZoD)tCYw;|rz53lE1Rc%^JHldebwq+a631O4Uabc%th)q&w^RefcCaUAaAWG0zy}YW%j@NM?4}Y>QQIC(3!?~7nFZP2&W1(}BU-HNF*G7R?^@Tynje#h!%M;- zHtgK=`EMbJtUh?}DYgZbND*-8nw|Tn)S{1c45N96=Ww6FI?NY|35sMnxYrNNgIMhT zW|nnZ7%b>V`QTRSz*d~yJ@?i1DV$>{B*GvI%)s;^DtB;3s-OG^dupbD*yL@GaGo<~w^nKVv~(8~yi--b8e5 zx`b-CkrSryV3mU`0jcny+@3>@a3VW;xdNgUZ(N8Sr56iIsmm}B@wY&irn>(hL3~`N zIZFnNP4S!UIYPIL50EDQiDvI#UmMhfsgb2o;c!YV6h$wa`Yuo^@^=vZUm&1y*ZRUM zzQ{m$*VJ(@H=OGMcT$X%ylY|n7g7%RgQJ@_FaEX*H@8|Z5r5;>4$s$6y@Ey-C;OPW zw1dT`$GwmcR03Iw2(MeFv;`Dh53 z2Yu}dqT@L?a+D;lGBAy(WUvmAW%N=lGrfzPQvLzDc|FM5Wp9$u?GL0)Nzm;J;KqaJ z?B6{{TIA~xE=3ICt^QlOUB{CaM44!LNj%$!#}U)~?*IMQnhCFKt#j?;Y~`@F%gkUl z{jq4Fimmq1QZ<)zy6(qvQIN`X6Na1$mcg)=k~aJsFxOwZ1cEvh0eOaEj5apLLt71` z3`f#s1&m(~FPM=0-QwXD+9GH~@v*&ZThIKS+0K$;T$w<8&#hmC^J(227N&WD0ye#| ziN`MaEaR}F{|gJk!#uUKo*@1hClqh z_&d-}s||+x>u_8sL~TjNe<@Vh6Oh%2 zTrY&Q|NOj~D%Y-r_zyoi!S?lYdJH$$*Gi)CK^yFfVARoVg04tFh&KC8OXCU#mZmQP z)H2##%O-NIjg8YlMZ?8WF2=7vH3DrpF>N=?=2+ZCHCqZyS*I8Gv-o9HM&fBeRF78i zsYdlIY=ARdN^i!)=tPqK7=3abdSlBwmr^&zUD$5p4DVf<`&(#(xz`;n$9-X&T2@1r zj{uQZvJ8CAfzRAJQv$UCA^de;#{gc)lP@l(3t}$n?2bOVjo)qCewR}1IlA3}<@)HE=*AW?trdlepzK0sG8HOxMFSCHRFTaBBk>iK;;{k(0dlTJ~3 zwJbIVgH~$uSD0Kke_gi4wknAF3AD>VMbSo3Hv8WJ9V`@m-Z6HB%BTb?GQTHp9Jn%I zq_>$_fCS+1!95~E?L9MT?kqaWX6Po5(Zp=5yf{&O9-Ppg0rS^nbDg@ zmL=^GRE2v!Z8u|eIf9`eon~9*m3USF3E9wvwzpqNS65&%3=Bj9xx9y7tjk7`^**;d zyp>a|uxa|7?9*f*%l1Ycpt9sQ`4a@l($b6uAK`mnxy6Cp^s1C$4G2TE+ zNyrlVmO2^D?kY^<2(JS|%%Qn!pu=o%84h)M^usKx4>r3KxAl?hIvS)g)*GUOq)n+n z&3pzZ0Dz?^xM!D}-h*trkrtwL?7x(0r>4>!0FAa0#~TL`yuu-f&kH3dtDY%OiV7?c zKh3bH1R1)7E`^k@Rc_Y^D`#t%Dm&M%PzMY0 zw>?$;N^`|7VY`5zbrYVbFCV!D(ONtvuMp3aNxa<@D72}QqF>>uomQ|RBmWPI*+*fy z(8@wex&EM8q!t0X0dy34k^=8R{n$d1-_V9vI^ViFt*z1~#7@c7s3;T+6_QW+s#84a^N7j;cM!D_JGQC>;b zMciX1#(N?M*QZ>qAmp0$3s~ot(MXiSJJDo-IJ-O;R+Dme7WRlm0L34%iAFs;DJ?Kt zinfmmT&4U!M!|5QnSTz{jlh!M^1&u zYzFIy4GRMY^?q_fOup?Lfq+5O@{OcCqt$CSe<-ARj@pU`~xWfgW9G+nu={2OoN8NcV{3ROX~CKJ7x0#Sn(e&Tb#Fa~Tbe z^M?u)iO;hZqfS8NP4$WcZ}*2ggGNyVS;q{S#HM~Afkjs@jK*fFJP7P+!K*f#u23cy zb8YHw)5hxWyyd$VGViQ0(mETm=Uld*CavLuY zJ?sv(KG}Qkh&;)XsKS)-P-`!)#LQ3k17|THJfhUj4KT?D(`M z(IJXW0){DQZ>@urTFttxL(NbzzyF))LSs|Y`~GvG^`im+VEy|-v$XqBTwH8_YUQRT z|BK+Vr6KL~(`MEEpl+}Uonq*5KE}*P(lSpaOC!IJGY<_$l*VM6XaF!8`Q;-3Kmv%6 zm?9<1y%cTAgVwRb7ogf_5yPaV8Xu9!hLhqVogu*Ky(M5mrcEFZ6Iu4VwPOcwAyrZE zNK%AD>7eY?y}YZ6DjyMNmANk07>Drfj~e&y#c5VsELyf-%gB$hJjh*%c^x_Kc`@4= znoM}4uzJXV+6@mT@fBJLmt)7z9b_``n2(!NkFj%T>8qrX0ceyf-B;HX)x*MU2E-;h zjeU^2fu&mr=Aj78qioS>GyZ$nnR3J}c9gdEi+%oL~My~tw z<=0Ic$g=J`xS3{CNXSXW;I+ax1t(yU0QfG|%amTQsnr|CkXuJOk5cQJMZ!-o_sTwFK(blIxXzY2W>p zBcu#4lP4$wz?_x^lYChQQo^6CJ~L*H*EwcV+7vRnziZu9kF|l#_$t?)Lb0^8g zwv!M32-C0D4DKB%3;&Y0c2*0yw4-v>s$A&ZsPffA!%gEc#FsPHcS%mF9P31^dP>6& z-LrW>8f{-o&XE~*l_hb-BdagHtf%B~ox9`UVLiHB{Qib{RSbu$HtIB*~4icv0a^quEp+LR< z-SSPY)_nL84Yyc6txWrY_pn+2sejJFTkYGn^*Hlw?=^Mv?{youJwPi02mqiA0ssK_ z-{Hi{(A7}i(9zM*<6oNF|FCe!sL9%|3m|m8t5Mv}54N}D}oocILJ<2bs^ z@C7k}qon7oLOoLuF6GKk=8Q8Bztyl53r%ZgqW)p3Na)F<@Qj~_`5yQ8ofEe+M zM+M?U0i{>z7OWPYd0+!-GK-VyJ7Z@IT9&uzp%pZC*N?FIHNuj~FPf)#nK>Yp(>)C8 zm2QU|jJrw}*DOF2sOgRUEV@1jZ>K@z_6ZxPbtKMd7Z94?pn^n;fgbym&!Ro?o9*y% zz9e(JP7`?+%(Yn&%Vl3QR|@%N?ql#k8P+*7U?~f^ajvdcU#!?s-wlf(iSIs(MOfp= zKNp|s;jtA2k%ov&0nDZF7U*^ukjya23+0R+4VvGolb`$ox}X)7#v`{DcPv)Q!yDih zoDotsr2`s`=FAxFdYD#%=l(n!tNH2uheJSKezI^!H|2F{fIfyZ3Zs zUnf&*uPHO+G+`xKQvVDT4m#?iQ>$DnoYgUM4g;P8+5X*cI6|$nqCRNx7RrB?kl>n1 z7t^$6yfQ@kdgb-{V1o;nJ4i*ZKlqxos#bgkxbj5!G48B0x9yy>&0ZnPY|^CHvi_;QGi~{w(+V|H(^pDh{NG zWFt}5e|pnyIsSjL@_)>4V?!Goqo3}ue@*dxLxq%R6C;0(6j_yrqgib;VG)UK zl<^o3KOo?braYjC@lNv`g!GoffJ>z7j|v8(B##Z`P~g_Z`gI=x{zbHt!3+5d#3Vfs zYQg`Vv5-y@iNAeZkbW^K=+ZoqjtQeVs#btPKIr!Z_!`jwou^xvNr-oeRZueylDw88>KC$LZaht+*?@Kg*M?)`+ZmmxpA9%sGSo!FeHyAM%xQg_L0lN86eZLwQs0+MSY zH&K651|1Qd>xD8W2ZuV~71thi3_iqOfN&%@Jg~nbhByGQTGl~y+sxwC0`9LKZv~^GjkJeWkQ7{ zh_?G3XmZ8@30ho@sJ!3u)78FsQL}X*#I)Mjp141hI&(+jgE>qwgPd+3T;*b)_YQSm z&&CWeuBVnLNw`{qS(59UBS>*Z2hc8kHMp72Wa{5u3H*T_`9Gu!5cBC z)?jgLqCMa;EaxWlN&)uG46f^0w5Sy^E&=(E2m}daSGY`ea zsIki91#-OgWjZ4_1s8Pfb54pS#uS0bW5wn9!X&X262HQMo=gSo@mVczPfPc1CBI7x zb(^aDwq*(56l*B=Ig7WuJZ+p4r}+qP}n z&aCv4wry0}wr$(isonjh-_v_vSO15I6>CO}Id0N>8}KhsY?ln07Ro>Gk}5YAR6hY+ zvdN-9ChP@oF_;X=W#zB3*pbQQ?Od{zfS57TG?L6_x~Tkvd6i1Gml#^IAEuYlKfNij zBOyo}zyhC)7^m<&IgKjP{Q<=pWFbR7hSiB#y%^k_az7 zmwC7ChL^!4=#&%Ve9s(VRp*Wdy<+N>P7N$QFSBE(c`>)1Xb9g|7pB_8Fqe>&=%>Qz zVnee~sUXTC0O@Z6^rb<(mSey!K6i1A;2(!>DD`G$na2`kN)Pe_<@1u9K{Zk+3u~)p zJ*arZIs0Vr9Po7VJ-95^4=3GuPC!7ZK$RSM-vtI2nP-fv;=DuF+ti5W^pp?y#I1aR zcaw0j(e>+_@T2)pEkC5cChx2rlqj-9w#&6wuXL7{TBX}eKcJP{O_*Ipni((`_o@|Z zhUuuB@=747x1SN5_es*0?Re@o{u+<4OS(@4r6FWNl#Dne_D(uEsN+x|WlcdhY)~ls zz}T_Cs$jx3ieyoNX-Q+%=Xmj;YJH-1BgS14Y?A5yCLx2^CCu_`iGQDZvUqj07;bSg zxEVFYBfWZxSy*)c@dXv~xh@(6c@6u;2`0gXEAwTQiq1sEMbY-+2+8aPttmTmxOCG((3r7EyZaB$UJ~5zSeYXeojdCtskqKh*AzIcNB$o5;MO=Q?<_icVi&vv-zvOC%(^ZC) z-mlMJi+jQ}5mJQtYo}d(s&QPj=cM;cp@gpg@)6mr6aX2-6dXEV@KM3-LTFM-!%!Ou zqcpOgEXi=+JHStLp?88xhEk$dN8s6;%1ONAxusGLRbSiLRnI+wP5miC8%0Ctwbqf< zzg(!@wN8DkUH@EUanuA$Lb@w08>+8z@*7EgiEz(&M3LXJ|h2wwGFH&FU7i_Yt> z*ASUu)T-%tDsHc?evDJ(DGH2f`-6oxN4`loJt?t?&+3k zSX1hN!n7oGV}-c}7MLH14`@jlZcO+53{^|0E9L~^jN?(sR!fM)5l=S!kFaE5rK{Sg z;q<$mq7WyvY4&)a<%~<~R40~GOWY8av3;@pvvE$>%vo9Hr&=+yF`dfMk~;L3brWA{ z-}n9{4BU^d_zDh2GW44_Zfgn&s)I8@*aG z-m}e@)<(CtFo*G;uwY${@rJ&>ipP{CojOO&B~IJ<=Luj;%}RyI}cs?4|V$@(aT{y5le=sRlo ziL;9b1*eG+&6y-T#7-zLyol{yx%-( zJfZME;QyHkSFts;SN~lxF(U&3;rut5(8TgTDa`*N5ngF%%U`k~`OVbU^-CAjsBHmt z;YLD633O#UQ)MrN!w^l7CJpWmmS0^q{B(DdN@tk_7KC=QAVy%#_&5{bIQ3D7IBr*{ z0S!l}!DGNQv&woxq$z?2F|-LBsyQuF^~w=LITfzi$X}3XYRp zFo4epTKwg*?}((+egG?0)LGEW6geHgcpfjhf=)OI@BL)^7c-GMPUdsL#GMLFk!2_a z2ZEg>m_QKVV|53mn~oWth)~^d&+Q;i&gXfVM8W5Ip0soIX#zBbhNk3OjOJ6b*22vx z(zc7b*!DsWnN`&1Esn^{$UiH(b9b2z3s8Zm5CCKZxTR)${Q!(C_elg?Dv;`%`NQ{gY7j9?P+&ZL0ae-pfu2HaxV(Fr%vM5@#6|qYcdm>_X zJcY5~GrM5~7CJpW8*}b+OJ{t)#7XD{&lS^HTM(*VV&YGTH<84)vq|?3A*;ThlVPlR z|7Q@b+;yRWmJJjfi@dFyGG^8LWw596P(eBJDqPZS*D&2NUSkF;46-Ykp_*}0i!GN> z1%#`W?{s2PPfhW(iYU80#Z*E#Kd}mUj7=8!Q4Q`uG5I%3@P^&F>cS4hmwO0bqIWPQ zP&u-KKEmO6;O3##0*=W===!nj`Z3pLX0^l4^6G+MYrOLp*c+nO>5rJljhMur`@jhs zM%(I}<)!&xYm>m#VbqV)kxBTRHq6>8V<>O^kwuHC2HlEPtaP$YSZ@kXW ztNYqagHNj#FW4`eJyn9Kd9TKEE%P{it>=!?dhU%cdr6Onuu*yN zGASx1jO{y|u$3P-o6S4#ztr5zGV=`s-Jxdf6D?0B@m>^Gw2l&xaAnzK@HR097s8)e zNo;!`J}NkV4g~5moR%76)Fhf+5>%1Fb7E{0EqT@*@zlw4JdnW6Lcdr1l@`Z$1P;FS zT;X#b)XL^vU&l^!FfT~E5T;c3ROAK{v>z7gvE#&ly4+;YQHttd_m07@NpFs(&s=hF z61S5;5Y`J8D7P^A{eN`=4Sk~=1}IFZDL#XA`v0$@)_>Kqpn(2MGUk^Ctcphn1VmQ> z1Vr%Pd@RfjT}<8oe^*HVgLCYX*V1`ItYPPgdU&=H=2QU+SgNJa%d0s`X6^o5OVO>t zv`vAKv&bMAC|H>3A@14zl3?RH9hf0F8p(BV2FHGxJTNEcI;Z>ibdP+^t6OxqZ}7>y zVR+SWqH6O-^DrX~jimQZH4bYkLL<$KN^el?%#!Y!v#y`h4y!ku@!> zFKhLgB{C^~oz3winlMIkzhf4-nU}pJ|3FlCKc=8p+Q`Z|0)E>G*8?AK#BNvIasm2_ zH~&nVzeL}72CtZYD5>p$n;F}w3>vS)c|=-J@1YiF|H3BAFN6Dz)t*tlb5%E_FHP(h z6p^jxbZ*hD;9!2C!Ds03@7eEsUl*>wmx+68eLGisM?(fbMbrlKDT+*5Wq*yh^ksDA zZ_)IF(Ou#O1 zz`jFb`3rUT#R^SjvKd@bUQO#m1fM=$X`%+`#`Lsw_rbrbCf2`6@}aLJ@MLs%?l$=? z%I5i9N*)<}$pZb|>~Osyyj(A6R4caoGTA>xQ++w}AAW%gnX$B#(mj!X9%O-JqsaXU z;J?*5A0d3hO?(=-bG9)roXOJ}ExUSGynHpbf>Z{dYD3`)$PUXZfQ&uKuM21C%+LQk zei;n48`eV#gwsbvc-bM)Uq3O&NX5mNXUJEETpY)TC2e8GUl6r9Fo1wYN|?uczHIdSXT_27gU0wwmxjaEKBB@ed>zKZeyKL|)qj$|9= z`R)DApIg;b*eGo^Xy&_3e=QcD!>!sm#^lDq!5Qcq3hx6?0$U?lmh0Ks`HTB++$Y8I z@Ed>Y&e{n^@Ic|t<-^xR>n(8voWh*7@qQSegr}YvC&30g8jGM&j6n_ zPslO>7OZ9FOyI@E5FbbCIKdQy-}txio?ZT^KQ>tn`&%S*ju`(G-pUG>t`OT8jSIO! zecUqI`X-4x+B~I1|J}kZ!2P)_Ye)|2I8jpZk=fm@><4O#kXk!#YlRZ$A1lYshsMIk zo*F?{RO6H+plq`j{Fq~?#SH})KzdK9dwOXPQMbZem91x{K zBL77N4XsoQ?E~=7LKt9t=5S6Rkc@cpJr+4~Lqo`ek-<8v>m5#zq0G6A%{f416>WnRtVVhq&?Z#4KE|+<& zJ;f%-Nc<6;jilymOj7sp{1_n$$bdA`)o7ZRZEkoUs-mEJj!f~(v83>dup{xTFbl1#bu(7EREGm@yrHksV|lxcpW9QfRR|qm7 zX1>E5(86hDc`uxfyfc47C=C!G^V^bZCyKscvLPbb_Y#kFaSqdlwM6Xj*dZDm6Nd*X z7q2a(c8?Yvd7(y?cX1`CUlmOnD7y!T%;SOZ>v)ki^Xc4tg8$ zPBj!N=Hu49hS-{3KW!pzAd?#(vk5IntG>i`z93N88eq`eRYJwqLLC@%z&HKNfQJD~ zcIdB-tgdvF&J(aSFo$F~*T@#K$K^>=Vfo9GcfOAt=MXuWOP^{Mv?Vct;^d9mY-2ap zGQ@6w{mC4{lVHwrKfw}->7ds&x~jG>SlA5t6D7r->1c?c#ffn@E83N`zoK1b`+$y zcv$+^wJuc1tioWDHe{)9|AXO=7=E;h{A8V-M3?AYKs5u}$HDuMjnA$3zDN6WikF@P znto^xL_VZbXZt)AdYKA_x{qdP(ZboMWEYi@gt&!7D2Ur)1aTPTzXy0MZqfD0u_vb-6_%SxbL_h2xop=NiK>^Jw(#E=elbc=E7$}@1!zYLhX*h?{)C+R zTz+KRP=Q6VkuTz3`v;J5{_FCK!Aj{sMOiZebD3W8)bD;$T+62|B|Urk#1T1uo#|$j zpisQIa>>)3KrYeg1`hC7lLOQ2nSYqlY|+O{^r*dFD0^wUtzkJH1xPYce5i8l`<+K{ zQzP|4QBIxjy>xjZ6<7D-Yq1FvKf=w=b6a%TMCR*8UqNhUVP++e5a2}?s8egpcJ)TO z!Y@i_Z`B!nw+(|eg`i`=xxj-=jutv2$E=B16&%GEOOF3KOmB>@maUnjC9?L=ml}AZ z$v2%`Crw>p(7Gkq2sdCBV|z1X7FmYTXHsEM##BuEnWf3>b4>)=d}8Lz;S+4HY{-yg zC?$#_V@yfx>Dhl=7s>P|o?G_UXMwV8g8GDF{)QRxFYQezph`4GI}2X##(491B%LAyHz27LzY z57X@jvlNu@m)}&#$tu4fK`SsTCvZoCIjTI}KKYvHBF_JdFvhVr^5Z0=Ri^7!Kr$8U zv{cF6OkNs01fOYU*4TO`r@-}>LZ%C!e5q2G1;hA<#dK7G3;>uCIifArs zb5&?HmM@KxdY=^@y?(D|OgdC7ICfXO|Oh!1*(sF($uojeg7}3X-G948j7Zrp`)z58#jXz{>rsSlI;q z>t|wCR|Flhlcei!hIV7w7rYYfWd{5YNyd}0<|^NyhL|^*D(&VuFNvqOG95hPDw1#{ z#pRS9jzAK@*s`qhb|a=^ra7Tnt(0{d1dA0+IHjeyXs`=oh`af5G{w@ueFmKn*Kkk} zH=KgvW*pH~Q116n{y!rcn;!N}w%!$Ho1&MNU~V+Y)7#y9K@S&v&SpFUQS#=-KRx?+ zVU>e-cnde}dX(x&j!Fs8Rfh;lmo=Vg$Dd-ROVt)x#Lsp`olxzf7fmo#DWZEhH?NHO zky?)w$Rlbl1Au3w^aVZ;lsZp|{pt0g==UFeG&ojS3DrW9_$&9PYuM}aq%4HacgB>Dnn-Qa^# zXn5DjU&XGw<&@5FtIe2kw{-g2*{hp^{_g_;JYyBO_U93>tWI{+Rro> ztZs;fq`}d4U})39rO|vY%@i?Tqk^+&!6aKM!tn*q~_%mD@0h_DNy2Ck`rPiaJy@V)wzi^ z-P=OcEiTF5$hHQh>_yhd?ra4-j>JWV_svoW)|BPY^%jZsE_BF|`4EC04kIL;2aBNN z_UvKmQvl%#6#d3>u@}Z528%cftl_#1)wgvs7fv2?n$u$_YeyMjp4g~Kk>a|!RK?bv zMcX;d#Fw68qLyzJ^H?^Cd#u#RgX85*U{Etvduo<+mG49rM6Rms{+_WtJR_hUvD;rf zBiTzW%fE}z+$X?!4FK_T!P(|t&nZE%p#E+u&kirzC;$X7gKGVS_mLJ5+w=|kR{bY^ zn;xqjGivKc=MOQP6P4@ge6rvzwK&6<4Yy{HCKz___(f9pWl8|={LZdk^P#&!S*PAo^i0O+?F{mOW%$9oBU18t+d z%}pmIapJc9!;2_BV*NPbi=(L6+n42c6K8)n0y3sQs8d$N{Y<;tp?tlc$6AwcSqevy zuPc7u3jsluckc8uvEZDQ?FGW4%@|nZb+icQ_m-*Cd6wGCG#q@b5X&YM!-XZreLHjy z`fLt^jz;g-Jj2IVcIt+j{YHb{qu!G$3>Lku#Uk1y;BhRJ7sf^Xyi$5HdtS3%@pj4g zbayx#BA^i)UgCFaF{^|l&y}|{8k=+a@IpKMdV`E(=Th{aZGC*T%PRXs<{fS1&O1`0 zbW;7lB)-rEEC`>AL`8)fVV?Il{)`1t*RnC;t9^6{K3;B|0h1b+G|@OxjY87~8G7}j zH%wDNgJ>o9Y!BJ#Gev(sOf!Nh$nMS6-Y<84#&a8_v1I+|Ko=3k_Ja=skHpNxL+hyw zs!CKq?#%?|O&_qfrHzjRcT?c(SOX!6PNUxHU!b)d4qrXNJD(i6gUxf9VR0I~`McRc zCzKI(pi>7-CL?}xRWFOHvUySs#wXDFdHw==z=;a@#kerBYoDe?#7DOshUZRJvt_$+ zHZ0@5<-^WU=eJ&}XQI;^0{<|gpTfwyiwLXj7(MtSdcIT{W7A3P+JDImex8iXS7R6+KM*Y5(HN$xK&F~2 zkveULMj5e>k&;we@r0z&ut#g_GxSlA3u!=`J7u;aURBgB(q{2nnMMbT>htA%v<|~; zwpls4ndEphFl>!>=v|zrH`;tmZ(qN*eQl)r!h{q%r zH%6ISo%4J>W6CyRJ#2Y7YPl9$`vpgv`HiJ%+s^J!kgY}6!YWsidj)|zz_nc0nln|4 zZNftu*{R&JF`)Il+_XUzj;yPzcw2#O3~Im`H-dZ3bqebb07kq8giaIe+&A?x^!H%hK*e>-5b2k*>pTP2b6iFFni0im5%vCaX)mc63tO(8%*KP^MEjuOxc z&99C@#U5P{pr6OB(LuX;`&T4NFUM76G`|Q3^-JPuB4kwyqY&OZgiM(ZKdJM~MY9aO zAb<0Hzea|5>~j?OT~cn4XL7Kdre#DCI_W}$5_4395`^`h#1)9G)qv^AMkAY>qfY!g zHl}ZD1l_iX(7{QdGtF*J=sKc6lNg-5WqsYM3A97Sd}LUB)UZ?ue$A|uIm8BiJ^T?q zU~@tDqC;;v56EvKA#+(f%fI2QG={z&kA{$9LaQ_#hF_mqQ#p~w2W)_7(sQ^MQYV-G zdrg^k%^}yQca9)W*-6L)(C?LgzDV4|>SV~@w$Mh^BH@d8GUS-xLhgMXzUq_rNGklM z!+=d?xjI{lchxc4IDWC*usbV}^H8#tut%WsG;%4G!9f3cTmMt%>4&fGP7%7$8M+@A zhSzu51U0VHJxG^`kPt)pcRn2@_BKkf9z!0$l?mc01CFFDZ0H_3seJ=QHQ$-1?8vce zn3c37MSH%~QlR=2x*bQJVCWfwxoaP!sy7t zZ!)^8|4&{&=MHMGO&TMWX^C#Xqa>7Uo538jDY135A&)b-C8q$2CHfA{9kzL=6ymJF z7R8+d(@Q;&v7Z$Zq()Zw2FqozCROKup~()tE^xR_+S%W8?-h8xom-KJeE) zY~recJ3^F7o{B}UU&2|0t#XP6E$1FFCnBFjWjF+VxCU>My>qnonO>$+7}=)DQ?bnc zblyIpB)<%$omxez5Rm{OgHkkK$UfHS1CT$f=$B_NaXNt<-`~d;i7a{hJI4gBP9fYu zn;iob+wPE=M0(_&3BvUnWw{9kN)r8E&BYzodTQd`@ee}ec)#oVqD!g3y$5bKC(r{V zrG&5tW-{gYen~Gj*)BgzGKfn>5pa@kZ0pLqMRn0)&{4+*lhFKRp#mC6;F?jmO1|)> z<1nIO!S%=h8m5|<_g;A^Zx>T`XAifMJE#_;e8f3Th=5cW6W`o0fa&zq@nnY@Z6_X9 zzcLnu^e4H6)!Nxc-`!oh>Se2ME@$iT82?ppj^f~KXx8xpZ2(}{r*zpm3C zV>giKP@z7)#@&4ivkiJZ^BJDMshWhI3~9N ztI};GS-j8kw8zNd$hM?pay}Y?ML$RqGDqT6WH8n_FV;DkL)z#{EP%L=p^>59cw+eO zs-SwU-bXB%bL5a_Jnl&T?h?aj>XeO4l_e@+R|wrRU_w0-#qwr$s4~oJ>XN&Y*NYqL ze-EszugYuQ?wNjQ7$TZuq#+!48|C)>gyn8c)n!VZy!0a`{7B@V^}NE5KbAM=ej6=; zuOHtqg1%{gJj0UMu*?WSm*Kab=dTM_VQfP4m!+zP8*n9Kf5Cn1xc6}TMJ$FN#Wi`S z=@&>G-g$I#&6{2ri(Lf=IXqF6GIjf;giTGm_|Cqu=1_1>CDk^jW}04! zEfH0j$iU8F&yef;vt|jMd`bLZa6RBFy5OBmgcGVm4LXOv+GZN9dLG;kPAn0Iu)Hb0 zY$0?B(Z$P;;QU}|eKth0Kf!mSA+Rb zY?gEi6_T9Lvbu+|X`gwrSvnb=NtW|Hl{vO zgd8g)fv`PK>xcv?f&%4T0pBixup^gV+~*g34tl-- zc7pWw@)EE}Yk%A3#Wj#HT^B-~URNp6!%RirIrDp4rK>BUD5$fXCrFQtRL zA2E&4g*z#1i4$>Ix=N1H8Vc*A)EQ2~|BPY&(~e;uw8T|>%?4c^LiP|AZ@v1IJ$B4` zKrS@bgwQJ_T6a`s!l?#AeaA6%=Xz#DG`6iPH!j@GKqkRzGa-QNK>)n-*8ZypyD5I# zaY#LA6fKOGs)-V#W=8ygwyU1U-dp&F%NsO|!_qzn%znEQx?8uF1UHVBKul1{PY=w$ zI$qhG?xLKqyYy~ME#xEWweohW9Bdj9YS*`M>;gxG_B^?9UyJH&z8UP`HNg#qt&z%h zRj2p7S}SlWN~>c@XYpB+n!C~Gi$rQ%-6U4bG-bo)SWM8{yO73~!|YMk>H(OfJa>!O zf$$yScU}2t(YXrc)CD&Ni&)`qFyfBd^Mlt%-PuEvpzvP}H8-SWYlBg#~j3gOGrpfCQ-LIT^S$f+HJREKU_U-SwcYxZg+0X3{RkDC9ex& z`Ej8Aw5_?&54ORoN3MCf&UcegDZZPAB}i;Cv}1)ApLgX!-9ENzYwQ~9pZ3g__d(K< zyYrEenL^IgXHQPJNWTrX%cWQzD7##toP(I^ksBo}UU>Al z)$9RBa4exyoLIT~GVT5N^}`(DfX{7_v5NjL9Sjn+(XPCrrgHUBj@YCRjV02H zkFTqHL;tsA5kE;j4B*9@N3PQ#GwA&BvvPJ6fFV?xhL34*>Juho@za^cRAl3tHR)>G%x~=LfJamcH3E6 zYu)jq~dv~iaMBUAN*gWI_hA=~`!ymSPi zoed?l>eSz^7LqU8iR2TM;a({skBWONq-^EDjA#O?1IxJv%w5RWIZl`VO4piHm7b|r z!%=RurWvfepH|Jk+bxMLEbHoN%@dE}Tw&FoC!6flQGRrYVX?%NlG@}8iPJ~oc;Ion zKChU9HdgX5N;QMoK-avl-cH?I>o)3^PuIS(&8Ryr*mlKRo7^u_NnD~Wi%WE@*xI*- z*60#nV4pV3=LeD1S~3>;aJMhwZ@}) zsc;aWR>_T@vA{+NPk! zlin&g>VN9nMSVeNn#Qz9z8F1JV-}SUJ{}csw>HS_%YFZ0t`FZH7BQ?YvTb|q8z7~u zs{GOp>mZ14YjCBQdp0SFrfiv^AWN6M`?RXB82^bVvAw7V-TzY$+QI!_b<_Wa^;?+Q zIQ$3F|35HFODTPJ8w_w^H=Z!WU6i5*d7^!NjD$^7@i) zO5Cz7xGlp=vh(s$pE6QX?a^lEv2i? zJp%|oTR&}F`qlGG#xoona8rDbeVCwjnqbc1q2K&*z2h)BuVl%588i{Yx*QRTnsk-9 z{6E;pdQR(J^G|LUx(ampBlYb27d}M+*q<}l<3%7V=DLHrgBh;#TEbS9wJEHOjjj^m z8F_k`xEAO*!P~ysdqIBaL_F1S^>SQ!MX9Xt5WdtjX4jpPdTxbKLj{c>dq$%CadG2h zP7YeKt6lf%-*#aEr-cuAi6ue;8nIcqql0S%U95Us0)joa;$TxOAV=`+g@rk56y+KE}Kw z2p^?BT_YFo0VvW}*-bnQ3TP<>x&4VEz4T){MBwI|cUpGYpWP<89vyrJO&>uZhM8p@PZ zjD9{6l1-FR%l}%EuSggqyx`y8-s(ja9C66qz=Ka31WE*L_v6CrBPi^C^gMd?9As5P zW#?JTtic$qUV0%w()?b_W6(LKDMbpSTrUb?N~W00q9KeSfj}Ku1KPWY8RQJEyFGnE z$cYggrZ!PNyY{)78Y3T)v6il{y`x32=a}Q}tnB2wZFd)40s>&9rtQN!K%|4n{Wdg_ z+lTRJDdEIyMV@2CNsLr-G@9;G|waiE1(on{DK)%pH1_#y2iN<~iC zr7f)mEM=1g?|^C~VXU%|(sPteUDR<(+EVTu?8q!pgf;oo(_JFr(xw45mPf^hc)Ytv z)Q1r=hZiWs?jxQ+o#Lc!U!+fJW0qWIp{F|^f{FFsl~x?AFTLr4Rd3PX+mhc@r6c%b z6#aTPn4{A%_3MPGB7-RLD{oHFiyeg=uC?yO0U{`(>tF|t`Yu{#e=vtE7m7O+FuJ1d z^inQ%5C=a-QZV0F(T|YN91{pN|4z@&IfzLIj=c(9oL_4dBYR>HFK%KYtM`9^lGcz_ z?11dHH9|&M<0&}g_kywnLu8wM{mhgBQ&Ns^1TQ)&Byum}!XjYr?t0v6HeN8Vz2N_# zenp<>&#_KA<^UU2QVM?}lC*&&6l2;YZ<;(KuzPf@i0SN9PLXq!(brH5Sm%JfP{5GJ z+-Ks-3q9u*CA0{z<`85mUcgkO&7c8H7f9^r-~IKTRYzby+M8-G+`Pf!7LoAr&<03Y zBgB2Vr_3!(jLS0G$=Ub2CWMl*7^Y zHB@J^HOe6?oT;E*FO&yKWK%`_$>G1Neza;)p5unqdMzJDfI4@4M&Dl;S_3gsj)*)j zEWDUgqts2GiOUV7*8(V3RW<2x8t=G6jC%ICAKT8^oLfbnETFGOV993x z!jT#Law0!{Dtc*HF6thxqc-x|Uv!tVg)Clw;iH$pqnC%) z4Nhuhkk3Oq;D+#9_-CI3(ldtBlC2PTdh~#eyZTAe@@vz%oFeW<{ws+bX ze%4d9ekGoEybCI{cqU1nY3koL3f1S0A1rOySJ~pszPxt^6N7(U)|M+2SI$un+`^bz zCN$}!!K=VYTRN$#s8R)dor%$9!s)YPyIC!>2`3UF&db{W>EjeB{RUEp?gE4LLa)1%vR(fYnD-TFu^h@*#&4WfP0 zlIPd*7p8mWhOTT?cl$%Rp_&RL%U(O)K6>HIusbeB`Ks3|h{2V5+Sh1*P6)HJgK(KD z+C95j4indwH2^ZRVVc3ptMS?l?$#~YM*cc-6CP@nwst{3yP*_QS2rIfz#BK-n3c~x z{utOUe?M49hGtvALajWaw&@i)tvm6z_F{I%?8jF2c_~^z-lBQN3&7!`c8G3@?2Lka z#7|c!tdu2(f#@MDPNQjQjC@Mr%6z;-*+(Zh_NJUId1sizRg;?gX6FPCE-3q4Uepi1 z!sv2LXnJ(l_v>tI)*C#PKXYVoxSvYVSM+gmmdO{qR$9Eeq{z5mC|*Z=Zx$n2vTUHt3wVnhG;?6$4p|I>fbklK#@!aohijJ{#NG?l7iX1}y3 z+3c`@goJlMsa3<~aXY!J(SfrX#QM+lnxuBe6`QQYZYtNw$K96uiSq$REFB9JC}$FU zM-aF-x5=6XPgKE60jqWSKfPp=rZD)EFKW`a)I^43x0P&VptIVdl-0m`iN-jthbrc5 zK|c`+Y_h6#fIGIL%Pb2-JlHg#-Cwz+_%F(5ii6G2*Yp8KqVKtg@UnePC_1YG*kepJ zGn|eC_4=D6u^23Tao=8PnJTyrWh^u%xFxyQn-Dl)@SV>x7Px~)FH^KmxzUBR-I zTcq?OkuCK-AFKg~?~-+x-?1@o+qNwN*T)usIuUQ>$qsoT~m|7{Qf_1 z=KtD13)u3Mw*U6e6YPJ#f1FMKpZ*88)OMmX1^#JlYB4XPh2j^eeL0|-rT{5IXMzi* zg)~VVlBlk&9oF-oFL}60mIydu&v;GF1fvJzl(1(`E z*y7Is`~CBL1o;`%Fd>&ED`?S;ITS*ulTMa2)o-HtOhzQ%a7;4^8}G_m9?76<&$~@) zcC~!vaZSy}`zvEaSb~RLsvlvEt3ev;rpE;aoz-gJes()c4>nB%x<8iZpqFtWGvO&5 z)+Z~`WNw!_>n6_Bt?SX|cboPGyLH&G5X|I(mCHog&7c|T(KpxQtwW&AN5vq_5^fzg z`w3bm(YywjJ%rk{3Sk@IdSCbBt@g(jlc$5;3_#c2v^EZFPDzOz_Z=9#>)qR*V;$u- z9aqFv-t{b)*F;h!LJHV;7}RGji=~oTH$EBli?O(Tqz_k#yKbvicwniH+MLSo-eSLD zY;t0eis~kY^5Yl45@jUYNjKdna^rt}V-$YT_BVAgmeZ7X$fdZV}XXJv^7B2 ze&`ypJG7pV>G!%i;Psxs|Dcuy%N>HA%_mRL{-Nai-s>T~3Y<;FQubhDA{jOZYSSgW zO8Mr7H8DB)nFHZ`4I|0f)0n2+`ofS*2H?0|IbJRt|_II_0P!9@GlzwckLY)8|VL*L{qG${a=a3 zZ>5%SF?=+wgiNpl1e}G&wUJeD*`vPE&;ll?B-Vxt_1`Ei+?$mf1q|_)tm1)BqbQ!w zY4>Zm0R~JQ0|cf5v^VhMLl*MV5ej~YzL_c~Y1{3Df7O*F{?vX!qyis1mwzfIYIUr< zOW<<#T8T0$IGBH`$Gx>PQW9g5)AZ{X>Lh3yG;U*V;@g!>6o9E0UF4$Z%ue_pxK0PGIaKAD(;>EsX6nlQ%}S?oDrt{% ztruLUyLJ=n&Ac$*?ua<&Pk94PvQoz>V^XkaS5R;I?itX9b^Qg`INqs`(GklW$+2R0 zyH2O9+Vu+YWMrs&l&in?lniDmxw7t;H_mcCNPalm+H!Q`YHKMoz8sjmfCs|W4{@{% zlx^;YPuQ8goWDP=hNj-aQ)?c`3{2^pUAGNH($Zn{+DTa_EGjk5OP*~8(_duQIlX5- zOH#gb--KhrU?6F9x=k?W<-BF2WR{!RG`$w9W1mUpSzFLsH(OMTo^@LB$z|i@xJ>?9 z_fhvZOd2fo0f$z%8N|{@md2q0WvQZ%V#mV8bJ6)ZxaQY$bLa-|i}@p|r&DxieoIf_>7ACCt`j5(w+k~b>@Cj*oPcKX`8N@i!Dc-Gg_0>> zNxd`L>*=z=7z6NDij+Sw)GS1XkHoj%e0`w~&@ zoVEY1SYfr>9Zz*8?ci|g&T;0XdydoK7**~kGTw9beVvsT)@DBZ^tNvBYXd``Z)nZKi2XoVA%Qrce9Gr)N14-?^ zqaq(?Cnw~;t-L;VBMqq*hP?CB)?7Fgr%nlB{F&?#&mVgpc=AdxVEB27|;9-LmG^xqUv{O@^F|GR{d7$4zvKGI^Zyz{ zVryz^@8qd(V`})Hf{*{w8hjI*}}rwU!pfft#av_FOD=?_wkTlKK?2W&)zA5FIFxG z_m-P`WiYqTHg4}`(zPa@IrAe#)ZHw_T$W)C`bN%1(+~acWID>uaB@&**?e0k3ZSHK ziN}}X7I&Dzlv`=ir`~;`yWWo0*@)4{oR(N{7$N9_mdBHpph@9%ZL=R;Pb)b#mvDxLl2RfJ5*Rw5Ou}1YOE3=T{({; zZz4>apH?yY79{IS-&Qz&xK&&(m+`5l_4#(v)q1>fz#ljDa4?sN-cY*wZPt(Jv1%53 z(~RWGm~(P;A6J66-bcEM0Oe{|AK9bb*cM%=wr#NT6b`pOBk901F2NAIz2~S^uKZ2P z3Q;PxRyV`&PAhUdm_O^dK)#CnE>)=e7$^RvPV3_2<@WjgNdJ8&{4=n--}q+6;K2?D zUeRKEeR&yk$67E+JgMRuT}=@_Vu2sUhD>1X%#5i&sO!i2cM6Ag(Zh-JqkHFUdeGb( z;fLv1t^%r~th!~1j$r1b#yA>*tvkh{qJ7Tzv>dxh!9%Iik2iJ3s1CIDvPU z-wy!k@D#um)4?MyuIR;>w+q%-Sj;z%+FUPvSK+kyD_QB-R&g?ZM}#FTx3M03n>-DC zP}rH_&7XNJ8LnTqWJB^dIVHn6@Pkzy?fxJwb8f(y;w@? z414TEfp=RPW%8V1frZkomPj5fhqHc(56ZZ|(G3hp{B-ig;XS+Gt-s%pM9ozmUi-5{ZWbhmVOBi%@ebSmB5DBWGsEhXLEA>6}PKhW>~ z^xng>pXdB>c;DG;*38~Bv(|d)6D8yX$Tr1`w&%dx=eR7)7Ul+5$6kr_uXQP-S_%?H zPRs(onLB49nq21noW^QeYl#WG8yXrgjO1|~#BD_iep`r>tcWysA1)DeNC|V+IsxaL zJo2r!A>ADNt86HtJ8}U|T;XZ58f>~Qr({_77xjL)LnMM=bOr=dt=d?;hLF7|2$$Cs zB+ehV$25ST=Gc#6(BUkKOuC+@H<>E6!;7n|amIZn1Zuy zqstW0@=NrT^HnoPYM^-DPosi*`FerrcO+p2sPnbq!$G;Hsy%Dy5~>p1ZK;HM*2zXb z{fVjKU5Lmm#Y@JI_1?xT_vk!KQ4C}L2uaXgAV~d*IxF$Zq$u^4HzobNY-0FCg%ONw zZ%eR|xGhesd4ug{yZHLbgtV@lwAh)I1c1Td#i{O4v1y!Mu&dv`G7V(zwG`%aO0aNi z_>wHsB<6Hscd%CBkM?XZPsg-W9(~zHjT+3+1_|D z4m1Qf-x$=LQsR)%6&u-!`@1=sdRP<_W_KDDX03Nnp9)8% zN7STZ*DEC0V~2$fz}c$Crz=z~TvXrDxfN#ie_E~O^H{gL?^ilnTU%L@$W;Jh6t9rw zS2fR|y(VuWk=!m#z)13-Lw?D7aynB zV#$BGG=9&vB2BFGt>%Jo-x(ZsuXseiyF{#5?JMw}iQZXcOzBvR))45lHQL>z>5JaV zfzpgth1o_5yCZWfJq$Rw=ExS?{a^#2G5taIe&lH0THeo(-hct4onE9zLaKg)x(N)rX+>f_{W-lxd zc7%22&ysUVtT$l=Zzm2U?!%oDY6^5vx>6o^yHX`|+k#`hs&P}Yvv@@8eo1_f0ZcPx zT&u_}tv^wX=_&fMtd}V-Qotilg#nv9c%%sAkSw zRBdoTcMjbYl^TUAPUyV0v_1`_M$Q6I1Da$Q8$aTjCzRYQ45hFWr(J6gI+pGvTL=}h zI|!DOLU#%X%PqGJR8EgP84Cd}M8;R@D1)0%SbiQXd*2L;xa-xjDlB^&-RP?trg-so zq`5Ow>}9S>BE+sdQClMG&hjbfnz6zMs`20zUbQ1H*k^sLP&Lpu6ydBe;kr=xq6(1n z0p^XtAK3s8F?&{NzZ9qqbb5;i6rUxuFJ0Bw<{eoaCsSC@tY5(Xt#YektA@zQw6n_s zcCEn6<=P=0$b_p=s2FAJ>RSI@6(4#vD6Hq218wCTy=#h&n#!kMpAZuozGE9lNKw$R zv0f--2bR;?X#FlcIQKQfSZlJ?+v}?WbKM~VYr0*b%EmbRQ7v&|v z$Q$$q`*Z!%qIQhHq-gC7rF6Cl7l8yY$NhY~gQODXWOf~=zis;gZR^P>RFcE6$>jyL z4ZT8b6T6Oo5ke~~Gi9Db82>h412BR|@&$OcRU{@Di+T!lFOB_m`Lq{FA*NU^u)|jh zyH;GcwPfba2ffV(%jT1acxrP~818!!_Osz#ccc&U!K*f)QVW#6#Ek3qeu_A?R>qvT zlsONi3t?**RC^AA;9ST?w$2~ON?J-)#fiydxTD9*?qVW%5MQX<>3d|0Cq4&DIyG)H zp@4TIwk>^+8geb=U%AQVHD9mAgPM6=m_jh8jdpO(;qR1>L1d%P5Nqs{-5DENMnKX-1t*of?aYM0_A~?aI?_FiWK)WfCdf@^mOF!^-g@Ujmq^w=exql_E3prwfgQs$fb?tyaI@%)D>|xQPe5Hv>_Gmx>P>959M6FQLoRpl@Hterv4>7 z`Kr~Rj|TaOTc_vb$_~xjUs(1wST|OB-IU5tc)PW|Obdnz5G~!RwGiaHyVWMpy-b(> zVt-!t0unFhCQ{>`uj61%IA~ET_J~SV&AGOT#DomDd^^~3w!enGW-596^u1N}J+td>oq71_kA$Ppy4_xH?Z zS>PYg2)c7O9T1LE6{qsr%*Dy#UX27RO=-uaZY}O!O8`>{2HY3NcP$@I!gwbEslDcq z8j@w^M8(qF8&?V(fHZ%ep$z&4?h6~^BEiS%i@nc!QDxgSW^68by@FEso;1fc;rixm z(bH<=az~x#3F5a_iWyqUP^GEfShze8EY9a7@%$R#70`}WOrc&_J8`Wb8Ql8spN>=%i zatt&tMxs$j*O9{ZzupK&l-S@EIUw00l|zlh0iENQRuVekAG_PZ$J00~g>vaP=TKQs z7_6~<6Cqn7>i4!SFN#@tyo9|46D?c0Yorc2zRyQ&sl~?BEoP`FsuC2PSX8!-3;m1G z*)UA*0MJa1ExZLzu}Y>;LcaW)sV*EX-j~LXtpr?dFA*e~%=GtdcG;gf5PN9VqYT~3 z(!7W6?0){a|69-pHINlG7E474bU0P$G!y|*^4d!K_h0KvqdjqPNvA^j0ztN*KDWNM z;78}JqpyNLl#$Kiw(T{VY|bvuw56d)Jk2S2tCv_N8vsLH@R9=GG3PcjG*4r0Adw!w zA3}&)E^?u?$d`G*Oh3&|L9XqUD3wJLDR&Sypd1T?Z}c+(pA$qdIAkboda!Q@#qcs} zLKPcVuYeOehG4lGHYdaGt|n7zod$~>d+lT8k8U$RVp%EMuEr>gC_~9$nOx8mj15Sa1l$WNZ4zqlTBcY9yOR;}w8ncp=g%DBl6b2V|CtT%4B;LoDK z1Y?59xk(yEa!;?nMJj;H;5uIznb)-J z3upu|DlqPfNU@E^uJF1Im}#1C^N1Ccr63Hc20s(Z>x9&)$wbOJRjbKDw4JS19gK}- z6{UJ?SyfP`1{V+LORDGNe!kjL3*7pHs2hFRYq-mpJcBeuXd4S@Y!2w!lWRlPqvh;Y z*9*?P7#R1HdCjj+H(XhlvNo(Bh1ryr(eao(5j5e=T#BvoUv!t0Y`s-Gl_GZ~h!JUfSi?DHyDWhK|ite+L z>)Rg6HZYwjpMt`3SGq{N3}?okbt^=}I;4 zQjVx zvEhcM>Nyvp{dWF0LCWmx7-a0tY}9nlomVCJ!co!N2Q;cx(GVHYuZ@h1jj_#9?U$Oe z>0a1XZQ#rBc*`c2<7=(#a)rko_GHf;UmN?LzjGX*Q`sD;%b+ytP4;y-gv?y@JINb~ zGn&=g!$1%MKH4h-?(#!VDIZO@h-0J=hQicN80!^e10DqV>OA5XiGRkBmyGuH!zwQk z*uD-9#D#&azi5#QD}CCY+=$e>CN#FjTu}M#sD`u{h=e4$=lk^A0~Wze)!02Se8&gT z{rF-KtD;f8A?>vWrrPr;Uk`ubky(o)>Y zfo!cKi0{s4KK+9G&8Y1y-Y(=I4Xx;!kp{@I(rS}A(;FUfG|wEdz{0~OP0lyz=iw+uu4Q0Ep{2R_8a(I#g>Za4z2E+uFTM5L8N^>kwfA70sMC@6A9pX|Gz~ zH$uDlzAn=Nn=Mgxq(}>N5|JV!+y0F-Hns&h8_q{W>BO=72>-Tti9m2OQldMfGpDsd zN;GvAWqmugt*O<$6mCLZ@0pa_%1iKhB zbkXmX($Yhsf^=0PT`bbmL0u=p)DQ%8Qzq0Z<=4w5^9v8k8K`eFJ&1eE(@NyHiBi|@ z3T&0%!9WTY?@aHUCG|o7Ooib4PM1||gJZ`~xCb!lu);|y zuqM~&C-~>{sG3EdVbo_jF_Mvm9|%^zqDY;GUjVbI5WuZ|P&xSH+4p;6-ZEUu959Cl z9NeJ^P+$azb<@*<33FiXvYH$SYp!(|-nBR=w{3criawuC6ToOJ#rGi0aHW!rMTj77 zIHsUgK{%XdrO8e)&GVo(pH|4{^|Hl8gZU;76VaWU$u-I`BfZHI`b|(eyRF{D`cqbn zla6Ip%cQ_8h#F3hh`DX?%ViA9dwEwTmz!2DdcFAs$!-Vr;P81PySqxPzW5dXgT9uF z(`?iI-kP27(4^3O!dGl`-K90}|O;iK34OE*jl|Hr4;HpR037|wD zF4j*~xBjrt8&hL|=*n*$^zFTtvIc~}Hg2exY<2Ey=z}|A!mOg&s|OZSmhG}Wd%H~u z(49M`tyR7Mo`*1Nw@d1{UBb<}NmMD;BogGBW%qVadPufo!?vQ~yDyxU@LUL;#YpEPy4xmjh%TWLc!!^f@Q6v8(319&CPNmNIw!QzOud);nW@RA%T9lhrw>i$tA0T?D!W~)25(`vQzX-O!AYvjk6BdGEckLG$a2BoTgHE$x!|`3%bFjuP5v?rq>}43r~hte z&?ya%x(>SZut7=H;pwTTU`6(H5Vim{YxF1D+!UtoTE0@n+a?IiRCEwcUH}pB+_Qt~atY4#`>Q(insX zgcxd7=WS5kA>L_?yc0NE9H)F0AGi)i3=U9~#qzO2WI);U__!LC_u=y z*qL!bd7v07@L@CYRgs0TMQ-Eat#s?5VzyucA0wDKp>VJ3VQi#0K6G^zy0!VWwG=0K zH|!;Y%Q4G4UU5@=!5LCc(>l!MT?p|7;Zejr6O;izQ`S4wV{n`AZx z$!Nz?l-hbd0fsHOVwaJ^72z3u1;|~cI%5k640DnxKcGDowB4!z3M9S5q)!I-Y_7Rz z-=#f(%xc6aNU4C=bme6;u?kglKoxP(&b|=JRuQ6{`24Yx4o3L}(>KMP;mCro#adx_ z{Z1;=TE$AK^se>hCHYgw_l)z;xbH6H!61)n8GwBcT>}bVT5Ex7Kn1KYcd+R)vX@$B zbl!tW_3ORb7MaFeZ{`|(0YrnxA`aR8A=8h9WZ4YTpeC`cp6Iz$AmN(Y-J^1!@qiNf zPN6hkSpeZ;{BS@(SWh3{ZwxJr9;OmQfS!z@rII2v5NHOVF8;&yLv8%VxEc6D(Wr+& z`o%FaN5*U}IV}N`YYw9`^DDy&1x8BpPyZwIr!suS&Ds|Y}OCyPrYQ6KUrPDDM@p4%R zWhE8HAl&0T(d*e_Yvetm)IzAGHYy?ixTiv0#{GBVcMH}TgviZDxcA{Wttbtmu(Tq< zZLw)YqLBws@)vutd1i9U^o)7UI z<^q90bkh4Je0D{WTCOO}AYIbwa1tS?oLX;(JZjKFQ}KwfFEaDp@Cyk;pM2NlA>i&R zx>0>ItC-Xkp{2qSIOE8mElhM->L8I6PQ?p_i92=SiYp+BR)g$PtqgJgwij)i9oIhg zD|C}B_t*7{BF%ob*Y?N0HL+Gv-rdI2?Ly>tNqJ-t(ffr(I-8gwBikT~bCxSd!tl_pneEWtc}vYq_=roYA}jOt zmB?QlODh?^W`;m8~5ctB-lB7uOW+u|46zJ+TX#9H|=}zbfxMs#U*g zD>IWQcb&Bra`PGbr7}&#Way#$>2u}9A$_1Mx=_neU^ZM)p~0h z5EN*-AaNI!d9qkGi1!d;8H2k-@fKj98lK&yqQP2HJdU!U>g)xY5?cW`uhf$7E^A)+ zWkpnb?5!O*pMW4@?z6nbtHq2qAVQ9tw|v+2R+zKb8{K>y_4Cr1`{I;FZ?Vg~(`fy3 z)5NOTL3EFPY{v4V@~BnnY#5cBxZ0h@-dnwd!<%Yg+by&g(m^d2tGRb@FQ{o7F9^#P zP{@+IeJO}t_)}dR4+lSeiC0KJSaiYSV1nO6=*V~D`_|Y^*mDDqA6umx=Qv*mx^Pg# z#(kND$BOKskLjJ0h`)(m%j{8xf4sm4t|yS3yg7;&<@FR;t9xg_S1Bg)?u~}B%n_p*WSTaxTPTueJMu3U z@$qws7$bugnM{WAL0i5Oi<^2IgCbYeDJdLZ*i+Ft0=I-1BQEgBjgqx1PDSUK)$%@ZO~N`m#unzLoNT zN!low6rK}u@a>W#K>~VmngR{U>gTEa#qU!zC~sE85y`!V?0f~ohs09s+b_9s+1!gu zz+vM=m7ydM#`dhr2Z5NZnP?Q8^Wa5pp%}Bbt(<}WyxVtyg8^EUsde#_{ znxI|7>1#1+8t#?|pjMR88EzKnI25|1j%6Jo}CLM*8dK#Zb6#=D2a z)_V93S$)&db zE3Ho=C}pyCUb`XVsGLi5Yvnjw`6PMS(C%>E*|Wo~;v`h>mc7N*5-h*ZO9_(|<#V-? zb!oA0T6G3C({C`kWi7ukY=qJ)4T@xMn5xUek@V3P1zmgUdIo-K)e$fX47nTc(-wHxb3An8NNhl-T`yxezG2(nym7jRj|}lEoJ3C zKS&#yi=nQ`PT{iD`0n(5I;Ft+36IGl*)%GA%zG1S@j2+0)xLEFHin>iS$NJdI4BM4 z8Xc6r*HCyO$kjZ4qZbWXubOU&DZQZX{rkOE;OmamXI$e=M{h zzU&3FHP1!mC{!$%46$ii8IbHbTp7q~)xB)horFMF%C;O|wwl!8@H67@0OuHFWXiI` z0N1h2#PEt$9rSL-PYxZkC*|pX<^tq!r@OT%^)faQE{cZI=F^J9P&LCh& z@sIze96hyol8~FrtE_OEJTqLIbwIL9J;Zw@~z%csf z1|5nw*r0lReW-*HDm=bzZ(p!sv}&H|R7Ne;TE0ZZK_mEHiL`)wuQ`Zd4Tp&3@>MUT zcSv-@lrdw^t|49XsG}e;HF#)9!ixh*9nY?Wg*7F|s9XI8{QE1>Xac~|my|DslRod~ z2&xrJI{qu4kDe#zyxJ_ItPK~Jyp`iGtw)>}iU+<92h232-+HTxlTcQSN=|xo>(OoLj(#_sRF>ULnZhVy-Ja^%*QgLL$rf%zQ=*XDE=rHo zfe2rEVnnp#=;*{-D37sC*R9PQvvk#$eoT-_KEi#C#IzD-bXOE4CQ-|(8sz!5>A1kx z+muXCD$md+q1`pO&(&EL1U4Glb|qY^mZ5rPp0rGH5%sxrdRjDtNCq)PC%8;P;6c2a zw%k(aWnzD?LrDOoMiZAb9O6Zfgf@~Qs{R)n{}uk5%bk^X{adFm8WBEJxw>7-I`j)d zr+UD|F?p=ZTawRoIpvU8mE!lZJ-5ZlRKmF+LPfJ#Hw6^?@HEhi=LqZoi)cii`k?bl zfd$vMENyDELw3E9Xh(TB`%?qk*@Bj@T4VGx6a~+kCbVRg847P7ao$o{|1sWVaJU@q zg`N?xD?+P-#t~*x)O)Ak`|!|R?WZd`Ph{4nD!`tXGazj9hZXPtVKLXJ+zZHOL~Pul zffy8WL?fy2pnP^2K@LA`zI=)d20|B#l2efvp{Sm@+5BN#oKCi=CLcaPXz-Zx08Yf% zhK~y}6`a*ot?m2z`t2vJa>t~-))U7dNlsV%VUI60exw5YnIRLW-2A>|k-m&4}1+*9P*R*mO5+bqH%);~Ob8rvS@OnC<93fLX z2#yd?Wd2CdBM;P!IT|@EQS3ytGWM_3L3}(S1X~EJCCL-Ra`w_vdjKaK~d5`p<<`*IP;jYkI#oXdEQ z252Ch4xK`gu*mBQQ=*U9?R4!J+iFT>r*v_$nuJ6|90gh~$SHQGlImj=m7(0-0sBZo z0nFUHG0ykQmPLxm>9hOqg(A_7RlKsibDMPK(s8}~wk*DmcibO2;Ca9`l5d;AC}-m7 zXYx12iz}ZsIjPyP;*-qA9V)XEDb}5tXDC+K5O(yi%c2H#I&0(}C%`lZN=5I*v5tk{ zq@=}DEd>Q5a^!VU6 z0eT2jyCZ&ojU=@mGw1so=xNLbUn6loe{qang#Po;!@(3vKuz`p0^*P8nb&DseXr*U zuX6)=oxFMG@Lg{ISdRb>=>Q*U)irTD1X+Cy!1sp>;qUc`qak49(Aw(nO>c~V9yFD%jd=4P}5q0u!)CbWr8*4lJ-$Z>dtO+hj{U$&*)Kd!x2=z%@7XT@+f84gd zg(*Nj^o^k%z>a?-|H29n$kZ`5y_kUQs3^cP68;G?D&F77)|NK9|FrfOm=9r{3fGx_ zV*t!iL?9rfCz$qxe_{TkEc|ycBskn2Y%_`d?M?|3?0WIUeflZys)rGXPA%#PXv{ z9K8Jt*~ZS=-df+<;&0T4GW)-9G{IT|%SwQwK>=i&5uRAXB>FGZzn(sSgZ{53{`cVp z`xod#qxblY;{O7DICX#Xi$V4$;Dalmd0&D*47rC3{Yfhep920rhFXbrujf7>UBV3E zM}I&_c+yxsLo0m~OI@UcfHr9a{Lb8u__Il7;){KEep`rxU+_hT0I$|o%U z49X`A|Et^l$5v03jUO{~)I4GMN$2?2c6zvjr&6bn@lhI|;QulHb)O2K{+i-vwaLd6 zC(S=o{BwZ+PYeoxhnjVh{6o-B$_(|adK=KHp`)6c7KBxRWqNi~rkBKx;e@*mfRLSo#J&m+@ z%#?-pUzq+i?&9|tp9Ts%W>mrWFO2^lHt>6tPY;qGQx;(TO!7a#r)>^W^E)!(;d)%BS!@o+|&h=Ktq?@acl?F@6rs hU-*xgcL30TeIEe_EK-1g$N+yGfcJL@+K0FH{{dOjytV)U diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip new file mode 100644 index 0000000000000000000000000000000000000000..128e321078793f41154613544ab7016878d5617e GIT binary patch literal 42437 zcmagEV~j39x2FBHZQHhcw{6?DZQHhX+qP}nHg?>CvkN2FC`PIKdTjbLAX`34>8PU>@i^KjzbxRdnn zB+RF;R|mYL%{iA-2?6_Z zU7!jh>L$P>>K^$-VC%T%aRpM%5D^BJeMvMC>o3P@>OwN1h1~0HkiLZyal6>nXCpXU zA|vKIC@@Rx`$TDd0~n@W!fuGfSr1hzmQ`yWq~XanN(?kcRyPuEDIz*KgU(MqualOl z+=_@y&<$zZq}dBEC{hsI0*=$S>pvAplua&-aa%mWL7=*UfT}-@@j_Ht;JH{)Oxoc! zj|`hX=uZWvajH_HY&+Ntg%azO8S(7ITriEi#GVB6GiJsJ2C+7>xf*q=k?k2_+wg@< zrLWtr*9~{zL4pZOPDKo&Fx&H&n9Vt@r>SiIdBBeXHR#Rq#3|}BsbNL9i##!W3g;ca zEm7DH7{`Vsa3niI_E2NUfk-vT4h8sbC%9C-4{_w!sF2How@~5F^|w)J+V;{nTjry zxbfjP18pM4NSH|&=62wWSMu~AoJD2JQ(!w~R8;EtEe4LZgdq;0JCPRKNkb z|4+q^#QbsUKV+RG8#wC_iq?c+y3ARM7W?$ReEoBD9JI;1O{Q@ZjEIv~e+#W0gzQOC zE2M}>mKN%qu`E6|=QwD$`i9FsSxe6u5+qSJW|>C&z4+i$vZ*ifeOMIEz{9N(<@5%u zHU(khTf6NFod9N?L(#J){c%UN;%WFb+}H1JZP5`ozj1t^Gl%%AwRx2^!BwchFMwJO z>j)lHhpy{d+9m>}bSXMT)464g=JY@mXYtLEI$bDe&c&;-Dr3RTowZjsf;4~J+&bvd zB1hVO=E=Zxsz%`V5?RzT*3ytwP6Z$0Uv5UMyA<_SNTHD_@?3pLboUCx?J;{b$ncaB zAJ3p-s%jf&eXSSD>GUQOPG=S0gmP@1yswUphw6AS)>U~w;xifo9~;?!_}3-;uC;$T zAozDC)|axyy*?gEc^&DOlNL>`K@bIml~+c3U}w`s}ngJPJQQra!J< z7EYZxt$HoA3lI&vO}PD0z~bZL1|~dhlY%zGIDzb4}z5eUA zEN3yJ9;b@VLUS`u02WueBo_1QA@c(2ZeZJf*tKG=@V|2)TO?2B$NBAs84~~?vH$>3 z|2qeaZ7fagT%1jv+)SPRI}6Y_&$Rw){U2Fy&1>zv#hJACO;bM^Uf_ndGCO0pl2+-X z%BMZyvDV<+%F@xXL6?<8m}wlfoTS3jS@C(hh5#6jZu}={b~{u~9Vr3~m%q1TL0C6@ zwbzIKXqA{=75V1ydHSq>H?d&A=%G&68j5=DyBg)wLv^f`W=q1{FrF^0IfUS$J_Xwq z<=T=EnlbUBIQc+{;y=S@df3%%*QQCdYx?-voihLYUE5P9m6%>r&CwFlHzqEAMh2L~ z31d2})$Ev$J=3H*mVf-L){#`4I{mA?vCoqa_O+>EN`eyr$yjXB*urzD89S0)85V~n z=S%uP2CFZ9No!DFI&-N;Utg4U9q!zM4;(x10W2PX2h~IgeXd$V);1bWVrA+VHP$VF z+n?{$tJ>cX8ptEojgLonWkt8>(qY@yBBMNPr`l$ueF{lV=hQE)#!8YnlT`_jplajf zQlg8tY5}gk`YX%kpFkIUEn&1dxGa8{mUo43n>%wto6Ua?5zp*hexfvE+w{20qx7-R3IO(1)UnqZec6Dk7!00=$_SwuKdc?*pmp5OB%}Kr zv6TmG>6YHs)%{#~`|UjQ@#gmS$=c8D$=2}+OhqRzU9_oVm@irP zsN_g;aTJxwYTL3Slh>C6#M>`fw_0C^;9Ib>vpcP-JXIsA4{p(4&Q`pJg2L&3ooU(@ zy{x^sTs-?em-NeAb%9pzugD~fRr$*)X0S2&bS2ET(U9x4>H=4m;JWum_LUPRZFNik zHX6}j%|cD%t0m5{=nw?r;N#&Dp-gJwTHE0Y-@@Bf!5Ne=(Rl zYI2$Gm&|5BV%Fi}&w&iwaONeq|7koGj)q44#WL6*@T86w3%f=}-tS1_R&bZt@nH(l zJSZ3Ru6p-KU1_o*3lxn4UgxSY;~(UgoDWy8B#a-%nb&f%hzM`om9$45ZxX^fAC7kg zXI6%}5hsTrt(|F#3 zwK}d2^Q6$B8)#-72?Xa?<8fPw1X!qFFsF|kW~lKy?JsC6rHqY}RQ{1cHCw5K7XmK5 zYA$c2j3}k57pC#Iif5=o-b&x7PPovjt@26_f)w^|o_Rs6@`Od7l-oSD*0g1-7X>$v z+W+PA`0=_~YV}GWb8Yj;gU$Y7BJ#52ksF7sBoLrQ&3fnMHE}Z@s)%(yjajE!ZQ#_8 z(rx*MhT3MeI1HHO;J_%0$>jzt?Q-Ydp1wN$BrB;$9B)`o*~O429jfm?(L=N&ktr2aQx!?dT` zN`c*(d{Ab#iLRPYefN)&wxXgHSRMaEnJmz-g&`C)Y&1(<8PUKU(-PH2fJD~cEw1!t zW>K~;Dfl1v+nj^IfGkscH(8psqxaTPk8^a^EA9-BT||(r`T>Jf)heQ{ydW&okScM8 za-?&J0fuIA+zS&=;kPjb3A@d6Y&vDZ=wT$yso*ph4@HCQI57Q5YC%Ni|eq^0{9ar>J1HzAV7w^Ezy`tt?kJq zEXP8M{&Y=AzAJ#dkyNAHx*P2zzTrQ{k;g8Ld3#t;85bbqBN-4;9P=+pU^+`iM)4}<}I#IYg z3GWQ*EF|oN&Ot5rG93anmX(i9m<%SkJ86qCiLIS&D71WVc_4LE#+y=3{D=Gi|nPm~W$T_IPrnphZ4+w{Wev!O(ZEB;u zkM35Kn{VW={4TT*xPQk}Itl~Bub@0BdQs=DTUKq85<+;A8nKB+2QR}l<8dV2rk~(s zT0ZJ~d8@Ad;a#^jT|tegDXlY_&3N5Z#D^Hzp_BmE=MGA+kUx$KcHsyHP1IbqHK7JgoShdU6m-E}MsvS$ zNTIt&g(*D-Z25^yUohuR52^o5JBY09%uf19qN63s&y?+q{{TqtkJ7e`qR`fE$fL)d z1n-&!y7JYwPN;@sZM4;h`kkaC4sfddJLtah&D3pyNUFJKTAOx@N@9;p*~@f!%nHKu$FbTo8!X&$u+$eJ#JN$=>q)*w3i0-${nfgS zXR|?wCkhQqVR5hf(;QezAMNfsVjCbY5FdYi4-?WY>cW@dvj?iC@D` zZ+MiT##BD*afvwl)4ax5N@6EK3IX=}k-0g5i6$X_rFPL)VjED;M*6Qu-lmW$%8ZZ# zsh73)OtXEl9^!=xtPd=?p|NWb1S0(6?F^mUQBr12{X)%0+9x|biPaK(;vgnoBiVM; ze%R>-UsB8C(3t6a`4Tyo#8E!BLl#3D>+&PM-$05ze8vRl+sG}ReUNiR9_Oz5e9fp> zX!L$mid_{R5n+~FNs1!gt|QvwRkwCp&HnBOK>xQbDfcWCGqZKR6&D41j}Hygjm?wfZS|IN9=vxOhYz17fo z29W+`mmvA^;A{VoFEVd6(2W(geLm)&Lnmf85wVr<*aLYZ!+xPm`dvljTa3x(=Nm?5 zumCPC6l+wz9KO;Ex{3y~o=dW3>22$%j?6P}3q#0;`4mR@Axv1=vIO&_2WUqvPKmzd zI1$u9wTTVm8tK@VL?*~<8-U?=0t;GMjVSWV;jrY^>L1X5Nr!M4=w<;kAaZ)3RDusa z?c_D_(6cTUHI9dxX50dqO%|(-Imb}yzfwBE%3qSUV{NXYaYaSp*Vo??zmKan2rCb4 z9((;*ITH1+FN{pW+BAHC*{32uu!5mXPN%bjU+!(!`2C;M_5&Uvs)d1yiuqtpXXO*U zXP||N)dl)W`U=cZIm$Ry#9F-<$1-yD$%?k-lz93c_=8xFJzrX6jHUUS7vBWM_$nr7 z?~mD7CBwvb3#t&p@FD^?CSL+lJ9i_Oy>J>iJVPS$9%ZrU-uN7mQyQALO25$Ow9-ht^r&{C@*{B}1j7+k?ngnC2)+>esdc zeB4nK9N5jwLq3u^H87?Kkd3clJP;y_Jj=s$h=)7$cx#Ht$>=524MGI>ZTVlJg6~UI zY_8n~_PE^@mBf(tV{fdp1Tmw=*=vP3fJs&j4_nz5%U?ksk1$$n z;ZZ0QB_mf0^;Un<&r8b%qP_tTbSZP@>1y^X3K~V1_1FLUb{k_FrqY!g(vD)y(Go4< z&j;=|>U`2LJiDBMu8?z%&YPdC?U^;SmX-qJ(ZZyCCpfpyJ%g!=J%+%_FEul)%;2zI zVqmRo4W4CJtPY<~GS;ML?`V|n7SzMn<>Sp%kt{{#cG4^pIckO<;0U5uaY73nj692G zSS5y;Fzklk1fd=p=q?qFD5-eez&3DL(fvA#imuNiGylraa+w@5T(&X$A0=~Kyq`R& zd6OPxS1t9GaI~#0CA$wm^-y|qmcZUG$KqqY_EzXVkQTqKyBRL_7O-yAyc~^jZBed~ z`gzP<@-Lce&Nx;G;K6<();!`!B0phyP8wo@6q1dBcg*4+Bg?D9Uv3aCP8jQM??kX( z8nRj8vCG%k`e4l|Rm5n|4Xi`bD5>NcNMimi_x7Tqg>a*(q`faQfM)lC@O@Xv1z8J6 zl85~y9jVN8@J(Zb=YM;T|NVS@IIBCT8_1Xi5q(R6=Z=@>+xE-O)9dkeck)cr%j?dL zI~<(n2P*4{&Z@S-FOe)f`bp!%%k3B0X-%3_E-OE>Fbh!fN@GkBh%z0=1Yhv=K}}yB z$c6A&L#Z^J3uxhYJw)VY61|(ILJ)YBaSN3JZ5mHcVt<3fRTYy7 zzTncV1O)Zk($7e37h&~ym-*-n5B(!WJ-p!yFv%B2=Cu(=sy-TOqQLW7gBjv^gMzJg z{T4P#?Azlc^bxTCYmyz?*+$W}6i>0ot;83WP^1-rwB9~xUm^vNU7K|*i(7D-eMxtw9(-WwMhj{TCSp=r9=8y07I|I>czux3|U?y9R zl4f0+|F0#=7TNr2XR(LNSgtQ8HfpP?iNukS!2m zL7HeS+YKYI480WkrN+zj3X|cGWBYleiD(&%DZxI8NDrPFGo*b>SSB{i&N$>Ux_%P> zj2n{4ARao!V*am0#{AmS7Q$NkMc2j5%*573y*Y8tqzNX!5AVD9Jhh)FlwUI&zuHbw zY~@D0WiMxpa)xSLp@W6Gc5K;B+fs5gjH}o*JJ0=wTb1k@;-=a|5EneN~c zUZmcYg};&25S~#V+&=|e?VYXUt6#w2z#=_;X<6HxmIZEK&kRPRgv`Ss#HgT%$!ks5 zo)q*Eoew83n3;_3C=aTruTBI(iS&9pfe$5t)YjIvJ?z&U&Ap3H(Fx{6Iaya; zp4?ntwpOpi*`W~_*yyd#fMOf1xn$Y=HmBdqiMkr0A(j4CviYS4AI}A#>(9)<>z^lC z>qxM2oMTCS-5Fxw&7SpZg*(0AYJri4BXSsyp(2=YKf|Dv$cVxVeAGROIOcN10fZnY ziffYYvtAD5ngom?$7^K7mI@0_dY4$!4IXh5_(eW9dcy?* zHkZtl0H-;pPp)YHcD7eT%0h0Icx$)$VV9r`mvCQyXiwgGecGn(d>qny8Fj>$XQGLo zojU!3$`!8J{&BNszw2L8yu%lW}g zBtBsrZ?q66>7Ie#`1@xMKY3^G8$Y@T?HC@u3=vw}EkdKPuq)Kw zQ9HFQSfI`dQ!#d zu06-gXT`gQZxipkNgO%aMNRB&_QmbB#V*A@;%LN=!nzk+-6A7tb6MUJr!!$N%4x@L zNXo>woxm7nR$3a>-k9&vzdc1Ka)l^GPi^84-P9m@u^AWdg(=u-bnNq}ET&nO@7*^y zGWGJ_DUz(9NxrXYlLzboko~)ED?#-8FaiNIah^w_LVerd3>n_;-9Q53(Sj&do%K~9 zJsY#_mMt!m2!9=&e?{;VRx$S7AvR}RE7VdeHfxVj|D2>{-{j$Ro`(m#ukBw+v*ZfL z|B!uuaPr9?Tf2hM&_~+~93)A3jmwIkP8W6PlSJY|VT#$Gy^1M2cUJdbjGMXGZQa|E zL&n>?Kkiy+b(cS=?j4?7<&SUpQY0q~zCF{ELEN_mtqH^Ec`I@*xUW_ek?rXAtdd80 zYeMNuInCdfkCXSxBSN@nI9}?O;UC`#eT{020by-}yjA7V=A&S=`fd3i4x$RX0XM!f z2c6}VbE-^w?PdJLp-e*%buLZ#9F+fEUc!K%Ux7!5T7E7b8~6?C|D6b2o85YTy~=;l zg}DG)-1_}!V3ygXgYK3 zz9DgB)pR0tTL||owE}gN>CToawE?P}IEIRYTJ0RHUSz6}8Pi(~kyVD?C9=Y@>?7Y- zrb%)9U)E;A9xftid^sx}=_zq~eJlI|VqyOu*g%4+6T0d(M0$)nuU_Z?KYC$L4_5** zTHia@H`={KULR2q2OyyvjeyY6z`49>=*P14ysEuzXmMKIx3P&thL|m?`{{s*h}BCj z=CuXxkwF602)r~-QSMVM&WgvP!J)+aKR=ukA`5U)+hT1khnNK|2DvcxdB zwUXrD6|CP^QZt^}6SaGv);O+lOD2*IMnh)TbZU~f3>zA5K{+pcCLj0NJ8Urjih3c4 zZ5wGX`D$L|j?AH*?jc@(2in=n&i&1h+~!}Yt)AK%C#pW*n3JnKk`*Na`G^)}@w!>~ zEJ`JUio~CIGS2*s78fC&5`cMQHoq~~i4XfPMK0H&KJ{}4f;+6N$dGb-0dw$hoX>4AV}^Zaa5U z@`!U|fMO5#x40jk?Q&jOjFo_p+7xS+MvEGf%Z+IwH>8XGjT5(Z`OF*5h(qWyc z6SREUb6M3-$9euwGQQ8HBk&c0Yv_mI9nSy1zLYTfi|BE!z7xGAhUmYh&!j_Og`r(cHc!lXt~(`2Xi3z=mwKozO=9g_rQ6UIbrkY- zoBN$Ww(+pju@){=Na4-;x9@!h_5(a^kB&$kEIk|cr@Ke<=a_V2WJL5q46R0>9<52i zV1f+Wz+s3C<%`5>9E=5=HA_r#XsuX_TK1P?TVwecF@shDCbO7das!Kef&n{o1UMyx zLuM!*=IvnQl_81LxTYp$c0bEfyP)3-qOTE7YBW~9V!1MGrP9W>%CQAlk3&k^f~*YL zfs7N(Y>+mfCKkD7M0H*wsHBFno~DJx=TqzC_1?6}`<;G9bw{D&o~@%wHi@)gQQQ6- zg2pj^K&EET(C$nV_(>s5-GQUP4Pslc4asUoWVnV6G(gq6+aUViZr6fFp;#tLUM>g- zOP-OGt7l8#WZG6#@gEeUnTJo$EDwS6onzm;^7|?rj&{9!0u(3Y>(-C7h5iQ6j=+Uq8stvPN zNT+=EJP)vUIjX}eBYe387{vz;6+m%h%iuWBQ&GAeo=e1iUd15H89!Z>rA}Qe-=RTq&NvcM zR^MOp*K^Wz=`Tz(pr=W$!XM4>2VwjOstMxAbNGT^dB>UEl50@LATckD*Wtu6>;Ft& zQUj26+5b3U_WxP`+XB zG;~ySQnB&s*q6`~`V@18n3_xB3q3T1RM8VHgp zF#pG+;I<*%OBD_P02TuPF#o%0SsA(+>KohJ*q9o-SlZh;|38j*tz&Dy#fkiLqd(x> zC!xJ2(fN6RXoEg(w$t^L#2T>!hQ1?IV#}pQBh9BMcS!Tnw>Nz!TC9glA-6@V03vGZ zWX73;#g{`=!eh+f=P^LO!75h=zrM9lh4DL_CfBxgonyMBN~}SD%q(rn6L#A~3vu|Q zbqvN>zxI_ibW4hqOgyW^F4>+U)JSZMab*-o*Vw{(Gf7KAYm{h;@l(&MvX%<%g1%9RV!8>ILE@7>2rWD_-^ZV4mKY=c-_eN3> zc4rU7ReFg2cBb=CI6eU0fLjMM@R(D#!q}jbwBkSPrg_ZEeu~@|5Ldu5EAF7}0!H3* zV9kQv`7rZFGTkb^H#clVD=?T!Sr`haP+swi(W5oRdNIvI!^((utmx$)+TgP{!npdY zhw#9&WZb`v!ILW{eOGKX0mexeTkF{f+;w_Lon{bCDVI(|T1E&8Cz)|QqZYb`C-kPU zpZGv=enUJJ?9_Sy1QlQaWG*!-yH8FS9~1wCIjm+7An>S|26AU@b=n~D&DXmj>SVlqLsK~OcptN|m*I+5CJQ34SB zlK#oSjQFMYh(duVlT~A|xrJgTPQ_2&EYb#vnWuZ@&Eh_`HdC=dzYrws5+9;SEOP1(RwWH)DF;%( zbP6z0-hjW%wjAgI(G+>I>MZa=-{os{_oGdF)-(jb#UM7Y01l`x1A{9L7c0_GYgFqW2IUpr9@PU;;Pmi?dNs>5t$P>zJW-{YkA(9$eCDW;!a^vO}B5$c?|7bv`tN~2{0Hg$F7`u7)B@vmysjgZSTznLR8qV_oj zu*G~BSX77R)L{s{x$0n@cePZ+Od<wr5Ixb8)(cT=A|cl5 z4y||FYu@*?p9wLL0`X*hl_|iT6Ya_boE{_GKwxJh0Ahs{vO)+-*)kp-N*ymghV6Q(Inw&h7&@X%Ib z-`yG)e-%hRx^4YnU<+wA1iwGis~OxL)Lf))-ZLb^OPNmMT!-PYtp}@R_S%$zqKS4xvoJ4wRPAfv&ph@p^y$$C$tr zQ)%Z6d5B3qlf2Z3Y=|-sE6o%KSBZ3(+Fw-wUFySk(bQBCf8h~_xw-gRpXiymImpYG zyu5Y#jeY>+P+Ol}aT}IhisH6F2q+O|)2I1nO~w6Jy;~6S`NguUpIr)m|GZdebkS5| z$#F$DQU_ZILAo!M0VEym{O?+Sg_jQsWv1;&&fi~g=12-NZqTbCgbPk9jw1wl`8bGp zV?R)6G`Ebu)V$vsJI92FU)-JM!nq*P+-dcKsJH=Is_Ms6li(G$Bap!#AL<5))HWhu zEtFGt`E2-K28Se1Or@1#{|+2=YatMJJ1vFtk!%pK1D3%}%vt~rL%E1ysOkw$uNG}V zpj&C=EwThN0IvNPJkdM2Im$xHYBM6J3XFq+$~LiaR9@E~(5_Y%#RvKmT?3mhhI$b{ zOk>DoO#e#iM->QMK`L4s4O$H^n#U1YwwCK1g>-*jNPItQ$XX=pXto4CQ_l6vc2j8; z0VoYsR&qMhG9-3bZC-K`=>Hgh0-@rNIUOv8hZTuB)=5$6PyZxF`6CWL@>|={biU$p zpc$-C-!2}2VLMdJVUk{!8!~ggPKNEu-2G6#^}lo;+JazZ19dH z9<3~wVFeOL3EUTG3J?HB(I__?I$c?8wCof6=^=-si|!l(TzEPBT!g&5`wrC0%6H0- zP&(4NmwHq0ZH}*;5hOR4lDOVr-1+S%>%})DjI9mes_6{F&vw82W3NPAk%lQyy5cO| z?|QYehSG&BuJW%RuPmyZ@Gbe6x3OwTg9zl9!7T`r|O+`m-X3*$&p)c#mxt@^5{Bi0HBCh#o)37K|ghtlDhH zLB1{&tiwFMV~DG19s#y2EI7#=h>&E>ps?_x-K}rYU|5r`F7k459d<$G9Y>8cK^@=8 z6qy;%Yc3>pMasI3Hd*h1+@RQfZNJ8j+*ZQ9<*GGjDV)yBR6*))=a)9_=^8QpUuyA5 zLHMbb#@}AojI?J9^+R2dp5nk4?Guh`oa|5zRV(fEuD6ZC;}sK}NxkdlnGy!SRvQ2N z$2;VVc7g)fQTe7oOY*&`gHflqAMuwmilGJc5~jorkk8w*=e@Zrg0t=)Yucz8gb-h- zA&zC_w+;HPF?qVWFtFkfyL1Eo@kF=gvh(51V2bh!-ML)E*mNzqGBp9UX(C=D-J90QkcX-hp@y~zH&NQvNxEM*FF8HD zJ1YwO-lGTxx{d7oj%87sVP%BRM1$Xx zQy=7x?k8KuMcrB{?Mwlas&iMcJIU*MtZAGDm?7nxxBNmOPo-T zsk@zZ8*l{KiC>%oZs{(bJDcalb519;SD>#JkWk9cA{gYlx9B0xpRIFd6zpdmv(q{Y zti?!gN!y#3tbp6;Vuz_gRY`*jIq7DWo0&Mr)OC@?*2=UFZ9V$!X8vQ8t@VL?GTX1Q z^3&+0?v=a>w0ba^QAQa*?%E`D*pDGEJ}h^uS+mon4`a>WdzlgNGw{sUMt2-#c}kSj z_yTEm6!cLa7XRAQpUv;s_F0jaNzerDmOjLVPPhEYP_R6L2s|h&7Ft@Jwuz|zbFb&m z;mv0}=FFrRU(C_ROMA9Ktytw01CxR^bu>WH3wATE6i)>X_DiaS$-Bm?W?(CJD(ieT zQrjVU8-@;>lRiI#+b}22D|+6SbE9M(q_y5&SI5ulb-n}sB&z+T5!6LYsVGGI4H6XL za5PIPM2#b6?(}Z?eLkF?f=2ai)y$XR#%1eoM|dZI+~1Ts|F}giL_mQUul<`9SjnQ2QFYpVoD|era2t3sJ_=ar1}O zj*T0^hZ$Q3KXzAV*W2aujGn#X&i{Y@mj25f|A#O_)usvns5t`tk30Uq#x!$77gKjb z&;K6L{x9zMKOI}=EwQAZS$!sH{gohiu&KBv_bF=M1hP%V!nRaO?mu-fxX8lD=#Yc} z%FOJCpO>q2PY?-esT}E>p+8iSpk`iPUY*`~H+c&v4X-ZpGLM`Sry!0jUOlh+Z^cX3 zZgUSFsnm*R#1mGkla_Qh$zDVc6XI+&G#=E&Z*sem_v}~k(!wdq5pJh zuSD}&_QG;!wKd6|k}9Q^VKk#nne~^AvL>wtNS0y9dcD2LE49+R6fz@K(W@G>PD?I9 zS^stD0IwZV&fTb}Q823HMAO!ZR+PRwRWxE|?nxF;dde6W2jQW1%Oq>3)_V6JDOQ-& z7y+rouc-L!U@*MA{_bAD9z&a{!Vl9Bb}3Y7)s8geRLPK|N!17^#}{r?>ZMXgi8RtW z85pFJ5t^cvKut2&dQRB3C=>&EgA=HwEYRadYiucOQpt6dQPLVFWHNf5|zxrYf z$5lw~S zAM?!pr4GoHqX6Mr2;zCbAyT1}`-~01n{dzaH(1vlgaI!Lj!2dy~$5kjUJ~XAs%Jm0FZ(AUIX{hL+4Wn#q zhz3`@&8Y|%FE6i`i|2!{6JM_fj4ykqV2UGmia_5FL)YKm2@6j*Cl5Yei~{3q_-PeG z*N(rlbKqUA5`Ncxt@83k)IFBp>O-+s3O2pVO9i`IIN0}h1dSg9vPK_2AXJR>9skJ8{L>O(bPLM zJ1QgJx^NQgONvdU2p@`B0Ld11L~H=S0i~#@A|l7Kw1W{-c%gpLEnp8E{J9D}HeR@CgX}CcZeEGiM0wGrZ^SbHhfz1`B!xo-r z+ITf#!goNVcNseO;`zBnXU|-p9p_J=Zu&fV0*ZuvRn|8U%T(3JYc95u zV9PfO#mc~Ujss$#lMpd>nAjt;KH9=HFRUJzS9G}CP&M#d4G1hQkWoAN9=mx-D?9_7KGexvX0(EDWywpcSw_W(SDD%w-i`^?SCW+iR!g_wTP z8KpmU`9&aC2Ma5A4y-Y=;kSqcxIm2&yE?ktI(|16{~|1umKGmsAx&$k(SY`W6x9M1 zX`qR6%7IMOgS4P`l#apdd03#w_LqE<0)JGrN-wOBI&}}*M>uS$?TORWkw>X zTjrutUBX&IumET%S%^6lfv4k(cuCa@Jm6>qm5i4IVrZWx1yUpq2!S)tcMc}XU*ANn z#sO|be&OBR*3L`llq&V^i>)OYlP zkJi#ZolW~h%N*JAcsw72YzjP8FY)**motgf?n-3Ak!rl zL^Blo&iLW96=Ow$7RkbN^7eY|alSx-!nQ(C5taf~3u6Lz?(?CNB?h|}fq~a!L+Tv% z?4=q)0rCwq1J-u{sVV)ccJcX($;K=a)F01ZmJggUUt#3%iQ6cf2kll#j2k zFXz|4u?NBcCTUIkOlkXQ>9-(9;RWmtZxFNZ$LRljQq%jen!{1CWaAQs-q8?7z9z2Z zB^II8XG8HKk9mO$Jmu1|XcPu~+$$+7j4T-4-P>X9kX`oH4T%R(m$0Ouc#Y!70d+zr zRsbBv1x;1UEt#bSh?83|yKanA5T6K%1@sAh1D9$7!WQvu74k=6SbhZd(an2_!1{~p zjSc8ykF+L+_?CR?RpGbmWdb!}R5%enSr;5sDJ(sp1N{&*U(FWP)ECavjt`ddaC#3w zTm-OXmq3XI0#CSOvg*oPxja;7eC0ChPdRzGq0dS{$x@*mM9V2ulywsLyR}03N_Thj&~?wYHC*_M>oAfzhNsG|VQ& zCgDmet7-&=+oe&hG%3mA^$%RJg4>f~ljZ{gL5TkmL6sah$9E%AI-!UFwe=fwh0ZOd^_O>An<4upV_eQ@RR#2g1x)h3}9iY89 zz<1L~R)bL}x!H2fk-`iLxiavHfRlmzI}#bii1b2Lfu(8Gi)~U|+P*dtV2NO%y|KDx zUp|?H6Qu#%I+f&KG2YFOnOb#?Dglw9{uQx(8_~e#ycJOnTjhzQHj2HE*;=fOjM1Na zkCV2ZaTO+zXB|#1#J14dn2a16ODv-i4AT?N#-$w6yn?UcN}U{fxFf{_^;9(7t6IJzXZg-J2Er=4D8gF1;E zgQJBKaj}6`$dMW4$_ETYB|?MlatNBdx`1Xu%hIrdL4O7VkZTQQ9n`}zqGaScC{B3d zt(O@p1#=g|q5zkxG*X5wTqFvId8wetv_?J#g{Fd4BiZjunXo~7Rw{Wi=n{QuJ?UU7 zga*?xJXvKKeGnNY2j0^Jhk*NUu`1^Ij}MNOqZzcSTAH;B)u#4l!L6yO4quoK@fWgF=TpKU-=s%d026H2~G$?`kSV7-Rf(4Gi|1?-ucRbrh$G$FC8 zwOG>rzrt%qx0`%!s3PuS;xZh7LH{(_p{;w zZ=d_(1)Sf=WOn#?RfDx+4tEVDwoTKG>V9CgeY+>3h{;l28+r30cA8+V8kQLmlgPZD zlAshzG2USjG>S4REflEdKFR7hc(cWvcx8g{?Yx5C%BFZ>%00q!6v(ngANUhfT);7OY5! z*I*6mlUx?5>XRY(8_!W)rh7-vbYIrU?bNZH?5As-; zAO}#AG`UlKofccd1JEw3I0WJ=JdssR0b7k zw7~T_K-UZ*VNwDZo2xqLTGQ$}ACk;mmu9qOi8Cb^EZX)E{ekIX;eqVW5pCwY+r+qa zN5{0UWs-VWO~kokpG9$C=ufv5A!-@*+Bs=^#4G3?5r83U!(1GFpy#LuI%nYTT zVd~TYt;&==K?A(19Yb%6(5#3mRj`x{Q7`7ynHHC(eQ(7AGTgK6{5bUk#nFtW;A1nx zFT{I3y`=S-D?Hr~Mrc0Xa-N=tw~3~`J!rNEFnCg;+U^xep}IEn?;G#~B)HigHtRNK zY1&kENdbAUQErIUWHk?TB#OY&06?-aBPf5VOfJ|imBuAzlT5V7rCm^Op&)76p=>py zxx?ZHfIE~|Ni^NiXak3EV3qDYHE)f9zOBRR4}&C?sgwsBJ357&SS4sK=R;ixAN97d zNJl_7jn2MrQL+u5h&dnGgf+s$3Sr{GQnGRDLdeRD)E7r$|C>7ZY3T$T7mU{WwN0 zSrT>xq|q1eTa*0oD0zB@EM>pfc#gg%2(x7=|Zn?5zyD+H(z<3 zsbij*ScK!An2WV^Vu$9zsRFLx5)QY|$uvBAOL?mZVVfWW{}E3S4WL6PH><5Wsk6la zSxVnm(QSo#-8$mq|MryAEKzC92#ABvzMw}Drp;Yy^-9%c0?}fnV?Ko_1X{#m?3$0A z)r~!hWB+uM$;0WS*ld+DP8|dHIdgL`!@W)CZq5Fgv!f@3&tzNGa5o;KWP2t|r!7sE zhlcQbI{vvMoyF0(nkP#61)hHc0=LW4uTO@j$J>k3{OZk+YSTMQ({2u7#O1t3k=k+k z8*9>y0UlkTUPyaBzi|j(7poi3bhX-Pt{L0N==>r&8=k!yUKutqTVhA>&Xw3PcAoTa zdZcca>*b=eD49u;;{g|gwKjG&C(-^6>GxWbc3`QE%^MB2+TKQKzZs9|I{yQUw2R5+ zcJ0}(cjD_oL+=Y>sCCZkK~o0{X&58P;9Qh|xLJQY9NiAf|dOy#SoGWNzj&~9bc1PqUcA$U1xOnET>h&eA3EIARg zM}SGqC9=XJhg3lYx+!N=@q(uDBAPb_-CEw1FN_D61*Y_{j;dm2tJ!NRxg7MH{2i0= zyL!Ui^muyIT@Xks3c|88g?YokGXYDsp$fmO7Y`_vQIw>L%mkx+7U)LX#-YTMcP=xF zsj^(5VCf*fq;|&MC}+^PP(X=MPTsQE9NFd)XeI`>`e$m#hT4Pcg&i(554+mapab0f zIEBr3QbmAjwuS&grd-n`JL&5i(?{5;nU>M?b){;{$i2oQH@7VO!j*3{Ws@h+ShSWC%Y!X1mzv^=L z$Gc@nX-{xt`a*U%S{LdD58x)%iJs=gEND)sq3^yAbqogEjNpofuXKv~PKK$><{`06 z9@EXhRo&1--Qkg}+Md$@H-zdJvLq6A16$mZR@=j-T{U&n6gajt(z;32scBx8sw(m< zrlN{$v%k$b`3gjzvAI@E7MUM<<4+8>owWXeZ%yo9(P?UKVtfYXcIisiqRj76B6tME zyTj+n!*m^p3)dGLpC`d4>0@ z&1n+LVC7N}EHGxcF4tauT#p{IjI_lN1$SG~HE$-vZ3Dh6NyI=(-9AzPH$Kbh$XKtj z80J1%z;{`!x8GbAY+2EbF*HJI{cBdFMSmc!9SOW}ruo>tGA#ms&H}ZV2wUO{ufM$+ zMOVk~uD)#P=zbo`vmx$XD~VwzdI&%Ta{EoL>x=00ieMX+QujH$LewR$u19p` z04AVeTV76+iL*sh{Piv%o4O_5IjNOV2z~NrhP64I4IZ20b7%`&QN#2!OBTV7)S!p{ zQiHqNXA}&eo<5i09eu+977MxhLKM%xMrp^I=b|ylp*|w z9loi0X_ROAYzCX!bN!uV4G;D?ZhGkE*n_vmcugE2<3HUI;!*^Yzb46F`)2gM9kypy zeY#N7TkqPnUHd*sb@NWy5q8k053B{#u^bt%c@d^6xl?#q#{88JN;D=0((Cg9yl6na z5|j_G_0}cDJO>PmI6Q2$E%iRAYw~EEYrBtAZ*$JXyCK-p0@gEkOr_ym%+&dYjqYXC z3+K=}-sHqH(8p8V6Yf{lxxSoXsWsKYUv+bvZ$I4M-+yjGJKR1iW!0qNN4&R|5QzUS zc+qei`i7dXrfrMXwctrd8+d9oXved65(^#Z#mh_|e+Kta))eFHI`$kVRguayJ4UPX$NUs^ z6NjR|qN6W%>)Toc9kZ>u?j~QXDew$K4Sjj4nVTvH9**+BHV)w;H#!KzMLdrwo{%|F zHdh(u;@WDiWip_u*Pi}n4|&d#v{t2#CUK;yaVm&M=mKng`Z~|MbU?@DAI{eBkp2vf z_xlwt1fKVzBp?Nl@KkS;BR)_|EZ~VXS%XJ{-DECtcO*9O&<&o_Hi0?pE;%=({3O%y zBViNO_xW@de>umC1y8#zIAImSJ03(yPqNtrXWR&|?MYKTDxlzsb`C8AeqfLEvEJZu zIH0bN@Yg!ipc!w*fo7atbQ;|2$E}l1PD3#@VS4gqcKi3b5}g^~hA`12m8HqXjFH6H zW*}`Mw>=9zNaugr`|{>Cjx5jr^(m^Mi7?=SV5y^LVuIUr4^7cBZIRTHl-&+l0Th9v zSQZGtC_og~>+in#uFRKL0i@)scM#ncN#uRJeCO{mErJDrt!u1!6>(nz($vMv*%^ZL zLw9CaPdZ8h47HzM9Xx-L4HCq8xMC+?e0{LrstbjW8NtBszncmkgioaFSkRrX zsBDV`2ubUmBNI`|4j0&n;2&#Q_n+^boo%(Mrqh49m1MnKwe0lpz6W2!!1O|!1{76` zt@ZGU+{L7Zb9_fv13{O{d$I6VCstOqL^3jIPB-V<5a1;SQfLc@hKBHSxr*X(FlhOaeMATQYt zTFi;IjNlN=xhwPpa4tXf3?SMJ{j5M^1pHOMP`(06D}-z|y*>os2FF7?fMU@{!b6ig zjFn^X0quLDHnQ=s4JszNrHwHIYLR0eRm!VZ38YhHo*c)SLOZqlOPEK9@=stI-` ziF+q4Sxy_z3B8&eT>x(@0CT`yG!)Y8`tq&1`I%Pql4)0}U2jPK zQZ?%_I!ejPp1eGN@$!85?a85~%L9PqGGp(oSC^0eOOY?EZ%#VWWFum9xv@5vUV#T= z>VJBAaymRcIDdJ1JUlu$JKKA9(AwR@;{4?4NdX&CoPqv=+7nBlb+X1k|06$A{JwXW zL$gm{3V}O0=sHNW!O4-}mT4m(nn@iIuwhlCD6)4iAg_DTwOpbx3$#e)4lz288nR)l z{T4KiS|??>$H~Y+VK2j)U@U#P-jTYd~`9LB;!O2Dx<$DN8z6>LZ5mp-)n27;TZ# zBxQBM`&J!j^+!l&=4J(_M!16406MYSL8oN`r&cQliLnDB4Kd&gRqhLvZ$Pz6!K!a( z^Bd9yVWg=*4<8jM-GScuy&I5-xgY`zP_#|b&2{2c?RL^j}0$GkB|R2!T6>( zpgg#!fMF`SN~=f!Vv~A|Lq`~v>iuF`kLwl7ht7qO7vbX8nsOwLL;)|d#(%Vp-aF<> zHz@Oia_syV77{Ghz#&y5@hVW0D^dfJwFrwg9FA;ffIr}2GriQS;m~5Yl3OWwLld|3 zDL6Zue?DcmXQgppFU%PHAwhuC^s+aw&4$%oZBuSwgBG1ryI+7t0S`6);NJ1HJHGhg z2hXVPY6y-Gh=*$aIE2WE;ZHVPyshQXU4NUF2}@eFY(*T`ZU9wI0}j_^<^W?KUlhLGxilmy>N;QIY$uf#rIJbyY1y1sEcb_X^n)__3>whvl;R zpa*qF4AglTZ)(_IzmZLA-Sr+H9HP5~q0{!}TPWd$;`kn+0AK`R81A|8vXZyBSSE+l zx>;?$ufgc=cwsrD0{-!206hkOaUT}nNl?ldW(n78B`!G)1q?S9P{Q=A2q7x#g4TWn zcZ2%PjEw--gLo^w(SQS*{=J3$O;&#Oy7h3{RT%&RZZ|{EZJ{vUH(Lsw-}OOu=I>cz zp#DUX)_i+@`Gs?+f!(=w7BMI8I}{rQt4$sf+AxF?a!?`kCAs5(Q3Dm;^vcd>m-QQP z`YMZ&at0c;oa<;i`hqOE8$fmUF;Zq%Lv;5G<+_%DPr+|-QOHuke<0v?0fVVP`7i{3 z+Ef!wGldEO@WJK}j!U>U-0$*sXrxV3ANZeyF+^K&xV94n@ek@;37Y{QF{RN;kh7Ek zZW``}!S1dk*wzK$R^|kk-WU`xGXsx&vLtt7U{{OL+1c}bn%@ZJHLvoB44YT;W@YU? zBZB;(P%|DAdIX^^#H-#VMN^M7>fA`z@ckSxVz|9keMyqgiVQk%UV<}$u}p@+gXnB= zKU0eo8HCdzk4VYCf{3Q}j*2cEf*L zg>DsA0({mBxN6wnZE;nj?U}4UjTlyx`zPA_-?}o<$eLn`*dCW=qfsbJuPR>C2vEdr zl>-#MxK(@WcNenx;z)Xn&ut)^SP3W^3K0l@Ftkhl;&H(cBmh2=#09a78aR4}w`_R1 z77n%LgwZeT;Ui!A;KcAjV}Qve=%Ng@M(r4gA<(KX4>c>;ZQex>bOSz4I+AASI@5d_ z=P#sI#*xWrljJ$4zjbMynZ?5hE6 zFDNQ1->7;8FA7-}L50$!e4^AG6|vuucy%J%hDv&8Q7+SUIjI6*AsiB{rlYL!VrJeD zjdD`K)v=&D4d5scx`Ee|au)1*4&hc4f!;Z~nI0Svz6-%q!S_O{u#^v#;>GekH_W&{ zFCHVTNW)Ms7&8p4%NSi3%gdG0nZn1>Y4KQE&CS*D6AM_M2ODQ+Cz8Cr{JP?Z#Dy7E z>LKz3`n8}&mJutwvEhp*_!WX`#%LEIT7}fnFE}!6@;vU!Lm4oU6zyxHf?YlP%s^H9 zOQH}_yj?{?Q-zpJ?Rwy`4C!B-P{tXynzw2LzCk`{z1olITB^K7#1s5yYmf(wH&NgC;eZJd+4By?H z9%m!lEZF}Qc{4jd(G1)Fd-2%6D|c$Qb4Rm(gb7rU7Z4RGBvj7wbCftjcelC{e^4|C z5?r_$#9-~Xx zASo~G#-j(8;m$;R@gS+rN2p?o8OH`gK;3zKRX5ps}^yp8MZA;-91sGoZS{9FW=vY zT_Zy2qI5CSyg&rzYNM@ZhSWlEx8EJgR~J=C24S^B9-0Jf5&zNMFBn;*$TTZcUm}Tz z#mC|YuII3WKB+y+Mlx&qM1V-?TfoDbTZ~gD*%W6L#Vp{Fn_Ev34*RSxy~Z@0tTpFJ z?osr0ye!9+!pKLQ?Rk8oTv^BngLtFmw7~p*a#$FIbWDvqb2++dei%cKY#=dXe>8^P zJ{IO%`*WsF+*J7J9nA zzP`m=vA*;Z*`1=qo>HnDV+c1Jr;vd%_hwgRU=Xm52m1ny$Fsq-O1;#t!e5D`}$#fjEG~w;eND^I9h)Gz+rAnpkyD*H7 zVIY|3quIeS859nq1ZjS5n*6e{)y-ACg^u1EC;JHBmaPR#!ip=l&YWA9FV#vM;QMx$BoA& z2W<>`qIk@+*Nr?mn2Of|!@BOZ0zNf@j7~nh2D&!?sf~3{^pw@aQ__=N^0euy_deqo z%ddan2@D_bZ%SCCBszG>zL3Rqj`7$LI*~GHDw9>p&ZK(O=iGs9*|39yt3e)Hj?B zH9Dv9Ij`Tr*$y<}36Kz(r?`jPL}$ScAzBEkWF3%(s8?B55hNvr@oODm^oWJ8B5l_W znlCAMIqA~ir5?HVy|FGTr$WRrOzX>vxG}aQNjeHNl?YJOs&+O}*)~V^xpbQHCvHV> z{kOW8gJ`{b@uAT`Ki(GsE8uYG)kz9TN%BAmmBc8i6ugY=X?${`K_zD7PM3akiZG|v{sy>N~)e0t4W zma??K=?guhJkH+vN1*|2f(Q#J#=Bn0R~o|D5;LRS@;bSxID-tDr^p(EIt3Z2^`CKl zSVtN%BAAiK`HH5phbf#g`*afHtI4a4UO_!}j8|2We zrnK|RKtUtHLVmrXJ74nP5V(95I97&w20(#jS$>p!RczirTw;}M#vU3DdbcsD&IRe) zqoNDb<`Q&H)L^0$)1zg2v!-l8355s6I5l&N0pXpB5p6z}qRsym=%qr!(WEmQ;;+WX zXES;!ZH8uQxu4@~hPX1xry4#CemJ}=SLKv|XS>p7FYYzc6%vIDnhf=f@Zlv(F(EB% zL`eL>Znp@vq8K?XO{7)|R4(D`qAIHxCDZl*t(pA*mMF0nx@}L0%`O&}j&_(Z{NGyi&<(v;Q0f~8oXx-gf+^LWwXt+KYlzPIjFIaWnVge-RPffHeZIB3GkKA z<_WZpSjAd9`V-ufHfEl$N()cSjb&^dQAO36x4jY|nA+q}vnXkOXQU2?1B%MjT|c(` zT}o3(o5*bWC*7%|GK%RR~v>nZ+Wb4tSv2`|L_H1&$Ispkjf#MVaK{N?522=^Hv z^^uDO1iI2u%h;&N!aS|mzCr)zl z^Z&q#dF$(%Qp$s=>D{~r1{P&M*A*^Wn>zjvsO&8-hrs&5!KSCw@JrAhzE!&GzH7Gk z>hWvyzd%h(hs?G^|8%*1g*6F|2Ir>-dqqR*)l#1|5-z+C#JTSjc^6DjsTT%3AGday$OY{3rA21#smc=-m8(tm>9wB7$vWM zT2aSmnA`4VFI-9iVvzz=Ly%?)lHGPY%Zm^jCuh}(OAR$hOImo^7F-5TK!_XD*{5(O z4+};UPx0ssrr?lj)DUuLHSoD^-axWWK&C;tB#TNDhHSy3*NM#_9d?=P{`~ltou8ll z(z6&%xs7XI2s6GM?2o6HH|2~_pgGmt4#n1h%$uVerUI?uHJpr&nc)vJGsPcfcgLYh zlfBk4cvbv0vDkYFS?4+s#X0NC(8?J#=V`HNUUcc2b4AC@*TZ6u^FuG@H3TU@NI$1^ zT~wTOoRvPvELY_$(MUi=nmIItg^T%mY3B2;Y6j`mQnF}H3T%Z%N{en^B0^Mz!<4yu z2{ENl!gNY!st|bs%Bd7uOs5>tIlN@?obh@nV|3#|fz$XbN_CRH*4$1=WGx=KB9m?2 zRu2KmYo$ERFRDv`USK45p%)D1&O{jVNz=r1TOkxqY6KA=0yWuIAZQ`aAh^fw*EP4K zBG>O}ZgDmi-PnE1Ey~8Cn`$+;+|_s*SWRUFZYMr;a+1lI?*(=+)pV>+lZxO4K;j~BAazuh4&^Vb}xTWJww0O3FEBxDz6l z;66R0oEAY_+#B;UDxOUNLpW^Q)clt9qT2&~wobgTVdG^<@isfJIq*4#F^19|>Cmai zk8)~hD}SYPYa36V+|z^ogTrqRoRhxFmlu2J3%n27#4$5Nw;ogj$%5GPy~Uje;fT^I z%;q=tK|^9Bcs1xAPj2DBWPM3cIPh0V?Wf)sOnYd)k_?4KA$m{&RCaUNjMYX-0HuMY za;pqsvcpgk1lHNb2K+B`hvDT-5VxB`Afn@MqkAa8i*++FyibFpCpSsa z5ry1~bl=)cU6~$4%lku4j>L1c!BO}yn(@rlnTLP8?}BS$zk@%C^abk4&fv#+9XKew ziV}Xf(}hW~^}Uioyp9)K_oZg4Z3n#AZ9RPZR+!o7xYpVzQgxoSdX_JtE^rk$e zY)&$?ip!!vk}D2{^Ch6<8uq=FrVUgF=~H#c;1lM2Ap>4()|b@#85Hq0^*#w%`U%r~ zTZR^|htN?*$@}nm2sUQcZ}GvfiQW+wAxKnVsInw-v~A?GY3n&;U!#HInJj~$h6u;m zQda<#lr7y!-l>5RkXGgRsv95qbVx3oFU!z{=lmm{CI~?P9nMm4n6^1Uk^i6G-Bf0n zKkA0-L3%O@u@LPt%GWA;Hp+GsOA0_55XFmR$wFo>*|A4PKzN+3E+pHC3JI4HTaDA9 zPr$xmy0c!m0V3eHM+|22%*gvnNi5OA-yY?CMI;fQ2hvCqh4V)dznUt5mbI;?CPTMp z3qfXhCrM`Lq3UfCLUO2kL)q5d;^$w2udq1gZ6YF|LteN9WUL!5=8N^T%sMh;g%2qtb2tFCe~joH;eIjH#Mlf zxe>d;Ok$QalP?9i#9S_w@fh0HS%=`0N}6K!MP*DeyU8#MnxAEw6-c7FzSA=erWuCM zxpaL&FP@ke?Dzi&&AFBS5UCmV12RN-ilZQ&5fws)=@*V9&_F(f3!0;W8=}AvEy6D2? zLEUZhj~+sZ0ood7DX(Xk-`-p>-0gZc2^=xGwW1KVpWuaHGC)y+Vt?%wd&Rkt9!P`9 zcNarBxkAsSRHKwG*A)nXIOWwc)EgItpxhO3{t{Z#J+pW9#d<|OK8W^>Y+U~95$Q9! z;S-fmCv+kH0lw>d%!@C7Vu3DkO<@DjkF=3dS40mP5UfmdxC9Ai%3nZ33Cx{kAj5v> zlkG4(ZsyAdHp#4LGd7RKvYv-1*nsIY23fMEB!!dt8X?h*yiUo@H99$Gt{)__MaV9^ zf_c3*%coS+v4#TGxO6iEs@P9b1W$qjI*6yBZfru`G;#38q{+jiPJjYY9bqJB!EUUs zXeJ+5`pd(E=TA2#dG)t4)~L7jl9%M<<4bfQceXjC%N?REMLWOa^wmh*pbj7KM954` zi`I*YX|JmXF)+_=z3Re}h3waTHqwD5eh0VNy${{FJMY@Z*)b7&a+~87 z23CT|;*2WB*Kqsi)8B@>&~`oc`_g#Vb1d9B?_$e%)BBuc^uqIsWMqrH(p2u*!Nz3W z+nrNo!}~Bb(%p`(HMP5tvEvhS3y;1yu-6tIhtQw()#KM;5GxrC7{=V?9)?e>i~qH+ zOUygT0Lj_YW92+RzMrf_fSj{05|ATAQn2Cn10OU9pzYdZbXYx+JHD4}3MW_}(0M`^I|%Tck6G>DXLTPYA502myz3q-?+fDhcr$YWvPHESK) z(02hN$N|ik;e0^=L}dda<}Q~O_s@J@lIvXGYmOZtR83{IsnLjPHeo$pft&e-!!cb+ zmK_ln$w|bIX&n9d?t0{*&zB-|_ls)qW?lzXtSb8lE1i=1@%h z19p1w<@1C6^TU(lE%;$5jO_5&d*AL2pYI(%8}2{fJ3~=!vys(1Zrd5*?{={CT?4NZ z=3|S`iznoPNbazYUpq$l?B)J`1Xs#i{!tHOa4b5ld|nU)67>rabWW38y@%|owY5C@ zfrBuBUhf`SansX~kl3bP`@&Y5X#=zI>`aNv@Rk;kK&X8khg)8n(xmc zDMk{C3I7EnytyY3$Oj0}$kkA`Fl-8bXH^tv;i~r~G{@Pl;GOuku5Ki%^PTybTL1j> z&x^TPKKRRkUG;h(^`h3v)O6wseQRqi&2G%ErK(jctylGQaEjk08dBd%IuMuLO$J#v z^1O*793zd<=-t4TPK#Rm;#L~#tAT%t*5Y?%js6PKHV(TuCluXsGJ*f2cSkuTMPk)_ z_JS7M(1pp|b&Cqd#qHmI#2L$Jf}|#Zkc5AupKT5Gx2_tr;jhiH9`Ig~Joto#1gK`S zf9&e1aydC;JVwfzf=6L&PZH@c382l+|B|CcRn}@6T18}*TB(|WK1h9_x3}{(*1X4b z02_jq^$Eqe!V^<2ZRUD6WM3Eg&B6CT7XL2Ce!f)a{qv5eH0G7asg3z-PjL*?l+>ma z`Om#56*Bx??lLb-m{}&L0k$3fp>@FbQvRCj_3nCDPeR~RfjN|~?;%?oDK6h&QW=ph9{WUn9I1T^&$tMw;@lw~ zAuxwrA&s;I90?r>PiPy6(Dsql1}~4lIR<=+O?m?vR$4XS9h~1xd!V)=wBE9o!1S^I z&EB(vx?%X92y`*N=Y>UH@I5#N57-X*YEb}u=` z+T08F#pr3sI=Ol34QVz9Q>Umo=}(xZ)1PQb(nS7lWY1e5{iSSZ#(iwa8y0`HyzN>M} zm4Y=qJos)qH=Zn&kA&~Gsl9sqI*Rr9lL*1>Rx|{nQGI2vo8hcMmfO;C?T_itoIkb+ zwJ}WBYN)^;Wo7{;|@nP!=WYFx%+W*zF|jQW3RbZ z!{f(gxU%*KD=ox2QlBs9WLjxtrv}cTXqf(e=x8d z77s(8N}&J>{RKw&MWBlI_>;m5vDdq@fXqB_n1T`z*yg z5CYIH%cLY2ta`-3!4zM|S~QHqqnqfDiTpLqkC0NCP^jFtcv~B&Hb+)7HkqHBuDZPu zo?Fl1hFY;DRdcGAb=m=ZwPiV}9?dT=?IA}a)U1XFp0v)8*<-%qWXX_jVBNq2-6Fn4 z+)60RvGnH10r}MVmiz_YHSnW+4|CADl)pd!%Dm4nPEY>%eIl_kK2bvNK{K;64hQ6QI@;%Ybhi^@RO3pIDYYX7V>TB@j6lmyqXcjXhrnr> zrL5P^3HY1d)Q#H922!ha2vksv5%D@?H*-f+CQIOvn-x(2k9-*}&Km@N0lQv-?+-St-!&NRo8S zv_-jCy-mx7aq8si>cv$I=p7QQ>+^YsZn(xBI&w?5T{o`$9Xd%qh}r&PDMt?=*%mPg z;lW7YP0NULXD;Ouoj<&4HKMd0D%lOA)@-%DejU5U(L`iEsil`MW9c+fN3-t=Wqb3# z?S=#|`{;n=l>08wbw+#S>bk#Wo*0)ZCpl{QzU$rVvud?o1Xw|S`;e?4n0ZTVJ+&8* zR%=)eW*eT;5Lc%qPe|@wQp4`sZZB$WtMi&@HF>kT8=BOr>GUxQG}X}&sW6obwm+`Lx^XZ`DLj(Gc%^o-Ir- ztBw#GEY`WqS=nYS{_^qpDF{M`$LGV7FTOt5KOY`G&2EXkEOt{8S1={Dg^2$gE8<^n z@&jEyo=MP$g>k9reYVO-mkUR^K4dAopqr>ADFVBbj zC&%9&oSv)HFPmNzzb!B1Z$p$~2m_d*sGIDmq4rL{Q3_`-p-WGm;f;a#bP8fs)W+^4 zFUZ-!safMY_xKKc3dJU$cp!47>)qH7D9#;KPZGrrpWY=>#*5)rN+4ZPp)6vX>QM`b zaXO0|IB1>S8X5O`z;R*c^>uYrPp5Sg=S<5aLZO#Was{OlpVz`av|t03{TV>?Dh)wOcVe^9S01qgv7pJGrhHSyNsO)GeHl8Ya+vtC~Q_q zF<*b>PbZW1#JSzK$vhrr!hQ7ut)le?<}LI@i1}5I(+*GUmg8tkC6eryIMH<4L!Bxj zRFLdriZ{ca+bnPLvcgRk-Ig`0C5+k`LeR?1dpqhVhT7l&ktMU5o9@5%_J~H}%d>5) zspk1IhS)9lb+>KHgC)q)LeMM0;{;5VdP;UX<8`x|U&F)I2@>An@`3HZ)I~U1WXRBm zzc$nm5>AoZbz>MVSD>0TYs1dDvnY7p?2>+6PS8gbZa!A9c2&>~PwVl@rDs;2>+k`C z(~Q7T&|%)4fM*CefKD4#IldD6-SvF3#w=OL3L@!|p%ZtB&Tim&UMjEx)YM*7rg%MH z_F@tb3%668EZ9Gvf**)*3g0>?p31CdHe0y2-|JcZ6F|gJCKP~LoYc*#o|)-G1LTt@ z?pov};`yn;r;8@c8OD1=j6i4GE`OCnNwt?M5cMq?= z@|aYE?fDb|F=LnDYSP`6~%-6}Qrf0W@i}rUHv%{Fb=%rx`d_MB}mbJ+8K@eT2b`3m-g22MaM;Yik zj=Z7tc0@N;rl}HE2yj!SL8a{sule5TX^?j@eG$1FUGqENaCcldDo#qGDM5w_?~2-N z|ACdLz}>iYTPL38zwiEnA0B4#+#k{LR2G^7$yfkikHAK+G8d`!?c9MA;`Kg`2mfm<+Q?ro!x+o^R3bczR3x z1f_&)GLKh=V-7bn%rv{L_zNDZ!JH!x6L8mpCt%PA;`HfoPR6dSnvH zdF*?{TPKcYhX20vQT_kuEas%zbO0I{tTDeVDie6s6Y}q!g~M1z`xigs_UvMA)B7J> zmn!@t9a`Ot=hFFGoMoNgJhaED!T5s7pYp_0F6ckUahF~Bu!C;6>a32mm2U_SLRY*G zs!2}_I;1ou^?}*LmV7WR?@8;9CE*S~)VxZkAUZDn;%!weGTw$hqnV-}Lv>YJXa*=( ztwvYfbR$BjTws!3ep$^I73MGX)UY;FSJB0gYrN;42tux`BL`!U&~l`vy8}%LK}nqh zip233&ulIp9Pr;;Q9f^T9;fT3eix<=+Gh0z;XpUm@J5ZLVV#rOw*+Bp%VQK;<(Pnk zI_Bh3di`(hmV;dkQVr&PSbV>p%w0=)8*-&xs!;>rrLSjn63$Pp5rbJ2ce~saoW7@EQNEgpH#;jL zDHGSZ>!A-elsDr6SJ0G>l$FhjFL^zCH-9SwA?F0NR@c`x_OqmKR`YpbWXoHgse6w# zWJAWq9zI$vLrsugCNXh%cJ%^xFPBx-D}or#pd+T^V})G?)Zp*vmVr!^%S7VDDNhQGUPQ@oj29$ zhflwp&pYO9+$wFu(XIJoAAVB*eo+$QM?}Ie5Ek2%z{p@3ADqJLay7m(<0{{ja5~8F zGh5Y=nY(^7D@}iuZ~r-;QzdfDSV;;Mj3tq)0k+K!{c>!+?Edg+CtfAD|H$Uj8L*8u zH6jja-MoG6BlTg!EDNp<{Xw>g4;A2Yn@%{609wO@0A5P~p$-Z8CSI?(<`POu{7Ov* zs=Ms6LvAkJ<8ltBq>;_sIS+!Oz!QtV&xUU2|4ciP?h(MgvH#%@Z zim{i|dTf2Y_06`InyJEn<7HKuc7RMw%k^x0WmPf(H_`OTo%u05B1_Z9)o!QLr(17- z141-kteV|dul?t9Qdh(I%&_>i4I+_fo?&C&#`P)@0-{S8;fm?iq%^d9R{x)BDBm;P zLMli@fgghhoUNFktzRtXz}OAnEa&TmjcFmTHE_I6@s-bUae^2$vFRbo5H(7XPXdp{ zxK-|mD&A3^*uO`hQv>z&=Od2x02V+8r_`Ffo>f(`b7@AnGxB#a)w*B8bO6K_JPBOe zGYZ569+}8HzUj~b8KNUfYfc7bxkS^X88LMdGZb)9$Qgr*1~;TxRAajPh%E(D_2R0! z2AiamM*=gc_#?Q8hfbbA6s%`x@tqhY@DAC|JA7aG|XOaV8W&mO^( zJt>zHBV)3&J>=QA6vhO{6((c@+u)Av{L|BfUFCR_Y{%YU2%q}EWkM8Y`O@IJow|c)?mNu?v1>qcolY(t^Lj#iQo>?#O z7PR~6l3I^!!I67Q{L2c={zwKj8uVfsZ^=m{O5xz-_{*)tcf-^|9xq_8%+~QHN@-}f zp*$+Z+Zr|_TCsq7?LfoJia=lJ8<#y6OZcz1xdDi0*<}h52n|>RQ}F4K$70?LCe^!{ zc`*kr77L>^7oZ{V+PDPm#-wV-%X)zjWTkmnD}39RNSVxUgW@|_7ci~hNrZqII`OMI zFp<-3Y9u>uEIkIq`lBz1Y4gvuAn`keN7b<;5E{F_Mivy@yexC$Kb+Fa1YZ^`LXK?J zI}BkJPv77ON1!rrrh(lM;i;UFxECx-Q*>#u`@Bu9Pjd8%_loQK%@yU6qyWQ3wVd9P zWIG8sSoS`KvTMTFEASmP@9;G-(DjVKL4X*6b`|y?^G1bR5zeBIe#-P}O-od8;-j9V zd_y9kd<6SUi^Vzc-w+}HtCORH&o1iOXFp>2q3^EeiXeAMq0F3-2v4MTHYDaGrsWy_ zB;aG={?yK!PBls0gGuX01-YsL_i}+a z4<0;tSnTgTfBwbZ{x`+hi-Y~cFAw(xy(I7`4FOC*-x631k(+21<`u3PLD&2~t|+P; z-Lp}X?bNRi?65Faos*JT%EdUbTIY>4b6Fuf9_kG3f-xA@lJehbw;^S!SdcPZ63P(E za7RoDa-j`%26sVb7Ca{9%Nx`=D^yUC$~o6K!B|T%7>VjiW1_>>xDxA zgeUl0&nFdndy7Y@btq_vvs$967j+Isrk{GK+yq*9YZc7nat(3}oUln%PRW;k;LmXi zGuSOoVJ4x8y>5nL^)p5MG^4rz@OeY8lqKa9Rg!Qaa8XpjrO<{krG#M%909A?mY4(p zkpf^w01;i!y1jv>8Rg{=-GX=7?_>bcaOOCngM~hZeRrB6*{u0raNOufQ7f)F`lrj? z^ySZIC;Q(F&(2Q|_Kx(jf^VleEko|J8@KFAr(IRN!UH4DO}g1>gqW6pd$d|Fm`|t| z@BXu2{MG-9k4gedBsa_N)~q4?1bX>o+z$9vR{-NNbaxzD(~Fua&7=g2K5<1u?;9{a z1YF~F1DSI=d+ftReoIRWO~0~m;5Nq2y*Zs!3igh;v;6${mz|%V{L=ID86{TRYlPi+ z?EUSU0p{@kbbu9<>lc*=fN)gaUQ~iU3Na)N1(BHZUzM}u3h4eGnYn^#};JPiYUouB+ASZ2=I3Hw?9EGBrMS?=gIYnwZaL% z6WgL$Cr9v}T17VMu1C<+h{&qwMy_*B_JgDGAF~z-$B9$oiuPD41 zA#lb(hX`mS;78@4jRwV@ZYmU?G(;i$MZn-hH3jkfU_(L-;7IR&KShS_RJX?cZ|Raw zj+F9($6kaRA8^`s=W1wK%!1bjcg-wq(izBm3}|D!|B)nsX+Or!8NH(JOqks_gtjw? ziVD%a23NjG1sqc1)%0JaUN0Op6{T&(IE*yzMxTo!rR;g9umA#vhaWO!#Ut<^HhOhH zatL=KmEo_PfQ_96P#jyBs0Vk4;O_434#C~s-GX}{!6jI54IbPzxI4jZaCaT_k*&LL z@2!2i_tjKSbx%!w{nu$ZJ^wlLf3H7`0C*fPp>f)rq@NK2Xl|62$$!Tp28~|64V&Vp z*|#T%vtR__J&XPhZ{BO1fEG&nz|@coRivDNTWmM$gS?x46J?SPmv;&MEX^~Y0@|v* zw7>QSeVzt*pe@KceXZ~eTVE;8rTAVEW}`JWN4Ux`3RSiCH7^XG2)B zjY3kpA>Zuuv7Q6HEPg!=_*}F=zK%z+&xkLc5nSkj(L|3<)kySns|N&^$^5}p%9S}A zCJlH-j2#!=JCf9Xhi+Cf%3$9gut%ha70r=!M5_5jY21EQpQ0NMhMuXps+bXLIMz#- z{A$jd6^u50ZFTgxuo+D3eeP3-Z&yL(E^<-L07fkKFRz*T!C(mfbN$O6* zn2&ugwZDQ}cS#-xi_xkA4$a+InG(d{z?2y2JSYcV#+A8DP-BqS^A3hd{s`3cA`<6xs7OA?@TER#2})$U>+fW=4P zM(db#zMpsI?RO&9joyWQfUGgT8m47OlPIci)@uQh+N<4wGzj39)rCq^q5-bJIzby5 zwo}hSlna$}WlV%^1;95{(ZmKsp6pS)P5I?7E83uwP$8Ct{Nc@YHZ8(_S^Gf&mPo)~n zX#unMoKIlY?gvj1TI#h>HW~C{$x@-xAjt^mVuSN70t23b7RmE4B@O`|R-I4Q3^lcg z%y35ps4ATcz75w0!^b6XLVD|nGdl(ZZjCwmdoI#WbMhjsrF!})xAl1>-&cCvp67F% z$En>q`wx;6zkl@}e~kQb!-;f=4d>^tg=-wn0p8-bhb^pXIp}-dx5sj!#`hhg?L0np zgzvk^-bhZ)56rR)lftqr!I&m&rP&kk$uDtKUU(3ls4wmSFcAdy2*`nWKN;Lq>&0xX z;5o2Cn!Amit@b_JZ0y}kf+wxz0GXX$Y?+(G6tI%)0l1g zj3WHD{K*kh7F3$6YfO!39cE{oyMnzVUu5czvt{_0u9Rn7y>azGmtm|6?9}m;qREt^ zWmJw#y|l`msLLoqXn+lBL{{gAdhOY#3)zj^)vxwrWt~CC%bz#J0_Gk`O)DK9)wsL4 z`buNK^aTUHJPBkQ-mBR%%*j9HUp}?X)gGF7A5(HRvZVJp_4D$%*PY8?0Z?=VNrW<3 zC9j1D@B9b!3BV`@aIKlaXGcCj>%)%IiCPoqw|%)v<|s#J#O+c^#LFE8Q_bZStzC;s z!}=vTB*=I1wlgQgE}()BgfS%jiOtG;VNrtp)ck#zSj;ajo}}bj5t5Vw@-Cp=^&YxD z$g|z{BoYSeEkBpTbE_>CeCBjjtkb+Cs~+|io!dtluhqty>U9shdGNx*s@kUf34wi# zVrBwbQVNPP$;KFNeIfZHxh(`ZTRg3=_HN1pR;K_5qCI#pQzff?l!`Yi-)SvDIS;ta zb|g~+P|A%TD8ioiTu5FkofOR|W+Y2bzOWGPzu>I=xg;tT2Mie++7jGYf*BU9+QEc} zBu?8q*kpde8mnW8>lgtjhHGP{8mCoB>DwhGlY`_ybQ^^hu6X5(@@lMh1$=B4u5tAF z?ddjHBKd$0AIuZ2Wn1skhqRJL6ijO?j!Hhi^$LWK}qJ8$i~+T zhOH5QDJO7)F}sgXYf+sivq&vHKqg`=cW#;$vW`oAyeQ0DZ>_AuUXS8*Jt8Z;tjc|M zuWYr>raY(m)Q1c~|0nQgQ-h^m?Hw7X%+vtb)zk}imwL56(@nk}_jiHL?|SucecfD( zAs4@xM~aBd%)?2)(o6{F8X>xw?U=X;QhrG?PewHaj{e7(Xvik$o|(sgs=mE5H3_Q4 zlN^OuDy%d~W=U0$ohb+m^y*}UGTF)NNBUUyK0$)rBeVDkOJRu4ydpYajV1IA<)1}hOCSZRG|&J5Ri3{oM4VlnJ~^2= z+5bxc*arUh0OelJ3-2fcpn-RsbzGTde*cpa(7i8^F+cie8YbHca_S@+9 z-(SeqaT1$aNZh+LQidf|3?Ch=66O4wm+^yf9-s_f`ovPo+@XI9#MtEda&{hO+LjeAkkhy{qk5zk2uW53j z&8mxe^YC+0^b$YlF*kakiU3~tPC8!kE-`(6w2T0;fH`m%VKdNJ_2QH__S)yc5@#JT z5M?kC%3Cx$OZaX7+9===1QOmI^CGj#J}5QXDZAXU6~JX)@A<($0W=?&o5oU01nxSs z24HBjD5Z<^!XYai^g&5|23dRIhZbn{;>NyS!XUn8LF#-71HT7CCib5=c=1C8^c?md z8`r}Qeuvd-nXu+d>+162gLAT)T`_L-vg?(ogkXUBg|zYY+W3wtb!ql4^LPxq>kZZ! zN^5%q4LAw5zki#7zt7f;w{E~J_Fz!SB-A{AVD5zhq7(Uw&9fBnvjm(GKUx1YY^H>b@>Nc>63EZi6=VHOhvum{HjW zLA$v#unM%{Y_jo@8J?>bt52Ugb-ak+iF%<0k4Z$`XGdxAdd(VhZ2!Ssf|v2k5(8*0 z6jlKQZcItMDeeJ-$_|v1ZAEFC5Ln>*aakpE!CFums&yGu#kAhI;kqv@pGnA$H?4i# z*h6!gOi?}EpsmQ+A&ZO|e*1h#Biglrk1hWsC9a$-R%*{7X0Npke5~bc=%(e-F_!Tp z#|*$AkqQEpxApJdG|84sC}WgK&^Sttr$JhE(NyhDJ4|5npx~!XENiP}`eNyCG+KJ_ zkpFDum2`nuJjigl(6hUOP=x*!_jyIp@{Y%!*HO-7ys$bIsyAbPm0({4M*P$44}qU# zg^1liIUc*~wy615hmJ51T>K26g{U{}4E&>Q?y>O4z69 zxP1uP$bm5s`2s!TdWKoYaXDv#)iXV+Zh+6s;CkH$NNz5}*F+S$^Z;i2p21 z>ou9|VW=!tv^<6q;Ri^S=^}`|Yy^=JR00^UT&T$)q3aWaavKpJI0Y@ck3zS@E+SOk z0|MVRGg?fLJd_e4wsW7kg@f^`54?AvsJ!}vTU;9^INDN7y=S&_vLU(M#wxLD&uXyK z;2amohE7hQ+#^?^n2suyzY8HDAGbH#T~ZC(grt9?@T}l5Z!hiYi53)rN~`g?#fC|^ zFWOu0*w;WHs&5ha?rgcwhBRJRZN9WzVP)A(6T~O0_<_OnQm|JBx)8!9VIs;4U;N8C z63?;n^~9o39$H9xx*%#Osj?Tga=zK1iF+s>p}!2A#eJ;(6~$%;KqydiADkb z`Yg(rV-l8EB9`~43m-u0vC(T<6Gut@eDf@PV;d~mjR{`lJpLsOXG}k?%2ak(_cxYk zX>z0Nkb!%N_X+=I)f~)+aEu;`Tkm^c9~;ugSXN?^T}S&T|h$MG6~7YykGQ8 zjDD}9)!8m>|1m$F%3{fr_7s)^GsLy#5VA=3=q!Fgh$unP?u$J%F1;6=Z%{jn*}Zca zf>Z(F0sIZ|bi(GIqc?HH^KS~WQgYpUWrytNtHp2AUD>V{?O#5?!;HNV!%NepM4}mE z5#&)KR^Rq~zE84`rd-sEtc%gBGub+Jl$j%|L-U84Tu8Cv@J))@>|^cwvF|`T*X__B zmi2+%;rc%d=+JfhVhi?q!F z8`Wrio^|GMr%5E7PG7(75!F~c?R$@2!hr%`a;7&zEcOVN1G?ju&g+Dn&5 z$ZEAdc2!?yhiF+&OK)5?Ajlj}@%8z4$x<=C>0u4(f@n}e0ZUes&RTPIlo zyXnjlRi#KS3AXcxB&UK|WLoJ)d_P<~5v^dS7M|b?&b3qYu+k$IXsg8Np&niNMSPBd zbZOTf9=-w&?4MoX&rxeOnpz|p43mD_5)Dl*6j5_b6>5!jMNfgtowD8K_;3hcz#fD% zq8$u>S6R*~XM`6I$<&6b*>WB0XiNe%Z#*Vsn=o$X`!Whui>D^tIt^@fVRg4eMP`+OeZw6%&fc%@AXeR}z|8#iZ2*!bKPDBc z#XQb(7lO`7j{Z(0&Lj!=g@?zly;<5;VCM$(<#zZW5iR)Z(3Zc#Lk07rV^W$}WlsWD zd27Cc=jsSI9b^RF6uby{u_AZlPcO)r+|^*HmaHq;2wClx_VlLwj71l6=R9LcmF3m8 zmftB2B6$)N1AWIxM0B&52ym%b-ibZs6$gdd2Z=V&77ozxG;)3tq;908IN7jR_Py{t zKzYHRsICWg`m)g|w5qM|^7;l^SjA~{ibY& zI-@BjQ4whJhb-cAp6P#XyS|^N5RPM6Ir-Art8q>_<2ea^o)I=zITZ{OLp;j5QnC2# zY$R$%H}=k8Q@_14KBxnK#jMf$_51>14x;&>n|NtIr}G*1Jn}aEo@>Ks7qkL8HNIf& z_DF7}W<*=H%{ z?G?KBJ`*}{j!@kqI7ws1^U;S&sO3?Rp3xhOFT*@eFb|{H1XX(!6MZ`gnj9bi?=qLR=vB^U@gxYKf3~K^d{}RhHAOeM&hr-efEkJ zU6dz!f0tvA*4?G;HDstF@1?0oDg%OHiz|_>(praSTHxdf5!#`wmdn%yvX zM0#~NgXu&F>vsRFeXJt5h~(99NN1!XB41O*(3SzNVJ!}u;8M9L6YQ_x+S2{-1qs(| zPMgD|LgtcUB1XIE1bR$FjUy7cd4BxjLVGYIp$hg2kOUhr49dtFnGequNSzR9F-+vg z11}He6F53dpgMVdNHMvc#1z+-9nN?P4X(3R+i%E2$~Q61cEJWJ{yG%3Vre#|O|I_g zHg_Kd==*1D#dk8qsfd;?iW{+BhC=hz2`QorHX}IDbnzGwF8T1PG5KT`k?IW514D;I zQgCQfp0%yzM_cK%Z8^Ogx#Q=jKV|mxb99uFzE6SonNitK?KP^F3iGbc|SZRHR=#<+vtZ*0nhXu+I;}F!O;IF>50iCbB)f`PJzo%p}D>0V5 zvU8t$6C(FI!S+L(`-1n-O5s%!;!z9vz(uB3i`pqqp ztJT#Zul69(K$S=C1--E1^-C-#$qY+DHXl}0Jscj4^im@*Ug=SMovG4ugt{%Vt7?Fc ztqI+-IoqOI{ck?Kt1}mX$052CWx$hZ2;S`i*HAA^DiZfG9vYu%yfM{8c@@Xw_r~S! zeWR6!Pz@CC%S(Pwd0DFYv_eC*nvJ?8p+R4t4XTWdUnX{uPC%$)E7|-re7`mvw*1*Z zN(j)PKrPr;1M6*EX{3twlna@MCV{!iA>vw1MBq8Mrm-AaW8Y%>Haha!madnAf-!9= zllMJAyDiNgs4+z|*?8f@5p#7=#{oGs8}5^_dnV8bRL^^o>;ms5z;WJ9tL3IYnt>)b zImDuz*0vp&d&r7i zR4&MsL|yWZyJ@SRwMwoE51W}3Mib1Yq}EC0Ci4yEztS^*{OaN@#10Km008WJiU#FB zRjO__R__Ane>137Xeq?4aG`b_YoQKH>hi1;`OqLfOpz(nh@=A)T;ai)8CAII&`T54 zxnA%6#lK-%@DG2eU`mmlSjgb+op!t1+Jh>0C!aIv6>1N>-NM$T%@*{%D|Q@EWbxdI zMUo2q7E>L~PX7cps!4=-&Tzc5?(C+E-T$=R&gZ-zsv@s@Cz*-?Hyn$<$8iG(t3#c{ z(&Xdxr8|NKz*((dFLf_Z#~;4Tx;P6h+=zWvl3_DT7z5n-v>H%x9A>(VuP7X7xX>Cl z(QMbu7oJapH(i@89G+!A;-T)|K#N1PM0E>d6O?L33r-R>?pV>8ZfG-TFug{9F)pQJ zi?Gh4S#S5hemXm&T%nqEsz+a>If~UAUH6{)*zp$a9AcAuf!|b%9A10)V^VS)%boyM zEabbc48U3kSF{0ZyUZk}9*`^K8WHh zTZJvEH=DPCz;#?I0HTmSYvSbDqF1NOju7A%I6#tZ!f*FZ4X5*$Vx30QNrsC1UmT(A zEV@%qLdtsh`4|OAvAK(kD>@Tq;h*f?-I1)htg2l(;RiCjp73}4I_1@o;RnmTAsh4I zk-VWs31@E0QlB~XprjNsgMKma%q2vZJE1SHzs>BJ?XZ5Zfzk);M}YDVX|K{3q%S=B z@5pmiXc!Iobr1s8{ce+BpwpN*?|cCTD-2*=_RL}qx9!LCa!q?7dUuzz1acM0pTUSn zGLmYO8mliMRQ$qd2?8TbUdpuZl z;Zm?whIT&eHc}i&}Ec-vdsFpsc4KYsqCQla%AEfvG}Yn9rD1Xb2=CWdbt++d-2 zx7kmLCEP6>1+Z6bl1aQ}YH?tG2;V!AD1G@+S6l^3wzhrwJqYzh4j*)v2}28t1hwR@ z*ghz7QV;G6H+&uMAumfT>6be%0b@Y$pcRd?cOz|FlY(K6^YKVujMygUqi33A8n2v= zNBgJ)z2@dtQr6<1tMW=xfw}CyHj3)@F zL~*5}*3jCjK4)S7a|Yx1`?;3j&}$>9Om01=8Q=*1o?=!4qEK4^VtaWMbl`h zqNiCTkDlo{0oV5M`rlY+!26diPe+Bz<_lJMCc>@4{3~iDY38vWZ>wlL2d-wOSaCLhZX#Y!Y z-pJ^$?-?192$E2(g*xq5J z)_`}j+~29N(0@_Q>}@O@Ke<`BzK8jr8MkdGBy-eEuuce`$n2*ni5||K14DRDWUr zp=$s4ru{?xr(6B+)U5xc{y$CY|4#q!O2EIi!Ph%lixZ_iONfiWKVoR0jay-(NQG LCzn+IpRNA`+u!Y8 literal 0 HcmV?d00001 diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 7c664966ed74e..dbb463f6005a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -998,8 +998,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) - return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) + sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7bed5216eabf3..ebdd665e349c5 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -29,7 +29,7 @@ from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main -from pyspark.serializers import read_int, write_int +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer def compute_real_exit_code(exit_code): @@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code): return 1 -def worker(sock): +def worker(sock, authenticated): """ Called by a worker process after the fork(). """ @@ -56,6 +56,18 @@ def worker(sock): # otherwise writes also cause a seek that makes us miss data on the read side. infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) + + if not authenticated: + client_secret = UTF8Deserializer().loads(infile) + if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret: + write_with_length("ok".encode("utf-8"), outfile) + outfile.flush() + else: + write_with_length("err".encode("utf-8"), outfile) + outfile.flush() + sock.close() + return 1 + exit_code = 0 try: worker_main(infile, outfile) @@ -153,8 +165,11 @@ def handle_sigterm(*args): write_int(os.getpid(), outfile) outfile.flush() outfile.close() + authenticated = False while True: - code = worker(sock) + code = worker(sock, authenticated) + if code == 0: + authenticated = True if not reuse or code: # wait for closing try: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3e704fe9bf6ec..0afbe9dc6aa3e 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -21,16 +21,19 @@ import select import signal import shlex +import shutil import socket import platform +import tempfile +import time from subprocess import Popen, PIPE if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_gateway import java_import, JavaGateway, GatewayParameters from pyspark.find_spark_home import _find_spark_home -from pyspark.serializers import read_int +from pyspark.serializers import read_int, write_with_length, UTF8Deserializer def launch_gateway(conf=None): @@ -41,6 +44,7 @@ def launch_gateway(conf=None): """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) + gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -59,40 +63,40 @@ def launch_gateway(conf=None): ]) command = command + shlex.split(submit_args) - # Start a socket that will be used by PythonGatewayServer to communicate its port to us - callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - callback_socket.bind(('127.0.0.1', 0)) - callback_socket.listen(1) - callback_host, callback_port = callback_socket.getsockname() - env = dict(os.environ) - env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host - env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) - - # Launch the Java gateway. - # We open a pipe to stdin so that the Java gateway can die when the pipe is broken - if not on_windows: - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_func(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) - else: - # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) - - gateway_port = None - # We use select() here in order to avoid blocking indefinitely if the subprocess dies - # before connecting - while gateway_port is None and proc.poll() is None: - timeout = 1 # (seconds) - readable, _, _ = select.select([callback_socket], [], [], timeout) - if callback_socket in readable: - gateway_connection = callback_socket.accept()[0] - # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile(mode="rb")) - gateway_connection.close() - callback_socket.close() - if gateway_port is None: - raise Exception("Java gateway process exited before sending the driver its port number") + # Create a temporary directory where the gateway server should write the connection + # information. + conn_info_dir = tempfile.mkdtemp() + try: + fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir) + os.close(fd) + os.unlink(conn_info_file) + + env = dict(os.environ) + env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdin=PIPE, env=env) + + # Wait for the file to appear, or for the process to exit, whichever happens first. + while not proc.poll() and not os.path.isfile(conn_info_file): + time.sleep(0.1) + + if not os.path.isfile(conn_info_file): + raise Exception("Java gateway process exited before sending its port number") + + with open(conn_info_file, "rb") as info: + gateway_port = read_int(info) + gateway_secret = UTF8Deserializer().loads(info) + finally: + shutil.rmtree(conn_info_dir) # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -111,7 +115,9 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, + auto_convert=True)) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") @@ -126,3 +132,16 @@ def killChild(): java_import(gateway.jvm, "scala.Tuple2") return gateway + + +def do_server_auth(conn, auth_secret): + """ + Performs the authentication protocol defined by the SocketAuthHelper class on the given + file-like object 'conn'. + """ + write_with_length(auth_secret.encode("utf-8"), conn) + conn.flush() + reply = UTF8Deserializer().loads(conn) + if reply != "ok": + conn.close() + raise Exception("Unexpected reply from iterator server.") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b44f76747264..d5a237a5b2855 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,11 @@ else: from itertools import imap as map, ifilter as filter +from pyspark.java_gateway import do_server_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ + UTF8Deserializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -136,7 +138,8 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) -def _load_from_socket(port, serializer): +def _load_from_socket(sock_info, serializer): + port, auth_secret = sock_info sock = None # Support for both IPv4 and IPv6. # On most of IPv6-ready systems, IPv6 will take precedence. @@ -156,8 +159,12 @@ def _load_from_socket(port, serializer): # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) + + sockfile = sock.makefile("rwb", 65536) + do_server_auth(sockfile, auth_secret) + # The socket will be automatically closed when garbage-collected. - return serializer.load_stream(sock.makefile("rb", 65536)) + return serializer.load_stream(sockfile) def ignore_unicode_prefix(f): @@ -822,8 +829,8 @@ def collect(self): to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) - return list(_load_from_socket(port, self._jrdd_deserializer)) + sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(sock_info, self._jrdd_deserializer)) def reduce(self, f): """ @@ -2380,8 +2387,8 @@ def toLocalIterator(self): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _load_from_socket(port, self._jrdd_deserializer) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(sock_info, self._jrdd_deserializer) def _prepare_for_python_RDD(sc, command): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 16f8e52dead7b..213dc158f9328 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -463,8 +463,8 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + sock_info = self._jdf.collectToPython() + return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(2.0) @@ -477,8 +477,8 @@ def toLocalIterator(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.toPythonIterator() - return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + sock_info = self._jdf.toPythonIterator() + return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix @since(1.3) @@ -2087,8 +2087,8 @@ def _collectAsArrow(self): .. note:: Experimental. """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer())) + sock_info = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(sock_info, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a1a4336b1e8de..8bb63fcc7ff9c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,6 +27,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.java_gateway import do_server_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType @@ -301,9 +302,11 @@ def process(): if __name__ == '__main__': - # Read a local port to connect to from stdin - java_port = int(sys.stdin.readline()) + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) sock_file = sock.makefile("rwb", 65536) + do_server_auth(sock_file, auth_secret) main(sock_file, sock_file) diff --git a/python/setup.py b/python/setup.py index 794ceceae3008..d309e0564530a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -201,7 +201,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.6'], + install_requires=['py4j==0.10.7'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bafb129032b49..134b3e5fef11a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1152,7 +1152,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.6-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a129be7c06b53..59b0f29e37d84 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -265,7 +265,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.6-src.zip", + s"$sparkHome/python/lib/py4j-0.10.7-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index bac154e10ae62..bf3da18c3706e 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index cd4def71e6f3b..d518e07bfb62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3187,7 +3187,7 @@ class Dataset[T] private[sql]( EvaluatePython.javaToPython(rdd) } - private[sql] def collectToPython(): Int = { + private[sql] def collectToPython(): Array[Any] = { EvaluatePython.registerPicklers() withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) @@ -3200,7 +3200,7 @@ class Dataset[T] private[sql]( /** * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ - private[sql] def collectAsArrowToPython(): Int = { + private[sql] def collectAsArrowToPython(): Array[Any] = { withAction("collectAsArrowToPython", queryExecution) { plan => val iter: Iterator[Array[Byte]] = toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) @@ -3208,7 +3208,7 @@ class Dataset[T] private[sql]( } } - private[sql] def toPythonIterator(): Int = { + private[sql] def toPythonIterator(): Array[Any] = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) } From 628c7b517969c4a7ccb26ea67ab3dd61266073ca Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 17 Apr 2018 13:29:43 -0700 Subject: [PATCH 0768/1161] [SPARKR] Match pyspark features in SparkR communication protocol. --- R/pkg/R/client.R | 4 +- R/pkg/R/deserialize.R | 10 +++- R/pkg/R/sparkR.R | 39 +++++++++++-- R/pkg/inst/worker/daemon.R | 4 +- R/pkg/inst/worker/worker.R | 5 +- .../org/apache/spark/api/r/RAuthHelper.scala | 38 +++++++++++++ .../org/apache/spark/api/r/RBackend.scala | 43 +++++++++++++-- .../spark/api/r/RBackendAuthHandler.scala | 55 +++++++++++++++++++ .../org/apache/spark/api/r/RRunner.scala | 35 ++++++++---- .../org/apache/spark/deploy/RRunner.scala | 6 +- 10 files changed, 210 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala create mode 100644 core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 9d82814211bc5..7244cc9f9e38e 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout) { +connectBackend <- function(hostname, port, timeout, authSecret) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) { con <- socketConnection(host = hostname, port = port, server = FALSE, blocking = TRUE, open = "wb", timeout = timeout) - + doServerAuth(con, authSecret) assign(".sparkRCon", con, envir = .sparkREnv) con } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index a90f7d381026b..cb03f1667629f 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) { stop(paste("Unsupported type for deserialization", type))) } -readString <- function(con) { - stringLen <- readInt(con) - raw <- readBin(con, raw(), stringLen, endian = "big") +readStringData <- function(con, len) { + raw <- readBin(con, raw(), len, endian = "big") string <- rawToChar(raw) Encoding(string) <- "UTF-8" string } +readString <- function(con) { + stringLen <- readInt(con) + readStringData(con, stringLen) +} + readInt <- function(con) { readBin(con, integer(), n = 1, endian = "big") } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index a480ac606f10d..38ee79477996f 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -158,6 +158,10 @@ sparkR.sparkContext <- function( " please use the --packages commandline instead", sep = ",")) } backendPort <- existingPort + authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET") + if (nchar(authSecret) == 0) { + stop("Auth secret not provided in environment.") + } } else { path <- tempfile(pattern = "backend_port") submitOps <- getClientModeSparkSubmitOpts( @@ -186,16 +190,27 @@ sparkR.sparkContext <- function( monitorPort <- readInt(f) rLibPath <- readString(f) connectionTimeout <- readInt(f) + + # Don't use readString() so that we can provide a useful + # error message if the R and Java versions are mismatched. + authSecretLen = readInt(f) + if (length(authSecretLen) == 0 || authSecretLen == 0) { + stop("Unexpected EOF in JVM connection data. Mismatched versions?") + } + authSecret <- readStringData(f, authSecretLen) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || length(monitorPort) == 0 || monitorPort == 0 || - length(rLibPath) != 1) { + length(rLibPath) != 1 || length(authSecret) == 0) { stop("JVM failed to launch") } - assign(".monitorConn", - socketConnection(port = monitorPort, timeout = connectionTimeout), - envir = .sparkREnv) + + monitorConn <- socketConnection(port = monitorPort, blocking = TRUE, + timeout = connectionTimeout, open = "wb") + doServerAuth(monitorConn, authSecret) + + assign(".monitorConn", monitorConn, envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -205,7 +220,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort, timeout = connectionTimeout) + connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret) }, error = function(err) { stop("Failed to connect JVM\n") @@ -687,3 +702,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) { NULL } } + +# Utility function for sending auth data over a socket and checking the server's reply. +doServerAuth <- function(con, authSecret) { + if (nchar(authSecret) == 0) { + stop("Auth secret not provided.") + } + writeString(con, authSecret) + flush(con) + reply <- readString(con) + if (reply != "ok") { + close(con) + stop("Unexpected reply from server.") + } +} diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 2e31dc5f728cd..fb9db63b07cd0 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) + port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout) + +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # Waits indefinitely for a socket connecion by default. selectTimeout <- NULL diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 00789d815bba8..ba458d2b9ddfb 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) + outputCon <- socketConnection( port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala new file mode 100644 index 0000000000000..ac6826a9ec774 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.SparkConf +import org.apache.spark.security.SocketAuthHelper + +private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) { + + override protected def readUtf8(s: Socket): String = { + SerDe.readString(new DataInputStream(s.getInputStream())) + } + + override protected def writeUtf8(str: String, s: Socket): Unit = { + val out = s.getOutputStream() + SerDe.writeString(new DataOutputStream(out), str) + out.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 2d1152a036449..3b2e809408e0f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.r -import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetAddress, InetSocketAddress, ServerSocket} +import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. @@ -45,7 +47,7 @@ private[spark] class RBackend { /** Tracks JVM objects returned to R for this RBackend instance. */ private[r] val jvmObjectTracker = new JVMObjectTracker - def init(): Int = { + def init(): (Int, RAuthHelper) = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) @@ -53,6 +55,7 @@ private[spark] class RBackend { conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) + val authHelper = new RAuthHelper(conf) bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) @@ -71,13 +74,16 @@ private[spark] class RBackend { new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) + .addLast(new RBackendAuthHandler(authHelper.secret)) .addLast("handler", handler) } }) channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() - channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + + val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + (port, authHelper) } def run(): Unit = { @@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging { val sparkRBackend = new RBackend() try { // bind to random port - val boundPort = sparkRBackend.init() + val (boundPort, authHelper) = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // Connection timeout is set by socket client. To make it configurable we will pass the @@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.writeInt(backendConnectionTimeout) + SerDe.writeString(dos, authHelper.secret) dos.close() f.renameTo(new File(path)) @@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging { val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) + + // Wait for the R process to connect back, ignoring any failed auth attempts. Allow + // a max number of connection attempts to avoid looping forever. try { - val inSocket = serverSocket.accept() + var remainingAttempts = 10 + var inSocket: Socket = null + while (inSocket == null) { + inSocket = serverSocket.accept() + try { + authHelper.authClient(inSocket) + } catch { + case e: Exception => + remainingAttempts -= 1 + if (remainingAttempts == 0) { + val msg = "Too many failed authentication attempts." + logError(msg) + throw new IllegalStateException(msg) + } + logInfo("Client connection failed authentication.") + inSocket = null + } + } + serverSocket.close() + // wait for the end of socket, closed if R process die inSocket.getInputStream().read(buf) } finally { + serverSocket.close() sparkRBackend.close() System.exit(0) } @@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging { } System.exit(0) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala new file mode 100644 index 0000000000000..4162e4a6c7476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets.UTF_8 + +import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Authentication handler for connections from the R process. + */ +private class RBackendAuthHandler(secret: String) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + // The R code adds a null terminator to serialized strings, so ignore it here. + val clientSecret = new String(msg, 0, msg.length - 1, UTF_8) + try { + require(secret == clientSecret, "Auth secret mismatch.") + ctx.pipeline().remove(this) + writeReply("ok", ctx.channel()) + } catch { + case e: Exception => + logInfo("Authentication failure.", e) + writeReply("err", ctx.channel()) + ctx.close() + } + } + + private def writeReply(reply: String, chan: Channel): Unit = { + val out = new ByteArrayOutputStream() + SerDe.writeString(new DataOutputStream(out), reply) + chan.writeAndFlush(out.toByteArray()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 88118392003e8..e7fdc3963945a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -74,14 +74,19 @@ private[spark] class RRunner[U]( // the socket used to send out the input of task serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() + dataStream = try { + val inSocket = serverSocket.accept() + RRunner.authHelper.authClient(inSocket) + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + RRunner.authHelper.authClient(outSocket) + val inputStream = new BufferedInputStream(outSocket.getInputStream) + new DataInputStream(inputStream) + } finally { + serverSocket.close() + } try { return new Iterator[U] { @@ -315,6 +320,11 @@ private[r] object RRunner { private[this] var errThread: BufferedStreamThread = _ private[this] var daemonChannel: DataOutputStream = _ + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new RAuthHelper(conf) + } + /** * Start a thread to print the process's stderr to ours */ @@ -349,6 +359,7 @@ private[r] object RRunner { pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") + pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) @@ -370,8 +381,12 @@ private[r] object RRunner { // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() + try { + authHelper.authClient(sock) + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + } finally { + serverSocket.close() + } } try { daemonChannel.writeInt(port) diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 6eb53a8252205..e86b362639e57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -68,10 +68,13 @@ object RRunner { // Java system properties etc. val sparkRBackend = new RBackend() @volatile var sparkRBackendPort = 0 + @volatile var sparkRBackendSecret: String = null val initialized = new Semaphore(0) val sparkRBackendThread = new Thread("SparkR backend") { override def run() { - sparkRBackendPort = sparkRBackend.init() + val (port, authHelper) = sparkRBackend.init() + sparkRBackendPort = port + sparkRBackendSecret = authHelper.secret initialized.release() sparkRBackend.run() } @@ -91,6 +94,7 @@ object RRunner { env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() From 7aaa148f593470b2c32221b69097b8b54524eb74 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 9 May 2018 11:09:19 -0700 Subject: [PATCH 0769/1161] [SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs ## What changes were proposed in this pull request? Provide evaluateEachIteration method or equivalent for spark.ml GBTs. ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21097 from WeichenXu123/GBTeval. --- .../ml/classification/GBTClassifier.scala | 15 +++++++++ .../spark/ml/regression/GBTRegressor.scala | 17 +++++++++- .../org/apache/spark/ml/tree/treeParams.scala | 6 +++- .../classification/GBTClassifierSuite.scala | 29 ++++++++++++++++- .../ml/regression/GBTRegressorSuite.scala | 32 +++++++++++++++++-- 5 files changed, 94 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 0aa24f0a3cfcc..3fb6d1e4e4f3e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -334,6 +334,21 @@ class GBTClassificationModel private[ml]( // hard coded loss, which is not meant to be changed in the model private val loss = getOldLossType + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss, + OldAlgo.Classification + ) + } + @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 8598e808c4946..d7e054bf55ef6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ /** @@ -269,6 +269,21 @@ class GBTRegressionModel private[ml]( new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param dataset Dataset for validation. + * @param loss The loss function used to compute error. Supported options: squared, absolute + */ + @Since("2.4.0") + def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = { + val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) + } + GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, + convertToOldLossType(loss), OldAlgo.Regression) + } + @Since("2.0.0") override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 81b6222acc7ce..ec8868bb42cbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -579,7 +579,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams /** (private[ml]) Convert new loss to old loss. */ override private[ml] def getOldLossType: OldLoss = { - getLossType match { + convertToOldLossType(getLossType) + } + + private[ml] def convertToOldLossType(loss: String): OldLoss = { + loss match { case "squared" => OldSquaredError case "absolute" => OldAbsoluteError case _ => diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index f0ee5496f9d1d..e20de196d65ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.RegressionLeafNode -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -365,6 +365,33 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } + test("model evaluateEachIteration") { + val gbt = new GBTClassifier() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType("logistic") + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTClassificationModel("gbt-cls-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses) + val model2 = new GBTClassificationModel("gbt-cls-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses) + + val evalArr = model3.evaluateEachIteration(validationData.toDF) + val remappedValidationData = validationData.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData, + model1.trees, model1.treeWeights, model1.getOldLossType) + val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData, + model2.trees, model2.treeWeights, model2.getOldLossType) + val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData, + model3.trees, model3.treeWeights, model3.getOldLossType) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index fad11d078250f..773f6d2c542fe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -201,7 +202,34 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { assert(mostImportantFeature !== mostIF) } - + test("model evaluateEachIteration") { + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(1L) + .setMaxDepth(2) + .setMaxIter(3) + .setLossType(lossType) + val model3 = gbt.fit(trainData.toDF) + val model1 = new GBTRegressionModel("gbt-reg-model-test1", + model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures) + val model2 = new GBTRegressionModel("gbt-reg-model-test2", + model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures) + + for (evalLossType <- GBTRegressor.supportedLossTypes) { + val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType) + val lossErr1 = GradientBoostedTrees.computeError(validationData, + model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType)) + val lossErr2 = GradientBoostedTrees.computeError(validationData, + model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType)) + val lossErr3 = GradientBoostedTrees.computeError(validationData, + model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType)) + + assert(evalArr(0) ~== lossErr1 relTol 1E-3) + assert(evalArr(1) ~== lossErr2 relTol 1E-3) + assert(evalArr(2) ~== lossErr3 relTol 1E-3) + } + } + } ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load From fd1179c17273283d32f275d5cd5f97aaa2aca1f7 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 9 May 2018 11:32:17 -0700 Subject: [PATCH 0770/1161] [SPARK-24214][SS] Fix toJSON for StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation ## What changes were proposed in this pull request? We should overwrite "otherCopyArgs" to provide the SparkSession parameter otherwise TreeNode.toJSON cannot get the full constructor parameter list. ## How was this patch tested? The new unit test. Author: Shixiong Zhu Closes #21275 from zsxwing/SPARK-24214. --- .../execution/streaming/StreamingRelation.scala | 3 +++ .../spark/sql/streaming/StreamingQuerySuite.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index f02d3a2c3733f..24195b5657e8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -66,6 +66,7 @@ case class StreamingExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString @@ -97,6 +98,7 @@ case class StreamingRelationV2( output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = sourceName @@ -116,6 +118,7 @@ case class ContinuousExecutionRelation( output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { + override def otherCopyArgs: Seq[AnyRef] = session :: Nil override def isStreaming: Boolean = true override def toString: String = source.toString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 0cb2375e0a49a..57986999bf861 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -831,6 +831,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi CheckLastBatch(("A", 1))) } + test("StreamingRelationV2/StreamingExecutionRelation/ContinuousExecutionRelation.toJSON " + + "should not fail") { + val df = spark.readStream.format("rate").load() + assert(df.logicalPlan.toJSON.contains("StreamingRelationV2")) + + testStream(df)( + AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingExecutionRelation")) + ) + + testStream(df, useV2Sink = true)( + StartStream(trigger = Trigger.Continuous(100)), + AssertOnQuery(_.logicalPlan.toJSON.contains("ContinuousExecutionRelation")) + ) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) From 9e3bb313682bf88d0c81427167ee341698d29b02 Mon Sep 17 00:00:00 2001 From: wuyi Date: Wed, 9 May 2018 15:44:36 -0700 Subject: [PATCH 0771/1161] [SPARK-24141][CORE] Fix bug in CoarseGrainedSchedulerBackend.killExecutors ## What changes were proposed in this pull request? In method *CoarseGrainedSchedulerBackend.killExecutors()*, `numPendingExecutors` should add `executorsToKill.size` rather than `knownExecutors.size` if we do not adjust target number of executors. ## How was this patch tested? N/A Author: wuyi Closes #21209 from Ngone51/SPARK-24141. --- .../spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5627a557a12f3..d8794e8e551aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -633,7 +633,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } doRequestTotalExecutors(requestedTotalExecutors) } else { - numPendingExecutors += knownExecutors.size + numPendingExecutors += executorsToKill.size Future.successful(true) } From 9341c951e85ff29714cbee302053872a6a4223da Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Wed, 9 May 2018 19:56:03 -0700 Subject: [PATCH 0772/1161] [SPARK-23852][SQL] Add test that fails if PARQUET-1217 is not fixed ## What changes were proposed in this pull request? Add a new test that triggers if PARQUET-1217 - a predicate pushdown bug - is not fixed in Spark's Parquet dependency. ## How was this patch tested? New unit test passes. Author: Henry Robinson Closes #21284 from henryr/spark-23852. --- .../test/resources/test-data/parquet-1217.parquet | Bin 0 -> 321 bytes .../datasources/parquet/ParquetFilterSuite.scala | 10 ++++++++++ 2 files changed, 10 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/parquet-1217.parquet diff --git a/sql/core/src/test/resources/test-data/parquet-1217.parquet b/sql/core/src/test/resources/test-data/parquet-1217.parquet new file mode 100644 index 0000000000000000000000000000000000000000..eb2dc4f79907019ffe4edaa9adb951924aad1c5d GIT binary patch literal 321 zcmbu5(MrQG7=^Pml#v^~@DB~-BFI`Q)R6|)b+DV=$S%Yc=L=*_hK0@zhq6oG%h&Ne zT(Vd2z~P6(;p68ti 0").count() === 2) + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 62d01391fee77eedd75b4e3f475ede8b9f0df0c2 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 9 May 2018 21:48:54 -0700 Subject: [PATCH 0773/1161] [SPARK-24073][SQL] Rename DataReaderFactory to InputPartition. ## What changes were proposed in this pull request? Renames: * `DataReaderFactory` to `InputPartition` * `DataReader` to `InputPartitionReader` * `createDataReaderFactories` to `planInputPartitions` * `createUnsafeDataReaderFactories` to `planUnsafeInputPartitions` * `createBatchDataReaderFactories` to `planBatchInputPartitions` This fixes the changes in SPARK-23219, which renamed ReadTask to DataReaderFactory. The intent of that change was to make the read and write API match (write side uses DataWriterFactory), but the underlying problem is that the two classes are not equivalent. ReadTask/DataReader function as Iterable/Iterator. One InputPartition is a specific partition of the data to be read, in contrast to DataWriterFactory where the same factory instance is used in all write tasks. InputPartition's purpose is to manage the lifecycle of the associated reader, which is now called InputPartitionReader, with an explicit create operation to mirror the close operation. This was no longer clear from the API because DataReaderFactory appeared to be more generic than it is and it isn't clear why a set of them is produced for a read. ## How was this patch tested? Existing tests, which have been updated to use the new name. Author: Ryan Blue Closes #21145 from rdblue/SPARK-24073-revert-data-reader-factory-rename. --- .../sql/kafka010/KafkaContinuousReader.scala | 20 ++--- .../sql/kafka010/KafkaMicroBatchReader.scala | 21 ++--- .../sql/kafka010/KafkaSourceProvider.scala | 3 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sql/sources/v2/MicroBatchReadSupport.java | 2 +- ...ory.java => ContinuousInputPartition.java} | 8 +- .../sources/v2/reader/DataSourceReader.java | 10 +-- ...ReaderFactory.java => InputPartition.java} | 16 ++-- ...aReader.java => InputPartitionReader.java} | 4 +- .../v2/reader/SupportsReportPartitioning.java | 2 +- .../v2/reader/SupportsScanColumnarBatch.java | 10 +-- .../v2/reader/SupportsScanUnsafeRow.java | 8 +- .../partitioning/ClusteredDistribution.java | 4 +- .../v2/reader/partitioning/Distribution.java | 6 +- .../v2/reader/partitioning/Partitioning.java | 4 +- ...va => ContinuousInputPartitionReader.java} | 6 +- .../v2/reader/streaming/ContinuousReader.java | 10 +-- .../v2/reader/streaming/MicroBatchReader.java | 2 +- .../datasources/v2/DataSourceRDD.scala | 11 +-- .../datasources/v2/DataSourceV2ScanExec.scala | 46 +++++------ .../continuous/ContinuousDataSourceRDD.scala | 22 +++--- .../ContinuousQueuedDataReader.scala | 13 ++-- .../ContinuousRateStreamSource.scala | 20 ++--- .../sql/execution/streaming/memory.scala | 12 +-- .../sources/ContinuousMemoryStream.scala | 18 ++--- .../sources/RateStreamMicroBatchReader.scala | 15 ++-- .../execution/streaming/sources/socket.scala | 29 +++---- .../sources/v2/JavaAdvancedDataSourceV2.java | 22 +++--- .../sql/sources/v2/JavaBatchDataSourceV2.java | 12 +-- .../v2/JavaPartitionAwareDataSource.java | 12 +-- .../v2/JavaSchemaRequiredDataSource.java | 4 +- .../sources/v2/JavaSimpleDataSourceV2.java | 18 ++--- .../sources/v2/JavaUnsafeRowDataSourceV2.java | 16 ++-- .../sources/RateStreamProviderSuite.scala | 12 +-- .../sql/sources/v2/DataSourceV2Suite.scala | 78 ++++++++++--------- .../sources/v2/SimpleWritableDataSource.scala | 14 ++-- .../sql/streaming/StreamingQuerySuite.scala | 11 +-- .../ContinuousQueuedDataReaderSuite.scala | 8 +- .../sources/StreamingDataSourceV2Suite.scala | 4 +- 39 files changed, 272 insertions(+), 263 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{ContinuousDataReaderFactory.java => ContinuousInputPartition.java} (78%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataReaderFactory.java => InputPartition.java} (81%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/{DataReader.java => InputPartitionReader.java} (92%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/{ContinuousDataReader.java => ContinuousInputPartitionReader.java} (84%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index f26c134c2f6e9..88abf8a8dd027 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType /** @@ -86,7 +86,7 @@ class KafkaContinuousReader( KafkaSourceOffset(JsonUtils.partitionOffsets(json)) } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { import scala.collection.JavaConverters._ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset) @@ -108,7 +108,7 @@ class KafkaContinuousReader( case (topicPartition, start) => KafkaContinuousDataReaderFactory( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) - .asInstanceOf[DataReaderFactory[UnsafeRow]] + .asInstanceOf[InputPartition[UnsafeRow]] }.asJava } @@ -161,18 +161,18 @@ case class KafkaContinuousDataReaderFactory( startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartition[UnsafeRow] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[UnsafeRow] = { val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset] require(kafkaOffset.topicPartition == topicPartition, s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}") - new KafkaContinuousDataReader( + new KafkaContinuousInputPartitionReader( topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } - override def createDataReader(): KafkaContinuousDataReader = { - new KafkaContinuousDataReader( + override def createPartitionReader(): KafkaContinuousInputPartitionReader = { + new KafkaContinuousInputPartitionReader( topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss) } } @@ -187,12 +187,12 @@ case class KafkaContinuousDataReaderFactory( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -class KafkaContinuousDataReader( +class KafkaContinuousInputPartitionReader( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, - failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] { + failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[UnsafeRow] { private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false) private val converter = new KafkaRecordToUnsafeRowConverter diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index cbe655f9bff1f..8a377738ea782 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.UninterruptibleThread @@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader( } } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { // Find the new partitions, and get their earliest offsets val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet) val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) @@ -146,7 +146,7 @@ private[kafka010] class KafkaMicroBatchReader( new KafkaMicroBatchDataReaderFactory( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } - factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava + factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava } override def getStartOffset: Offset = { @@ -299,27 +299,28 @@ private[kafka010] class KafkaMicroBatchReader( } } -/** A [[DataReaderFactory]] for reading Kafka data in a micro-batch streaming query. */ +/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ private[kafka010] case class KafkaMicroBatchDataReaderFactory( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] { + reuseKafkaConsumer: Boolean) extends InputPartition[UnsafeRow] { override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray - override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader( - offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = + new KafkaMicroBatchInputPartitionReader(offsetRange, executorKafkaParams, pollTimeoutMs, + failOnDataLoss, reuseKafkaConsumer) } -/** A [[DataReader]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReader( +/** A [[InputPartitionReader]] for reading Kafka data in a micro-batch streaming query. */ +private[kafka010] case class KafkaMicroBatchInputPartitionReader( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - reuseKafkaConsumer: Boolean) extends DataReader[UnsafeRow] with Logging { + reuseKafkaConsumer: Boolean) extends InputPartitionReader[UnsafeRow] with Logging { private val consumer = KafkaDataConsumer.acquire( offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 36b9f0466566b..d225c1ea6b7f1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSessio import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -149,7 +150,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } /** - * Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader]] to read + * Creates a [[ContinuousInputPartitionReader]] to read * Kafka data in a continuous streaming query. */ override def createContinuousReader( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index d2d04b68de6ab..871f9700cd1db 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -678,7 +678,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))), Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) - val factories = reader.createUnsafeRowReaderFactories().asScala + val factories = reader.planUnsafeInputPartitions().asScala .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java index 209ffa7a0b9fa..7f4a2c9593c76 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupport.java @@ -34,7 +34,7 @@ public interface MicroBatchReadSupport extends DataSourceV2 { * streaming query. * * The execution engine will create a micro-batch reader at the start of a streaming query, - * alternate calls to setOffsetRange and createDataReaderFactories for each batch to process, and + * alternate calls to setOffsetRange and planInputPartitions for each batch to process, and * then call stop() when the execution is complete. Note that a single query may have multiple * executions due to restart or failure recovery. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java similarity index 78% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index a61697649c43e..c24f3b21eade1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousDataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -21,15 +21,15 @@ import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset; /** - * A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories can - * implement this interface to provide creating {@link DataReader} with particular offset. + * A mix-in interface for {@link InputPartition}. Continuous input partitions can + * implement this interface to provide creating {@link InputPartitionReader} with particular offset. */ @InterfaceStability.Evolving -public interface ContinuousDataReaderFactory extends DataReaderFactory { +public interface ContinuousInputPartition extends InputPartition { /** * Create a DataReader with particular offset as its startOffset. * * @param offset offset want to set as the DataReader's startOffset. */ - DataReader createDataReaderWithOffset(PartitionOffset offset); + InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index a470bccc5aad2..f898c296e4245 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,8 +31,8 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link DataReaderFactory}s that are returned by - * {@link #createDataReaderFactories()}. + * logic is delegated to {@link InputPartition}s that are returned by + * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column @@ -65,8 +65,8 @@ public interface DataSourceReader { StructType readSchema(); /** - * Returns a list of reader factories. Each factory is responsible for creating a data reader to - * output data for one RDD partition. That means the number of factories returned here is same as + * Returns a list of read tasks. Each task is responsible for creating a data reader to + * output data for one RDD partition. That means the number of tasks returned here is same as * the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization @@ -76,5 +76,5 @@ public interface DataSourceReader { * If this method fails (by throwing an exception), the action would fail and no Spark job was * submitted. */ - List> createDataReaderFactories(); + List> planInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java similarity index 81% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 32e98e8f5d8bd..c581e3b5d0047 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReaderFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -22,20 +22,20 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is + * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is * responsible for creating the actual data reader. The relationship between - * {@link DataReaderFactory} and {@link DataReader} + * {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that, the reader factory will be serialized and sent to executors, then the data reader - * will be created on executors and do the actual reading. So {@link DataReaderFactory} must be - * serializable and {@link DataReader} doesn't need to be. + * Note that input partitions will be serialized and sent to executors, then the partition reader + * will be created on executors and do the actual reading. So {@link InputPartition} must be + * serializable and {@link InputPartitionReader} doesn't need to be. */ @InterfaceStability.Evolving -public interface DataReaderFactory extends Serializable { +public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this reader factory can run faster, + * The preferred locations where the data reader returned by this partition can run faster, * but Spark does not guarantee to run the data reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. @@ -57,5 +57,5 @@ default String[] preferredLocations() { * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. */ - DataReader createDataReader(); + InputPartitionReader createPartitionReader(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index bb9790a1c819e..1b7051f1ad0af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,7 +23,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for + * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for * outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data @@ -31,7 +31,7 @@ * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving -public interface DataReader extends Closeable { +public interface InputPartitionReader extends Closeable { /** * Proceed to next record, returns false if there is no more records. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java index 607628746e873..6b60da7c4dc1d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -24,7 +24,7 @@ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report data partitioning and try to avoid shuffle at Spark side. * - * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid + * Note that, when the reader creates exactly one {@link InputPartition}, Spark may avoid * adding a shuffle even if the reader does not implement this interface. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java index 2e5cfa78511f0..0faf81db24605 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java @@ -30,22 +30,22 @@ @InterfaceStability.Evolving public interface SupportsScanColumnarBatch extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanColumnarBatch."); + "planInputPartitions not supported by default within SupportsScanColumnarBatch."); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, but returns columnar data + * Similar to {@link DataSourceReader#planInputPartitions()}, but returns columnar data * in batches. */ - List> createBatchDataReaderFactories(); + List> planBatchInputPartitions(); /** * Returns true if the concrete data source reader can read data in batch according to the scan * properties like required columns, pushes filters, etc. It's possible that the implementation * can only support some certain columns with certain types. Users can overwrite this method and - * {@link #createDataReaderFactories()} to fallback to normal read path under some conditions. + * {@link #planInputPartitions()} to fallback to normal read path under some conditions. */ default boolean enableBatchRead() { return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java index 9cd749e8e4ce9..f2220f6d31093 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -33,14 +33,14 @@ public interface SupportsScanUnsafeRow extends DataSourceReader { @Override - default List> createDataReaderFactories() { + default List> planInputPartitions() { throw new IllegalStateException( - "createDataReaderFactories not supported by default within SupportsScanUnsafeRow"); + "planInputPartitions not supported by default within SupportsScanUnsafeRow"); } /** - * Similar to {@link DataSourceReader#createDataReaderFactories()}, + * Similar to {@link DataSourceReader#planInputPartitions()}, * but returns data in unsafe row format. */ - List> createUnsafeRowReaderFactories(); + List> planUnsafeInputPartitions(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java index 2d0ee50212b56..38ca5fc6387b2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * A concrete implementation of {@link Distribution}. Represents a distribution where records that * share the same values for the {@link #clusteredColumns} will be produced by the same - * {@link DataReader}. + * {@link InputPartitionReader}. */ @InterfaceStability.Evolving public class ClusteredDistribution implements Distribution { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index f6b111fdf220d..d2ee9518d628f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -18,13 +18,13 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link DataReader} outputs data for one partition). + * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition). * Note that this interface has nothing to do with the data ordering inside one - * partition(the output records of a single {@link DataReader}). + * partition(the output records of a single {@link InputPartitionReader}). * * The instance of this interface is created and provided by Spark, then consumed by * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java index 309d9e5de0a0f..f460f6bfe3bb9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.sources.v2.reader.partitioning; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning; /** @@ -31,7 +31,7 @@ public interface Partitioning { /** - * Returns the number of partitions(i.e., {@link DataReaderFactory}s) the data source outputs. + * Returns the number of partitions(i.e., {@link InputPartition}s) the data source outputs. */ int numPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java similarity index 84% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java index 47d26440841fd..7b0ba0bbdda90 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousDataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousInputPartitionReader.java @@ -18,13 +18,13 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; /** - * A variation on {@link DataReader} for use with streaming in continuous processing mode. + * A variation on {@link InputPartitionReader} for use with streaming in continuous processing mode. */ @InterfaceStability.Evolving -public interface ContinuousDataReader extends DataReader { +public interface ContinuousInputPartitionReader extends InputPartitionReader { /** * Get the offset of the current record, or the start offset if no records have been read. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 7fe7f00ac2fa8..716c5c0e9e15a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -27,7 +27,7 @@ * A mix-in interface for {@link DataSourceReader}. Data source readers can implement this * interface to allow reading in a continuous processing mode stream. * - * Implementations must ensure each reader factory output is a {@link ContinuousDataReader}. + * Implementations must ensure each partition reader is a {@link ContinuousInputPartitionReader}. * * Note: This class currently extends {@link BaseStreamingSource} to maintain compatibility with * DataSource V1 APIs. This extension will be removed once we get rid of V1 completely. @@ -35,7 +35,7 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousDataReader} instances for each + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each * partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); @@ -47,7 +47,7 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset deserializeOffset(String json); /** - * Set the desired start offset for reader factories created from this reader. The scan will + * Set the desired start offset for partitions created from this reader. The scan will * start from the first record after the provided offset, or from an implementation-defined * inferred starting point if no offset is provided. */ @@ -61,8 +61,8 @@ public interface ContinuousReader extends BaseStreamingSource, DataSourceReader Offset getStartOffset(); /** - * The execution engine will call this method in every epoch to determine if new reader - * factories need to be generated, which may be required if for example the underlying + * The execution engine will call this method in every epoch to determine if new input + * partitions need to be generated, which may be required if for example the underlying * source system has had partitions added or removed. * * If true, the query will be shut down and restarted with a new reader. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java index 67ebde30d61a9..0159c731762d9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReader.java @@ -33,7 +33,7 @@ @InterfaceStability.Evolving public interface MicroBatchReader extends DataSourceReader, BaseStreamingSource { /** - * Set the desired offset range for reader factories created from this reader. Reader factories + * Set the desired offset range for input partitions created from this reader. Partition readers * will generate only data within (`start`, `end`]; that is, from the first record after `start` * to the record with offset `end`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index f85971be394b1..1a6b32429313a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -22,14 +22,14 @@ import scala.reflect.ClassTag import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition -class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T]) +class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T]) extends Partition with Serializable class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[DataReaderFactory[T]]) + @transient private val readerFactories: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { @@ -39,7 +39,8 @@ class DataSourceRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader() + val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition + .createPartitionReader() context.addTaskCompletionListener(_ => reader.close()) val iter = new Iterator[T] { private[this] var valuePrepared = false @@ -63,6 +64,6 @@ class DataSourceRDD[T: ClassTag]( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations() + split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 77cb707340b0f..c6a7684bf6ab0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -59,13 +59,13 @@ case class DataSourceV2ScanExec( } override def outputPartitioning: physical.Partitioning = reader match { - case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 => + case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchPartitions.size == 1 => SinglePartition - case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 => + case r: SupportsScanColumnarBatch if !r.enableBatchRead() && partitions.size == 1 => SinglePartition - case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 => + case r if !r.isInstanceOf[SupportsScanColumnarBatch] && partitions.size == 1 => SinglePartition case s: SupportsReportPartitioning => @@ -75,19 +75,19 @@ case class DataSourceV2ScanExec( case _ => super.outputPartitioning } - private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match { - case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala + private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala case _ => - reader.createDataReaderFactories().asScala.map { - new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow] + reader.planInputPartitions().asScala.map { + new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow] } } - private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match { + private lazy val batchPartitions: Seq[InputPartition[ColumnarBatch]] = reader match { case r: SupportsScanColumnarBatch if r.enableBatchRead() => assert(!reader.isInstanceOf[ContinuousReader], "continuous stream reader does not support columnar read yet.") - r.createBatchDataReaderFactories().asScala + r.planBatchInputPartitions().asScala } private lazy val inputRDD: RDD[InternalRow] = reader match { @@ -95,19 +95,18 @@ case class DataSourceV2ScanExec( EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), sparkContext.env) - .askSync[Unit](SetReaderPartitions(readerFactories.size)) + .askSync[Unit](SetReaderPartitions(partitions.size)) new ContinuousDataSourceRDD( sparkContext, sqlContext.conf.continuousStreamingExecutorQueueSize, sqlContext.conf.continuousStreamingExecutorPollIntervalMs, - readerFactories) - .asInstanceOf[RDD[InternalRow]] + partitions).asInstanceOf[RDD[InternalRow]] case r: SupportsScanColumnarBatch if r.enableBatchRead() => - new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, batchPartitions).asInstanceOf[RDD[InternalRow]] case _ => - new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]] + new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]] } override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD) @@ -132,19 +131,22 @@ case class DataSourceV2ScanExec( } } -class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType) - extends DataReaderFactory[UnsafeRow] { +class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType) + extends InputPartition[UnsafeRow] { - override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations + override def preferredLocations: Array[String] = partition.preferredLocations - override def createDataReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader( - rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind()) + override def createPartitionReader: InputPartitionReader[UnsafeRow] = { + new RowToUnsafeInputPartitionReader( + partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind()) } } -class RowToUnsafeDataReader(val rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) - extends DataReader[UnsafeRow] { +class RowToUnsafeInputPartitionReader( + val rowReader: InputPartitionReader[Row], + encoder: ExpressionEncoder[Row]) + + extends InputPartitionReader[UnsafeRow] { override def next: Boolean = rowReader.next diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 0a3b9dcccb6c5..a7ccce10b0cee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,14 +21,14 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeInputPartitionReader} import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset} import org.apache.spark.util.{NextIterator, ThreadUtils} class ContinuousDataSourceRDDPartition( val index: Int, - val readerFactory: DataReaderFactory[UnsafeRow]) + val inputPartition: InputPartition[UnsafeRow]) extends Partition with Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -51,12 +51,12 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]]) + @transient private val readerFactories: Seq[InputPartition[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { readerFactories.zipWithIndex.map { - case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory) + case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } @@ -75,7 +75,7 @@ class ContinuousDataSourceRDD( if (partition.queueReader == null) { partition.queueReader = new ContinuousQueuedDataReader( - partition.readerFactory, context, dataQueueSize, epochPollIntervalMs) + partition.inputPartition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -96,17 +96,17 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations() + split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() } } object ContinuousDataSourceRDD { private[continuous] def getContinuousReader( - reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = { + reader: InputPartitionReader[UnsafeRow]): ContinuousInputPartitionReader[_] = { reader match { - case r: ContinuousDataReader[UnsafeRow] => r - case wrapped: RowToUnsafeDataReader => - wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]] + case r: ContinuousInputPartitionReader[UnsafeRow] => r + case wrapped: RowToUnsafeInputPartitionReader => + wrapped.rowReader.asInstanceOf[ContinuousInputPartitionReader[Row]] case _ => throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 01a999f6505fc..d8645576c2052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.Closeable -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.util.control.NonFatal -import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset import org.apache.spark.util.ThreadUtils @@ -38,11 +37,11 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - factory: DataReaderFactory[UnsafeRow], + partition: InputPartition[UnsafeRow], context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = factory.createDataReader() + private val reader = partition.createPartitionReader() // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = @@ -132,7 +131,7 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when - * a new row arrives to the [[DataReader]]. + * a new row arrives to the [[InputPartitionReader]]. */ class DataReaderThread extends Thread( s"continuous-reader--${context.partitionId()}--" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 2f0de2612c150..8d25d9ccc43d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeM import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( @@ -67,7 +67,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) override def getStartOffset(): Offset = offset - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => @@ -91,7 +91,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) i, numPartitions, perPartitionRate) - .asInstanceOf[DataReaderFactory[Row]] + .asInstanceOf[InputPartition[Row]] }.asJava } @@ -119,13 +119,13 @@ case class RateStreamContinuousDataReaderFactory( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReaderFactory[Row] { + extends ContinuousInputPartition[Row] { - override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[Row] = { + override def createContinuousReader(offset: PartitionOffset): InputPartitionReader[Row] = { val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] require(rateStreamOffset.partition == partitionIndex, s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousDataReader( + new RateStreamContinuousInputPartitionReader( rateStreamOffset.currentValue, rateStreamOffset.currentTimeMs, partitionIndex, @@ -133,18 +133,18 @@ case class RateStreamContinuousDataReaderFactory( rowsPerSecond) } - override def createDataReader(): DataReader[Row] = - new RateStreamContinuousDataReader( + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamContinuousInputPartitionReader( startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamContinuousDataReader( +class RateStreamContinuousInputPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousDataReader[Row] { + extends ContinuousInputPartitionReader[Row] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 6720cdd24b1b2..daa2963220aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -139,7 +139,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) if (endOffset.offset == -1) null else endOffset } - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]] + new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -202,9 +202,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) - extends DataReaderFactory[UnsafeRow] { - override def createDataReader(): DataReader[UnsafeRow] = { - new DataReader[UnsafeRow] { + extends InputPartition[UnsafeRow] { + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { + new InputPartitionReader[UnsafeRow] { private var currentIndex = -1 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index a8fca3c19a2d2..fef792eab69d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -34,8 +34,8 @@ import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils @@ -99,7 +99,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) ) } - override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): ju.List[InputPartition[Row]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = @@ -108,7 +108,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) startOffset.partitionNums.map { case (part, index) => new ContinuousMemoryStreamDataReaderFactory( - endpointName, part, index): DataReaderFactory[Row] + endpointName, part, index): InputPartition[Row] }.toList.asJava } } @@ -160,9 +160,9 @@ object ContinuousMemoryStream { class ContinuousMemoryStreamDataReaderFactory( driverEndpointName: String, partition: Int, - startOffset: Int) extends DataReaderFactory[Row] { - override def createDataReader: ContinuousMemoryStreamDataReader = - new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition[Row] { + override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = + new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) } /** @@ -170,10 +170,10 @@ class ContinuousMemoryStreamDataReaderFactory( * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamDataReader( +class ContinuousMemoryStreamInputPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousDataReader[Row] { + startOffset: Int) extends ContinuousInputPartitionReader[Row] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index f54291bea6678..723cc3ad5bb89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -134,7 +134,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: LongOffset(json.toLong) } - override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = { + override def planInputPartitions(): java.util.List[InputPartition[Row]] = { val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") @@ -169,7 +169,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: (0 until numPartitions).map { p => new RateStreamMicroBatchDataReaderFactory( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : DataReaderFactory[Row] + : InputPartition[Row] }.toList.asJava } @@ -188,19 +188,20 @@ class RateStreamMicroBatchDataReaderFactory( rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReaderFactory[Row] { + relativeMsPerValue: Double) extends InputPartition[Row] { - override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader( - partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + override def createPartitionReader(): InputPartitionReader[Row] = + new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, + localStartTimeMs, relativeMsPerValue) } -class RateStreamMicroBatchDataReader( +class RateStreamMicroBatchInputPartitionReader( partitionId: Int, numPartitions: Int, rangeStart: Long, rangeEnd: Long, localStartTimeMs: Long, - relativeMsPerValue: Double) extends DataReader[Row] { + relativeMsPerValue: Double) extends InputPartitionReader[Row] { private var count = 0 override def next(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 90f4a5ba4234d..8240e06d4ab72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.LongOffset import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -140,7 +140,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { assert(startOffset != null && endOffset != null, "start offset and end offset should already be set before create read tasks.") @@ -165,21 +165,22 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR (0 until numPartitions).map { i => val slice = slices(i) - new DataReaderFactory[Row] { - override def createDataReader(): DataReader[Row] = new DataReader[Row] { - private var currentIdx = -1 + new InputPartition[Row] { + override def createPartitionReader(): InputPartitionReader[Row] = + new InputPartitionReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } - override def get(): Row = { - Row(slice(currentIdx)._1, slice(currentIdx)._2) + override def close(): Unit = {} } - - override def close(): Unit = {} - } } }.toList.asJava } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 172e5d5eebcbe..714638e500c94 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -79,8 +79,8 @@ public Filter[] pushedFilters() { } @Override - public List> createDataReaderFactories() { - List> res = new ArrayList<>(); + public List> planInputPartitions() { + List> res = new ArrayList<>(); Integer lowerBound = null; for (Filter filter : filters) { @@ -94,33 +94,33 @@ public List> createDataReaderFactories() { } if (lowerBound == null) { - res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 4) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); } else if (lowerBound < 9) { - res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema)); + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); } return res; } } - static class JavaAdvancedDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaAdvancedInputPartition implements InputPartition, InputPartitionReader { private int start; private int end; private StructType requiredSchema; - JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) { + JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { this.start = start; this.end = end; this.requiredSchema = requiredSchema; } @Override - public DataReader createDataReader() { - return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema); + public InputPartitionReader createPartitionReader() { + return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java index c55093768105b..97d6176d02559 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -42,14 +42,14 @@ public StructType readSchema() { } @Override - public List> createBatchDataReaderFactories() { + public List> planBatchInputPartitions() { return java.util.Arrays.asList( - new JavaBatchDataReaderFactory(0, 50), new JavaBatchDataReaderFactory(50, 90)); + new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); } } - static class JavaBatchDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaBatchInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; @@ -59,13 +59,13 @@ static class JavaBatchDataReaderFactory private OnHeapColumnVector j; private ColumnarBatch batch; - JavaBatchDataReaderFactory(int start, int end) { + JavaBatchInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); ColumnVector[] vectors = new ColumnVector[2]; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 32fad59b97ff6..e49c8cf8b9e16 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -43,10 +43,10 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new SpecificDataReaderFactory(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificDataReaderFactory(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override @@ -73,12 +73,12 @@ public boolean satisfy(Distribution distribution) { } } - static class SpecificDataReaderFactory implements DataReaderFactory, DataReader { + static class SpecificInputPartition implements InputPartition, InputPartitionReader { private int[] i; private int[] j; private int current = -1; - SpecificDataReaderFactory(int[] i, int[] j) { + SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; @@ -101,7 +101,7 @@ public void close() throws IOException { } @Override - public DataReader createDataReader() { + public InputPartitionReader createPartitionReader() { return this; } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 048d078dfaac4..80eeffd95f83b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -24,7 +24,7 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.types.StructType; public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema { @@ -42,7 +42,7 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Collections.emptyList(); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 96f55b8a76811..8522a63898a3b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -25,8 +25,8 @@ import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; import org.apache.spark.sql.types.StructType; @@ -41,25 +41,25 @@ public StructType readSchema() { } @Override - public List> createDataReaderFactories() { + public List> planInputPartitions() { return java.util.Arrays.asList( - new JavaSimpleDataReaderFactory(0, 5), - new JavaSimpleDataReaderFactory(5, 10)); + new JavaSimpleInputPartition(0, 5), + new JavaSimpleInputPartition(5, 10)); } } - static class JavaSimpleDataReaderFactory implements DataReaderFactory, DataReader { + static class JavaSimpleInputPartition implements InputPartition, InputPartitionReader { private int start; private int end; - JavaSimpleDataReaderFactory(int start, int end) { + JavaSimpleInputPartition(int start, int end) { this.start = start; this.end = end; } @Override - public DataReader createDataReader() { - return new JavaSimpleDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaSimpleInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java index c3916e0b370b5..3ad8e7a0104ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java @@ -38,20 +38,20 @@ public StructType readSchema() { } @Override - public List> createUnsafeRowReaderFactories() { + public List> planUnsafeInputPartitions() { return java.util.Arrays.asList( - new JavaUnsafeRowDataReaderFactory(0, 5), - new JavaUnsafeRowDataReaderFactory(5, 10)); + new JavaUnsafeRowInputPartition(0, 5), + new JavaUnsafeRowInputPartition(5, 10)); } } - static class JavaUnsafeRowDataReaderFactory - implements DataReaderFactory, DataReader { + static class JavaUnsafeRowInputPartition + implements InputPartition, InputPartitionReader { private int start; private int end; private UnsafeRow row; - JavaUnsafeRowDataReaderFactory(int start, int end) { + JavaUnsafeRowInputPartition(int start, int end) { this.start = start; this.end = end; this.row = new UnsafeRow(2); @@ -59,8 +59,8 @@ static class JavaUnsafeRowDataReaderFactory } @Override - public DataReader createDataReader() { - return new JavaUnsafeRowDataReaderFactory(start - 1, end); + public InputPartitionReader createPartitionReader() { + return new JavaUnsafeRowInputPartition(start - 1, end); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index ff14ec38e66a8..39a010f970ce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -142,9 +142,9 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 1) - val dataReader = tasks.get(0).createDataReader() + val dataReader = tasks.get(0).createPartitionReader() val data = ArrayBuffer[Row]() while (dataReader.next()) { data.append(dataReader.get()) @@ -159,11 +159,11 @@ class RateSourceSuite extends StreamTest { val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 11) val readData = tasks.asScala - .map(_.createDataReader()) + .map(_.createPartitionReader()) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[Row]() while (reader.next()) buf.append(reader.get()) @@ -304,7 +304,7 @@ class RateSourceSuite extends StreamTest { val reader = new RateStreamContinuousReader( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) reader.setStartOffset(Optional.empty()) - val tasks = reader.createDataReaderFactories() + val tasks = reader.planInputPartitions() assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[Row]() @@ -314,7 +314,7 @@ class RateSourceSuite extends StreamTest { .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader] + val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e0a53272cd222..505a3f3465c02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -346,8 +346,8 @@ class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5)) } } @@ -359,20 +359,21 @@ class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { - java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5), new SimpleDataReaderFactory(5, 10)) + override def planInputPartitions(): JList[InputPartition[Row]] = { + java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SimpleDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SimpleInputPartition(start: Int, end: Int) + extends InputPartition[Row] + with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = new SimpleDataReaderFactory(start, end) + override def createPartitionReader(): InputPartitionReader[Row] = + new SimpleInputPartition(start, end) override def next(): Boolean = { current += 1 @@ -413,21 +414,21 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { requiredSchema } - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val lowerBound = filters.collect { case GreaterThan("i", v: Int) => v }.headOption - val res = new ArrayList[DataReaderFactory[Row]] + val res = new ArrayList[InputPartition[Row]] if (lowerBound.isEmpty) { - res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(0, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 4) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedDataReaderFactory(5, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedInputPartition(5, 10, requiredSchema)) } else if (lowerBound.get < 9) { - res.add(new AdvancedDataReaderFactory(lowerBound.get + 1, 10, requiredSchema)) + res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) } res @@ -437,13 +438,13 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType) - extends DataReaderFactory[Row] with DataReader[Row] { +class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) + extends InputPartition[Row] with InputPartitionReader[Row] { private var current = start - 1 - override def createDataReader(): DataReader[Row] = { - new AdvancedDataReaderFactory(start, end, requiredSchema) + override def createPartitionReader(): InputPartitionReader[Row] = { + new AdvancedInputPartition(start, end, requiredSchema) } override def close(): Unit = {} @@ -468,24 +469,24 @@ class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanUnsafeRow { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createUnsafeRowReaderFactories(): JList[DataReaderFactory[UnsafeRow]] = { - java.util.Arrays.asList(new UnsafeRowDataReaderFactory(0, 5), - new UnsafeRowDataReaderFactory(5, 10)) + override def planUnsafeInputPartitions(): JList[InputPartition[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowInputPartitionReader(0, 5), + new UnsafeRowInputPartitionReader(5, 10)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class UnsafeRowDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[UnsafeRow] with DataReader[UnsafeRow] { +class UnsafeRowInputPartitionReader(start: Int, end: Int) + extends InputPartition[UnsafeRow] with InputPartitionReader[UnsafeRow] { private val row = new UnsafeRow(2) row.pointTo(new Array[Byte](8 * 3), 8 * 3) private var current = start - 1 - override def createDataReader(): DataReader[UnsafeRow] = this + override def createPartitionReader(): InputPartitionReader[UnsafeRow] = this override def next(): Boolean = { current += 1 @@ -503,7 +504,7 @@ class UnsafeRowDataReaderFactory(start: Int, end: Int) class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { class Reader(val readSchema: StructType) extends DataSourceReader { - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = + override def planInputPartitions(): JList[InputPartition[Row]] = java.util.Collections.emptyList() } @@ -516,16 +517,17 @@ class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsScanColumnarBatch { override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def createBatchDataReaderFactories(): JList[DataReaderFactory[ColumnarBatch]] = { - java.util.Arrays.asList(new BatchDataReaderFactory(0, 50), new BatchDataReaderFactory(50, 90)) + override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { + java.util.Arrays.asList( + new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) } } override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class BatchDataReaderFactory(start: Int, end: Int) - extends DataReaderFactory[ColumnarBatch] with DataReader[ColumnarBatch] { +class BatchInputPartitionReader(start: Int, end: Int) + extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { private final val BATCH_SIZE = 20 private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) @@ -534,7 +536,7 @@ class BatchDataReaderFactory(start: Int, end: Int) private var current = start - override def createDataReader(): DataReader[ColumnarBatch] = this + override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this override def next(): Boolean = { i.reset() @@ -568,11 +570,11 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { class Reader extends DataSourceReader with SupportsReportPartitioning { override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { // Note that we don't have same value of column `a` across partitions. java.util.Arrays.asList( - new SpecificDataReaderFactory(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificDataReaderFactory(Array(2, 4, 4), Array(6, 2, 2))) + new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) } override def outputPartitioning(): Partitioning = new MyPartitioning @@ -590,14 +592,14 @@ class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } -class SpecificDataReaderFactory(i: Array[Int], j: Array[Int]) - extends DataReaderFactory[Row] - with DataReader[Row] { +class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) + extends InputPartition[Row] + with InputPartitionReader[Row] { assert(i.length == j.length) private var current = -1 - override def createDataReader(): DataReader[Row] = this + override def createPartitionReader(): InputPartitionReader[Row] = this override def next(): Boolean = { current += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a5007fa321359..694bb3b95b0f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, DataSourceReader} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Reader(path: String, conf: Configuration) extends DataSourceReader { override def readSchema(): StructType = schema - override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = { + override def planInputPartitions(): JList[InputPartition[Row]] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -54,9 +54,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS name.startsWith("_") || name.startsWith(".") }.map { f => val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVDataReaderFactory( + new SimpleCSVInputPartitionReader( f.getPath.toUri.toString, - serializableConf): DataReaderFactory[Row] + serializableConf): InputPartition[Row] }.toList.asJava } else { Collections.emptyList() @@ -156,14 +156,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } } -class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) - extends DataReaderFactory[Row] with DataReader[Row] { +class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) + extends InputPartition[Row] with InputPartitionReader[Row] { @transient private var lines: Iterator[String] = _ @transient private var currentLine: String = _ @transient private var inputStream: FSDataInputStream = _ - override def createDataReader(): DataReader[Row] = { + override def createPartitionReader(): InputPartitionReader[Row] = { val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) inputStream = fs.open(filePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 57986999bf861..dcf6cb5d609ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -227,10 +227,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } // getBatch should take 100 ms the first time it is called - override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = { + override def planUnsafeInputPartitions(): ju.List[InputPartition[UnsafeRow]] = { synchronized { clock.waitTillTime(1350) - super.createUnsafeRowReaderFactories() + super.planUnsafeInputPartitions() } } } @@ -290,13 +290,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AdvanceManualClock(100), // time = 1150 to unblock getEndOffset AssertClockTime(1150), - AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350 + // will block on planInputPartitions that needs 1350 + AssertStreamExecThreadIsWaitingForTime(1350), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock createReadTasks + AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions AssertClockTime(1350), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index e755625d09e0f..f47d3ec8ae025 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -27,8 +27,8 @@ import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType} @@ -72,8 +72,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new DataReaderFactory[UnsafeRow] { - override def createDataReader() = new ContinuousDataReader[UnsafeRow] { + val factory = new InputPartition[UnsafeRow] { + override def createPartitionReader() = new ContinuousInputPartitionReader[UnsafeRow] { var index = -1 var curr: UnsafeRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index af4618bed5456..c1a28b9bc75ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.InputPartition import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.{OutputMode, StreamTest, Trigger} @@ -44,7 +44,7 @@ case class FakeReader() extends MicroBatchReader with ContinuousReader { def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map()) def setStartOffset(start: Optional[Offset]): Unit = {} - def createDataReaderFactories(): java.util.ArrayList[DataReaderFactory[Row]] = { + def planInputPartitions(): java.util.ArrayList[InputPartition[Row]] = { throw new IllegalStateException("fake source - cannot actually read") } } From e3d434994733ae16e7e1424fb6de2d22b1a13f99 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 10 May 2018 13:36:52 +0800 Subject: [PATCH 0774/1161] [SPARK-22279][SQL] Enable `convertMetastoreOrc` by default ## What changes were proposed in this pull request? We reverted `spark.sql.hive.convertMetastoreOrc` at https://github.com/apache/spark/pull/20536 because we should not ignore the table-specific compression conf. Now, it's resolved via [SPARK-23355](https://github.com/apache/spark/commit/8aa1d7b0ede5115297541d29eab4ce5f4fe905cb). ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21186 from dongjoon-hyun/SPARK-24112. --- docs/sql-programming-guide.md | 3 ++- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3e8946e424237..3f79ed6422205 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1017,7 +1017,7 @@ the vectorized reader is used when `spark.sql.hive.convertMetastoreOrc` is also - + @@ -1813,6 +1813,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. + - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 10c9603745379..bb134bbe68bd9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -105,11 +105,10 @@ private[spark] object HiveUtils extends Logging { .createWithDefault(false) val CONVERT_METASTORE_ORC = buildConf("spark.sql.hive.convertMetastoreOrc") - .internal() .doc("When set to true, the built-in ORC reader and writer are used to process " + "ORC tables created by using the HiveQL syntax, instead of Hive serde.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes") .doc("A comma separated list of class prefixes that should be loaded using the classloader " + From 94d671448240c8f6da11d2523ba9e4ae5b56a410 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 10 May 2018 20:38:52 +0900 Subject: [PATCH 0775/1161] [SPARK-23907][SQL] Add regr_* functions ## What changes were proposed in this pull request? The PR introduces regr_slope, regr_intercept, regr_r2, regr_sxx, regr_syy, regr_sxy, regr_avgx, regr_avgy, regr_count. The implementation of this functions mirrors Hive's one in HIVE-15978. ## How was this patch tested? added UT (values compared with Hive) Author: Marco Gaido Closes #21054 from mgaido91/SPARK-23907. --- .../catalyst/analysis/FunctionRegistry.scala | 9 + .../expressions/aggregate/Average.scala | 47 +++-- .../aggregate/CentralMomentAgg.scala | 60 +++--- .../catalyst/expressions/aggregate/Corr.scala | 52 ++--- .../expressions/aggregate/Count.scala | 47 +++-- .../expressions/aggregate/Covariance.scala | 36 ++-- .../expressions/aggregate/regression.scala | 190 ++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 172 ++++++++++++++++ .../sql-tests/inputs/udaf-regrfunctions.sql | 56 ++++++ .../results/udaf-regrfunctions.sql.out | 93 +++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 71 ++++++- 11 files changed, 721 insertions(+), 112 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala create mode 100644 sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 87b0911e150c5..087d000a9db70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -299,6 +299,15 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[RegrCount]("regr_count"), + expression[RegrSXX]("regr_sxx"), + expression[RegrSYY]("regr_syy"), + expression[RegrAvgX]("regr_avgx"), + expression[RegrAvgY]("regr_avgy"), + expression[RegrSXY]("regr_sxy"), + expression[RegrSlope]("regr_slope"), + expression[RegrR2]("regr_r2"), + expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 708bdbfc36058..a133bc2361eb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - override def prettyName: String = "avg" - - override def children: Seq[Expression] = child :: Nil +abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override def nullable: Boolean = true - // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") - private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) @@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* sum = */ - Add( - sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) - ) - override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right @@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _ => Cast(sum, resultType) / Cast(count, resultType) } + + protected def updateExpressionsDef: Seq[Expression] = Seq( + /* sum = */ + Add( + sum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override lazy val updateExpressions = updateExpressionsDef +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") +case class Average(child: Expression) + extends AverageLike(child) with ImplicitCastInputTypes { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 572d29caf5bc9..6bbb083f1e18e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression) override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val delta = child - avg - val deltaN = delta / newN - val newAvg = avg + deltaN - val newM2 = m2 + delta * (delta - deltaN) - - val delta2 = delta * delta - val deltaN2 = deltaN * deltaN - val newM3 = if (momentOrder >= 3) { - m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) - } else { - Literal(0.0) - } - val newM4 = if (momentOrder >= 4) { - m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + - delta * (delta * delta2 - deltaN * deltaN2) - } else { - Literal(0.0) - } - - trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) - )) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression) trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) + } } // Compute the population standard deviation of a column diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 95a4a0d5af634..3cdef72c1f2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ /** - * Compute Pearson correlation between two expressions. + * Base class for computing Pearson correlation between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. * * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") -// scalastyle:on line.size.limit -case class Corr(x: Expression, y: Expression) +abstract class PearsonCorrelation(x: Expression, y: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef + + override val mergeExpressions: Seq[Expression] = { + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) + } + + protected def updateExpressionsDef: Seq[Expression] = { val newN = n + Literal(1.0) val dx = x - xAvg val dxN = dx / newN @@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression) If(isNull, yMk, newYMk) ) } +} - override val mergeExpressions: Seq[Expression] = { - - val n1 = n.left - val n2 = n.right - val newN = n1 + n2 - val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) - val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) - val newXAvg = xAvg.left + dxN * n2 - val newYAvg = yAvg.left + dyN * n2 - val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 - val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 - Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) - } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") +// scalastyle:on line.size.limit +case class Corr(x: Expression, y: Expression) + extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 1990f2f2f0722..40582d0abd762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. - - _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. - - _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. - """) -// scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends DeclarativeAggregate { - +/** + * Base class for all counting aggregators. + */ +abstract class CountLike extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType - private lazy val count = AttributeReference("count", LongType, nullable = false)() + protected lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. + + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. + + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. + """) +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends CountLike { + override lazy val updateExpressions = { val nullableChildren = children.filter(_.nullable) if (nullableChildren.isEmpty) { @@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } - - override lazy val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override lazy val evaluateExpression = count - - override def defaultResult: Option[Literal] = Option(Literal(0L)) } object Count { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index fc6c34baafdd1..72a7c62b328ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) - override lazy val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val dx = x - xAvg - val dy = y - yAvg - val dyN = dy / newN - val newXAvg = xAvg + dx / newN - val newYAvg = yAvg + dyN - val newCk = ck + dx * (y - newYAvg) - - val isNull = IsNull(x) || IsNull(y) - Seq( - If(isNull, n, newN), - If(isNull, xAvg, newXAvg), - If(isNull, yAvg, newYAvg), - If(isNull, ck, newCk) - ) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression) Seq(newN, newXAvg, newYAvg, newCk) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala new file mode 100644 index 0000000000000..d8f4505588ff2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{AbstractDataType, DoubleType} + +/** + * Base trait for all regression functions. + */ +trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { + def y: Expression + def x: Expression + + override def children: Seq[Expression] = Seq(y, x) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { + assert(aggBufferAttributes.length == exprs.length) + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + exprs + } else { + exprs.zip(aggBufferAttributes).map { case (e, a) => + If(nullableChildren.map(IsNull).reduce(Or), a, e) + } + } + } +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", + since = "2.4.0") +case class RegrCount(y: Expression, x: Expression) + extends CountLike with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) + + override def prettyName: String = "regr_count" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSXX(y: Expression, x: Expression) + extends CentralMomentAgg(x) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_sxx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSYY(y: Expression, x: Expression) + extends CentralMomentAgg(y) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_syy" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgX(y: Expression, x: Expression) + extends AverageLike(x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgY(y: Expression, x: Expression) + extends AverageLike(y) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgy" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSXY(y: Expression, x: Expression) + extends Covariance(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), ck) + } + + override def prettyName: String = "regr_sxy" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSlope(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) + } + + override def prettyName: String = "regr_slope" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrR2(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) + } + + override def prettyName: String = "regr_r2" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrIntercept(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + xAvg - (ck / yMk) * yAvg) + } + + override def prettyName: String = "regr_intercept" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8f9e4ae18b3f1..28cf705eb9700 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -775,6 +775,178 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: Column, x: Column): Column = withAggregateFunction { + RegrCount(y.expr, x.expr) + } + + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { + RegrSXX(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: Column, x: Column): Column = withAggregateFunction { + RegrSYY(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgX(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { + RegrSXY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: Column, x: Column): Column = withAggregateFunction { + RegrSlope(y.expr, x.expr) + } + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: Column, x: Column): Column = withAggregateFunction { + RegrR2(y.expr, x.expr) + } + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { + RegrIntercept(y.expr, x.expr) + } + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) + + + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql new file mode 100644 index 0000000000000..92c7e26e3add2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql @@ -0,0 +1,56 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x); + +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px; + + +select id, regr_count(y,x) over (partition by px) from t1 order by id; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out new file mode 100644 index 0000000000000..d7d009a64bf84 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out @@ -0,0 +1,93 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px +-- !query 1 schema +struct +-- !query 1 output +1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 +2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 +3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 +4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 +5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 + + +-- !query 2 +select id, regr_count(y,x) over (partition by px) from t1 order by id +-- !query 2 schema +struct +-- !query 2 output +101 4 +102 4 +103 4 +104 4 +201 4 +202 4 +203 4 +204 4 +301 4 +302 4 +303 4 +304 4 +401 4 +402 4 +403 4 +404 4 +501 0 +502 0 +503 0 +504 0 +601 0 +602 0 +603 0 +604 0 +701 0 +702 0 +703 0 +704 0 +800 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7776e36702ad..4337fb2290fbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ + val absTol = 1e-8 + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("moments") { - val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) @@ -686,4 +687,72 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23907: regression functions") { + val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") + val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) + .toDF("a", "b") + val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( + (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b") + checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) + checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) + checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), + Row(null), absTol) + + + checkAggregatesWithTol(correlatedData.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), + absTol) + checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), + absTol) + } } From f4fed0512101a67d9dae50ace11d3940b910e05e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 10 May 2018 09:44:49 -0700 Subject: [PATCH 0776/1161] [SPARK-24171] Adding a note for non-deterministic functions ## What changes were proposed in this pull request? I propose to add a clear statement for functions like `collect_list()` about non-deterministic behavior of such functions. The behavior must be taken into account by user while creating and running queries. Author: Maxim Gekk Closes #21228 from MaxGekk/deterministic-comments. --- R/pkg/R/functions.R | 11 +++++ python/pyspark/sql/functions.py | 18 ++++++++ .../MonotonicallyIncreasingID.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/randomExpressions.scala | 8 ++-- .../org/apache/spark/sql/functions.scala | 46 +++++++++++++++++-- 6 files changed, 81 insertions(+), 8 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 0ec99d19e21e4..04d0e4620b28a 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -805,6 +805,8 @@ setMethod("factorial", #' #' The function by default returns the first values it sees. It will return the first non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param na.rm a logical value indicating whether NA values should be stripped #' before the computation proceeds. @@ -948,6 +950,8 @@ setMethod("kurtosis", #' #' The function by default returns the last values it sees. It will return the last non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. +#' Note: the function is non-deterministic because its results depends on order of rows which +#' may be non-deterministic after a shuffle. #' #' @param x column to compute on. #' @param na.rm a logical value indicating whether NA values should be stripped @@ -1201,6 +1205,7 @@ setMethod("minute", #' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. #' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. #' The method should be used with no argument. +#' Note: the function is non-deterministic because its result depends on partition IDs. #' #' @rdname column_nonaggregate_functions #' @aliases monotonically_increasing_id monotonically_increasing_id,missing-method @@ -2584,6 +2589,7 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' @details #' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) #' samples from U[0.0, 1.0]. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2612,6 +2618,7 @@ setMethod("rand", signature(seed = "numeric"), #' @details #' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples #' from the standard normal distribution. +#' Note: the function is non-deterministic in general case. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -3188,6 +3195,8 @@ setMethod("create_map", #' @details #' \code{collect_list}: Creates a list of objects with duplicates. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_list collect_list,Column-method @@ -3207,6 +3216,8 @@ setMethod("collect_list", #' @details #' \code{collect_set}: Creates a list of objects with duplicate elements eliminated. +#' Note: the function is non-deterministic because the order of collected results depends +#' on order of rows which may be non-deterministic after a shuffle. #' #' @rdname column_aggregate_functions #' @aliases collect_set collect_set,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ac3c79766702c..f5a584152b4f6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -152,6 +152,9 @@ def _(): _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_list('age')).collect() [Row(collect_list(age)=[2, 5, 5])] @@ -159,6 +162,9 @@ def _(): _collect_set_doc = """ Aggregate function: returns a set of objects with duplicate elements eliminated. + .. note:: The function is non-deterministic because the order of collected results depends + on order of rows which may be non-deterministic after a shuffle. + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) >>> df2.agg(collect_set('age')).collect() [Row(collect_set(age)=[5, 2])] @@ -401,6 +407,9 @@ def first(col, ignorenulls=False): The function by default returns the first values it sees. It will return the first non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows which + may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) @@ -489,6 +498,9 @@ def last(col, ignorenulls=False): The function by default returns the last values it sees. It will return the last non-null value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + + .. note:: The function is non-deterministic because its results depends on order of rows + which may be non-deterministic after a shuffle. """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) @@ -504,6 +516,8 @@ def monotonically_increasing_id(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + .. note:: The function is non-deterministic because its result depends on partition IDs. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -536,6 +550,8 @@ def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('rand', rand(seed=42) * 3).collect() [Row(age=2, name=u'Alice', rand=1.1568609015300986), Row(age=5, name=u'Bob', rand=1.403379671529166)] @@ -554,6 +570,8 @@ def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + .. note:: The function is non-deterministic in general case. + >>> df.withColumn('randn', randn(seed=42)).collect() [Row(age=2, name=u'Alice', randn=-0.7556247885860078), Row(age=5, name=u'Bob', randn=-0.0861619008451133)] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index ad1e7bdb31987..9f0779642271d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.types.{DataType, LongType} puts the partition ID in the upper 31 bits, and the lower 33 bits represent the record number within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. + The function is non-deterministic because its result depends on partition IDs. """) case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7eda65a867028..b7834696cafc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -117,12 +117,13 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.", + usage = """_FUNC_() - Returns an universally unique identifier (UUID) string. The value is returned as a canonical UUID 36-character string.""", examples = """ Examples: > SELECT _FUNC_(); 46707d92-02f4-4817-8116-a4c3b23e6266 - """) + """, + note = "The function is non-deterministic.") // scalastyle:on line.size.limit case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Stateful { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 70186053617f8..2653b28f6c3bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -68,7 +68,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful 0.8446490682263027 > SELECT _FUNC_(null); 0.8446490682263027 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Rand(child: Expression) extends RDG { @@ -96,7 +97,7 @@ object Rand { /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.", + usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""", examples = """ Examples: > SELECT _FUNC_(); @@ -105,7 +106,8 @@ object Rand { 1.1164209726833079 > SELECT _FUNC_(null); 1.1164209726833079 - """) + """, + note = "The function is non-deterministic in general case.") // scalastyle:on line.size.limit case class Randn(child: Expression) extends RDG { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 28cf705eb9700..225de0051d6fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -283,6 +283,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -291,6 +294,9 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -299,6 +305,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -307,6 +316,9 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * + * @note The function is non-deterministic because the order of collected results depends + * on order of rows which may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.6.0 */ @@ -422,6 +434,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -435,6 +450,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -448,6 +466,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -459,6 +480,9 @@ object functions { * The function by default returns the first values it sees. It will return the first non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -535,6 +559,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -548,6 +575,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 2.0.0 */ @@ -561,6 +591,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -572,6 +605,9 @@ object functions { * The function by default returns the last values it sees. It will return the last non-null * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. * + * @note The function is non-deterministic because its results depends on order of rows which + * may be non-deterministic after a shuffle. + * * @group agg_funcs * @since 1.3.0 */ @@ -1344,7 +1380,7 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1355,6 +1391,8 @@ object functions { * Generate a random column with independent and identically distributed (i.i.d.) samples * from U[0.0, 1.0]. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1364,7 +1402,7 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * - * @note This is indeterministic when data partitions are not fixed. + * @note The function is non-deterministic in general case. * * @group normal_funcs * @since 1.4.0 @@ -1375,6 +1413,8 @@ object functions { * Generate a column with independent and identically distributed (i.i.d.) samples from * the standard normal distribution. * + * @note The function is non-deterministic in general case. + * * @group normal_funcs * @since 1.4.0 */ @@ -1383,7 +1423,7 @@ object functions { /** * Partition ID. * - * @note This is indeterministic because it depends on data partitioning and task scheduling. + * @note This is non-deterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs * @since 1.6.0 From 6282fc64e32fc2f70e79ace14efd4922e4535dbb Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 10 May 2018 11:36:41 -0700 Subject: [PATCH 0777/1161] [SPARK-24137][K8S] Mount local directories as empty dir volumes. ## What changes were proposed in this pull request? Drastically improves performance and won't cause Spark applications to fail because they write too much data to the Docker image's specific file system. The file system's directories that back emptydir volumes are generally larger and more performant. ## How was this patch tested? Has been in use via the prototype version of Kubernetes support, but lost in the transition to here. Author: mcheah Closes #21238 from mccheah/mount-local-dirs. --- .../scala/org/apache/spark/SparkConf.scala | 5 +- .../k8s/features/LocalDirsFeatureStep.scala | 77 ++++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 10 +- .../k8s/KubernetesExecutorBuilder.scala | 9 +- .../features/LocalDirsFeatureStepSuite.scala | 111 ++++++++++++++++++ .../submit/KubernetesDriverBuilderSuite.scala | 13 +- .../k8s/KubernetesExecutorBuilderSuite.scala | 12 +- 7 files changed, 223 insertions(+), 14 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 129956e9f9ffa..dab409572646f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -454,8 +454,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def validateSettings() { if (contains("spark.local.dir")) { - val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + - "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)." + val msg = "Note that spark.local.dir will be overridden by the value set by " + + "the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS" + + " in YARN)." logWarning(msg) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala new file mode 100644 index 0000000000000..70b307303d149 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStep.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.nio.file.Paths +import java.util.UUID + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class LocalDirsFeatureStep( + conf: KubernetesConf[_ <: KubernetesRoleSpecificConf], + defaultLocalDir: String = s"/var/data/spark-${UUID.randomUUID}") + extends KubernetesFeatureConfigStep { + + // Cannot use Utils.getConfiguredLocalDirs because that will default to the Java system + // property - we want to instead default to mounting an emptydir volume that doesn't already + // exist in the image. + // We could make utils.getConfiguredLocalDirs opinionated about Kubernetes, as it is already + // a bit opinionated about YARN and Mesos. + private val resolvedLocalDirs = Option(conf.sparkConf.getenv("SPARK_LOCAL_DIRS")) + .orElse(conf.getOption("spark.local.dir")) + .getOrElse(defaultLocalDir) + .split(",") + + override def configurePod(pod: SparkPod): SparkPod = { + val localDirVolumes = resolvedLocalDirs + .zipWithIndex + .map { case (localDir, index) => + new VolumeBuilder() + .withName(s"spark-local-dir-${index + 1}") + .withNewEmptyDir() + .endEmptyDir() + .build() + } + val localDirVolumeMounts = localDirVolumes + .zip(resolvedLocalDirs) + .map { case (localDirVolume, localDirPath) => + new VolumeMountBuilder() + .withName(localDirVolume.getName) + .withMountPath(localDirPath) + .build() + } + val podWithLocalDirVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(localDirVolumes: _*) + .endSpec() + .build() + val containerWithLocalDirVolumeMounts = new ContainerBuilder(pod.container) + .addNewEnv() + .withName("SPARK_LOCAL_DIRS") + .withValue(resolvedLocalDirs.mkString(",")) + .endEnv() + .addToVolumeMounts(localDirVolumeMounts: _*) + .build() + SparkPod(podWithLocalDirVolumes, containerWithLocalDirVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index c7579ed8cb689..10b0154466a3a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -29,14 +29,18 @@ private[spark] class KubernetesDriverBuilder( new DriverServiceFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { val baseFeatures = Seq( provideBasicStep(kubernetesConf), provideCredentialsStep(kubernetesConf), - provideServiceStep(kubernetesConf)) + provideServiceStep(kubernetesConf), + provideLocalDirsStep(kubernetesConf)) val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 22568fe7ea3be..d8f63d57574fb 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,18 +17,21 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = - new MountSecretsFeatureStep(_)) { + new MountSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = + new LocalDirsFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala new file mode 100644 index 0000000000000..91e184b84b86e --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{EnvVarBuilder, VolumeBuilder, VolumeMountBuilder} +import org.mockito.Mockito +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} + +class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + private val defaultLocalDir = "/var/data/default-local-dir" + private var sparkConf: SparkConf = _ + private var kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf] = _ + + before { + val realSparkConf = new SparkConf(false) + sparkConf = Mockito.spy(realSparkConf) + kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + "resource", + "app-id", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + } + + test("Resolve to default local dir if neither env nor configuration are set") { + Mockito.doReturn(null).when(sparkConf).get("spark.local.dir") + Mockito.doReturn(null).when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 1) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath(defaultLocalDir) + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue(defaultLocalDir) + .build()) + } + + test("Use configured local dirs split on comma if provided.") { + Mockito.doReturn("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .when(sparkConf).getenv("SPARK_LOCAL_DIRS") + val stepUnderTest = new LocalDirsFeatureStep(kubernetesConf, defaultLocalDir) + val configuredPod = stepUnderTest.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size === 2) + assert(configuredPod.pod.getSpec.getVolumes.get(0) === + new VolumeBuilder() + .withName(s"spark-local-dir-1") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.pod.getSpec.getVolumes.get(1) === + new VolumeBuilder() + .withName(s"spark-local-dir-2") + .withNewEmptyDir() + .endEmptyDir() + .build()) + assert(configuredPod.container.getVolumeMounts.size === 2) + assert(configuredPod.container.getVolumeMounts.get(0) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-1") + .withMountPath("/var/data/my-local-dir-1") + .build()) + assert(configuredPod.container.getVolumeMounts.get(1) === + new VolumeMountBuilder() + .withName(s"spark-local-dir-2") + .withMountPath("/var/data/my-local-dir-2") + .build()) + assert(configuredPod.container.getEnv.size === 1) + assert(configuredPod.container.getEnv.get(0) === + new EnvVarBuilder() + .withName("SPARK_LOCAL_DIRS") + .withValue("/var/data/my-local-dir-1,/var/data/my-local-dir-2") + .build()) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 161f9afe7bba9..a511d254d2175 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val CREDENTIALS_STEP_TYPE = "credentials" private val SERVICE_STEP_TYPE = "service" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( @@ -36,6 +37,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) @@ -44,7 +48,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, - _ => secretsStep) + _ => secretsStep, + _ => localDirsStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( @@ -64,7 +69,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, - SERVICE_STEP_TYPE) + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index f5270623f8acc..9ee86b5a423a9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,20 +20,24 @@ import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, - _ => mountSecretsStep) + _ => mountSecretsStep, + _ => localDirsStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( @@ -46,7 +50,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty) - validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -63,6 +68,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE) } From 3e2600538ee477ffe3f23fba57719e035219550b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 10 May 2018 14:26:38 -0700 Subject: [PATCH 0778/1161] [SPARK-19181][CORE] Fixing flaky "SparkListenerSuite.local metrics" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Sometimes "SparkListenerSuite.local metrics" test fails because the average of executorDeserializeTime is too short. As squito suggested to avoid these situations in one of the task a reference introduced to an object implementing a custom Externalizable.readExternal which sleeps 1ms before returning. ## How was this patch tested? With unit tests (and checking the effect of this change to the average with a much larger sleep time). Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #21280 from attilapiros/SPARK-19181. --- .../spark/scheduler/SparkListenerSuite.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index da6ecb82c7e42..fa47a52bbbc47 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.Semaphore import scala.collection.JavaConverters._ @@ -294,10 +295,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) - // just to make sure some of the tasks take a noticeable amount of time + // just to make sure some of the tasks and their deserialization take a noticeable + // amount of time + val slowDeserializable = new SlowDeserializable val w = { i: Int => if (i == 0) { Thread.sleep(100) + slowDeserializable.use() } i } @@ -583,3 +587,12 @@ private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends Spar case _ => } } + +private class SlowDeserializable extends Externalizable { + + override def writeExternal(out: ObjectOutput): Unit = { } + + override def readExternal(in: ObjectInput): Unit = Thread.sleep(1) + + def use(): Unit = { } +} From d3c426a5b02abdec49ff45df12a8f11f9e473a88 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 10 May 2018 14:41:55 -0700 Subject: [PATCH 0779/1161] [SPARK-10878][CORE] Fix race condition when multiple clients resolves artifacts at the same time ## What changes were proposed in this pull request? When multiple clients attempt to resolve artifacts via the `--packages` parameter, they could run into race condition when they each attempt to modify the dummy `org.apache.spark-spark-submit-parent-default.xml` file created in the default ivy cache dir. This PR changes the behavior to encode UUID in the dummy module descriptor so each client will operate on a different resolution file in the ivy cache dir. In addition, this patch changes the behavior of when and which resolution files are cleaned to prevent accumulation of resolution files in the default ivy cache dir. Since this PR is a successor of #18801, close #18801. Many codes were ported from #18801. **Many efforts were put here. I think this PR should credit to Victsm .** ## How was this patch tested? added UT into `SparkSubmitUtilsSuite` Author: Kazuaki Ishizaki Closes #21251 from kiszk/SPARK-10878. --- .../org/apache/spark/deploy/SparkSubmit.scala | 42 ++++++++++++++----- .../spark/deploy/SparkSubmitUtilsSuite.scala | 15 +++++++ 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 427c797755b84..087e9c31a9c9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -22,6 +22,7 @@ import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowab import java.net.URL import java.security.PrivilegedExceptionAction import java.text.ParseException +import java.util.UUID import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -1204,7 +1205,33 @@ private[spark] object SparkSubmitUtils { /** A nice function to use in tests as well. Values are dummy strings. */ def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( - ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) + // Include UUID in module name, so multiple clients resolving maven coordinate at the same time + // do not modify the same resolution file concurrently. + ModuleRevisionId.newInstance("org.apache.spark", + s"spark-submit-parent-${UUID.randomUUID.toString}", + "1.0")) + + /** + * Clear ivy resolution from current launch. The resolution file is usually at + * ~/.ivy2/org.apache.spark-spark-submit-parent-$UUID-default.xml, + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.xml, and + * ~/.ivy2/resolved-org.apache.spark-spark-submit-parent-$UUID-1.0.properties. + * Since each launch will have its own resolution files created, delete them after + * each resolution to prevent accumulation of these files in the ivy cache dir. + */ + private def clearIvyResolutionFiles( + mdId: ModuleRevisionId, + ivySettings: IvySettings, + ivyConfName: String): Unit = { + val currentResolutionFiles = Seq( + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.xml", + s"resolved-${mdId.getOrganisation}-${mdId.getName}-${mdId.getRevision}.properties" + ) + currentResolutionFiles.foreach { filename => + new File(ivySettings.getDefaultCache, filename).delete() + } + } /** * Resolves any dependencies that were supplied through maven coordinates @@ -1255,14 +1282,6 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor - // clear ivy resolution from previous launches. The resolution file is usually at - // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. In between runs, this file - // leads to confusion with Ivy when the files can no longer be found at the repository - // declared in that file/ - val mdId = md.getModuleRevisionId - val previousResolution = new File(ivySettings.getDefaultCache, - s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") - if (previousResolution.exists) previousResolution.delete md.setDefaultConf(ivyConfName) @@ -1283,7 +1302,10 @@ private[spark] object SparkSubmitUtils { packagesDirectory.getAbsolutePath + File.separator + "[organization]_[artifact]-[revision](-[classifier]).[ext]", retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val paths = resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + val mdId = md.getModuleRevisionId + clearIvyResolutionFiles(mdId, ivySettings, ivyConfName) + paths } finally { System.setOut(sysOut) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index eb8c203ae7751..a0f09891787e0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -256,4 +256,19 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(jarPath.indexOf("mydep") >= 0, "should find dependency") } } + + test("SPARK-10878: test resolution files cleaned after resolving artifact") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + + IvyTestUtils.withRepository(main, None, None) { repo => + val ivySettings = SparkSubmitUtils.buildIvySettings(Some(repo), Some(tempIvyPath)) + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + ivySettings, + isTest = true) + val r = """.*org.apache.spark-spark-submit-parent-.*""".r + assert(!ivySettings.getDefaultCache.listFiles.map(_.getName) + .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned") + } + } } From a4206d58e05ab9ed6f01fee57e18dee65cbc4efc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 09:01:40 +0800 Subject: [PATCH 0780/1161] [SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/20136 . #20136 didn't really work because in the test, we are using local backend, which shares the driver side `SparkEnv`, so `SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER` doesn't work. This PR changes the check to `TaskContext.get != null`, and move the check to `SQLConf.get`, and fix all the places that violate this check: * `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. https://github.com/apache/spark/pull/21223 merged * `DataType#sameType` may be executed in executor side, for things like json schema inference, so we can't call `conf.caseSensitiveAnalysis` there. This contributes to most of the code changes, as we need to add `caseSensitive` parameter to a lot of methods. * `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't can't call `conf.parquetFilterPushDownDate` there. https://github.com/apache/spark/pull/21224 merged * `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. https://github.com/apache/spark/pull/21225 merged * `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. https://github.com/apache/spark/pull/21226 merged ## How was this patch tested? existing test Author: Wenchen Fan Closes #21190 from cloud-fan/minor. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../analysis/ResolveInlineTables.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 156 ++++++++++-------- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../org/apache/spark/sql/types/DataType.scala | 8 +- .../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++---- .../org/apache/spark/sql/SparkSession.scala | 21 ++- .../datasources/PartitioningUtils.scala | 5 +- .../datasources/json/JsonInferSchema.scala | 39 +++-- .../datasources/json/JsonSuite.scala | 4 +- 10 files changed, 188 insertions(+), 140 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..94b0561529e71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -260,7 +261,9 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { + val widerType = TypeCoercion.findWiderTypeForTwo( + dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) + if (widerType.isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index f2df3e132629f..4eb6e642b1c37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,7 +83,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { + val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( + inputTypes, conf.caseSensitiveAnalysis) + val tpe = wideType.getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..a7ba201509b78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes :: + WidenSetOperationTypes(conf) :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion :: + FunctionArgumentConversion(conf) :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion :: - IfCoercion :: + CaseWhenCoercion(conf) :: + IfCoercion(conf) :: StackCoercion :: Division :: - new ImplicitTypeCasts(conf) :: + ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,7 +83,10 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + def findTightestCommonType( + left: DataType, + right: DataType, + caseSensitive: Boolean): Option[DataType] = (left, right) match { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -102,22 +105,32 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => + val isSameType = if (caseSensitive) { + DataType.equalsIgnoreNullability(t1, t2) + } else { + DataType.equalsIgnoreCaseAndNullability(t1, t2) + } + + if (isSameType) { + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since t1 is same type of t2, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) + } else { + None + } case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) + val keyType = findTightestCommonType(kt1, kt2, caseSensitive) + val valueType = findTightestCommonType(vt1, vt2, caseSensitive) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -172,13 +185,14 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) + def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { + findTightestCommonType(t1, t2, caseSensitive) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2, caseSensitive) + .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -193,7 +207,8 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { + private def findWiderCommonType( + types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -201,7 +216,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) + case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) case _ => None }) } @@ -213,20 +228,22 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType): Option[DataType] = { - findTightestCommonType(t1, t2) + t2: DataType, + caseSensitive: Boolean): Option[DataType] = { + findTightestCommonType(t1, t2, caseSensitive) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) + findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + def findWiderTypeWithoutStringPromotion( + types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) case None => None }) } @@ -279,29 +296,32 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - object WidenSetOperationTypes extends Rule[LogicalPlan] { + case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + val newChildren: Seq[LogicalPlan] = + buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + val newChildren: Seq[LogicalPlan] = + buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes( + children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -316,18 +336,19 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType]): Seq[DataType] = { + castedTypes: mutable.Queue[DataType], + caseSensitive: Boolean): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes) + getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) } } @@ -432,7 +453,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType)) + .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) } // The number of columns/expressions must match between LHS and RHS of an @@ -461,7 +482,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { + findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -515,7 +536,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - object FunctionArgumentConversion extends TypeCoercionRule { + case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -523,7 +544,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -531,7 +552,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -542,7 +563,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -552,7 +573,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -580,7 +601,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types) match { + findWiderCommonType(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -590,14 +611,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -637,11 +658,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CaseWhenCoercion extends TypeCoercionRule { + case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -668,16 +689,17 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - object IfCoercion extends TypeCoercionRule { + case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { + widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -776,12 +798,11 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -804,17 +825,18 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType).map { commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { + commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b00edca97cd44..0b1965c438e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -107,7 +107,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") + } + confGetter.get()() + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1274,12 +1280,6 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..4ee12db9c10ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,11 +81,7 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - if (SQLConf.get.caseSensitiveAnalysis) { - DataType.equalsIgnoreNullability(this, other) - } else { - DataType.equalsIgnoreCaseAndNullability(this, other) - } + DataType.equalsIgnoreNullability(this, other) /** * Returns the same data type but set all nullability fields are true @@ -218,7 +214,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..f73e045685ee1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType) => Option[DataType], + widenFunc: (DataType, DataType, Boolean) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2) + var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1) + found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion + val rule = TypeCoercion.FunctionArgumentConversion(conf) val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion(conf), NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion + val rule = TypeCoercion.IfCoercion(conf) val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion(conf), CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion, Division) + val rules = Seq(FunctionArgumentConversion(conf), Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f9a24806953e6..1edf27619ad7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -521,6 +522,8 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) + case (t1, t2) => + TypeCoercion.findWiderTypeForTwo( + t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..e0424b7478122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -44,6 +45,7 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord + val caseSensitive = SQLConf.get.caseSensitiveAnalysis // perform schema inference on each row and merge afterwards val rootType = json.mapPartitions { iter => @@ -53,7 +55,7 @@ private[sql] object JsonInferSchema { try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions)) + Some(inferField(parser, configOptions, caseSensitive)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -68,7 +70,7 @@ private[sql] object JsonInferSchema { } } }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -98,14 +100,15 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + private def inferField( + parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions) + inferField(parser, configOptions, caseSensitive) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -122,7 +125,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions), + inferField(parser, configOptions, caseSensitive), nullable = true) } val fields: Array[StructField] = builder.result() @@ -137,7 +140,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions)) + elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) } ArrayType(elementType) @@ -243,13 +246,14 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode): (DataType, DataType) => DataType = { + parseMode: ParseMode, + caseSensitive: Boolean): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -259,7 +263,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2) + case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -267,8 +271,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType): DataType = { - TypeCoercion.findTightestCommonType(t1, t2).getOrElse { + def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { + TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -303,7 +307,8 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + val dataType = compatibleType( + fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -326,15 +331,17 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + ArrayType( + compatibleType(elementType1, elementType2, caseSensitive), + containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2) + compatibleType(DecimalType.forType(t1), t2, caseSensitive) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2)) + compatibleType(t1, DecimalType.forType(t2), caseSensitive) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..34d23ee53220d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2) + var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1) + actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } From 75cf369c742e7c7b68f384d123447c97be95c9f0 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 11 May 2018 09:05:35 +0800 Subject: [PATCH 0781/1161] [SPARK-24197][SPARKR][SQL] Adding array_sort function to SparkR ## What changes were proposed in this pull request? The PR adds array_sort function to SparkR. ## How was this patch tested? Tests added into R/pkg/tests/fulltests/test_sparkSQL.R ## Example ``` > df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) > head(collect(select(df, array_sort(df[[1]])))) ``` Result: ``` array_sort(_1) 1 1, 2, 3, NA 2 4, 5, 6, NA, NA ``` Author: Marek Novotny Closes #21294 from mn-mikke/SPARK-24197. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 21 ++++++++++++++++++--- R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++++---- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 8cd00352d1956..5f8209689a559 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,6 +204,7 @@ exportMethods("%<=>%", "array_max", "array_min", "array_position", + "array_sort", "asc", "ascii", "asin", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 04d0e4620b28a..1f97054443e1b 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -207,7 +207,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21))) +#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -3043,6 +3043,20 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array +#' must be orderable. NA elements will be placed at the end of the returned array. +#' +#' @rdname column_collection_functions +#' @aliases array_sort array_sort,Column-method +#' @note array_sort since 2.4.0 +setMethod("array_sort", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc) + column(jc) + }) + #' @details #' \code{flatten}: Transforms an array of arrays into a single array. #' @@ -3125,8 +3139,9 @@ setMethod("size", }) #' @details -#' \code{sort_array}: Sorts the input array in ascending or descending order according -#' to the natural ordering of the array elements. +#' \code{sort_array}: Sorts the input array in ascending or descending order according to +#' the natural ordering of the array elements. NA elements will be placed at the beginning of +#' the returned array in ascending order or at the end of the returned array in descending order. #' #' @rdname column_collection_functions #' @param asc a logical flag indicating the sorting order. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 4ef12d19b3575..5faa51eef3abd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -769,6 +769,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 43725e0ebd3bf..b8bfded0ebf2d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,8 +1479,7 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_max(), array_min(), array_position(), element_at() - # and sort_array() + # Test array_contains(), array_max(), array_min(), array_position() and element_at() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1497,10 +1496,16 @@ test_that("column functions", { result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 6)) + # Test array_sort() and sort_array() + df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) + + result <- collect(select(df, array_sort(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA))) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] - expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA))) result <- collect(select(df, sort_array(df[[1]])))[[1]] - expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) # Test flattern df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), From 54032682b910dc5089af27d2c7b6efe55700f034 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 11 May 2018 17:40:35 +0800 Subject: [PATCH 0782/1161] [SPARK-24182][YARN] Improve error message when client AM fails. Instead of always throwing a generic exception when the AM fails, print a generic error and throw the exception with the YARN diagnostics containing the reason for the failure. There was an issue with YARN sometimes providing a generic diagnostic message, even though the AM provides a failure reason when unregistering. That was happening because the AM was registering too late, and if errors happened before the registration, YARN would just create a generic "ExitCodeException" which wasn't very helpful. Since most errors in this path are a result of not being able to connect to the driver, this change modifies the AM registration a bit so that the AM is registered before the connection to the driver is established. That way, errors are properly propagated through YARN back to the driver. As part of that, I also removed the code that retried connections to the driver from the client AM. At that point, the driver should already be up and waiting for connections, so it's unlikely that retrying would help - and in case it does, that means a flaky network, which would mean problems would probably show up again. The effect of that is that connection-related errors are reported back to the driver much faster now (through the YARN report). One thing to note is that there seems to be a race on the YARN side that causes a report to be sent to the client without the corresponding diagnostics string from the AM; the diagnostics are available later from the RM web page. For that reason, the generic error messages are kept in the Spark scheduler code, to help guide users to a way of debugging their failure. Also of note is that if YARN's max attempts configuration is lower than Spark's, Spark will not unregister the AM with a proper diagnostics message. Unfortunately there seems to be no way to unregister the AM and still allow further re-attempts to happen. Testing: - existing unit tests - some of our integration tests - hardcoded an invalid driver address in the code and verified the error in the shell. e.g. ``` scala> 18/05/04 15:09:34 ERROR cluster.YarnClientSchedulerBackend: YARN application has exited unexpectedly with state FAILED! Check the YARN application logs for more details. 18/05/04 15:09:34 ERROR cluster.YarnClientSchedulerBackend: Diagnostics message: Uncaught exception: org.apache.spark.SparkException: Exception thrown in awaitResult: Caused by: java.io.IOException: Failed to connect to localhost/127.0.0.1:1234 ``` Author: Marcelo Vanzin Closes #21243 from vanzin/SPARK-24182. --- docs/running-on-yarn.md | 5 +- .../spark/deploy/yarn/ApplicationMaster.scala | 103 +++++++----------- .../org/apache/spark/deploy/yarn/Client.scala | 43 +++++--- .../spark/deploy/yarn/YarnRMClient.scala | 29 +++-- .../cluster/YarnClientSchedulerBackend.scala | 35 ++++-- 5 files changed, 112 insertions(+), 103 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ceda8a3ae2403..c9e68c3bfd056 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -133,9 +133,8 @@ To use a custom metrics.properties for the application master and executors, upd diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 595077e7e809f..3d6ee50b070a3 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -346,7 +346,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends synchronized { if (!finished) { val inShutdown = ShutdownHookManager.inShutdown() - if (registered) { + if (registered || !isClusterMode) { exitCode = code finalStatus = status } else { @@ -389,37 +389,40 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } private def registerAM( + host: String, + port: Int, _sparkConf: SparkConf, - _rpcEnv: RpcEnv, - driverRef: RpcEndpointRef, - uiAddress: Option[String]) = { + uiAddress: Option[String]): Unit = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() val historyAddress = ApplicationMaster .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) - val driverUrl = RpcEndpointAddress( - _sparkConf.get("spark.driver.host"), - _sparkConf.get("spark.driver.port").toInt, + client.register(host, port, yarnConf, _sparkConf, uiAddress, historyAddress) + registered = true + } + + private def createAllocator(driverRef: RpcEndpointRef, _sparkConf: SparkConf): Unit = { + val appId = client.getAttemptId().getApplicationId().toString() + val driverUrl = RpcEndpointAddress(driverRef.address.host, driverRef.address.port, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString // Before we initialize the allocator, let's log the information about how executors will // be run up front, to avoid printing this out for every single executor being launched. // Use placeholders for information that changes such as executor IDs. logInfo { - val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt - val executorCores = sparkConf.get(EXECUTOR_CORES) - val dummyRunner = new ExecutorRunnable(None, yarnConf, sparkConf, driverUrl, "", + val executorMemory = _sparkConf.get(EXECUTOR_MEMORY).toInt + val executorCores = _sparkConf.get(EXECUTOR_CORES) + val dummyRunner = new ExecutorRunnable(None, yarnConf, _sparkConf, driverUrl, "", "", executorMemory, executorCores, appId, securityMgr, localResources) dummyRunner.launchContextDebugInfo() } - allocator = client.register(driverUrl, - driverRef, + allocator = client.createAllocator( yarnConf, _sparkConf, - uiAddress, - historyAddress, + driverUrl, + driverRef, securityMgr, localResources) @@ -434,15 +437,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends reporterThread = launchReporterThread() } - /** - * @return An [[RpcEndpoint]] that communicates with the driver's scheduler backend. - */ - private def createSchedulerRef(host: String, port: String): RpcEndpointRef = { - rpcEnv.setupEndpointRef( - RpcAddress(host, port.toInt), - YarnSchedulerBackend.ENDPOINT_NAME) - } - private def runDriver(): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -456,11 +450,16 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Duration(totalWaitTime, TimeUnit.MILLISECONDS)) if (sc != null) { rpcEnv = sc.env.rpcEnv - val driverRef = createSchedulerRef( - sc.getConf.get("spark.driver.host"), - sc.getConf.get("spark.driver.port")) - registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl)) - registered = true + + val userConf = sc.getConf + val host = userConf.get("spark.driver.host") + val port = userConf.get("spark.driver.port").toInt + registerAM(host, port, userConf, sc.ui.map(_.webUrl)) + + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(host, port), + YarnSchedulerBackend.ENDPOINT_NAME) + createAllocator(driverRef, userConf) } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. @@ -486,10 +485,18 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends val amCores = sparkConf.get(AM_CORES) rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr, amCores, true) - val driverRef = waitForSparkDriver() + + // The client-mode AM doesn't listen for incoming connections, so report an invalid port. + registerAM(hostname, -1, sparkConf, sparkConf.getOption("spark.driver.appUIAddress")) + + // The driver should be up and listening, so unlike cluster mode, just try to connect to it + // with no waiting or retrying. + val (driverHost, driverPort) = Utils.parseHostPort(args.userArgs(0)) + val driverRef = rpcEnv.setupEndpointRef( + RpcAddress(driverHost, driverPort), + YarnSchedulerBackend.ENDPOINT_NAME) addAmIpFilter(Some(driverRef)) - registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress")) - registered = true + createAllocator(driverRef, sparkConf) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -600,40 +607,6 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends } } - private def waitForSparkDriver(): RpcEndpointRef = { - logInfo("Waiting for Spark driver to be reachable.") - var driverUp = false - val hostport = args.userArgs(0) - val (driverHost, driverPort) = Utils.parseHostPort(hostport) - - // Spark driver should already be up since it launched us, but we don't want to - // wait forever, so wait 100 seconds max to match the cluster mode setting. - val totalWaitTimeMs = sparkConf.get(AM_MAX_WAIT_TIME) - val deadline = System.currentTimeMillis + totalWaitTimeMs - - while (!driverUp && !finished && System.currentTimeMillis < deadline) { - try { - val socket = new Socket(driverHost, driverPort) - socket.close() - logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) - driverUp = true - } catch { - case e: Exception => - logError("Failed to connect to driver at %s:%s, retrying ...". - format(driverHost, driverPort)) - Thread.sleep(100L) - } - } - - if (!driverUp) { - throw new SparkException("Failed to connect to driver!") - } - - sparkConf.set("spark.driver.host", driverHost) - sparkConf.set("spark.driver.port", driverPort.toString) - createSchedulerRef(driverHost, driverPort.toString) - } - /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter(driver: Option[RpcEndpointRef]) = { val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 134b3e5fef11a..7225ff03dc34e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1019,8 +1019,7 @@ private[spark] class Client( appId: ApplicationId, returnOnRunning: Boolean = false, logApplicationReport: Boolean = true, - interval: Long = sparkConf.get(REPORT_INTERVAL)): - (YarnApplicationState, FinalApplicationStatus) = { + interval: Long = sparkConf.get(REPORT_INTERVAL)): YarnAppReport = { var lastState: YarnApplicationState = null while (true) { Thread.sleep(interval) @@ -1031,11 +1030,13 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") cleanupStagingDir(appId) - return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + return YarnAppReport(YarnApplicationState.KILLED, FinalApplicationStatus.KILLED, None) case NonFatal(e) => - logError(s"Failed to contact YARN for application $appId.", e) + val msg = s"Failed to contact YARN for application $appId." + logError(msg, e) // Don't necessarily clean up staging dir because status is unknown - return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) + return YarnAppReport(YarnApplicationState.FAILED, FinalApplicationStatus.FAILED, + Some(msg)) } val state = report.getYarnApplicationState @@ -1073,14 +1074,14 @@ private[spark] class Client( } if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { cleanupStagingDir(appId) - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } if (returnOnRunning && state == YarnApplicationState.RUNNING) { - return (state, report.getFinalApplicationStatus) + return createAppReport(report) } lastState = state @@ -1129,16 +1130,17 @@ private[spark] class Client( throw new SparkException(s"Application $appId finished with status: $state") } } else { - val (yarnApplicationState, finalApplicationStatus) = monitorApplication(appId) - if (yarnApplicationState == YarnApplicationState.FAILED || - finalApplicationStatus == FinalApplicationStatus.FAILED) { + val YarnAppReport(appState, finalState, diags) = monitorApplication(appId) + if (appState == YarnApplicationState.FAILED || finalState == FinalApplicationStatus.FAILED) { + diags.foreach { err => + logError(s"Application diagnostics message: $err") + } throw new SparkException(s"Application $appId finished with failed status") } - if (yarnApplicationState == YarnApplicationState.KILLED || - finalApplicationStatus == FinalApplicationStatus.KILLED) { + if (appState == YarnApplicationState.KILLED || finalState == FinalApplicationStatus.KILLED) { throw new SparkException(s"Application $appId is killed") } - if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) { + if (finalState == FinalApplicationStatus.UNDEFINED) { throw new SparkException(s"The final status of application $appId is undefined") } } @@ -1477,6 +1479,12 @@ private object Client extends Logging { uri.startsWith(s"$LOCAL_SCHEME:") } + def createAppReport(report: ApplicationReport): YarnAppReport = { + val diags = report.getDiagnostics() + val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None + YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) + } + } private[spark] class YarnClusterApplication extends SparkApplication { @@ -1491,3 +1499,8 @@ private[spark] class YarnClusterApplication extends SparkApplication { } } + +private[spark] case class YarnAppReport( + appState: YarnApplicationState, + finalState: FinalApplicationStatus, + diagnostics: Option[String]) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 17234b120ae13..b59dcf158d87c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -42,23 +42,20 @@ private[spark] class YarnRMClient extends Logging { /** * Registers the application master with the RM. * + * @param driverHost Host name where driver is running. + * @param driverPort Port where driver is listening. * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. - * @param securityMgr The security manager. - * @param localResources Map with information about files distributed via YARN's cache. */ def register( - driverUrl: String, - driverRef: RpcEndpointRef, + driverHost: String, + driverPort: Int, conf: YarnConfiguration, sparkConf: SparkConf, uiAddress: Option[String], - uiHistoryAddress: String, - securityMgr: SecurityManager, - localResources: Map[String, LocalResource] - ): YarnAllocator = { + uiHistoryAddress: String): Unit = { amClient = AMRMClient.createAMRMClient() amClient.init(conf) amClient.start() @@ -70,10 +67,19 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, - trackingUrl) + amClient.registerApplicationMaster(driverHost, driverPort, trackingUrl) registered = true } + } + + def createAllocator( + conf: YarnConfiguration, + sparkConf: SparkConf, + driverUrl: String, + driverRef: RpcEndpointRef, + securityMgr: SecurityManager, + localResources: Map[String, LocalResource]): YarnAllocator = { + require(registered, "Must register AM before creating allocator.") new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, localResources, new SparkRackResolver()) } @@ -88,6 +94,9 @@ private[spark] class YarnRMClient extends Logging { if (registered) { amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } + if (amClient != null) { + amClient.stop() + } } /** Returns the attempt ID. */ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 06e54a2eaf95a..f1a8df00f9c5b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.yarn.api.records.YarnApplicationState import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.yarn.{Client, ClientArguments} +import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnAppReport} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkAppHandle @@ -75,13 +75,23 @@ private[spark] class YarnClientSchedulerBackend( val monitorInterval = conf.get(CLIENT_LAUNCH_MONITOR_INTERVAL) assert(client != null && appId.isDefined, "Application has not been submitted yet!") - val (state, _) = client.monitorApplication(appId.get, returnOnRunning = true, - interval = monitorInterval) // blocking + val YarnAppReport(state, _, diags) = client.monitorApplication(appId.get, + returnOnRunning = true, interval = monitorInterval) if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - throw new SparkException("Yarn application has already ended! " + - "It might have been killed or unable to launch application master.") + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + val genericMessage = "The YARN application has already ended! " + + "It might have been killed or the Application Master may have failed to start. " + + "Check the YARN application logs for more details." + val exceptionMsg = diags match { + case Some(msg) => + logError(genericMessage) + msg + + case None => + genericMessage + } + throw new SparkException(exceptionMsg) } if (state == YarnApplicationState.RUNNING) { logInfo(s"Application ${appId.get} has started running.") @@ -100,8 +110,13 @@ private[spark] class YarnClientSchedulerBackend( override def run() { try { - val (state, _) = client.monitorApplication(appId.get, logApplicationReport = false) - logError(s"Yarn application has already exited with state $state!") + val YarnAppReport(_, state, diags) = + client.monitorApplication(appId.get, logApplicationReport = true) + logError(s"YARN application has exited unexpectedly with state $state! " + + "Check the YARN application logs for more details.") + diags.foreach { err => + logError(s"Diagnostics message: $err") + } allowInterrupt = false sc.stop() } catch { @@ -124,7 +139,7 @@ private[spark] class YarnClientSchedulerBackend( private def asyncMonitorApplication(): MonitorThread = { assert(client != null && appId.isDefined, "Application has not been submitted yet!") val t = new MonitorThread - t.setName("Yarn application state monitor") + t.setName("YARN application state monitor") t.setDaemon(true) t } From 928845a42230a2c0a318011002a54ad871468b2e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 10:00:28 -0700 Subject: [PATCH 0783/1161] [SPARK-24172][SQL] we should not apply operator pushdown to data source v2 many times ## What changes were proposed in this pull request? In `PushDownOperatorsToDataSource`, we use `transformUp` to match `PhysicalOperation` and apply pushdown. This is problematic if we have multiple `Filter` and `Project` above the data source v2 relation. e.g. for a query ``` Project Filter DataSourceV2Relation ``` The pattern match will be triggered twice and we will do operator pushdown twice. This is unnecessary, we can use `mapChildren` to only apply pushdown once. ## How was this patch tested? existing test Author: Wenchen Fan Closes #21230 from cloud-fan/step2. --- .../v2/PushDownOperatorsToDataSource.scala | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 9293d4f831bff..e894f8afd6762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -23,17 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project import org.apache.spark.sql.catalyst.rules.Rule object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply( - plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan match { // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => - // merge the filters - val filters = relation.filters match { - case Some(existing) => - existing ++ newFilters - case _ => - newFilters - } + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + assert(relation.filters.isEmpty, "data source v2 should do push down only once.") val projectAttrs = project.map(_.toAttribute) val projectSet = AttributeSet(project.flatMap(_.references)) @@ -67,5 +60,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { } else { filtered } + + case other => other.mapChildren(apply) } } From 92f6f52ff0ce47e046656ca8bed7d7bfbbb42dcb Mon Sep 17 00:00:00 2001 From: aditkumar Date: Fri, 11 May 2018 14:42:23 -0500 Subject: [PATCH 0784/1161] [MINOR][DOCS] Documenting months_between direction ## What changes were proposed in this pull request? It's useful to know what relationship between date1 and date2 results in a positive number. Author: aditkumar Author: Adit Kumar Closes #20787 from aditkumar/master. --- R/pkg/R/functions.R | 6 +++++- python/pyspark/sql/functions.py | 7 +++++-- .../catalyst/expressions/datetimeExpressions.scala | 14 +++++++++++--- .../spark/sql/catalyst/util/DateTimeUtils.scala | 8 ++++---- .../scala/org/apache/spark/sql/functions.scala | 7 ++++++- 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 1f97054443e1b..4964594284aa0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1912,6 +1912,7 @@ setMethod("atan2", signature(y = "Column"), #' @details #' \code{datediff}: Returns the number of days from \code{y} to \code{x}. +#' If \code{y} is later than \code{x} then the result is positive. #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method @@ -1971,7 +1972,10 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} +#' are on the same day of month, or both are the last day of month, time of day will be ignored. +#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5a584152b4f6..b62748e9a2d6c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1108,8 +1108,11 @@ def add_months(start, months): @since(1.5) def months_between(date1, date2, roundOff=True): """ - Returns the number of months between date1 and date2. - Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. + Returns number of months between dates date1 and date2. + If date1 is later than date2, then the result is positive. + If date1 and date2 are on the same day of month, or both are the last day of month, + returns an integer (time of day will be ignored). + The result is rounded off to 8 digits unless `roundOff` is set to `False`. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 76aa61415a11f..03422fecb3209 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1194,13 +1194,21 @@ case class AddMonths(startDate: Expression, numMonths: Expression) } /** - * Returns number of months between dates date1 and date2. + * Returns number of months between times `timestamp1` and `timestamp2`. + * If `timestamp1` is later than `timestamp2`, then the result is positive. + * If `timestamp1` and `timestamp2` are on the same day of month, or both + * are the last day of month, time of day will be ignored. Otherwise, the + * difference is calculated based on 31 days per month, and rounded to + * 8 digits unless roundOff=false. */ // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. - The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", + _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result + is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both + are the last day of month, time of day will be ignored. Otherwise, the difference is + calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false. + """, examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index e646da0659e85..80f15053005ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -885,13 +885,13 @@ object DateTimeUtils { /** * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. + * microseconds since 1.1.1970. If time1 is later than time2, the result is positive. * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). + * If time1 and time2 are on the same day of month, or both are the last day of month, + * returns, time of day will be ignored. * * Otherwise, the difference is calculated based on 31 days per month. - * If `roundOff` is set to true, the result is rounded to 8 decimal places. + * The result is rounded to 8 decimal places if `roundOff` is set to true. */ def monthsBetween( time1: SQLTimestamp, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 225de0051d6fa..e7f866ddca681 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2903,7 +2903,12 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. - * The result is rounded off to 8 digits. + * If `date1` is later than `date2`, then the result is positive. + * If `date1` and `date2` are on the same day of month, or both are the last day of month, + * time of day will be ignored. + * + * Otherwise, the difference is calculated based on 31 days per month, and rounded to + * 8 digits. * @group datetime_funcs * @since 1.5.0 */ From f27a035daf705766d3445e5c6a99867c11c552b0 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 11 May 2018 17:00:51 -0700 Subject: [PATCH 0785/1161] [SPARKR] Require Java 8 for SparkR This change updates the SystemRequirements and also includes a runtime check if the JVM is being launched by R. The runtime check is done by querying `java -version` ## How was this patch tested? Tested on a Mac and Windows machine Author: Shivaram Venkataraman Closes #21278 from shivaram/sparkr-skip-solaris. --- R/pkg/DESCRIPTION | 1 + R/pkg/R/client.R | 35 +++++++++++++++++++++++++++++++++++ R/pkg/R/sparkR.R | 1 + R/pkg/R/utils.R | 4 ++-- 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f16..f52d785e05cdd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -13,6 +13,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html +SystemRequirements: Java (== 8) Depends: R (>= 3.0), methods diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 7244cc9f9e38e..e9295e05872bd 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -60,6 +60,40 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack combinedArgs } +checkJavaVersion <- function() { + javaBin <- "java" + javaHome <- Sys.getenv("JAVA_HOME") + javaReqs <- utils::packageDescription(utils::packageName(), fields=c("SystemRequirements")) + sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) + if (javaHome != "") { + javaBin <- file.path(javaHome, "bin", javaBin) + } + + # If java is missing from PATH, we get an error in Unix and a warning in Windows + javaVersionOut <- tryCatch( + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) + javaVersionFilter <- Filter( + function(x) { + grepl("java version", x) + }, javaVersionOut) + + javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] + # javaVersionStr is of the form 1.8.0_92. + # Extract 8 from it to compare to sparkJavaVersion + javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) + if (javaVersionNum != sparkJavaVersion) { + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) + } +} + launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { @@ -67,6 +101,7 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { } else { sparkSubmitBin <- sparkSubmitBinName } + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(launchScript(sparkSubmitBin, combinedArgs)) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 38ee79477996f..d6a2d08f9c218 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,6 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) + checkJavaVersion() launchBackend( args = path, sparkHome = sparkHome, diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index f1b5ecaa017df..c3501977e64bc 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -746,7 +746,7 @@ varargsToJProperties <- function(...) { props } -launchScript <- function(script, combinedArgs, wait = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") # on Windows, intern = F seems to mean output to the console. (documentation on this is missing) @@ -756,7 +756,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE) { # stdout = F means discard output # stdout = "" means to its console (default) # Note that the console of this child process might not be the same as the running R process. - system2(script, combinedArgs, stdout = "", wait = wait) + system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr) } } From e3dabdf6ef210fb9f4337e305feb9c4983a57350 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 12 May 2018 12:15:36 +0800 Subject: [PATCH 0786/1161] [SPARK-23907] Removes regr_* functions in functions.scala ## What changes were proposed in this pull request? This patch removes the various regr_* functions in functions.scala. They are so uncommon that I don't think they deserve real estate in functions.scala. We can consider adding them later if more users need them. ## How was this patch tested? Removed the associated test case as well. Author: Reynold Xin Closes #21309 from rxin/SPARK-23907. --- .../org/apache/spark/sql/functions.scala | 171 ------------------ .../spark/sql/DataFrameAggregateSuite.scala | 68 ------- 2 files changed, 239 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e7f866ddca681..3c9ace407a58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -811,177 +811,6 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: Column, x: Column): Column = withAggregateFunction { - RegrCount(y.expr, x.expr) - } - - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { - RegrSXX(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: Column, x: Column): Column = withAggregateFunction { - RegrSYY(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgX(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { - RegrSXY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: Column, x: Column): Column = withAggregateFunction { - RegrSlope(y.expr, x.expr) - } - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: Column, x: Column): Column = withAggregateFunction { - RegrR2(y.expr, x.expr) - } - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { - RegrIntercept(y.expr, x.expr) - } - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) - - ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 4337fb2290fbc..96c28961e5aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -687,72 +687,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } - - test("SPARK-23907: regression functions") { - val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") - val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) - .toDF("a", "b") - val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( - (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b") - checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) - checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) - checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), - Row(null), absTol) - - - checkAggregatesWithTol(correlatedData.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), - absTol) - checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), - absTol) - } } From 5902125ac7ad25a0cb7aa3d98825c8290ee33c12 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Sat, 12 May 2018 19:21:42 +0800 Subject: [PATCH 0787/1161] [SPARK-24198][SPARKR][SQL] Adding slice function to SparkR ## What changes were proposed in this pull request? The PR adds the `slice` function to SparkR. The function returns a subset of consecutive elements from the given array. ``` > df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) > tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) > head(select(tmp, slice(tmp$v1, 2L, 2L))) ``` ``` slice(v1, 2, 2) 1 6, 110 2 6, 110 3 4, 93 4 6, 110 5 8, 175 6 6, 105 ``` ## How was this patch tested? A test added into R/pkg/tests/fulltests/test_sparkSQL.R Author: Marek Novotny Closes #21298 from mn-mikke/SPARK-24198. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 17 +++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 5 +++++ 4 files changed, 27 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5f8209689a559..c575fe255f57a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -352,6 +352,7 @@ exportMethods("%<=>%", "sinh", "size", "skewness", + "slice", "sort_array", "soundex", "spark_partition_id", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4964594284aa0..77d70cb5d19e6 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -212,6 +212,7 @@ NULL #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, slice(tmp$v1, 2L, 2L))) #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) @@ -3142,6 +3143,22 @@ setMethod("size", column(jc) }) +#' @details +#' \code{slice}: Returns an array containing all the elements in x from the index start +#' (or starting from the end if start is negative) with the specified length. +#' +#' @rdname column_collection_functions +#' @param start an index indicating the first element occuring in the result. +#' @param length a number of consecutive elements choosen to the result. +#' @aliases slice slice,Column-method +#' @note slice since 2.4.0 +setMethod("slice", + signature(x = "Column"), + function(x, start, length) { + jc <- callJStatic("org.apache.spark.sql.functions", "slice", x@jc, start, length) + column(jc) + }) + #' @details #' \code{sort_array}: Sorts the input array in ascending or descending order according to #' the natural ordering of the array elements. NA elements will be placed at the beginning of diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5faa51eef3abd..fbc4113e2becc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1194,6 +1194,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("slice", function(x, start, length) { standardGeneric("slice") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index b8bfded0ebf2d..2a550b9efb506 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1507,6 +1507,11 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) + # Test slice() + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(4L, 5L)))) + result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] + expect_equal(result, list(list(2L, 3L), list(5L))) + # Test flattern df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), list(list(list(5L, 6L), list(7L, 8L))))) From 348ddfd20f5b88777014f18a6374f33ee9b12731 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 12 May 2018 08:35:14 -0500 Subject: [PATCH 0788/1161] [BUILD] Close stale PRs Closes https://github.com/apache/spark/pull/20458 Closes https://github.com/apache/spark/pull/20530 Closes https://github.com/apache/spark/pull/20557 Closes https://github.com/apache/spark/pull/20966 Closes https://github.com/apache/spark/pull/20857 Closes https://github.com/apache/spark/pull/19694 Closes https://github.com/apache/spark/pull/18227 Closes https://github.com/apache/spark/pull/20683 Closes https://github.com/apache/spark/pull/20881 Closes https://github.com/apache/spark/pull/20347 Closes https://github.com/apache/spark/pull/20825 Closes https://github.com/apache/spark/pull/20078 Closes https://github.com/apache/spark/pull/21281 Closes https://github.com/apache/spark/pull/19951 Closes https://github.com/apache/spark/pull/20905 Closes https://github.com/apache/spark/pull/20635 Author: Sean Owen Closes #21303 from srowen/ClosePRs. From 32acfa78c60465efc03ae01e022614ad91345b1c Mon Sep 17 00:00:00 2001 From: Cody Allen Date: Sat, 12 May 2018 14:35:40 -0500 Subject: [PATCH 0789/1161] Improve implicitNotFound message for Encoder The `implicitNotFound` message for `Encoder` doesn't mention the name of the type for which it can't find an encoder. Furthermore, it covers up the fact that `Encoder` is the name of the relevant type class. Hopefully this new message provides a little more specific type detail while still giving the general message about which types are supported. ## What changes were proposed in this pull request? Augment the existing message to mention that it's looking for an `Encoder` and what the type of the encoder is. For example instead of: ``` Unable to find encoder for type stored in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. ``` return this message: ``` Unable to find encoder for type Exception. An implicit Encoder[Exception] is needed to store Exception instances in a Dataset. Primitive types (Int, String, etc) and Product types (ca se classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. ``` ## How was this patch tested? It was tested manually in the Scala REPL, since triggering this in a test would cause a compilation error. ``` scala> implicitly[Encoder[Exception]] :51: error: Unable to find encoder for type Exception. An implicit Encoder[Exception] is needed to store Exception instances in a Dataset. Primitive types (Int, String, etc) and Product types (ca se classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases. implicitly[Encoder[Exception]] ^ ``` Author: Cody Allen Closes #20869 from ceedubs/encoder-implicit-msg. --- .../src/main/scala/org/apache/spark/sql/Encoder.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index ccdb6bc5d4b7c..7b02317b8538f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -68,10 +68,10 @@ import org.apache.spark.sql.types._ */ @Experimental @InterfaceStability.Evolving -@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + - "(Int, String, etc) and Product types (case classes) are supported by importing " + - "spark.implicits._ Support for serializing other types will be added in future " + - "releases.") +@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + + "classes) are supported by importing spark.implicits._ Support for serializing other types " + + "will be added in future releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ From 0d210ec8b610e4b0570ce730f3987dc86787c663 Mon Sep 17 00:00:00 2001 From: Kelley Robinson Date: Sun, 13 May 2018 13:19:03 -0700 Subject: [PATCH 0790/1161] [SPARK-24262][PYTHON] Fix typo in UDF type match error message ## What changes were proposed in this pull request? Updates `functon` to `function`. This was called out in holdenk's PyCon 2018 conference talk. Didn't see any existing PR's for this. holdenk happy to fix the Pandas.Series bug too but will need a bit more guidance. Author: Kelley Robinson Closes #21304 from robinske/master. --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8bb63fcc7ff9c..5d2e58bef6466 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -82,7 +82,7 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " + raise TypeError("Return type of the user-defined function should be " "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " From 2fa33649d96394ae630a092a9f7e1261d1893f6e Mon Sep 17 00:00:00 2001 From: Fan Donglai Date: Sun, 13 May 2018 18:10:00 -0500 Subject: [PATCH 0791/1161] Update StreamingKMeans.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? I think the ‘n_t+t’ in the following code may be wrong, it shoud be ‘n_t+1’ that means is the number of points to the cluster after it finish the no.t+1 min-batch. *
    * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ * n_t+t &= n_t * a + m_t * \end{align} * $$ *
    Author: Fan Donglai Closes #21179 from ddna1021/master. --- .../org/apache/spark/mllib/clustering/StreamingKMeans.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 3ca75e8cdb97a..7a5e520d5818e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ - * n_t+t &= n_t * a + m_t + * n_t+1 &= n_t * a + m_t * \end{align} * $$ * From 3f0e801c11e600ed28491924e550d3ba93f19c19 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 14 May 2018 09:48:54 +0800 Subject: [PATCH 0792/1161] [SPARK-24186][R][SQL] change reverse and concat to collection functions in R ## What changes were proposed in this pull request? reverse and concat are already in functions.R as column string functions. Since now these two functions are categorized as collection functions in scala and python, we will do the same in R. ## How was this patch tested? Add test in test_sparkSQL.R Author: Huaxin Gao Closes #21307 from huaxingao/spark_24186. --- R/pkg/R/functions.R | 35 ++++++++++++++------------- R/pkg/R/generics.R | 4 +-- R/pkg/tests/fulltests/test_sparkSQL.R | 17 +++++++++++-- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 77d70cb5d19e6..fcb3521f901ea 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,7 +208,7 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -218,7 +218,10 @@ NULL #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) #' head(select(tmp3, map_values(tmp3$v3))) -#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} +#' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL #' Window functions for Column operations @@ -1260,9 +1263,9 @@ setMethod("quarter", }) #' @details -#' \code{reverse}: Reverses the string column and returns it as a new string column. +#' \code{reverse}: Returns a reversed string or an array with reverse order of elements. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases reverse reverse,Column-method #' @note reverse since 1.5.0 setMethod("reverse", @@ -2055,20 +2058,10 @@ setMethod("countDistinct", #' @details #' \code{concat}: Concatenates multiple input columns together into a single column. -#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. +#' The function works with strings, binary and compatible array columns. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases concat concat,Column-method -#' @examples -#' -#' \dontrun{ -#' # concatenate strings -#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), -#' s2 = concat(df$Class, df$Sex, df$Age), -#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), -#' s4 = concat_ws("_", df$Class, df$Sex), -#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) -#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2409,6 +2402,13 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat_ws("_", df$Class, df$Sex), +#' s2 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -3063,7 +3063,8 @@ setMethod("array_sort", }) #' @details -#' \code{flatten}: Transforms an array of arrays into a single array. +#' \code{flatten}: Creates a single array from an array of arrays. +#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. #' #' @rdname column_collection_functions #' @aliases flatten flatten,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fbc4113e2becc..61da30badac4e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -817,7 +817,7 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) @@ -1134,7 +1134,7 @@ setGeneric("regexp_replace", #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 2a550b9efb506..13b55ac6e6e3c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,7 +1479,7 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_max(), array_min(), array_position() and element_at() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1496,6 +1496,13 @@ test_that("column functions", { result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 6)) + result <- collect(select(df, reverse(df[[1]])))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(4L, 5L, 6L))) + + df2 <- createDataFrame(list(list("abc"))) + result <- collect(select(df2, reverse(df2[[1]])))[[1]] + expect_equal(result, "cba") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1512,7 +1519,13 @@ test_that("column functions", { result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] expect_equal(result, list(list(2L, 3L), list(5L))) - # Test flattern + # Test concat() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + list(list(7L, 8L, 9L), list(10L, 11L, 12L)))) + result <- collect(select(df, concat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L, 5L, 6L), list(7L, 8L, 9L, 10L, 11L, 12L))) + + # Test flatten() df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), list(list(list(5L, 6L), list(7L, 8L))))) result <- collect(select(df, flatten(df[[1]])))[[1]] From 7a2d4895c75d4c232c377876b61c05a083eab3c8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 14 May 2018 10:01:06 +0800 Subject: [PATCH 0793/1161] [SPARK-17916][SQL] Fix empty string being parsed as null when nullValue is set. ## What changes were proposed in this pull request? I propose to bump version of uniVocity parser up to 2.6.3 where quoted empty strings are replaced by the empty value (passed to `setEmptyValue`) instead of `null` values as in the current version 2.5.9: https://github.com/uniVocity/univocity-parsers/blob/v2.6.3/src/main/java/com/univocity/parsers/csv/CsvParser.java#L125 Empty value for writer is set to `""`. So, empty string in dataframe/dataset is stored as empty quoted string `""`. Empty value for reader is set to empty string (zero size). In this way, saved empty quoted string will be read as just empty string. Please, look at the tests for more details. Here are main changes made in [2.6.0](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.0), [2.6.1](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.1), [2.6.2](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.2), [2.6.3](https://github.com/uniVocity/univocity-parsers/releases/tag/v2.6.3): - CSV parser now parses quoted values ~30% faster - CSV format detection process has option provide a list of possible delimiters, in order of priority ( i.e. settings.detectFormatAutomatically( '-', '.');) - https://github.com/uniVocity/univocity-parsers/issues/214 - Implemented trim quoted values support - https://github.com/uniVocity/univocity-parsers/issues/230 - NullPointer when stopping parser when nothing is parsed - https://github.com/uniVocity/univocity-parsers/issues/219 - Concurrency issue when calling stopParsing() - https://github.com/uniVocity/univocity-parsers/issues/231 Closes #20068 ## How was this patch tested? Added tests from the PR https://github.com/apache/spark/pull/20068 Author: Maxim Gekk Closes #21273 from MaxGekk/univocity-2.6. --- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- sql/core/pom.xml | 2 +- .../datasources/csv/CSVOptions.scala | 3 +- .../datasources/csv/CSVBenchmarks.scala | 80 +++++++++++++++++++ .../execution/datasources/csv/CSVSuite.scala | 46 +++++++++++ 7 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f552b81fde9f4..e710e26348117 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -190,7 +190,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 024b1fca717df..97ad17a9ff7b1 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -191,7 +191,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 938de7bc06663..e21bfef8c4291 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -211,7 +211,7 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar token-provider-1.0.1.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar xbean-asm5-shaded-4.4.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ef41837f89d68..f270c70fbfcf0 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.9 + 2.6.3 jar diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index ed2dc65a47914..1066d156acd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -164,7 +164,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setEmptyValue("\"\"") writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -185,6 +185,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) + settings.setEmptyValue("") settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala new file mode 100644 index 0000000000000..d442ba7e59c61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure CSV read/write performance. + * To run this: + * spark-submit --class --jars + */ +object CSVBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-csv-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) + + withTempPath { path => + val str = (0 until 10000).map(i => s""""$i"""").mkString(",") + + spark.range(rowsNum) + .map(_ => str) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val schema = new StructType().add("value", StringType) + val ds = spark.read.option("header", true).schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"One quoted string", numIters) { _ => + ds.filter((_: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + One quoted string 30273 / 30549 0.0 605451.2 1.0X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 461abdd96d3f3..07e6c74b14d0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1322,4 +1322,50 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where an empty string is not coerced to null when `nullValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("nullValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("nullValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, ""), + (3, litNull), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where empty string us coerced to nullValue is not passed. + withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } } From b6c50d7820aafab172835633fb0b35899e93146b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 14 May 2018 10:57:10 +0800 Subject: [PATCH 0794/1161] [SPARK-24228][SQL] Fix Java lint errors ## What changes were proposed in this pull request? This PR fixes the following Java lint errors due to importing unimport classes ``` $ dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java:[25] (sizes) LineLength: Line is longer than 100 characters (found 109). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java:[38] (sizes) LineLength: Line is longer than 100 characters (found 102). [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java:[21,8] (imports) UnusedImports: Unused import - java.io.ByteArrayInputStream. [ERROR] src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java:[29,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java:[110] (sizes) LineLength: Line is longer than 100 characters (found 101). ``` With this PR ``` $ dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks passed. ``` ## How was this patch tested? Existing UTs. Also manually run checkstyles against these two files. Author: Kazuaki Ishizaki Closes #21301 from kiszk/SPARK-24228. --- .../datasources/parquet/SpecificParquetRecordReaderBase.java | 1 - .../datasources/parquet/VectorizedPlainValuesReader.java | 1 - .../sql/sources/v2/reader/partitioning/Distribution.java | 3 ++- .../sql/sources/v2/reader/streaming/ContinuousReader.java | 4 ++-- .../apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java | 3 ++- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 10d6ed85a4080..daedfd7e78f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index aacefacfc1c1a..c62dc3d86386e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -26,7 +26,6 @@ import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; -import org.apache.spark.unsafe.Platform; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index d2ee9518d628f..5e32ba6952e1c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -22,7 +22,8 @@ /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition). + * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one + * partition). * Note that this interface has nothing to do with the data ordering inside one * partition(the output records of a single {@link InputPartitionReader}). * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 716c5c0e9e15a..6e960bedf8020 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -35,8 +35,8 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each - * partition to a single global offset. + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances + * for each partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 714638e500c94..445cb29f5ee3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -107,7 +107,8 @@ public List> planInputPartitions() { } } - static class JavaAdvancedInputPartition implements InputPartition, InputPartitionReader { + static class JavaAdvancedInputPartition implements InputPartition, + InputPartitionReader { private int start; private int end; private StructType requiredSchema; From 1430fa80e37762e31cc5adc74cd609c215d84b6e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 14 May 2018 10:49:12 -0700 Subject: [PATCH 0795/1161] [SPARK-24263][R] SparkR java check breaks with openjdk ## What changes were proposed in this pull request? Change text to grep for. ## How was this patch tested? manual test Author: Felix Cheung Closes #21314 from felixcheung/openjdkver. --- R/pkg/R/client.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index e9295e05872bd..14a17c600b17f 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -82,7 +82,7 @@ checkJavaVersion <- function() { }) javaVersionFilter <- Filter( function(x) { - grepl("java version", x) + grepl(" version", x) }, javaVersionOut) javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] From c26f673252c2cbbccf8c395ba6d4ab80c098d60e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 14 May 2018 11:37:57 -0700 Subject: [PATCH 0796/1161] [SPARK-24246][SQL] Improve AnalysisException by setting the cause when it's available ## What changes were proposed in this pull request? If there is an exception, it's better to set it as the cause of AnalysisException since the exception may contain useful debug information. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #21297 from zsxwing/SPARK-24246. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++--- .../spark/sql/catalyst/analysis/ResolveInlineTables.scala | 2 +- .../org/apache/spark/sql/catalyst/analysis/package.scala | 5 +++++ .../org/apache/spark/sql/execution/datasources/rules.scala | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dfdcdbc1eb2c7..3eaa9ecf5d075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -676,13 +676,13 @@ class Analyzer( try { catalog.lookupRelation(tableIdentWithDb) } catch { - case _: NoSuchTableException => - u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") + case e: NoSuchTableException => + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e) // If the database is defined and that database is not found, throw an AnalysisException. // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exist.") + s"database ${e.db} doesn't exist.", e) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 4eb6e642b1c37..31ba9d792024b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -105,7 +105,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas castedExpr.eval() } catch { case NonFatal(ex) => - table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 7731336d247db..354a3fa0602a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -41,6 +41,11 @@ package object analysis { def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } + + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String, cause: Throwable): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause)) + } } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0dea767840ed3..cab00251622b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { case _: ClassNotFoundException => u case e: Exception => // the provider is valid, but failed to create a logical plan - u.failAnalysis(e.getMessage) + u.failAnalysis(e.getMessage, e) } } } From 075d678c8844614910b50abca07282bde31ef7e0 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 14 May 2018 13:35:54 -0700 Subject: [PATCH 0797/1161] [SPARK-24155][ML] Instrumentation improvements for clustering ## What changes were proposed in this pull request? changed the instrument for all of the clustering methods ## How was this patch tested? N/A Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21218 from ludatabricks/SPARK-23686-1. --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 7 +++++-- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 5 ++++- .../main/scala/org/apache/spark/ml/clustering/KMeans.scala | 4 +++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 438e53ba6197c..1ad4e097246a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -261,8 +261,9 @@ class BisectingKMeans @Since("2.0.0") ( transformSchema(dataset.schema, logging = true) val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, rdd) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val instr = Instrumentation.create(this, dataset) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, + minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() .setK($(k)) @@ -275,6 +276,8 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88d618c3a03a8..3091bb5a2e54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -352,7 +352,7 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) @@ -425,6 +425,9 @@ class GaussianMixture @Since("2.0.0") ( val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) model.setSummary(Some(summary)) + instr.logNamedValue("logLikelihood", logLikelihood) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 97f246fbfd859..e72d7f9485e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -342,7 +342,7 @@ class KMeans @Since("1.5.0") ( instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -359,6 +359,8 @@ class KMeans @Since("1.5.0") ( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() From 8cd83acf4075d369bfcf9e703760d4946ef15f00 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 14 May 2018 14:05:42 -0700 Subject: [PATCH 0798/1161] [SPARK-24027][SQL] Support MapType with StringType for keys as the root type by from_json ## What changes were proposed in this pull request? Currently, the from_json function support StructType or ArrayType as the root type. The PR allows to specify MapType(StringType, DataType) as the root type additionally to mentioned types. For example: ```scala import org.apache.spark.sql.types._ val schema = MapType(StringType, IntegerType) val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() in.select(from_json($"value", schema, Map[String, String]())).collect() ``` ``` res1: Array[org.apache.spark.sql.Row] = Array([Map(a -> 1, b -> 2, c -> 3)]) ``` ## How was this patch tested? It was checked by new tests for the map type with integer type and struct type as value types. Also roundtrip tests like from_json(to_json) and to_json(from_json) for MapType are added. Author: Maxim Gekk Author: Maxim Gekk Closes #21108 from MaxGekk/from_json-map-type. --- python/pyspark/sql/functions.py | 10 ++- .../expressions/jsonExpressions.scala | 10 ++- .../sql/catalyst/json/JacksonParser.scala | 18 ++++- .../org/apache/spark/sql/functions.scala | 29 ++++---- .../apache/spark/sql/JsonFunctionsSuite.scala | 66 +++++++++++++++++++ 5 files changed, 113 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b62748e9a2d6c..6866c1cf9f882 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2095,12 +2095,13 @@ def json_tuple(col, *fields): return Column(jc) +@ignore_unicode_prefix @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` - of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an - unparseable string. + Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` + as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with + the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. @@ -2117,6 +2118,9 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] + >>> schema = MapType(StringType(), IntegerType()) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 34161f0f03f4a..04a4eb0ffc032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -548,7 +548,7 @@ case class JsonToStructs( forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) => + case _: StructType | ArrayType(_: StructType, _) | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") @@ -558,6 +558,7 @@ case class JsonToStructs( lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st + case mt: MapType => mt } // This converts parsed rows to the desired output by the given schema. @@ -567,6 +568,8 @@ case class JsonToStructs( (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: MapType => + (rows: Seq[InternalRow]) => rows.head.getMap(0) } @transient @@ -613,6 +616,11 @@ case class JsonToStructs( } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def sql: String = schema match { + case _: MapType => "entries" + case _ => super.sql + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a5a4a13eb608b..c3a4ca8f64bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( - schema: StructType, + schema: DataType, val options: JSONOptions) extends Logging { import JacksonUtils._ @@ -57,7 +57,14 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { + private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = { + dt match { + case st: StructType => makeStructRootConverter(st) + case mt: MapType => makeMapRootConverter(mt) + } + } + + private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) val fieldConverters = st.map(_.dataType).map(makeConverter).toArray (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { @@ -87,6 +94,13 @@ class JacksonParser( } } + private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = { + val fieldConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) { + case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3c9ace407a58e..b71dfdad8aa9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3231,9 +3231,9 @@ object functions { from_json(e, schema.asInstanceOf[DataType], options) /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3263,9 +3263,9 @@ object functions { from_json(e, schema, options.asScala.toMap) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3292,8 +3292,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, + * `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3305,9 +3306,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3322,9 +3323,9 @@ object functions { } /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string, it could be a diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 00d2acc4a1d8a..055e1fc5640f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -326,4 +326,70 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg4.getMessage.startsWith( "A type of keys and values in map() must be string, but got")) } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() + val schema = + """ + |{ + | "type" : "map", + | "keyType" : "string", + | "valueType" : "integer", + | "valueContainsNull" : true + |} + """.stripMargin + val out = in.select(from_json($"value", schema, Map[String, String]())) + + assert(out.columns.head == "entries") + checkAnswer(out, Row(Map("a" -> 1, "b" -> 2, "c" -> 3))) + } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, new StructType().add("b", IntegerType), true) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Row(1)))) + } + + test("SPARK-24027: from_json - map>") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, MapType(StringType, IntegerType)) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) + } + + test("SPARK-24027: roundtrip - from_json -> to_json - map") { + val json = """{"a":1,"b":2,"c":3}""" + val schema = MapType(StringType, IntegerType, true) + val out = Seq(json).toDS().select(to_json(from_json($"value", schema))) + + checkAnswer(out, Row(json)) + } + + test("SPARK-24027: roundtrip - to_json -> from_json - map") { + val in = Seq(Map("a" -> 1)).toDF() + val schema = MapType(StringType, IntegerType, true) + val out = in.select(from_json(to_json($"value"), schema)) + + checkAnswer(out, in) + } + + test("SPARK-24027: from_json - wrong map") { + val in = Seq("""{"a" 1}""").toDS() + val schema = MapType(StringType, IntegerType) + val out = in.select(from_json($"value", schema, Map[String, String]())) + + checkAnswer(out, Row(null)) + } + + test("SPARK-24027: from_json of a map with unsupported key type") { + val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType) + + checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + } } From 061e0084ce19c1384ba271a97a0aa1f87abe879d Mon Sep 17 00:00:00 2001 From: Henry Robinson Date: Mon, 14 May 2018 14:35:08 -0700 Subject: [PATCH 0799/1161] [SPARK-23852][SQL] Add withSQLConf(...) to test case ## What changes were proposed in this pull request? Add a `withSQLConf(...)` wrapper to force Parquet filter pushdown for a test that relies on it. ## How was this patch tested? Test passes Author: Henry Robinson Closes #21323 from henryr/spark-23582. --- .../datasources/parquet/ParquetFilterSuite.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4d0ecdef60986..90da7eb8c4fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -650,13 +650,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-23852: Broken Parquet push-down for partially-written stats") { - // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. - // The row-group statistics include null counts, but not min and max values, which - // triggers PARQUET-1217. - val df = readResourceParquetFile("test-data/parquet-1217.parquet") + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. + // The row-group statistics include null counts, but not min and max values, which + // triggers PARQUET-1217. + val df = readResourceParquetFile("test-data/parquet-1217.parquet") - // Will return 0 rows if PARQUET-1217 is not fixed. - assert(df.where("col > 0").count() === 2) + // Will return 0 rows if PARQUET-1217 is not fixed. + assert(df.where("col > 0").count() === 2) + } } } From 9059f1ee6ae13c8636c9b7fdbb708a349256fb8e Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Mon, 14 May 2018 19:20:25 -0700 Subject: [PATCH 0800/1161] [SPARK-23780][R] Failed to use googleVis library with new SparkR ## What changes were proposed in this pull request? change generic to get it to work with googleVis also fix lintr ## How was this patch tested? manual test, unit tests Author: Felix Cheung Closes #21315 from felixcheung/googvis. --- R/pkg/R/client.R | 5 +++-- R/pkg/R/generics.R | 2 +- R/pkg/R/sparkR.R | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 14a17c600b17f..4c87f64e7f0e1 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -63,7 +63,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack checkJavaVersion <- function() { javaBin <- "java" javaHome <- Sys.getenv("JAVA_HOME") - javaReqs <- utils::packageDescription(utils::packageName(), fields=c("SystemRequirements")) + javaReqs <- utils::packageDescription(utils::packageName(), fields = c("SystemRequirements")) sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) if (javaHome != "") { javaBin <- file.path(javaHome, "bin", javaBin) @@ -90,7 +90,8 @@ checkJavaVersion <- function() { # Extract 8 from it to compare to sparkJavaVersion javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) if (javaVersionNum != sparkJavaVersion) { - stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", + javaVersionStr)) } } diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 61da30badac4e..3ea181157b644 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -624,7 +624,7 @@ setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) +setGeneric("toJSON", function(x, ...) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index d6a2d08f9c218..f7c1663d32c96 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -194,7 +194,7 @@ sparkR.sparkContext <- function( # Don't use readString() so that we can provide a useful # error message if the R and Java versions are mismatched. - authSecretLen = readInt(f) + authSecretLen <- readInt(f) if (length(authSecretLen) == 0 || authSecretLen == 0) { stop("Unexpected EOF in JVM connection data. Mismatched versions?") } From e29176fd7dbcef04a29c4922ba655d58144fed24 Mon Sep 17 00:00:00 2001 From: Goun Na Date: Tue, 15 May 2018 14:11:20 +0800 Subject: [PATCH 0801/1161] [SPARK-23627][SQL] Provide isEmpty in Dataset ## What changes were proposed in this pull request? This PR adds isEmpty() in DataSet ## How was this patch tested? Unit tests added Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Goun Na Author: goungoun Closes #20800 from goungoun/SPARK-23627. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 10 ++++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d518e07bfb62c..f001f16e1d5ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -511,6 +511,16 @@ class Dataset[T] private[sql]( */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + /** + * Returns true if the `Dataset` is empty. + * + * @group basic + * @since 2.4.0 + */ + def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) == 0 + } + /** * Returns true if this Dataset contains one or more sources that continuously * return data as it arrives. A Dataset that reads data from a streaming source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e0f4d2ba685e1..d477d78dc14e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1425,6 +1425,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23627: provide isEmpty in DataSet") { + val ds1 = spark.emptyDataset[Int] + val ds2 = Seq(1, 2, 3).toDS() + + assert(ds1.isEmpty == true) + assert(ds2.isEmpty == false) + } + test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] From 80c6d35a3edbfb2e053c7d6650e2f725c36af53e Mon Sep 17 00:00:00 2001 From: maryannxue Date: Mon, 14 May 2018 23:34:42 -0700 Subject: [PATCH 0802/1161] [SPARK-24035][SQL] SQL syntax for Pivot - fix antlr warning ## What changes were proposed in this pull request? 1. Change antlr rule to fix the warning. 2. Add PIVOT/LATERAL check in AstBuilder with a more meaningful error message. ## How was this patch tested? 1. Add a counter case in `PlanParserSuite.test("lateral view")` Author: maryannxue Closes #21324 from maryannxue/spark-24035-fix. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 +++ .../spark/sql/catalyst/parser/PlanParserSuite.scala | 10 ++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f7f921ec22c35..7c54851097af3 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* (pivotClause | lateralView*)? + : FROM relation (',' relation)* lateralView* pivotClause? ; aggregation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 64eed23884584..b9ece295c2510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -504,6 +504,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging withJoinRelations(join, relation) } if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } withPivot(ctx.pivotClause, from) } else { ctx.lateralView.asScala.foldLeft(from)(withGenerate) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 812bfdd7bb885..fb51376c6163f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", expected) + + intercept( + """select * + |from t + |lateral view explode(x) expl + |pivot ( + | sum(x) + | FOR y IN ('a', 'b') + |)""".stripMargin, + "LATERAL cannot be used together with PIVOT in FROM clause") } test("joins") { From 4a2b15f0af400c71b7f20b2048f38a8b74d43dfa Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Tue, 15 May 2018 16:04:17 +0800 Subject: [PATCH 0803/1161] [SPARK-24241][SUBMIT] Do not fail fast when dynamic resource allocation enabled with 0 executor ## What changes were proposed in this pull request? ``` ~/spark-2.3.0-bin-hadoop2.7$ bin/spark-sql --num-executors 0 --conf spark.dynamicAllocation.enabled=true Java HotSpot(TM) 64-Bit Server VM warning: ignoring option PermSize=1024m; support was removed in 8.0 Java HotSpot(TM) 64-Bit Server VM warning: ignoring option MaxPermSize=1024m; support was removed in 8.0 Error: Number of executors must be a positive number Run with --help for usage help or --verbose for debug output ``` Actually, we could start up with min executor number with 0 before if dynamically ## How was this patch tested? ut added Author: Kent Yao Closes #21290 from yaooqinn/SPARK-24241. --- .../spark/deploy/SparkSubmitArguments.scala | 7 +++++-- .../spark/deploy/SparkSubmitSuite.scala | 20 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0733fdb72cafb..fed4e0a5069c3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils - /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. @@ -76,6 +75,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var proxyUser: String = null var principal: String = null var keytab: String = null + private var dynamicAllocationEnabled: Boolean = false // Standalone cluster mode only var supervise: Boolean = false @@ -198,6 +198,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + dynamicAllocationEnabled = + sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -274,7 +276,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { error("Total executor cores must be a positive number") } - if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + if (!dynamicAllocationEnabled && + numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7451e07b25a1f..43286953e4383 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -180,6 +180,26 @@ class SparkSubmitSuite appArgs.toString should include ("thequeue") } + test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") { + val clArgs1 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=true", + "thejar.jar") + new SparkSubmitArguments(clArgs1) + + val clArgs2 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=false", + "thejar.jar") + + val e = intercept[SparkException](new SparkSubmitArguments(clArgs2)) + assert(e.getMessage.contains("Number of executors must be a positive number")) + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", From d610d2a3f57ca551f72cb4e5dfed78f27be62eec Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 May 2018 22:06:58 +0800 Subject: [PATCH 0804/1161] [SPARK-24259][SQL] ArrayWriter for Arrow produces wrong output ## What changes were proposed in this pull request? Right now `ArrayWriter` used to output Arrow data for array type, doesn't do `clear` or `reset` after each batch. It produces wrong output. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21312 from viirya/SPARK-24259. --- python/pyspark/sql/tests.py | 20 +++++++++++++++++++ .../sql/execution/arrow/ArrowWriter.scala | 8 ++++++++ 2 files changed, 28 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16aa9378ad8ee..a1b6db71782bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4680,6 +4680,26 @@ def test_supported_types(self): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 22b63513548fe..66888fce7f9f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter { valueVector match { case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case listVector: ListVector => + // Manual "reset" the underlying buffer. + // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call + // `listVector.reset()`. + val buffers = listVector.getBuffers(false) + buffers.foreach(buf => buf.setZero(0, buf.capacity())) + listVector.setValueCount(0) + listVector.setLastSet(0) case _ => } count = 0 From 3fabbc576203c7fd63808a259adafc5c3cea1838 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 15 May 2018 10:25:29 -0700 Subject: [PATCH 0805/1161] [SPARK-24040][SS] Support single partition aggregates in continuous processing. ## What changes were proposed in this pull request? Support aggregates with exactly 1 partition in continuous processing. A few small tweaks are needed to make this work: * Replace currentEpoch tracking with an ThreadLocal. This means that current epoch is scoped to a task rather than a node, but I think that's sustainable even once we add shuffle. * Add a new testing-only flag to disable the UnsupportedOperationChecker whitelist of allowed continuous processing nodes. I think this is preferable to writing a pile of custom logic to enforce that there is in fact only 1 partition; we plan to support multi-partition aggregates before the next Spark release, so we'd just have to tear that logic back out. * Restart continuous processing queries from the first available uncommitted epoch, rather than one that's guaranteed to be unused. This is required for stateful operators to overwrite partial state from the previous attempt at the epoch, and there was no specific motivation for the original strategy. In another PR before stabilizing the StreamWriter API, we'll need to narrow down and document more precise semantic guarantees for the epoch IDs. * We need a single-partition ContinuousMemoryStream. The way MemoryStream is constructed means it can't be a text option like it is for rate source, unfortunately. ## How was this patch tested? new unit tests Author: Jose Torres Closes #21239 from jose-torres/withAggr. --- .../UnsupportedOperationChecker.scala | 1 + .../continuous/ContinuousExecution.scala | 11 +-- .../ContinuousQueuedDataReader.scala | 7 +- .../continuous/ContinuousWriteRDD.scala | 18 +++-- .../streaming/continuous/EpochTracker.scala | 58 +++++++++++++++ .../sources/ContinuousMemoryStream.scala | 14 ++-- .../streaming/state/StateStoreRDD.scala | 10 ++- .../sql/streaming/StreamingQueryManager.scala | 4 +- .../ContinuousAggregationSuite.scala | 72 +++++++++++++++++++ .../ContinuousQueuedDataReaderSuite.scala | 1 + 10 files changed, 167 insertions(+), 29 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d3d6c636c4ba8..2bed41672fe33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f58146ac42398..0e7d1019b9c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -122,16 +122,7 @@ class ContinuousExecution( s"Batch $latestEpochId was committed without end epoch offsets!") } committedOffsets = nextOffsets.toStreamProgress(sources) - - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index d8645576c2052..f38577b6a9f16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -46,8 +46,6 @@ class ContinuousQueuedDataReader( // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = ContinuousDataSourceRDD.getContinuousReader(reader).getOffset - private var currentEpoch: Long = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong /** * The record types in the read buffer. @@ -115,8 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), currentEpoch, currentOffset)) - currentEpoch += 1 + context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -184,7 +181,7 @@ class ContinuousQueuedDataReader( private val epochCoordEndpoint = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That // field represents the epoch wrt the data being processed. The currentEpoch here is just a // counter to ensure we send the appropriate number of markers if we fall behind the driver. private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 91f1576581511..ef5f0da1e7cc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -45,7 +45,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) while (!context.isInterrupted() && !context.isCompleted()) { var dataWriter: DataWriter[InternalRow] = null @@ -54,19 +55,24 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor try { val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) + context.partitionId(), + context.attemptNumber(), + EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) } logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch is committing.") + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") val msg = dataWriter.commit() epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + CommitPartitionEpoch( + context.partitionId(), + EpochTracker.getCurrentEpoch.get, + msg) ) logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch committed.") - currentEpoch += 1 + s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.") + EpochTracker.incrementCurrentEpoch() } catch { case _: InterruptedException => // Continuous shutdown always involves an interrupt. Just finish the task. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala new file mode 100644 index 0000000000000..bc0ae428d4521 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +/** + * Tracks the current continuous processing epoch within a task. Call + * EpochTracker.getCurrentEpoch to get the current epoch. + */ +object EpochTracker { + // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will + // update the underlying AtomicLong as it finishes epochs. Other code should only read the value. + private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] { + override def initialValue() = new AtomicLong(-1) + } + + /** + * Get the current epoch for the current task, or None if the task has no current epoch. + */ + def getCurrentEpoch: Option[Long] = { + currentEpoch.get().get() match { + case n if n < 0 => None + case e => Some(e) + } + } + + /** + * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * between epochs. + */ + def incrementCurrentEpoch(): Unit = { + currentEpoch.get().incrementAndGet() + } + + /** + * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * at the beginning of a task. + */ + def initializeCurrentEpoch(startEpoch: Long): Unit = { + currentEpoch.get().set(startEpoch) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index fef792eab69d5..4daafa65850de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -47,10 +47,9 @@ import org.apache.spark.util.RpcUtils * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified * offset within the list, or null if that offset doesn't yet have a record. */ -class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) - private val NUM_PARTITIONS = 2 protected val logicalPlan = StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) @@ -58,7 +57,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) // ContinuousReader implementation @GuardedBy("this") - private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + private val records = Seq.fill(numPartitions)(new ListBuffer[A]) @GuardedBy("this") private var startOffset: ContinuousMemoryStreamOffset = _ @@ -69,17 +68,17 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = synchronized { // Distribute data evenly among partition lists. data.toSeq.zipWithIndex.map { - case (item, index) => records(index % NUM_PARTITIONS) += item + case (item, index) => records(index % numPartitions) += item } // The new target offset is the offset where all records in all partitions have been processed. - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } override def setStartOffset(start: Optional[Offset]): Unit = synchronized { // Inferred initial offset is position 0 in each partition. startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) }.asInstanceOf[ContinuousMemoryStreamOffset] } @@ -152,6 +151,9 @@ object ContinuousMemoryStream { def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + + def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 01d8e75980993..3f11b8f79943c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( StateStoreId(checkpointLocation, operatorId, partition.index), queryRunId) + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. + val currentVersion = EpochTracker.getCurrentEpoch match { + case None => storeVersion + case Some(value) => value + } + store = StateStore.get( - storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7cefd03e43bc3..97da2b1325f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -242,7 +242,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo (sink, trigger) match { case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + } new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala new file mode 100644 index 0000000000000..b7ef637f5270e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.OutputMode + +class ContinuousAggregationSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("not enabled") { + val ex = intercept[AnalysisException] { + val input = ContinuousMemoryStream.singlePartition[Int] + testStream(input.toDF().agg(max('value)), OutputMode.Complete)() + } + + assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + } + + test("basic") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + } + + test("repeated restart") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + StartStream(), + StopStream, + StartStream(), + StopStream, + StartStream(), + AddData(input, 0), + CheckAnswer(2), + AddData(input, 5), + CheckAnswer(5)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index f47d3ec8ae025..e663fa8312da4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -51,6 +51,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { startEpoch, spark, SparkEnv.get) + EpochTracker.initializeCurrentEpoch(0) } override def afterEach(): Unit = { From 6b94420f6c672683678a54404e6341a0b9ab3c24 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 15 May 2018 14:16:31 -0700 Subject: [PATCH 0806/1161] [SPARK-24231][PYSPARK][ML] Provide Python API for evaluateEachIteration for spark.ml GBTs ## What changes were proposed in this pull request? Add evaluateEachIteration for GBTClassification and GBTRegressionModel ## How was this patch tested? doctest Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21335 from ludatabricks/SPARK-14682. --- python/pyspark/ml/classification.py | 15 +++++++++++++++ python/pyspark/ml/regression.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ec17653a1adf9..424ecfd89b060 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)], + ... ["indexed", "features"]) + >>> model.evaluateEachIteration(validation) + [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] .. versionadded:: 1.4.0 """ @@ -1319,6 +1323,17 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + return self._call_java("evaluateEachIteration", dataset) + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9a66d87d7f211..dd0b62f184d26 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))], + ... ["label", "features"]) + >>> model.evaluateEachIteration(validation, "squared") + [0.0, 0.0, 0.0, 0.0, 0.0] .. versionadded:: 1.4.0 """ @@ -1156,6 +1160,20 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset, loss): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + :param loss: + The loss function used to compute error. + Supported options: squared, absolute + """ + return self._call_java("evaluateEachIteration", dataset, loss) + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, From 8a13c5096898f95d1dfcedaf5d31205a1cbf0a19 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 May 2018 16:50:09 -0700 Subject: [PATCH 0807/1161] [SPARK-24058][ML][PYSPARK] Default Params in ML should be saved separately: Python API ## What changes were proposed in this pull request? See SPARK-23455 for reference. Now default params in ML are saved separately in metadata file in Scala. We must change it for Python for Spark 2.4.0 as well in order to keep them in sync. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21153 from viirya/SPARK-24058. --- python/pyspark/ml/tests.py | 38 ++++++++++++++++++++++++++++++++++++++ python/pyspark/ml/util.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 093593132e56d..0dde0db9e3339 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1595,6 +1595,44 @@ def test_default_read_write(self): self.assertEqual(lr.uid, lr3.uid) self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + class LDATest(SparkSessionTestCase): diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a486c6a3fdeb5..9fa85664939b8 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -30,6 +30,7 @@ from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession +from pyspark.util import VersionUtils def _jvm(): @@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): - sparkVersion - uid - paramMap + - defaultParamMap (since 2.4.0) - (optionally, extra metadata) :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc. :param paramMap: If given, this is saved in the "paramMap" field. @@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): """ uid = instance.uid cls = instance.__module__ + '.' + instance.__class__.__name__ - params = instance.extractParamMap() + + # User-supplied param values + params = instance._paramMap jsonParams = {} if paramMap is not None: jsonParams = paramMap else: for p in params: jsonParams[p.name] = params[p] + + # Default param values + jsonDefaultParams = {} + for p in instance._defaultParamMap: + jsonDefaultParams[p.name] = instance._defaultParamMap[p] + basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), - "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} + "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, + "defaultParamMap": jsonDefaultParams} if extraMetadata is not None: basicMetadata.update(extraMetadata) return json.dumps(basicMetadata, separators=[',', ':']) @@ -523,11 +534,26 @@ def getAndSetParams(instance, metadata): """ Extract Params from metadata, and set them in the instance. """ + # Set user-supplied param values for paramName in metadata['paramMap']: param = instance.getParam(paramName) paramValue = metadata['paramMap'][paramName] instance.set(param, paramValue) + # Set default param values + majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion']) + major = majorAndMinorVersions[0] + minor = majorAndMinorVersions[1] + + # For metadata file prior to Spark 2.4, there is no default section. + if major > 2 or (major == 2 and minor >= 4): + assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \ + "`defaultParamMap` section not found" + + for paramName in metadata['defaultParamMap']: + paramValue = metadata['defaultParamMap'][paramName] + instance._setDefault(**{paramName: paramValue}) + @staticmethod def loadParamsInstance(path, sc): """ From 943493b165185c5362c8350dd355276cc458aad0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 16 May 2018 22:01:24 +0800 Subject: [PATCH 0808/1161] =?UTF-8?q?Revert=20"[SPARK-22938][SQL][FOLLOWUP?= =?UTF-8?q?]=20Assert=20that=20SQLConf.get=20is=20acces=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …sed only on the driver" This reverts commit a4206d58e05ab9ed6f01fee57e18dee65cbc4efc. This is from https://github.com/apache/spark/pull/21299 and to ease the review of it. Author: Wenchen Fan Closes #21341 from cloud-fan/revert. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../analysis/ResolveInlineTables.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 156 ++++++++---------- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../org/apache/spark/sql/types/DataType.scala | 8 +- .../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++---- .../org/apache/spark/sql/SparkSession.scala | 21 +-- .../datasources/PartitioningUtils.scala | 5 +- .../datasources/json/JsonInferSchema.scala | 39 ++--- .../datasources/json/JsonSuite.scala | 4 +- 10 files changed, 140 insertions(+), 188 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 94b0561529e71..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -261,9 +260,7 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - val widerType = TypeCoercion.findWiderTypeForTwo( - dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) - if (widerType.isEmpty) { + if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 31ba9d792024b..71ed75454cd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,9 +83,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( - inputTypes, conf.caseSensitiveAnalysis) - val tpe = wideType.getOrElse { + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a7ba201509b78..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes(conf) :: + WidenSetOperationTypes :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion(conf) :: + FunctionArgumentConversion :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion(conf) :: - IfCoercion(conf) :: + CaseWhenCoercion :: + IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts(conf) :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,10 +83,7 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - def findTightestCommonType( - left: DataType, - right: DataType, - caseSensitive: Boolean): Option[DataType] = (left, right) match { + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -105,32 +102,22 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => - val isSameType = if (caseSensitive) { - DataType.equalsIgnoreNullability(t1, t2) - } else { - DataType.equalsIgnoreCaseAndNullability(t1, t2) - } - - if (isSameType) { - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since t1 is same type of t2, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - } else { - None - } + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2, caseSensitive) - val valueType = findTightestCommonType(vt1, vt2, caseSensitive) + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -185,14 +172,13 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2, caseSensitive) - .map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -207,8 +193,7 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -216,7 +201,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeForTwo(d, c) case _ => None }) } @@ -228,22 +213,20 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType, - caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) + findWiderTypeWithoutStringPromotionForTwo(et1, et2) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) case None => None }) } @@ -296,32 +279,29 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { + object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes( - children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -336,19 +316,18 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType], - caseSensitive: Boolean): Seq[DataType] = { + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) + getWidestTypes(children, attrIndex + 1, castedTypes) } } @@ -453,7 +432,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) + .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an @@ -482,7 +461,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { + findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -536,7 +515,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { + object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -544,7 +523,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -552,7 +531,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -563,7 +542,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -573,7 +552,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -601,7 +580,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -611,14 +590,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -658,11 +637,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { + object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -689,17 +668,16 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { + object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -798,11 +776,12 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -825,18 +804,17 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0b1965c438e27..b00edca97cd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -107,13 +107,7 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") - } - confGetter.get()() - } + def get: SQLConf = confGetter.get()() val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1280,6 +1274,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4ee12db9c10ca..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true @@ -214,7 +218,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index f73e045685ee1..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType, Boolean) => Option[DataType], + widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) + var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) + found = widenFunc(t2, t1) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion(conf) + val rule = TypeCoercion.FunctionArgumentConversion val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion(conf) + val rule = TypeCoercion.IfCoercion val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion(conf), Division) + val rules = Seq(FunctionArgumentConversion, Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,7 +898,6 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1023,20 +1022,14 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) - } + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) - } + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1069,14 +1062,6 @@ object SparkSession extends Logging { } } - private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException( - "SparkSession should only be created and accessed on the driver.") - } - } - /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1edf27619ad7b..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -522,8 +521,6 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => - TypeCoercion.findWiderTypeForTwo( - t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) + case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e0424b7478122..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +44,6 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val caseSensitive = SQLConf.get.caseSensitiveAnalysis // perform schema inference on each row and merge afterwards val rootType = json.mapPartitions { iter => @@ -55,7 +53,7 @@ private[sql] object JsonInferSchema { try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions, caseSensitive)) + Some(inferField(parser, configOptions)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -70,7 +68,7 @@ private[sql] object JsonInferSchema { } } }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -100,15 +98,14 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField( - parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions, caseSensitive) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -125,7 +122,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions, caseSensitive), + inferField(parser, configOptions), nullable = true) } val fields: Array[StructField] = builder.result() @@ -140,7 +137,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) + elementType, inferField(parser, configOptions)) } ArrayType(elementType) @@ -246,14 +243,13 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode, - caseSensitive: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -263,7 +259,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) + case (ty1, ty2) => compatibleType(ty1, ty2) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -271,8 +267,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { - TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { + def compatibleType(t1: DataType, t2: DataType): DataType = { + TypeCoercion.findTightestCommonType(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -307,8 +303,7 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType( - fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -331,17 +326,15 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType( - compatibleType(elementType1, elementType2, caseSensitive), - containsNull1 || containsNull2) + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2, caseSensitive) + compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2), caseSensitive) + compatibleType(t1, DecimalType.forType(t2)) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 34d23ee53220d..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) + var actual = compatibleType(t1, t2) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) + actual = compatibleType(t2, t1) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } From 6fb7d6c4f71be0007942f7d1fc3099f1bcf8c52b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 17 May 2018 00:40:39 +0800 Subject: [PATCH 0809/1161] [SPARK-24275][SQL] Revise doc comments in InputPartition ## What changes were proposed in this pull request? In #21145, DataReaderFactory is renamed to InputPartition. This PR is to revise wording in the comments to make it more clear. ## How was this patch tested? None Author: Gengliang Wang Closes #21326 from gengliangwang/revise_reader_comments. --- .../spark/sql/sources/v2/ReadSupport.java | 2 +- .../sql/sources/v2/ReadSupportWithSchema.java | 2 +- .../spark/sql/sources/v2/WriteSupport.java | 2 +- .../sql/sources/v2/reader/DataSourceReader.java | 16 ++++++++-------- .../sql/sources/v2/reader/InputPartition.java | 17 +++++++++-------- .../sql/sources/v2/writer/DataSourceWriter.java | 6 +++--- .../sources/v2/writer/DataWriterFactory.java | 2 +- 7 files changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 0ea4dc6b5def3..b2526ded53d92 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 { /** * Creates a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param options the options for the returned data source reader, which is an immutable diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index 3801402268af1..f31659904cc53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 { /** * Create a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param schema the full schema of this data source reader. Full schema usually maps to the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index cab56453816cc..83aeec0c47853 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 { * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index f898c296e4245..36a3e542b5a11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,7 +31,7 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link InputPartition}s that are returned by + * logic is delegated to {@link InputPartition}s, which are returned by * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: @@ -45,8 +45,8 @@ * only one of them would be respected, according to the priority list from high to low: * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * - * If an exception was throw when applying any of these query optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these query optimizations, the action will fail + * and no Spark job will be submitted. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark @@ -59,21 +59,21 @@ public interface DataSourceReader { * Returns the actual schema of this data source reader, which may be different from the physical * schema of the underlying storage, as column pruning or other optimizations may happen. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for creating a data reader to - * output data for one RDD partition. That means the number of tasks returned here is same as - * the number of RDD partitions this scan outputs. + * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for + * creating a data reader to output data of one RDD partition. The number of input partitions + * returned here is the same as the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before * Spark issues the scan request. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ List> planInputPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index c581e3b5d0047..3524481784fea 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -23,13 +23,14 @@ /** * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is - * responsible for creating the actual data reader. The relationship between - * {@link InputPartition} and {@link InputPartitionReader} + * responsible for creating the actual data reader of one RDD partition. + * The relationship between {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that input partitions will be serialized and sent to executors, then the partition reader - * will be created on executors and do the actual reading. So {@link InputPartition} must be - * serializable and {@link InputPartitionReader} doesn't need to be. + * Note that {@link InputPartition}s will be serialized and sent to executors, then + * {@link InputPartitionReader}s will be created on executors to do the actual reading. So + * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to + * be. */ @InterfaceStability.Evolving public interface InputPartition extends Serializable { @@ -41,10 +42,10 @@ public interface InputPartition extends Serializable { * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in - * the returned locations. By default this method returns empty string array, which means this - * task has no location preference. + * the returned locations. The default return value is empty string array, which means this + * input partition's reader has no location preference. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ default String[] preferredLocations() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0a0fd8db58035..0030a9f05dba7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -34,8 +34,8 @@ * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * - * If an exception was throw when applying any of these writing optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these writing optimizations, the action will fail + * and no Spark job will be submitted. * * The writing procedure is: * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the @@ -58,7 +58,7 @@ public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ DataWriterFactory createWriterFactory(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index c2c2ab73257e8..7527bcc0c4027 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -35,7 +35,7 @@ public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param partitionId A unique id of the RDD partition that the returned writer will process. From 8e60a16b73490007fe1c480d77cc09d760f0a02b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 16 May 2018 13:34:54 -0700 Subject: [PATCH 0810/1161] [SPARK-23601][BUILD][FOLLOW-UP] Keep md5 checksums for nexus artifacts. The repository.apache.org server still requires md5 checksums or it won't publish the staging repo. Author: Marcelo Vanzin Closes #21338 from vanzin/SPARK-23601. --- dev/create-release/release-build.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c00b00b845401..5faa3d3260a56 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -371,11 +371,18 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc and .sha1 - it really doesn't like anything else there + # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done From 991726f31a8d182ed6d5b0e59185d97c0c5c532f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 May 2018 14:55:02 -0700 Subject: [PATCH 0811/1161] [SPARK-24158][SS] Enable no-data batches for streaming joins ## What changes were proposed in this pull request? This is a continuation of the larger task of enabling zero-data batches for more eager state cleanup. This PR enables it for stream-stream joins. ## How was this patch tested? - Updated join tests. Additionally, updated them to not use `CheckLastBatch` anywhere to set good precedence for future. Author: Tathagata Das Closes #21253 from tdas/SPARK-24158. --- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../StreamingSymmetricHashJoinExec.scala | 14 +- .../spark/sql/streaming/StreamTest.scala | 15 +- .../sql/streaming/StreamingJoinSuite.scala | 217 +++++++++--------- 4 files changed, 130 insertions(+), 118 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba242..37a0b9d6c8728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -361,7 +361,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( - "Stream stream joins without equality predicate is not supported", plan = Some(plan)) + "Stream-stream join without equality predicate is not supported", plan = Some(plan)) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index fa7c8ee906ecd..afa664eb76525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec( s"${getClass.getSimpleName} should not take $x as the JoinType") } + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + val watermarkUsedForStateCleanup = + stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty + + // Latest watermark value is more than that used in this previous executed plan + val watermarkHasChanged = + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + + watermarkUsedForStateCleanup && watermarkHasChanged + } + protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) @@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec( // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. val cleanupIter = joinType match { - case Inner => - leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() case _ => throwBadJoinTypeException() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 9d139a927bea5..f348dac1319cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -199,15 +199,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsByFunc( globalCheckFunction: Seq[Row] => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName" - private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + override def toString: String = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - - private def operatorName = "CheckNewAnswer" + override def toString: String = s"CheckNewAnswer: ${expectedAnswer.mkString(",")}" } object CheckNewAnswer { @@ -218,6 +215,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) } + + def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows) } /** Stops the stream. It must currently be running. */ @@ -747,7 +746,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } } - pos += 1 } try { @@ -761,8 +759,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { actns.foreach(executeAction) } + pos += 1 - case action: StreamAction => executeAction(action) + case action: StreamAction => + executeAction(action) + pos += 1 } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index da8f9608c1e9c..1f62357e6d09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -62,20 +62,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1), CheckAnswer(), AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), AddData(input1, 10), // 10 arrived on input2 first, then input1, should join - CheckLastBatch((10, 20, 30)), + CheckNewAnswer((10, 20, 30)), AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), StopStream, StartStream(), AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckLastBatch((1, 2, 3), (1, 2, 3)), + CheckNewAnswer((1, 2, 3), (1, 2, 3)), StopStream, StartStream(), AddData(input1, 100), AddData(input2, 100), - CheckLastBatch((100, 200, 300)) + CheckNewAnswer((100, 200, 300)) ) } @@ -97,25 +97,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( AddData(input1, 1), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckNewAnswer((1, 10, 2, 3)), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), + CheckNewAnswer(), StopStream, StartStream(), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), + CheckNewAnswer((25, 30, 50, 75)), AddData(input1, 1), - CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark StopStream, StartStream(), AddData(input1, 5), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 5), - CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark ) } @@ -142,27 +142,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 1, updated = 1), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckAnswer((1, 10, 2, 3)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] + assertNumStateRows(total = 1, updated = 1), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10] + CheckNewAnswer((25, 30, 50, 75)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input2, 1), - CheckLastBatch(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 AddData(input1, 5), - CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark + CheckNewAnswer(), // Same reason as above assertNumStateRows(total = 2, updated = 0) ) } @@ -189,42 +189,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5)), CheckAnswer(), AddData(rightInput, (1, 11)), - CheckLastBatch((1, 5, 11)), + CheckNewAnswer((1, 5, 11)), AddData(rightInput, (1, 10)), - CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), - CheckLastBatch((1, 3, 10), (1, 3, 11)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer(), // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - - // Run another batch with event time = 25 to clear right state where rightTime <= 25 - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30) + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), // New data to right input should match with left side (1, 3) and (1, 5), as left state should // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and // state rows with rightTime <= 25 should be removed from state. // (1, 20) ==> filtered by event time watermark = 20 // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as state watermark = 25 + // as 21 < state watermark = 25 // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1), // New data to left input with leftTime <= 20 should be filtered due to event time watermark AddData(leftInput, (1, 20), (1, 21)), - CheckLastBatch((1, 21, 28)), - assertNumStateRows(total = 7, updated = 1) + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1) ) } @@ -275,38 +272,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 20)), CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), - CheckLastBatch(), // matches with nothing on the left + CheckNewAnswer(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 5), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) // Should drop < 20 from left, i.e., none // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) // Should drop < 25 from the right, i.e., 14 and 15 - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat - CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1), // only 31 added // Advance the watermark AddData(rightInput, (1, 80)), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 1), - + CheckNewAnswer(), // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) // Should drop < 36 from left, i.e., 20, 31 (30 was not added) // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - AddData(rightInput, (1, 50)), - CheckLastBatch((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1) // 50 added ) } @@ -322,7 +320,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with input1.addData(1) q.awaitTermination(10000) } - assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) + assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } test("stream stream self join") { @@ -404,10 +402,11 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1, 5), AddData(input2, 1, 5, 10), AddData(input3, 5, 10), - CheckLastBatch((5, 10, 5, 15, 5, 25))) + CheckNewAnswer((5, 10, 5, 15, 5, 25))) } } + class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { import testImplicits._ @@ -465,13 +464,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -492,15 +491,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with value <= 7 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The right rows with rightValue <= 7 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // rightValue = 9 > 7 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, 8, null), Row(5, 10, 10, null)), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -521,15 +520,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with value <= 4 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The left rows with leftValue <= 4 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // leftValue = 7 > 4 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, null, "12"), Row(5, 10, null, "15")), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -552,13 +551,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -568,14 +567,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -586,14 +585,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -627,21 +626,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5), (3, 5)), CheckAnswer(), AddData(rightInput, (1, 10), (2, 5)), - CheckLastBatch((1, 1, 5, 10)), + CheckNewAnswer((1, 1, 5, 10)), AddData(rightInput, (1, 11)), - CheckLastBatch(), // no match as left time is too low + CheckNewAnswer(), // no match as left time is too low assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), - CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + CheckNewAnswer((1, 1, 7, 10), (1, 1, 7, 11)), assertNumStateRows(total = 7, updated = 2), - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 1), - AddData(rightInput, (0, 30)), - CheckLastBatch(outerResult), - assertNumStateRows(total = 3, updated = 1) + AddData(rightInput, (0, 30)), // watermark = 30 - 10 = 20, no-data-batch computes nulls + CheckNewAnswer(outerResult), + assertNumStateRows(total = 2, updated = 1) ) } } @@ -665,36 +661,41 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), - CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 3), // only right 1, 2, 3 added + + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch cleared < 10 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 20 and 21 left in state + AddData(rightInput, 20), - CheckLastBatch( - Row(20, 30, 40, 60)), + CheckNewAnswer(Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows - MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), - CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - MultiAddData(leftInput, 70)(rightInput, 71), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 6), // all inputs added since last check + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), // watermark = 31 + CheckNewAnswer((40, 50, 80, 120), (41, 50, 82, 123)), + assertNumStateRows(total = 4, updated = 4), // only left 40, 41 + right 40,41 left in state + + MultiAddData(leftInput, 70)(rightInput, 71), // watermark = 60 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 70, 71 left in state + AddData(rightInput, 70), - CheckLastBatch((70, 80, 140, 210)), + CheckNewAnswer((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left - MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), - CheckLastBatch(), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), // watermark = 91 + CheckNewAnswer(), + assertNumStateRows(total = 6, updated = 3), // only 101 - 103 left in state + MultiAddData(leftInput, 1000)(rightInput, 1001), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added - AddData(rightInput, 1000), - CheckLastBatch( - Row(1000, 1010, 2000, 3000), + CheckNewAnswer( Row(101, 110, 202, null), Row(102, 110, 204, null), Row(103, 110, 206, null)), - assertNumStateRows(total = 3, updated = 1) + assertNumStateRows(total = 2, updated = 2) ) } } From bfd75cdfb22a8c2fb005da597621e1ccd3990e82 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Wed, 16 May 2018 17:54:06 -0700 Subject: [PATCH 0812/1161] [SPARK-22210][ML] Add seed for LDA variationalTopicInference ## What changes were proposed in this pull request? - Add seed parameter for variationalTopicInference - Add seed for calling variationalTopicInference in submitMiniBatch - Add var seed in LDAModel so that it can take the seed from LDA and use it for the function call of variationalTopicInference in logLikelihoodBound, topicDistributions, getTopicDistributionMethod, and topicDistribution. ## How was this patch tested? Check the test result in mllib.clustering.LDASuite to make sure the result is repeatable with the seed. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21183 from ludatabricks/SPARK-22210. --- .../org/apache/spark/ml/clustering/LDA.scala | 6 ++- .../spark/mllib/clustering/LDAModel.scala | 34 ++++++++++++--- .../spark/mllib/clustering/LDAOptimizer.scala | 42 +++++++++++-------- .../apache/spark/ml/clustering/LDASuite.scala | 6 +++ 4 files changed, 64 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index afe599cd167cb..fed42c959b5ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -569,10 +569,14 @@ abstract class LDAModel private[ml] ( class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, - @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel, + private[clustering] val oldLocalModel_ : OldLocalLDAModel, sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { + override private[clustering] def oldLocalModel: OldLocalLDAModel = { + oldLocalModel_.setSeed(getSeed) + } + @Since("1.6.0") override def copy(extra: ParamMap): LocalLDAModel = { val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b8a6e94248421..f915062d77389 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, Utils} /** * Latent Dirichlet Allocation (LDA) model. @@ -194,6 +194,8 @@ class LocalLDAModel private[spark] ( override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { + private var seed: Long = Utils.random.nextLong() + @Since("1.3.0") override def k: Int = topics.numCols @@ -216,6 +218,21 @@ class LocalLDAModel private[spark] ( override protected def formatVersion = "1.0" + /** + * Random seed for cluster initialization. + */ + @Since("2.4.0") + def getSeed: Long = seed + + /** + * Set the random seed for cluster initialization. + */ + @Since("2.4.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, @@ -298,6 +315,7 @@ class LocalLDAModel private[spark] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + val gammaSeed = this.seed // Sum bound components for each document: // component for prob(tokens) + component for prob(document-topic distribution) @@ -306,7 +324,7 @@ class LocalLDAModel private[spark] ( val localElogbeta = ElogbetaBc.value var docBound = 0.0D val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id) val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) // E[log p(doc | theta, beta)] @@ -352,6 +370,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -362,7 +381,8 @@ class LocalLDAModel private[spark] ( expElogbetaBc.value, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed + id) (id, Vectors.dense(normalize(gamma, 1.0).toArray)) } } @@ -376,6 +396,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed (termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -386,7 +407,8 @@ class LocalLDAModel private[spark] ( expElogbeta, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } @@ -403,6 +425,7 @@ class LocalLDAModel private[spark] ( */ @Since("2.0.0") def topicDistribution(document: Vector): Vector = { + val gammaSeed = this.seed val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) if (document.numNonzeros == 0) { Vectors.zeros(this.k) @@ -412,7 +435,8 @@ class LocalLDAModel private[spark] ( expElogbeta, this.docConcentration.asBreeze, gammaShape, - this.k) + this.k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 693a2a31f026b..f8e5f3ed76457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape val optimizeDocConcentration = this.optimizeDocConcentration + val seed = randomGenerator.nextLong() // If and only if optimizeDocConcentration is set true, // we calculate logphat in the same pass as other statistics. // No calculation of loghat happens otherwise. @@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { None } - val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - - val stat = BDM.zeros[Double](k, vocabSize) - val logphatPartOption = logphatPartOptionBase() - var nonEmptyDocCount: Long = 0L - nonEmptyDocs.foreach { case (_, termCounts: Vector) => - nonEmptyDocCount += 1 - val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids) + sstats - logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) - } - Iterator((stat, logphatPartOption, nonEmptyDocCount)) + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex { + (index, docs) => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) + + val stat = BDM.zeros[Double](k, vocabSize) + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L + nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 + val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index) + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) + } + Iterator((stat, logphatPartOption, nonEmptyDocCount)) } val elementWiseSum = ( @@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { } override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta) + .setSeed(randomGenerator.nextLong()) } } @@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer { expElogbeta: BDM[Double], alpha: breeze.linalg.Vector[Double], gammaShape: Double, - k: Int): (BDV[Double], BDM[Double], List[Int]) = { + k: Int, + seed: Long): (BDV[Double], BDM[Double], List[Int]) = { val (ids: List[Int], cts: Array[Double]) = termCounts match { case v: DenseVector => ((0 until v.size).toList, v.values) case v: SparseVector => (v.indices.toList, v.values) } // Initialize the variational distribution q(theta|gamma) for the mini-batch + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed)) val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 8d728f063dd8c..4d848205034c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -253,6 +253,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.allParamSettings, checkModelData) + + // Make sure the result is deterministic after saving and loading the model + val model = lda.fit(dataset) + val model2 = testDefaultReadWrite(model) + assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6) + assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6) } test("read/write DistributedLDAModel") { From 9a641e7f721d01d283afb09dccefaf32972d3c04 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 17 May 2018 12:07:58 +0800 Subject: [PATCH 0813/1161] [SPARK-21945][YARN][PYTHON] Make --py-files work with PySpark shell in Yarn client mode ## What changes were proposed in this pull request? ### Problem When we run _PySpark shell with Yarn client mode_, specified `--py-files` are not recognised in _driver side_. Here are the steps I took to check: ```bash $ cat /home/spark/tmp.py def testtest(): return 1 ``` ```bash $ ./bin/pyspark --master yarn --deploy-mode client --py-files /home/spark/tmp.py ``` ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() # executor side [1] >>> test() # driver side Traceback (most recent call last): File "", line 1, in File "", line 2, in test ImportError: No module named tmp ``` ### How did it happen? Unlike Yarn cluster and client mode with Spark submit, when Yarn client mode with PySpark shell specifically, 1. It first runs Python shell via: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java#L158 as pointed out by tgravescs in the JIRA. 2. this triggers shell.py and submit another application to launch a py4j gateway: https://github.com/apache/spark/blob/209b9361ac8a4410ff797cff1115e1888e2f7e66/python/pyspark/java_gateway.py#L45-L60 3. it runs a Py4J gateway: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L425 4. it copies (or downloads) --py-files into local temp directory: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L365-L376 and then these files are set up to `spark.submit.pyFiles` 5. Py4J JVM is launched and then the Python paths are set via: https://github.com/apache/spark/blob/7013eea11cb32b1e0038dc751c485da5c94a484b/python/pyspark/context.py#L209-L216 However, these are not actually set because those files were copied into a tmp directory in 4. whereas this code path looks for `SparkFiles.getRootDirectory` where the files are stored only when `SparkContext.addFile()` is called. In other cluster mode, `spark.files` are set via: https://github.com/apache/spark/blob/3cb82047f2f51af553df09b9323796af507d36f8/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L554-L555 and those files are explicitly added via: https://github.com/apache/spark/blob/ecb8b383af1cf1b67f3111c148229e00c9c17c40/core/src/main/scala/org/apache/spark/SparkContext.scala#L395 So we are fine in other modes. In case of Yarn client and cluster with _submit_, these are manually being handled. In particular https://github.com/apache/spark/pull/6360 added most of the logics. In this case, the Python path looks manually set via, for example, `deploy.PythonRunner`. We don't use `spark.files` here. ### How does the PR fix the problem? I tried to make an isolated approach as possible as I can: simply copy py file or zip files into `SparkFiles.getRootDirectory()` in driver side if not existing. Another possible way is to set `spark.files` but it does unnecessary stuff together and sounds a bit invasive. **Before** ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() [1] >>> test() Traceback (most recent call last): File "", line 1, in File "", line 2, in test ImportError: No module named tmp ``` **After** ```python >>> def test(): ... import tmp ... return tmp.testtest() ... >>> spark.range(1).rdd.map(lambda _: test()).collect() [1] >>> test() 1 ``` ## How was this patch tested? I manually tested in standalone and yarn cluster with PySpark shell. .zip and .py files were also tested with the similar steps above. It's difficult to add a test. Author: hyukjinkwon Closes #21267 from HyukjinKwon/SPARK-21945. --- python/pyspark/context.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dbb463f6005a1..ede3b6af0a8cf 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: - self._python_includes.append(filename) - sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + try: + filepath = os.path.join(SparkFiles.getRootDirectory(), filename) + if not os.path.exists(filepath): + # In case of YARN with shell mode, 'spark.submit.pyFiles' files are + # not added via SparkContext.addFile. Here we check if the file exists, + # try to copy and then add it to the path. See SPARK-21945. + shutil.copyfile(path, filepath) + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: + self._python_includes.append(filename) + sys.path.insert(1, filepath) + except Exception: + warnings.warn( + "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to " + "Python path:\n %s" % (path, "\n ".join(sys.path)), + RuntimeWarning) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) From 3e66350c2477a456560302b7738c9d122d5d9c43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florent=20P=C3=A9pin?= Date: Thu, 17 May 2018 13:31:14 +0900 Subject: [PATCH 0814/1161] [SPARK-23925][SQL] Add array_repeat collection function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The PR adds a new collection function, array_repeat. As there already was a function repeat with the same signature, with the only difference being the expected return type (String instead of Array), the new function is called array_repeat to distinguish. The behaviour of the function is based on Presto's one. The function creates an array containing a given element repeated the requested number of times. ## How was this patch tested? New unit tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite Author: Florent Pépin Author: Florent Pépin Closes #21208 from pepinoflo/SPARK-23925. --- python/pyspark/sql/functions.py | 14 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 149 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 18 +++ .../org/apache/spark/sql/functions.scala | 20 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 76 +++++++++ 6 files changed, 278 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6866c1cf9f882..925ac34196f4c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2329,6 +2329,20 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_repeat(col, count): + """ + Collection function: creates an array containing a column repeated count times. + + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=[u'ab', u'ab', u'ab'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 087d000a9db70..9c370599bc0df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -427,6 +427,7 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[ArrayRepeat]("array_repeat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 12b9ab2b272ab..2a4e42d4ba316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1468,3 +1468,152 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } + +/** + * Returns the array containing the given input value (left) count (right) times. + */ +@ExpressionDescription( + usage = "_FUNC_(element, count) - Returns the array containing element count times.", + examples = """ + Examples: + > SELECT _FUNC_('123', 2); + ['123', '123'] + """, + since = "2.4.0") +case class ArrayRepeat(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def nullable: Boolean = right.nullable + + override def eval(input: InternalRow): Any = { + val count = right.eval(input) + if (count == null) { + null + } else { + if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + + s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + } + val element = left.eval(input) + new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + } + } + + override def prettyName: String = "array_repeat" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val element = leftGen.value + val count = rightGen.value + val et = dataType.elementType + + val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { + genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) + } else { + genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) + } + val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) + + ev.copy(code = + s""" + |boolean ${ev.isNull} = false; + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = + | ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin) + } + + private def nullElementsProtection( + ev: ExprCode, + rightIsNull: String, + coreLogic: String): String = { + if (nullable) { + s""" + |if ($rightIsNull) { + | ${ev.isNull} = true; + |} else { + | ${coreLogic} + |} + """.stripMargin + } else { + coreLogic + } + } + + private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = { + val numElements = ctx.freshName("numElements") + val numElementsCode = + s""" + |int $numElements = 0; + |if ($count > 0) { + | $numElements = $count; + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (numElements, numElementsCode) + } + + private def genCodeForPrimitiveElement( + ctx: CodegenContext, + elementType: DataType, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val tempArrayDataName = ctx.freshName("tempArrayData") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val errorMessage = s" $prettyName failed." + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |if (!$leftIsNull) { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | } + |} else { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.setNullAt(k); + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForNonPrimitiveElement( + ctx: CodegenContext, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |if (!$leftIsNull) { + | for (int k = 0; k < $numElemName; k++) { + | $arrayName[k] = $element; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2851d071c7c6..57fc5f75dbca7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -468,4 +468,22 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(asa3), null) checkEvaluation(Flatten(asa4), null) } + + test("ArrayRepeat") { + val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType)) + + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi")) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi")) + checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true)) + checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1)) + checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2)) + checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null)) + checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null)) + checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2))) + checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b71dfdad8aa9b..550571a61a036 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3447,6 +3447,26 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(left: Column, right: Column): Column = withExpr { + ArrayRepeat(left.expr, right.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count)) + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ecce06f4c0755..e26565cd153b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -843,6 +843,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array_repeat function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val strDF = Seq( + ("hi", 2), + (null, 2) + ).toDF("a", "b") + + val strDFTwiceResult = Seq( + Row(Seq("hi", "hi")), + Row(Seq(null, null)) + ) + + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + + val intDF = { + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", IntegerType))) + val data = Seq( + Row(3, 2), + Row(null, 2) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val intDFTwiceResult = Seq( + Row(Seq(3, 3)), + Row(Seq(null, null)) + ) + + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + + val nullCountDF = { + val schema = StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))) + val data = Seq( + Row("hi", null), + Row(null, null) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq( + Row(null), + Row(null) + ) + ) + + // Error test cases + val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") + + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", $"b")) + } + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", lit("1"))) + } + intercept[AnalysisException] { + invalidTypeDF.selectExpr("array_repeat(a, 1.0)") + } + + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 6c35865d949a8b46f654cd53c7e5f3288def18d0 Mon Sep 17 00:00:00 2001 From: Artem Rudoy Date: Thu, 17 May 2018 18:49:46 +0800 Subject: [PATCH 0815/1161] [SPARK-22371][CORE] Return None instead of throwing an exception when an accumulator is garbage collected. ## What changes were proposed in this pull request? There's a period of time when an accumulator has been garbage collected, but hasn't been removed from AccumulatorContext.originals by ContextCleaner. When an update is received for such accumulator it will throw an exception and kill the whole job. This can happen when a stage completes, but there're still running tasks from other attempts, speculation etc. Since AccumulatorContext.get() returns an option we can just return None in such case. ## How was this patch tested? Unit test. Author: Artem Rudoy Closes #21114 from artemrd/SPARK-22371. --- .../org/apache/spark/util/AccumulatorV2.scala | 14 +++++++++----- .../scala/org/apache/spark/AccumulatorSuite.scala | 6 ++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 2bc84953a56eb..3b469a69437b9 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo private[spark] case class AccumulatorMetadata( @@ -211,7 +212,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { /** * An internal class used to track accumulators by Spark itself. */ -private[spark] object AccumulatorContext { +private[spark] object AccumulatorContext extends Logging { /** * This global map holds the original accumulator objects that are created on the driver. @@ -258,13 +259,16 @@ private[spark] object AccumulatorContext { * Returns the [[AccumulatorV2]] registered with the given ID, if any. */ def get(id: Long): Option[AccumulatorV2[_, _]] = { - Option(originals.get(id)).map { ref => - // Since we are storing weak references, we must check whether the underlying data is valid. + val ref = originals.get(id) + if (ref eq null) { + None + } else { + // Since we are storing weak references, warn when the underlying data is not valid. val acc = ref.get if (acc eq null) { - throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") + logWarning(s"Attempted to access garbage collected accumulator $id") } - acc + Option(acc) } } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 3990ee1ec326d..5d0ffd92647bc 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex System.gc() assert(ref.get.isEmpty) - // Getting a garbage collected accum should throw error - intercept[IllegalStateException] { - AccumulatorContext.get(accId) - } + // Getting a garbage collected accum should return None. + assert(AccumulatorContext.get(accId).isEmpty) // Getting a normal accumulator. Note: this has to be separate because referencing an // accumulator above in an `assert` would keep it from being garbage collected. From 6ec05826d7b0a512847e2522564e01256c8d192d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 17 May 2018 20:42:40 +0800 Subject: [PATCH 0816/1161] [SPARK-24107][CORE][FOLLOWUP] ChunkedByteBuffer.writeFully method has not reset the limit value ## What changes were proposed in this pull request? According to the discussion in https://github.com/apache/spark/pull/21175 , this PR proposes 2 improvements: 1. add comments to explain why we call `limit` to write out `ByteBuffer` with slices. 2. remove the `try ... finally` ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21327 from cloud-fan/minor. --- .../spark/util/io/ChunkedByteBuffer.scala | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 3ae8dfcc1cb66..700ce56466c35 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,15 +63,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - val curChunkLimit = bytes.limit() + val originalLimit = bytes.limit() while (bytes.hasRemaining) { - try { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) - } finally { - bytes.limit(curChunkLimit) - } + // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct + // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread. + // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may + // cause significant native memory leak, if a large direct ByteBuffer is allocated and + // cached, as it's never released until thread exits. Here we write the `bytes` with + // fixed-size slices to limit the size of the cached direct ByteBuffer. + // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details. + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + bytes.limit(originalLimit) } } } From 69350aa2f0a7aee4dcb1067f073b61a0b9f9cb51 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 17 May 2018 20:45:32 +0800 Subject: [PATCH 0817/1161] [SPARK-23922][SQL] Add arrays_overlap function ## What changes were proposed in this pull request? The PR adds the function `arrays_overlap`. This function returns `true` if the input arrays contain a non-null common element; if not, it returns `null` if any of the arrays contains a `null` element, `false` otherwise. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21028 from mgaido91/SPARK-23922. --- python/pyspark/sql/functions.py | 15 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 267 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 66 +++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 29 ++ 6 files changed, 388 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 925ac34196f4c..8490081facc5a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1855,6 +1855,21 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. + + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + @since(2.4) def slice(x, start, length): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9c370599bc0df..867c2d5eab53d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2a4e42d4ba316..c82db839438ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,15 +18,51 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +/** + * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit + * casting. + */ +trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression + with ImplicitCastInputTypes { + + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + s"been two ${ArrayType.simpleString}s with same element type, but it's " + + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + } + } +} + + /** * Given an array or map, returns its size. Returns -1 if null. */ @@ -529,6 +565,235 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Checks if the two arrays contain at least one common element. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); + true + """, since = "2.4.0") +// scalastyle:off line.size.limit +case class ArraysOverlap(left: Expression, right: Expression) + extends BinaryArrayExpressionWithImplicitCast { + + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (RowOrdering.isOrderable(elementType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") + } + case failure => failure + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + fastEval _ + } else { + bruteForceEval _ + } + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(a1: Any, a2: Any): Any = { + doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData]) + } + + /** + * A fast implementation which puts all the elements from the smaller array in a set + * and then performs a lookup on it for each element of the bigger one. + * This eval mode works only for data types which implements properly the equals method. + */ + private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) + } else { + (arr2, arr1) + } + if (smaller.numElements() > 0) { + val smallestSet = new mutable.HashSet[Any] + smaller.foreach(elementType, (_, v) => + if (v == null) { + hasNull = true + } else { + smallestSet += v + }) + bigger.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else if (smallestSet.contains(v1)) { + return true + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + /** + * A slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + if (arr1.numElements() > 0 && arr2.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (ordering.equiv(v1, v2)) { + return true + } + ) + }) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val smaller = ctx.freshName("smallerArray") + val bigger = ctx.freshName("biggerArray") + val comparisonCode = if (elementTypeSupportEquals) { + fastCodegen(ctx, ev, smaller, bigger) + } else { + bruteForceCodegen(ctx, ev, smaller, bigger) + } + s""" + |ArrayData $smaller; + |ArrayData $bigger; + |if ($a1.numElements() > $a2.numElements()) { + | $bigger = $a1; + | $smaller = $a2; + |} else { + | $smaller = $a1; + | $bigger = $a2; + |} + |if ($smaller.numElements() > 0) { + | $comparisonCode + |} + """.stripMargin + }) + } + + /** + * Code generation for a fast implementation which puts all the elements from the smaller array + * in a set and then performs a lookup on it for each element of the bigger one. + * It works only for data types which implements properly the equals method. + */ + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set = ctx.freshName("set") + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, + i, + s""" + |if ($set.contains($getFromBigger)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); + |for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode + |} + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | $elementIsInSetCode + |} + """.stripMargin + } + + /** + * Code generation for a slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val compareValues = nullSafeElementCodegen( + smaller, + j, + s""" + |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + |} + """.stripMargin, + s"${ev.isNull} = true;") + val isInSmaller = nullSafeElementCodegen( + bigger, + i, + s""" + |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { + | $compareValues + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { + | $isInSmaller + |} + """.stripMargin + } + + def nullSafeElementCodegen( + arrayVar: String, + index: String, + code: String, + isNullCode: String): String = { + if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { + s""" + |if ($arrayVar.isNullAt($index)) { + | $isNullCode + |} else { + | $code + |} + """.stripMargin + } else { + code + } + } + + override def prettyName: String = "arrays_overlap" +} + /** * Slices an array according to the requested start index and length */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 57fc5f75dbca7..6ae1ac18c4dc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,6 +136,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) + checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + + checkEvaluation(ArraysOverlap(a4, a5), true) + checkEvaluation(ArraysOverlap(a4, a6), null) + checkEvaluation(ArraysOverlap(a5, a6), false) + + // null handling + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + checkEvaluation(ArraysOverlap( + emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false) + checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) + checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) + checkEvaluation(ArraysOverlap( + Literal.create(Seq(null), ArrayType(IntegerType)), + Literal.create(Seq(null), ArrayType(IntegerType))), null) + + // arrays of binaries + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + + checkEvaluation(ArraysOverlap(b0, b1), true) + checkEvaluation(ArraysOverlap(b0, b2), false) + + // arrays of complex data types + val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), + ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), + ArrayType(ArrayType(StringType))) + val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(ArraysOverlap(aa0, aa1), true) + checkEvaluation(ArraysOverlap(aa0, aa2), false) + + // null handling with complex datatypes + val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false) + checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null) + checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null) + } + test("Slice") { val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 550571a61a036..2a8fe583b83bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3085,6 +3085,17 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both + * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns + * `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + /** * Returns an array containing all the elements in `x` from index `start` (or starting from the * end if `start` is negative) with the specified `length`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e26565cd153b4..d08982a138bc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -442,6 +442,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + checkAnswer( + Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b").selectExpr("arrays_overlap(a, b)"), + Row(true)) + + intercept[AnalysisException] { + sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(null, null)") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(map(1, 2), map(3, 4))") + } + } + test("slice function") { val df = Seq( Seq(1, 2, 3), From 8a837bf4f3f2758f7825d2362cf9de209026651a Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 17 May 2018 22:29:18 +0800 Subject: [PATCH 0818/1161] [SPARK-24193] create TakeOrderedAndProjectExec only when the limit number is below spark.sql.execution.topKSortFallbackThreshold. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Physical plan of `select colA from t order by colB limit M` is `TakeOrderedAndProject`; Currently `TakeOrderedAndProject` sorts data in memory, see https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala#L158 We can add a config – if the number of limit (M) is too big, we can sort by disk. Thus memory issue can be resolved. ## How was this patch tested? Test added Author: jinxing Closes #21252 from jinxing64/SPARK-24193. --- .../org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++++++ .../apache/spark/sql/execution/SparkStrategies.scala | 12 ++++++++---- .../apache/spark/sql/execution/PlannerSuite.scala | 12 ++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b00edca97cd44..2a673c6ce8f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1253,6 +1253,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = + buildConf("spark.sql.execution.topKSortFallbackThreshold") + .internal() + .doc("In SQL queries with a SORT followed by a LIMIT like " + + "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + + " in memory, otherwise do a global sort which spills to disk if necessary.") + .intConf + .createWithDefault(Int.MaxValue) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1424,6 +1433,8 @@ class SQLConf extends Serializable with Logging { def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 37a0b9d6c8728..b97a87a122406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,9 +66,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => // With whole stage codegen, Spark releases resources only when all the output data of the @@ -79,9 +81,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f0dfe6b76f7ae..a375f881c7d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -197,6 +197,18 @@ class PlannerSuite extends SharedSQLContext { assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } + test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { + val query0 = testData.select('value).orderBy('key).limit(100) + val planned0 = query0.queryExecution.executedPlan + assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + + val query1 = testData.select('value).orderBy('key).limit(2000) + val planned1 = query1.queryExecution.executedPlan + assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) + } + } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { val query = testData.select('key, 'value).sort('key.desc).cache() assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) From a7a9b1837808b281f47643490abcf054f6de7b50 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 17 May 2018 11:13:16 -0700 Subject: [PATCH 0819/1161] [SPARK-24115] Have logging pass through instrumentation class. ## What changes were proposed in this pull request? Fixes to tuning instrumentation. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21340 from MrBago/tunning-instrumentation. --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 10 +++++----- .../apache/spark/ml/tuning/TrainValidationSplit.scala | 10 +++++----- .../org/apache/spark/ml/util/Instrumentation.scala | 7 +++++++ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5e916cc4a9fdd..f327f37bad204 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() - logDebug(s"Train split $splitIndex with multiple sets of parameters.") + instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => @@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits - logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best cross-validation metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 13369c4df7180..14d6a69c36747 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } else None // Fit models in a Future for training in parallel - logDebug(s"Train split with multiple sets of parameters.") + instr.logDebug(s"Train split with multiple sets of parameters.") val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] @@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.unpersist() validationDataset.unpersist() - logInfo(s"Train validation split metrics: ${metrics.toSeq}") + instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best train validation split metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 3247c394dfa64..467130b37c16e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -58,6 +58,13 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( s" storageLevel=${dataset.getStorageLevel}") } + /** + * Logs a debug message with a prefix that uniquely identifies the training session. + */ + override def logDebug(msg: => String): Unit = { + super.logDebug(prefix + msg) + } + /** * Logs a warning message with a prefix that uniquely identifies the training session. */ From 439c69511812776cb4b82956547ce958d0669c52 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 17 May 2018 13:42:10 -0700 Subject: [PATCH 0820/1161] [SPARK-24114] Add instrumentation to FPGrowth. ## What changes were proposed in this pull request? Have FPGrowth keep track of model training using the Instrumentation class. ## How was this patch tested? manually Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21344 from MrBago/fpgrowth-instr. --- mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 0bf405d9abf9d..d7fbe28ae7a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -161,6 +161,8 @@ class FPGrowth @Since("2.2.0") ( private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val instr = Instrumentation.create(this, dataset) + instr.logParams(params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -183,7 +185,9 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + instr.logSuccess(model) + model } @Since("2.2.0") From d4a0895c628ca854895c3c35c46ed990af36ec61 Mon Sep 17 00:00:00 2001 From: Sandor Murakozi Date: Thu, 17 May 2018 16:33:06 -0700 Subject: [PATCH 0821/1161] [SPARK-22884][ML] ML tests for StructuredStreaming: spark.ml.clustering ## What changes were proposed in this pull request? Converting clustering tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. This PR is a new version of https://github.com/apache/spark/pull/20319 Author: Sandor Murakozi Author: Joseph K. Bradley Closes #21358 from jkbradley/smurakozi-SPARK-22884. --- .../ml/clustering/BisectingKMeansSuite.scala | 41 ++++++++++--------- .../ml/clustering/GaussianMixtureSuite.scala | 22 ++++------ .../spark/ml/clustering/KMeansSuite.scala | 31 +++++++------- .../apache/spark/ml/clustering/LDASuite.scala | 21 ++++------ 4 files changed, 50 insertions(+), 65 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f3ff2afcad2cd..81842afbddbbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -19,17 +19,18 @@ package org.apache.spark.ml.clustering import scala.language.existentials -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset -class BisectingKMeansSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -68,10 +69,13 @@ class BisectingKMeansSuite // Verify fit does not fail on very sparse data val model = bkm.fit(sparseDataset) - val result = model.transform(sparseDataset) - val numClusters = result.select("prediction").distinct().collect().length - // Verify we hit the edge case - assert(numClusters < k && numClusters > 1) + + testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") { + rows => + val numClusters = rows.distinct.length + // Verify we hit the edge case + assert(numClusters < k && numClusters > 1) + } } test("setter/getter") { @@ -104,19 +108,16 @@ class BisectingKMeansSuite val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) assert(model.clusterCenters.length === k) - - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } + // Check validity of model summary val numRows = dataset.count() assert(model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index d0d461a42711a..0b91f502f615b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -23,16 +23,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.stat.distribution.MultivariateGaussian -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Dataset, Row} -class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { - import testImplicits._ +class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { + import GaussianMixtureSuite._ + import testImplicits._ final val k = 5 private val seed = 538009335 @@ -119,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.weights.length === k) assert(model.gaussians.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName, probabilityColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - // Check prediction matches the highest probability, and probabilities sum to one. - transformed.select(predictionColName, probabilityColName).collect().foreach { - case Row(pred: Int, prob: Vector) => + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName, probabilityColName) { + case Row(_, pred: Int, prob: Vector) => val probArray = prob.toArray val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2 assert(pred === predFromProb) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 680a7c2034083..2569e7a432ca4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,20 +22,21 @@ import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest - with PMMLReadWriteTest { +class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -109,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val model = kmeans.fit(dataset) assert(model.clusterCenters.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) @@ -149,9 +148,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) - Seq(featuresColName, predictionColName).foreach { column => - assert(transformed.columns.contains(column)) - } + assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName)) assert(model.getFeaturesCol == featuresColName) assert(model.getPredictionCol == predictionColName) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 4d848205034c0..096b5416899e1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -21,11 +21,9 @@ import scala.language.existentials import org.apache.hadoop.fs.Path -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ object LDASuite { @@ -61,7 +59,7 @@ object LDASuite { } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LDASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -186,16 +184,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.topicsMatrix.numCols === k) assert(!model.isDistributed) - // transform() - val transformed = model.transform(dataset) - val expectedColumns = Array("features", lda.getTopicDistributionCol) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - transformed.select(lda.getTopicDistributionCol).collect().foreach { r => - val topicDistribution = r.getAs[Vector](0) - assert(topicDistribution.size === k) - assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", lda.getTopicDistributionCol) { + case Row(_, topicDistribution: Vector) => + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) } // logLikelihood, logPerplexity From 7b2dca5b12164b787ec4e8e7e9f92c60a3f9563e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 18 May 2018 15:32:29 +0800 Subject: [PATCH 0822/1161] [SPARK-24277][SQL] Code clean up in SQL module: HadoopMapReduceCommitProtocol ## What changes were proposed in this pull request? In HadoopMapReduceCommitProtocol and FileFormatWriter, there are unnecessary settings in hadoop configuration. Also clean up some code in SQL module. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21329 from gengliangwang/codeCleanWrite. --- .../io/HadoopMapReduceCommitProtocol.scala | 15 +++------------ .../datasources/orc/OrcColumnVector.java | 6 +----- .../parquet/VectorizedRleValuesReader.java | 4 ++-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/views.scala | 10 ++++------ .../execution/datasources/FileFormatWriter.scala | 11 +++++------ .../sql/execution/ui/SQLAppStatusListener.scala | 2 +- 7 files changed, 17 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 3e60c50ada59b..163511b7ffa3a 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -145,18 +145,9 @@ class HadoopMapReduceCommitProtocol( } override def setupJob(jobContext: JobContext): Unit = { - // Setup IDs - val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) - val taskId = new TaskID(jobId, TaskType.MAP, 0) - val taskAttemptId = new TaskAttemptID(taskId, 0) - - // Set up the configuration object - jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) - jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) - jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) - jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) - jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) - + // Create a dummy [[TaskAttemptContextImpl]] with configuration to get [[OutputCommitter]] + // instance on Spark driver. Note that the job/task/attampt id doesn't matter here. + val taskAttemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) committer.setupJob(jobContext) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..fcf73e8d7ae6c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -47,11 +47,7 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto OrcColumnVector(DataType type, ColumnVector vector) { super(type); - if (type instanceof TimestampType) { - isTimestamp = true; - } else { - isTimestamp = false; - } + isTimestamp = type instanceof TimestampType; baseData = vector; if (vector instanceof LongColumnVector) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index fe3d31ae8e746..de0d65a1e0906 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -571,7 +571,7 @@ private int readIntLittleEndian() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4)); } /** @@ -592,7 +592,7 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); + return (ch1 << 16) + (ch2 << 8) + (ch3); } case 4: { return readIntLittleEndian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index af20764f9a968..265a84b39a425 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -188,7 +188,7 @@ private[sql] object SQLUtils extends Logging { dataType match { case 's' => // Read StructType for DataFrame - val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]] + val fields = SerDe.readList(dis, jvmObjectTracker = null) Row.fromSeq(fields) case _ => null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 5172f32ec7b9c..6373584b10e35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -410,12 +410,10 @@ object ViewHelper { } // Detect cyclic references from subqueries. - plan.expressions.foreach { expr => - expr match { - case s: SubqueryExpression => - checkCyclicViewReference(s.plan, path, viewIdent) - case _ => // Do nothing. - } + plan.expressions.foreach { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..681bb1df6bbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -244,18 +244,17 @@ object FileFormatWriter extends Logging { iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) - val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) - val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the attempt context required to use in the output committer. val taskAttemptContext: TaskAttemptContext = { + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) hadoopConf.setBoolean("mapreduce.task.ismap", true) - hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -378,7 +377,7 @@ object FileFormatWriter extends Logging { dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.map(_.newFile(currentPath)) + statsTrackers.foreach(_.newFile(currentPath)) } override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { @@ -429,10 +428,10 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty + private val isPartitioned = desc.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined + private val isBucketed = desc.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index d254af400a7cf..2c4d0bcf103ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -170,7 +170,7 @@ class SQLAppStatusListener( .filter { case (id, _) => metricIds.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) } // Check the execution again for whether the aggregated metrics data has been calculated. From 0cf59fcbe3799dd3c4469cbf8cd842d668a76f34 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 18 May 2018 09:53:24 -0700 Subject: [PATCH 0823/1161] [SPARK-24303][PYTHON] Update cloudpickle to v0.4.4 ## What changes were proposed in this pull request? cloudpickle 0.4.4 is released - https://github.com/cloudpipe/cloudpickle/releases/tag/v0.4.4 There's no invasive change - the main difference is that we are now able to pickle the root logger, which fix is pretty isolated. ## How was this patch tested? Jenkins tests. Author: hyukjinkwon Closes #21350 from HyukjinKwon/SPARK-24303. --- python/pyspark/cloudpickle.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index ea845b98b3db2..88519d7311fcc 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -272,7 +272,7 @@ def save_memoryview(self, obj): if not PY3: def save_buffer(self, obj): self.save(str(obj)) - dispatch[buffer] = save_buffer + dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3 def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) @@ -801,10 +801,10 @@ def save_ellipsis(self, obj): def save_not_implemented(self, obj): self.save_reduce(_gen_not_implemented, ()) - if PY3: - dispatch[io.TextIOWrapper] = save_file - else: + try: # Python 2 dispatch[file] = save_file + except NameError: # Python 3 + dispatch[io.TextIOWrapper] = save_file dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented @@ -819,6 +819,11 @@ def save_logger(self, obj): dispatch[logging.Logger] = save_logger + def save_root_logger(self, obj): + self.save_reduce(logging.getLogger, (), obj=obj) + + dispatch[logging.RootLogger] = save_root_logger + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" From 7696b9de0df6e9eb85a74bdb404409da693cf65e Mon Sep 17 00:00:00 2001 From: Soham Aurangabadkar Date: Fri, 18 May 2018 10:29:34 -0700 Subject: [PATCH 0824/1161] [SPARK-20538][SQL] Wrap Dataset.reduce with withNewRddExecutionId. ## What changes were proposed in this pull request? Wrap Dataset.reduce with `withNewExecutionId`. Author: Soham Aurangabadkar Closes #21316 from sohama4/dataset_reduce_withexecutionid. --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f001f16e1d5ee..32267eb0300f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1617,7 +1617,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def reduce(func: (T, T) => T): T = withNewRDDExecutionId { + rdd.reduce(func) + } /** * :: Experimental :: From 807ba44cb742c5f7c22bdf6bfe2cf814be85398e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 18 May 2018 10:35:43 -0700 Subject: [PATCH 0825/1161] [SPARK-24159][SS] Enable no-data micro batches for streaming mapGroupswithState ## What changes were proposed in this pull request? Enabled no-data batches in flatMapGroupsWithState in following two cases. - When ProcessingTime timeout is used, then we always run a batch every trigger interval. - When event-time watermark is defined, then the user may be doing arbitrary logic against the watermark value even if timeouts are not set. In such cases, it's best to run batches whenever the watermark has changed, irrespective of whether timeouts (i.e. event-time timeout) have been explicitly enabled. ## How was this patch tested? updated tests Author: Tathagata Das Closes #21345 from tdas/SPARK-24159. --- .../FlatMapGroupsWithStateExec.scala | 17 ++- .../FlatMapGroupsWithStateSuite.scala | 120 ++++++++++-------- 2 files changed, 80 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 80769d728b8f1..8e82cccbc8fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -97,6 +97,18 @@ case class FlatMapGroupsWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -126,7 +138,6 @@ case class FlatMapGroupsWithStateExec( case _ => iter } - // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all @@ -194,11 +205,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timingOutPairs = store.getRange(None, None).filter { rowPair => val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => + timingOutPairs.flatMap { rowPair => callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b1416bff87ee7..988c8e6753e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -615,20 +615,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -657,15 +657,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) + CheckNewAnswer(("a", "1"), ("c", "1")) ) } @@ -694,22 +694,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Complete)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckNewAnswer(("a", 1)), AddData(inputData, "a", "b"), // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckNewAnswer(("a", 2), ("b", 1)), StopStream, StartStream(), AddData(inputData, "a", "b"), // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + CheckNewAnswer(("a", 3), ("b", 2)), StopStream, StartStream(), AddData(inputData, "a", "c"), // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) ) } @@ -729,8 +729,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } test("flatMapGroupsWithState - streaming with processing time timeout") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCannotGetWatermark { state.getCurrentWatermarkMs() } @@ -757,17 +757,17 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "b"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("b", "1")), + CheckNewAnswer(("b", "1")), assertNumStateRows(total = 2, updated = 1), AddData(inputData, "b"), AdvanceManualClock(10 * 1000), - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, @@ -775,38 +775,42 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "c"), AdvanceManualClock(11 * 1000), - CheckLastBatch(("b", "-1"), ("c", "1")), + CheckNewAnswer(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), - AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), - CheckLastBatch(("c", "2")), - assertNumStateRows(total = 1, updated = 1) + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows(total = 0, updated = 0) ) } test("flatMapGroupsWithState - streaming with event time timeout + watermark") { - // Function to maintain the max event time - // Returns the max event time in the state, or -1 if the state was removed by timeout + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } - val timeoutDelay = 5 - if (key != "a") { - Iterator.empty + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) } else { - if (state.hasTimedOut) { - state.remove() - Iterator((key, -1)) - } else { - val valuesSeq = values.toSeq - val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) - val timeoutTimestampMs = maxEventTime + timeoutDelay - state.update(maxEventTime) - state.setTimeoutTimestamp(timeoutTimestampMs * 1000) - Iterator((key, maxEventTime.toInt)) - } + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) } } val inputData = MemoryStream[(String, Int)] @@ -819,15 +823,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second")), - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... - CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckLastBatch(), // No output as data should get filtered by watermark - AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s - CheckLastBatch(), // No output as no data for "a" - AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored - CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 ) } @@ -856,20 +868,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -920,15 +932,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( setFailInTask(false), AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), + CheckNewAnswer(("a", 1L)), AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), + CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), ExpectFailure[SparkException](), // task should fail but should not increment count setFailInTask(false), StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count ) } @@ -938,7 +950,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch("a"), + CheckNewAnswer("a"), AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) ) } @@ -1000,7 +1012,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, ("a", 1L)), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")) + CheckNewAnswer(("a", "1")) ) } } From ed7ba7db8fa344ff182b72d23ae458e711f63432 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 18 May 2018 11:14:22 -0700 Subject: [PATCH 0826/1161] [SPARK-23850][SQL] Add separate config for SQL options redaction. The old code was relying on a core configuration and extended its default value to include things that redact desired things in the app's environment. Instead, add a SQL-specific option for which options to redact, and apply both the core and SQL-specific rules when redacting the options in the save command. This is a little sub-optimal since it adds another config, but it retains the current default behavior. While there I also fixed a typo and a couple of minor config API usage issues in the related redaction option that SQL already had. Tested with existing unit tests, plus checking the env page on a shell UI. Author: Marcelo Vanzin Closes #21158 from vanzin/SPARK-23850. --- .../spark/internal/config/package.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 24 +++++++++++++++++-- .../sql/execution/DataSourceScanExec.scala | 2 +- .../spark/sql/execution/QueryExecution.scala | 2 +- .../SaveIntoDataSourceCommand.scala | 5 ++-- .../SaveIntoDataSourceCommandSuite.scala | 3 --- 6 files changed, 27 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 82f0a04e94b1c..a54b091a64d50 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -342,7 +342,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password|url|user|username".r) + .createWithDefault("(?i)secret|password".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2a673c6ce8f4a..53a50305348fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1155,8 +1155,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = + buildConf("spark.sql.redaction.options.regex") + .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + + "information. The values of options whose names that match this regex will be redacted " + + "in the explain output. This redaction is applied on top of the global redaction " + + s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.") + .regexConf + .createWithDefault("(?i)url".r) + val SQL_STRING_REDACTION_PATTERN = - ConfigBuilder("spark.sql.redaction.string.regex") + buildConf("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + "information. When this regex matches a string part, that string part is replaced by a " + "dummy value. This is currently used to redact the output of SQL explain commands. " + @@ -1429,7 +1438,7 @@ class SQLConf extends Serializable with Logging { def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) - def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) @@ -1738,6 +1747,17 @@ class SQLConf extends Serializable with Logging { }.toSeq } + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions(options: Map[String, String]): Map[String, String] = { + val regexes = Seq( + getConf(SQL_OPTIONS_REDACTION_PATTERN), + SECRET_REDACTION_PATTERN.readFrom(reader)) + + regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap + } + /** * Return whether a given key is set in this [[SQLConf]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3d..61c14fee09337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text) + Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 15379a0663f7d..3112b306c365e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -225,7 +225,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * Redact the sensitive information in the given string. */ private def withRedaction(message: String): String = { - Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message) + Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message) } /** A special namespace for commands that can be used to debug query execution. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 568e953a5db66..00b1b5dedb593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.CreatableRelationProvider -import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. @@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand( } override def simpleString: String = { - val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala index 4b3ca8e60cab6..a1da3ec43eae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -23,9 +23,6 @@ import org.apache.spark.sql.test.SharedSQLContext class SaveIntoDataSourceCommandSuite extends SharedSQLContext { - override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.redaction.regex", "(?i)password|url") - test("simpleString is redacted") { val URL = "connection.url" val PASS = "123" From 1c4553d67de8089e8aa84bc736faa11f21615a6a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 18 May 2018 12:51:09 -0700 Subject: [PATCH 0827/1161] Revert "[SPARK-24277][SQL] Code clean up in SQL module: HadoopMapReduceCommitProtocol" This reverts commit 7b2dca5b12164b787ec4e8e7e9f92c60a3f9563e. --- .../io/HadoopMapReduceCommitProtocol.scala | 15 ++++++++++++--- .../datasources/orc/OrcColumnVector.java | 6 +++++- .../parquet/VectorizedRleValuesReader.java | 4 ++-- .../org/apache/spark/sql/api/r/SQLUtils.scala | 2 +- .../spark/sql/execution/command/views.scala | 10 ++++++---- .../execution/datasources/FileFormatWriter.scala | 11 ++++++----- .../sql/execution/ui/SQLAppStatusListener.scala | 2 +- 7 files changed, 33 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 163511b7ffa3a..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -145,9 +145,18 @@ class HadoopMapReduceCommitProtocol( } override def setupJob(jobContext: JobContext): Unit = { - // Create a dummy [[TaskAttemptContextImpl]] with configuration to get [[OutputCommitter]] - // instance on Spark driver. Note that the job/task/attampt id doesn't matter here. - val taskAttemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + // Setup IDs + val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) + jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) + jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) committer = setupCommitter(taskAttemptContext) committer.setupJob(jobContext) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index fcf73e8d7ae6c..12f4d658b1868 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -47,7 +47,11 @@ public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVecto OrcColumnVector(DataType type, ColumnVector vector) { super(type); - isTimestamp = type instanceof TimestampType; + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } baseData = vector; if (vector instanceof LongColumnVector) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index de0d65a1e0906..fe3d31ae8e746 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -571,7 +571,7 @@ private int readIntLittleEndian() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4)); + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } /** @@ -592,7 +592,7 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException { int ch3 = in.read(); int ch2 = in.read(); int ch1 = in.read(); - return (ch1 << 16) + (ch2 << 8) + (ch3); + return (ch1 << 16) + (ch2 << 8) + (ch3 << 0); } case 4: { return readIntLittleEndian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 265a84b39a425..af20764f9a968 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -188,7 +188,7 @@ private[sql] object SQLUtils extends Logging { dataType match { case 's' => // Read StructType for DataFrame - val fields = SerDe.readList(dis, jvmObjectTracker = null) + val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]] Row.fromSeq(fields) case _ => null } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 6373584b10e35..5172f32ec7b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -410,10 +410,12 @@ object ViewHelper { } // Detect cyclic references from subqueries. - plan.expressions.foreach { - case s: SubqueryExpression => - checkCyclicViewReference(s.plan, path, viewIdent) - case _ => // Do nothing. + plan.expressions.foreach { expr => + expr match { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 681bb1df6bbae..401597f967218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -244,17 +244,18 @@ object FileFormatWriter extends Logging { iterator: Iterator[InternalRow]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId) + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the attempt context required to use in the output committer. val taskAttemptContext: TaskAttemptContext = { - val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) - val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) // Set up the configuration object val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } @@ -377,7 +378,7 @@ object FileFormatWriter extends Logging { dataSchema = description.dataColumns.toStructType, context = taskAttemptContext) - statsTrackers.foreach(_.newFile(currentPath)) + statsTrackers.map(_.newFile(currentPath)) } override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { @@ -428,10 +429,10 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol) extends ExecuteWriteTask { /** Flag saying whether or not the data to be written out is partitioned. */ - private val isPartitioned = desc.partitionColumns.nonEmpty + val isPartitioned = desc.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. */ - private val isBucketed = desc.bucketIdExpression.isDefined + val isBucketed = desc.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index 2c4d0bcf103ff..d254af400a7cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -170,7 +170,7 @@ class SQLAppStatusListener( .filter { case (id, _) => metricIds.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) + id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq) } // Check the execution again for whether the aggregated metrics data has been calculated. From 7f82c4a47e94ee4f544dee8bb71b99534e919769 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 18 May 2018 12:54:19 -0700 Subject: [PATCH 0828/1161] [SPARK-24312][SQL] Upgrade to 2.3.3 for Hive Metastore Client 2.3 ## What changes were proposed in this pull request? Hive 2.3.3 was [released on April 3rd](https://issues.apache.org/jira/secure/ReleaseNote.jspa?version=12342162&styleName=Text&projectId=12310843). This PR aims to upgrade Hive Metastore Client 2.3 from 2.3.2 to 2.3.3. ## How was this patch tested? Pass the Jenkins with the existing tests. Author: Dongjoon Hyun Closes #21359 from dongjoon-hyun/SPARK-24312. --- docs/sql-programming-guide.md | 4 ++-- .../src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala | 2 +- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/client/package.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3f79ed6422205..b93d8531d9efe 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used
    @@ -2237,7 +2237,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index bb134bbe68bd9..cd321d41f43e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.2.") + s"0.12.0 through 2.3.3.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c2690ec32b9e7..2f34f69b5cf48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -98,7 +98,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 681ee9200f02b..25e9886fa6576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.2", + case object v2_3 extends HiveVersion("2.3.3", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) From 3159ee085b23e2e9f1657d80b7ae3efe82b5edb9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 18 May 2018 13:04:00 -0700 Subject: [PATCH 0829/1161] [SPARK-24149][YARN] Retrieve all federated namespaces tokens ## What changes were proposed in this pull request? Hadoop 3 introduces HDFS federation. This means that multiple namespaces are allowed on the same HDFS cluster. In Spark, we need to ask the delegation token for all the namenodes (for each namespace), otherwise accessing any other namespace different from the default one (for which we already fetch the delegation token) fails. The PR adds the automatic discovery of all the namenodes related to all the namespaces available according to the configs in hdfs-site.xml. ## How was this patch tested? manual tests in dockerized env Author: Marco Gaido Closes #21216 from mgaido91/SPARK-24149. --- docs/running-on-yarn.md | 9 ++- .../deploy/yarn/YarnSparkHadoopUtil.scala | 24 ++++++- .../yarn/YarnSparkHadoopUtilSuite.scala | 65 ++++++++++++++++++- 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c9e68c3bfd056..4dbcbeafbbd9d 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -424,9 +424,12 @@ To use a custom metrics.properties for the application master and executors, upd Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. -In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home -directory, Spark will also automatically obtain delegation tokens for the service hosting the -staging directory of the Spark application. +In YARN mode, when accessing Hadoop filesystems, Spark will automatically obtain delegation tokens +for: + +- the filesystem hosting the staging directory of the Spark application (which is the default + filesystem if `spark.yarn.stagingDir` is not set); +- if Hadoop federation is enabled, all the federated filesystems in the configuration. If an application needs to interact with other secure Hadoop filesystems, their URIs need to be explicitly provided to Spark at launch time. This is done by listing them in the diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 8eda6cb1277c5..7250e58b6c49a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -200,7 +200,29 @@ object YarnSparkHadoopUtil { .map(new Path(_).getFileSystem(hadoopConf)) .getOrElse(FileSystem.get(hadoopConf)) - filesystemsToAccess + stagingFS + // Add the list of available namenodes for all namespaces in HDFS federation. + // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its + // namespaces. + val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") { + Set.empty + } else { + val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") + // Retrieving the filesystem for the nameservices where HA is not enabled + val filesystemsWithoutHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode => + new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf) + } + } + // Retrieving the filesystem for the nameservices where HA is enabled + val filesystemsWithHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ => + new Path(s"hdfs://$ns").getFileSystem(hadoopConf) + } + } + (filesystemsWithoutHA ++ filesystemsWithHA).toSet + } + + filesystemsToAccess ++ hadoopFilesystems + stagingFS } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index f21353aa007c8..61c0c43f7c04f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -21,7 +21,8 @@ import java.io.{File, IOException} import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.io.Text +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } + test("SPARK-24149: retrieve all namenodes from HDFS") { + val sparkConf = new SparkConf() + val basicFederationConf = new Configuration() + basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + basicFederationConf.set("dfs.nameservices", "ns1,ns2") + basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020") + basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021") + val basicFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf), + new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf)) + val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, basicFederationConf) + basicFederationResult should be (basicFederationExpected) + + // when viewfs is enabled, namespaces are handled by it, so we don't need to take care of them + val viewFsConf = new Configuration() + viewFsConf.addResource(basicFederationConf) + viewFsConf.set("fs.defaultFS", "viewfs://clusterX/") + viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/") + val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf)) + YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected) + + // invalid config should not throw NullPointerException + val invalidFederationConf = new Configuration() + invalidFederationConf.addResource(basicFederationConf) + invalidFederationConf.unset("dfs.namenode.rpc-address.ns2") + val invalidFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf)) + val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, invalidFederationConf) + invalidFederationResult should be (invalidFederationExpected) + + // no namespaces defined, ie. old case + val noFederationConf = new Configuration() + noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + val noFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(noFederationConf)) + val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf) + noFederationResult should be (noFederationExpected) + + // federation and HA enabled + val federationAndHAConf = new Configuration() + federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA") + federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA") + federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2") + federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + + val federationAndHAExpected = Set( + new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf), + new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf)) + val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, federationAndHAConf) + federationAndHAResult should be (federationAndHAExpected) + } } From a53ea70c1d8903cdff051edf667b0127c8131a09 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 18 May 2018 13:38:36 -0700 Subject: [PATCH 0830/1161] [SPARK-23856][SQL] Add an option `queryTimeout` in JDBCOptions ## What changes were proposed in this pull request? This pr added an option `queryTimeout` for the number of seconds the the driver will wait for a Statement object to execute. ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro Closes #21173 from maropu/SPARK-23856. --- docs/sql-programming-guide.md | 11 +++++++++++ .../org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../datasources/jdbc/JDBCOptions.scala | 5 +++++ .../execution/datasources/jdbc/JDBCRDD.scala | 3 +++ .../jdbc/JdbcRelationProvider.scala | 2 +- .../execution/datasources/jdbc/JdbcUtils.scala | 16 +++++++++++++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 16 ++++++++++++++++ .../apache/spark/sql/jdbc/JDBCWriteSuite.scala | 18 ++++++++++++++++++ 8 files changed, 69 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b93d8531d9efe..f1ed316341b95 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1338,6 +1338,17 @@ the following case-insensitive options: + + + + +
    `spark.mllib` modelPMML model
    spark.mllib modelPMML model
    This configuration limits the number of remote requests to fetch blocks at any given point. When the number of hosts in the cluster increase, it might lead to very large number - of in-bound connections to one or more nodes, causing the workers to fail under load. + of inbound connections to one or more nodes, causing the workers to fail under load. By allowing it to limit the number of fetch requests, this scenario can be mitigated.
    4194304 (4 MB) The estimated cost to open a file, measured by the number of bytes could be scanned at the same - time. This is used when putting multiple files into a partition. It is better to over estimate, + time. This is used when putting multiple files into a partition. It is better to overestimate, then the partitions with small files will be faster than partitions with bigger files.
    0.8 for KUBERNETES mode; 0.8 for YARN mode; 0.0 for standalone mode and Mesos coarse-grained mode The minimum ratio of registered resources (registered resources / total expected resources) - (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarsed-grained + (resources are executors in yarn mode and Kubernetes mode, CPU cores in standalone mode and Mesos coarse-grained mode ['spark.cores.max' value is total expected resources for Mesos coarse-grained mode] ) to wait for before scheduling begins. Specified as a double between 0.0 and 1.0. Regardless of whether the minimum ratio of resources has been reached, @@ -1634,7 +1634,7 @@ Apart from these, the following properties are also available, and may be useful false (Experimental) If set to "true", Spark will blacklist the executor immediately when a fetch - failure happenes. If external shuffle service is enabled, then the whole node will be + failure happens. If external shuffle service is enabled, then the whole node will be blacklisted.
    spark.streaming.receiver.writeAheadLog.enable false - Enable write ahead logs for receivers. All the input data received through receivers - will be saved to write ahead logs that will allow it to be recovered after driver failures. + Enable write-ahead logs for receivers. All the input data received through receivers + will be saved to write-ahead logs that will allow it to be recovered after driver failures. See the deployment guide in the Spark Streaming programing guide for more details. spark.streaming.driver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the driver. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the metadata WAL on the driver. spark.streaming.receiver.writeAheadLog.closeFileAfterWrite false - Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + Whether to close the file after writing a write-ahead log record on the receivers. Set this to 'true' when you want to use S3 (or any file system that does not support flushing) for the data WAL on the receivers. spark.mesos.constraints (none) - Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting + Attribute-based constraints on mesos resource offers. By default, all resource offers will be accepted. This setting applies only to executors. Refer to Mesos Attributes & Resources for more information on attributes.
      diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e07759a4dba87..ceda8a3ae2403 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -418,7 +418,7 @@ To use a custom metrics.properties for the application master and executors, upd - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example, you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. # Kerberos diff --git a/docs/security.md b/docs/security.md index 3e5607a9a0d67..8c0c66fb5a285 100644 --- a/docs/security.md +++ b/docs/security.md @@ -374,7 +374,7 @@ replaced with one of the above namespaces.
    ${ns}.enabledAlgorithms None - A comma separated list of ciphers. The specified ciphers must be supported by JVM. + A comma-separated list of ciphers. The specified ciphers must be supported by JVM.
    The reference list of protocols can be found in the "JSSE Cipher Suite Names" section of the Java security guide. The list for Java 8 can be found at diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 8fa643abf1373..f06e72a387df1 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -338,7 +338,7 @@ worker during one single schedule iteration. # Monitoring and Logging -Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. +Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default, you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. In addition, detailed log output for each job is also written to the work directory of each slave node (`SPARK_HOME/work` by default). You will see two files for each job, `stdout` and `stderr`, with all output it wrote to its console. diff --git a/docs/sparkr.md b/docs/sparkr.md index 2909247e79e95..7fabab5d38f16 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -107,7 +107,7 @@ The following Spark driver properties can be set in `sparkConfig` with `sparkR.s With a `SparkSession`, applications can create `SparkDataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). ### From local data frames -The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically, we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R.
    {% highlight r %} @@ -169,7 +169,7 @@ df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings {% endhighlight %}
    -The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example we can save the SparkDataFrame from the previous example +The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example, we can save the SparkDataFrame from the previous example to a Parquet file using `write.df`.
    @@ -241,7 +241,7 @@ head(filter(df, df$waiting < 50)) ### Grouping, Aggregation -SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example, we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below
    {% highlight r %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9822d669050d5..55d35b9dd31db 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -165,7 +165,7 @@ In addition to simple column references and expressions, Datasets also have a ri
    -In Python it's possible to access a DataFrame's columns either by attribute +In Python, it's possible to access a DataFrame's columns either by attribute (`df.age`) or by indexing (`df['age']`). While the former is convenient for interactive data exploration, users are highly encouraged to use the latter form, which is future proof and won't break with column names that @@ -278,7 +278,7 @@ the bytes back into an object. Spark SQL supports two different methods for converting existing RDDs into Datasets. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This -reflection based approach leads to more concise code and works well when you already know the schema +reflection-based approach leads to more concise code and works well when you already know the schema while writing your Spark application. The second method for creating Datasets is through a programmatic interface that allows you to @@ -1243,7 +1243,7 @@ The following options can be used to configure the version of Hive that is used
    com.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc

    - A comma separated list of class prefixes that should be loaded using the classloader that is + A comma-separated list of class prefixes that should be loaded using the classloader that is shared between Spark SQL and a specific version of Hive. An example of classes that should be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need to be shared are those that interact with classes that are already shared. For example, @@ -1441,7 +1441,7 @@ SELECT * FROM resultTable # Performance Tuning -For some workloads it is possible to improve performance by either caching data in memory, or by +For some workloads, it is possible to improve performance by either caching data in memory, or by turning on some experimental options. ## Caching Data In Memory @@ -1804,7 +1804,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 2.3 to 2.4 - Since Spark 2.4, Spark maximizes the usage of a vectorized ORC reader for ORC files by default. To do that, `spark.sql.orc.impl` and `spark.sql.orc.filterPushdown` change their default values to `native` and `true` respectively. - - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unabled to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. + - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. - Since Spark 2.4, writing an empty dataframe to a directory launches at least one write task, even if physically the dataframe has no partition. This introduces a small behavior change that for self-describing file formats like Parquet and Orc, Spark creates a metadata-only file in the target directory when writing a 0-partition dataframe, so that schema inference can still work if users read that directory later. The new behavior is more reasonable and more consistent regarding writing empty dataframe. - Since Spark 2.4, expression IDs in UDF arguments do not appear in column names. For example, an column name in Spark 2.4 is not `UDF:f(col0 AS colA#28)` but ``UDF:f(col0 AS `colA`)``. - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. @@ -1966,11 +1966,11 @@ working with timestamps in `pandas_udf`s to get the best performance, see - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error prone. + - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. ## Upgrading From Spark SQL 2.1 to 2.2 - - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. + - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. - Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty). @@ -2013,7 +2013,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see ## Upgrading From Spark SQL 1.5 to 1.6 - - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC + - From Spark 1.6, by default, the Thrift server runs in multi-session mode. Which means each JDBC/ODBC connection owns a copy of their own SQL configuration and temporary function registry. Cached tables are still shared though. If you prefer to run the Thrift server in the old single-session mode, please set option `spark.sql.hive.thriftServer.singleSession` to `true`. You may either add @@ -2161,7 +2161,7 @@ been renamed to `DataFrame`. This is primarily because DataFrames no longer inhe directly, but instead provide most of the functionality that RDDs provide though their own implementation. DataFrames can still be converted to RDDs by calling the `.rdd` method. -In Scala there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for +In Scala, there is a type alias from `SchemaRDD` to `DataFrame` to provide source compatibility for some use cases. It is still recommended that users update their code to use `DataFrame` instead. Java and Python users will need to update their code. @@ -2170,11 +2170,11 @@ Java and Python users will need to update their code. Prior to Spark 1.3 there were separate Java compatible classes (`JavaSQLContext` and `JavaSchemaRDD`) that mirrored the Scala API. In Spark 1.3 the Java API and Scala API have been unified. Users of either language should use `SQLContext` and `DataFrame`. In general these classes try to -use types that are usable from both languages (i.e. `Array` instead of language specific collections). +use types that are usable from both languages (i.e. `Array` instead of language-specific collections). In some cases where no common type exists (e.g., for passing in closures or Maps) function overloading is used instead. -Additionally the Java specific types API has been removed. Users of both Scala and Java should +Additionally, the Java specific types API has been removed. Users of both Scala and Java should use the classes present in `org.apache.spark.sql.types` to describe schema programmatically. @@ -2231,7 +2231,7 @@ referencing a singleton. ## Compatibility with Apache Hive Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. -Currently Hive SerDes and UDFs are based on Hive 1.2.1, +Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore (from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). @@ -2323,10 +2323,10 @@ A handful of Hive optimizations are not yet included in Spark. Some of these (su less important due to Spark SQL's in-memory computational model. Others are slotted for future releases of Spark SQL. -* Block level bitmap indexes and virtual columns (used to build indexes) -* Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you +* Block-level bitmap indexes and virtual columns (used to build indexes) +* Automatically determine the number of reducers for joins and groupbys: Currently, in Spark SQL, you need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". -* Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still +* Meta-data only query: For queries that can be answered by using only metadata, Spark SQL still launches tasks to compute the result. * Skew data flag: Spark SQL does not follow the skew data flags in Hive. * `STREAMTABLE` hint in join: Spark SQL does not follow the `STREAMTABLE` hint. @@ -2983,6 +2983,6 @@ does not exactly match standard floating point semantics. Specifically: - NaN = NaN returns true. - - In aggregations all NaN values are grouped together. + - In aggregations, all NaN values are grouped together. - NaN is treated as a normal value in join keys. - NaN values go last when in ascending order, larger than any other numeric value. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md index 1dd54719b21aa..dacaa3438d489 100644 --- a/docs/storage-openstack-swift.md +++ b/docs/storage-openstack-swift.md @@ -39,7 +39,7 @@ For example, for Maven support, add the following to the pom.xml fi # Configuration Parameters Create core-site.xml and place it inside Spark's conf directory. -The main category of parameters that should be configured are the authentication parameters +The main category of parameters that should be configured is the authentication parameters required by Keystone. The following table contains a list of Keystone mandatory parameters. PROVIDER can be diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 257a4f7d4f3ca..a1b6942ffe0a4 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -17,7 +17,7 @@ Choose a machine in your cluster such that - Flume can be configured to push data to a port on that machine. -Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data. +Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able to push data. #### Configuring Flume Configure Flume agent to send data to an Avro sink by having the following in the configuration file. @@ -100,7 +100,7 @@ Choose a machine that will run the custom sink in a Flume agent. The rest of the #### Configuring Flume Configuring Flume on the chosen machine requires the following two steps. -1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink . +1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink. (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). @@ -128,7 +128,7 @@ Configuring Flume on the chosen machine requires the following two steps. agent.sinks.spark.port = agent.sinks.spark.channel = memoryChannel - Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. + Also, make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about configuring Flume agents. diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md index 9f0671da2ee31..becf217738d26 100644 --- a/docs/streaming-kafka-0-8-integration.md +++ b/docs/streaming-kafka-0-8-integration.md @@ -10,7 +10,7 @@ Here we explain how to configure Spark Streaming to receive data from Kafka. The ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. -However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write-Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write-ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write-Ahead Logs. Next, we discuss how to use this approach in your streaming application. @@ -55,11 +55,11 @@ Next, we discuss how to use this approach in your streaming application. **Points to remember:** - - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + - Topic partitions in Kafka do not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. - - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use + - If you have enabled Write-Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use `KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). 3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. @@ -80,9 +80,9 @@ This approach has the following advantages over the receiver-based approach (i.e - *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write-Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write-Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write-Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high-level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with-write-ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index ffda36d64a770..c30959263cdfa 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1461,7 +1461,7 @@ Note that the connections in the pool should be lazily created on demand and tim *** ## DataFrame and SQL Operations -You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore, this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.

    @@ -2010,10 +2010,10 @@ To run a Spark Streaming applications, you need to have the following. + *Mesos* - [Marathon](https://github.com/mesosphere/marathon) has been used to achieve this with Mesos. -- *Configuring write ahead logs* - Since Spark 1.2, - we have introduced _write ahead logs_ for achieving strong +- *Configuring write-ahead logs* - Since Spark 1.2, + we have introduced _write-ahead logs_ for achieving strong fault-tolerance guarantees. If enabled, all the data received from a receiver gets written into - a write ahead log in the configuration checkpoint directory. This prevents data loss on driver + a write-ahead log in the configuration checkpoint directory. This prevents data loss on driver recovery, thus ensuring zero data loss (discussed in detail in the [Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) @@ -2021,15 +2021,15 @@ To run a Spark Streaming applications, you need to have the following. come at the cost of the receiving throughput of individual receivers. This can be corrected by running [more receivers in parallel](#level-of-parallelism-in-data-receiving) to increase aggregate throughput. Additionally, it is recommended that the replication of the - received data within Spark be disabled when the write ahead log is enabled as the log is already + received data within Spark be disabled when the write-ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that - does not support flushing) for _write ahead logs_, please remember to enable + does not support flushing) for _write-ahead logs_, please remember to enable `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - Note that Spark will not encrypt data written to the write ahead log when I/O encryption is - enabled. If encryption of the write ahead log data is desired, it should be stored in a file + Note that Spark will not encrypt data written to the write-ahead log when I/O encryption is + enabled. If encryption of the write-ahead log data is desired, it should be stored in a file system that supports encryption natively. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming @@ -2284,9 +2284,9 @@ Having bigger blockinterval means bigger blocks. A high value of `spark.locality - Instead of relying on batchInterval and blockInterval, you can define the number of partitions by calling `inputDstream.repartition(n)`. This reshuffles the data in RDD randomly to create n number of partitions. Yes, for greater parallelism. Though comes at the cost of a shuffle. An RDD's processing is scheduled by driver's jobscheduler as a job. At a given point of time only one job is active. So, if one job is executing the other jobs are queued. -- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However the partitioning of the RDDs is not impacted. +- If you have two dstreams there will be two RDDs formed and there will be two jobs created which will be scheduled one after the another. To avoid this, you can union two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However, the partitioning of the RDDs is not impacted. -- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. +- If the batch processing time is more than batchinterval then obviously the receiver's memory will start filling up and will end up in throwing exceptions (most probably BlockNotFoundException). Currently, there is no way to pause the receiver. Using SparkConf configuration `spark.streaming.receiver.maxRate`, rate of receiver can be limited. *************************************************************************************************** @@ -2388,7 +2388,7 @@ then besides these losses, all of the past data that was received and replicated lost. This will affect the results of the stateful transformations. To avoid this loss of past received data, Spark 1.2 introduced _write -ahead logs_ which save the received data to fault-tolerant storage. With the [write ahead logs +ahead logs_ which save the received data to fault-tolerant storage. With the [write-ahead logs enabled](#deploying-applications) and reliable receivers, there is zero data loss. In terms of semantics, it provides an at-least once guarantee. The following table summarizes the semantics under failures: @@ -2402,7 +2402,7 @@ The following table summarizes the semantics under failures:
    Spark 1.1 or earlier, OR
    - Spark 1.2 or later without write ahead logs + Spark 1.2 or later without write-ahead logs
    Buffered data lost with unreliable receivers
    @@ -2416,7 +2416,7 @@ The following table summarizes the semantics under failures:
    Spark 1.2 or later with write ahead logsSpark 1.2 or later with write-ahead logs Zero data loss with reliable receivers
    At-least once semantics diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index 5647ec6bc5797..71fd5b10cc407 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -15,7 +15,7 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli For Python applications, you need to add this above library and its dependencies when deploying your application. See the [Deploying](#deploying) subsection below. -For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also see the [Deploying](#deploying) subsection below. +For experimenting on `spark-shell`, you need to add this above library and its dependencies too when invoking `spark-shell`. Also, see the [Deploying](#deploying) subsection below. ## Reading Data from Kafka diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 9a83f157452ad..602a4c70848e7 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,7 +8,7 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write-Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* Internally, by default, Structured Streaming queries are processed using a *micro-batch processing* engine, which processes data streams as a series of small batch jobs thereby achieving end-to-end latencies as low as 100 milliseconds and exactly-once fault-tolerance guarantees. However, since Spark 2.3, we have introduced a new low-latency processing mode called **Continuous Processing**, which can achieve end-to-end latencies as low as 1 millisecond with at-least-once guarantees. Without changing the Dataset/DataFrame operations in your queries, you will be able to choose the mode based on your application requirements. @@ -479,7 +479,7 @@ detail in the [Window Operations](#window-operations-on-event-time) section. ## Fault Tolerance Semantics Delivering end-to-end exactly-once semantics was one of key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers) -to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. +to track the read position in the stream. The engine uses checkpointing and write-ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure. # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` @@ -690,7 +690,7 @@ These examples generate streaming DataFrames that are untyped, meaning that the By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can reenable schema inference by setting `spark.sql.streaming.schemaInference` to `true`. -Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). +Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user-provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`). ## Operations on streaming DataFrames/Datasets You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use. @@ -2661,7 +2661,7 @@ sql("SET spark.sql.streaming.metricsEnabled=true") All queries started in the SparkSession after this configuration has been enabled will report metrics through Dropwizard to whatever [sinks](monitoring.html#metrics) have been configured (e.g. Ganglia, Graphite, JMX, etc.). ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write-ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index a3643bf0838a1..77aa083c4a584 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -177,7 +177,7 @@ The master URL passed to Spark can be in one of the following formats: # Loading Configuration from a File The `spark-submit` script can load default [Spark configuration values](configuration.html) from a -properties file and pass them on to your application. By default it will read options +properties file and pass them on to your application. By default, it will read options from `conf/spark-defaults.conf` in the Spark directory. For more detail, see the section on [loading default configurations](configuration.html#loading-default-configurations). diff --git a/docs/tuning.md b/docs/tuning.md index fc27713f28d46..912c39879be8f 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -196,7 +196,7 @@ To further tune garbage collection, we first need to understand some basic infor * A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old - enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked. + enough or Survivor2 is full, it is moved to Old. Finally, when Old is close to full, a full GC is invoked. The goal of GC tuning in Spark is to ensure that only long-lived RDDs are stored in the Old generation and that the Young generation is sufficiently sized to store short-lived objects. This will help avoid full GCs to collect diff --git a/python/README.md b/python/README.md index 3f17fdb98a081..2e0112da58b94 100644 --- a/python/README.md +++ b/python/README.md @@ -22,7 +22,7 @@ This packaging is currently experimental and may change in future versions (alth Using PySpark requires the Spark JARs, and if you are building this from source please see the builder instructions at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to setup your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). +The Python packaging for Spark is not intended to replace all of the other use cases. This Python packaged version of Spark is suitable for interacting with an existing cluster (be it Spark standalone, YARN, or Mesos) - but does not contain the tools required to set up your own standalone Spark cluster. You can download the full version of Spark from the [Apache Spark downloads page](http://spark.apache.org/downloads.html). **NOTE:** If you are using this with a Spark standalone cluster you must ensure that the version (including minor version) matches or you may experience odd errors. diff --git a/sql/README.md b/sql/README.md index fe1d352050c09..70cc7c637b58d 100644 --- a/sql/README.md +++ b/sql/README.md @@ -6,7 +6,7 @@ This module provides support for executing relational queries expressed in eithe Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Running `sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`. From 94524019315ad463f9bc13c107131091d17c6af9 Mon Sep 17 00:00:00 2001 From: Yuchen Huo Date: Fri, 6 Apr 2018 08:35:20 -0700 Subject: [PATCH 0573/1161] [SPARK-23822][SQL] Improve error message for Parquet schema mismatches ## What changes were proposed in this pull request? This pull request tries to improve the error message for spark while reading parquet files with different schemas, e.g. One with a STRING column and the other with a INT column. A new ParquetSchemaColumnConvertNotSupportedException is added to replace the old UnsupportedOperationException. The Exception is again wrapped in FileScanRdd.scala to throw a more a general QueryExecutionException with the actual parquet file name which trigger the exception. ## How was this patch tested? Unit tests added to check the new exception and verify the error messages. Also manually tested with two parquet with different schema to check the error message. screen shot 2018-03-30 at 4 03 04 pm Author: Yuchen Huo Closes #20953 from yuchenhuo/SPARK-23822. --- ...emaColumnConvertNotSupportedException.java | 62 +++++++++++++++++++ .../parquet/VectorizedColumnReader.java | 38 ++++++++---- .../execution/QueryExecutionException.scala | 3 +- .../execution/datasources/FileScanRDD.scala | 21 ++++++- .../parquet/ParquetSchemaSuite.scala | 55 ++++++++++++++++ 5 files changed, 166 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java new file mode 100644 index 0000000000000..82a1169cbe7ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * Exception thrown when the parquet reader find column type mismatches. + */ +@InterfaceStability.Unstable +public class SchemaColumnConvertNotSupportedException extends RuntimeException { + + /** + * Name of the column which cannot be converted. + */ + private String column; + /** + * Physical column type in the actual parquet file. + */ + private String physicalType; + /** + * Logical column type in the parquet schema the parquet reader use to parse all files. + */ + private String logicalType; + + public String getColumn() { + return column; + } + + public String getPhysicalType() { + return physicalType; + } + + public String getLogicalType() { + return logicalType; + } + + public SchemaColumnConvertNotSupportedException( + String column, + String physicalType, + String logicalType) { + super(); + this.column = column; + this.physicalType = physicalType; + this.logicalType = logicalType; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 47dd625f4b154..72f1d024b08ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.Arrays; import java.util.TimeZone; import org.apache.parquet.bytes.BytesUtils; @@ -31,6 +32,7 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.catalyst.util.DateTimeUtils; +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -231,6 +233,18 @@ private boolean shouldConvertTimestamps() { return convertTz != null && !convertTz.equals(UTC); } + /** + * Helper function to construct exception for parquet schema mismatch. + */ + private SchemaColumnConvertNotSupportedException constructConvertNotSupportedException( + ColumnDescriptor descriptor, + WritableColumnVector column) { + return new SchemaColumnConvertNotSupportedException( + Arrays.toString(descriptor.getPath()), + descriptor.getType().toString(), + column.dataType().toString()); + } + /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ @@ -261,7 +275,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -282,7 +296,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -321,7 +335,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; case BINARY: @@ -360,7 +374,7 @@ private void decodeDictionaryIds( } } } else { - throw new UnsupportedOperationException(); + throw constructConvertNotSupportedException(descriptor, column); } break; @@ -375,7 +389,9 @@ private void decodeDictionaryIds( */ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) { - assert(column.dataType() == DataTypes.BooleanType); + if (column.dataType() != DataTypes.BooleanType) { + throw constructConvertNotSupportedException(descriptor, column); + } defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } @@ -394,7 +410,7 @@ private void readIntBatch(int rowId, int num, WritableColumnVector column) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -414,7 +430,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -425,7 +441,7 @@ private void readFloatBatch(int rowId, int num, WritableColumnVector column) { defColumn.readFloats( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -436,7 +452,7 @@ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) { defColumn.readDoubles( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -471,7 +487,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) { } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } @@ -510,7 +526,7 @@ private void readFixedLenByteArrayBatch( } } } else { - throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); + throw constructConvertNotSupportedException(descriptor, column); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala index 16806c620635f..cffd97baea6a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala @@ -17,4 +17,5 @@ package org.apache.spark.sql.execution -class QueryExecutionException(message: String) extends Exception(message) +class QueryExecutionException(message: String, cause: Throwable = null) + extends Exception(message, cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 835ce98462477..28c36b6020d33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,11 +21,14 @@ import java.io.{FileNotFoundException, IOException} import scala.collection.mutable +import org.apache.parquet.io.ParquetDecodingException + import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator @@ -179,7 +182,23 @@ class FileScanRDD( currentIterator = readCurrentFile() } - hasNext + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. Details: " + throw new QueryExecutionException(message, e) + } + throw e + } } else { currentFile = null InputFileBlockHolder.unset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 2cd2a600f2b97..9d3dfae348beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.parquet.io.ParquetDecodingException import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -382,6 +385,58 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } } + // ======================================= + // Tests for parquet schema mismatch error + // ======================================= + def testSchemaMismatch(path: String, vectorizedReaderEnabled: Boolean): SparkException = { + import testImplicits._ + + var e: SparkException = null + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReaderEnabled.toString) { + // Create two parquet files with different schemas in the same folder + Seq(("bcd", 2)).toDF("a", "b").coalesce(1).write.mode("overwrite").parquet(s"$path/parquet") + Seq((1, "abc")).toDF("a", "b").coalesce(1).write.mode("append").parquet(s"$path/parquet") + + e = intercept[SparkException] { + spark.read.parquet(s"$path/parquet").collect() + } + } + e + } + + test("schema mismatch failure error message for parquet reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = false) + val expectedMessage = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the corresponding " + + "files. Details:" + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[ParquetDecodingException]) + assert(e.getCause.getMessage.startsWith(expectedMessage)) + } + } + + test("schema mismatch failure error message for parquet vectorized reader") { + withTempPath { dir => + val e = testSchemaMismatch(dir.getCanonicalPath, vectorizedReaderEnabled = true) + assert(e.getCause.isInstanceOf[QueryExecutionException]) + assert(e.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + + // Check if the physical type is reporting correctly + val errMsg = e.getCause.getMessage + assert(errMsg.startsWith("Parquet column cannot be converted in file")) + val file = errMsg.substring("Parquet column cannot be converted in file ".length, + errMsg.indexOf(". ")) + val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) + assert(col.length == 1) + if (col(0).dataType == StringType) { + assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + } else { + assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + } + } + } + // ======================================================= // Tests for converting Parquet LIST to Catalyst ArrayType // ======================================================= From d766ea2ff2bf59afbd631d3cc2e43bebfccdebed Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 7 Apr 2018 00:15:54 +0800 Subject: [PATCH 0574/1161] [SPARK-23861][SQL][DOC] Clarify default window frame with and without orderBy clause ## What changes were proposed in this pull request? Add docstring to clarify default window frame boundaries with and without orderBy clause ## How was this patch tested? Manually generate doc and check. Author: Li Jin Closes #20978 from icexelloss/SPARK-23861-window-doc. --- python/pyspark/sql/window.py | 4 ++++ .../main/scala/org/apache/spark/sql/expressions/Window.scala | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index e667fba099fb9..d19ced954f04e 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -44,6 +44,10 @@ class Window(object): >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + .. note:: When ordering is not defined, an unbounded window frame (rowFrame, + unboundedPreceding, unboundedFollowing) is used by default. When ordering is defined, + a growing window frame (rangeFrame, unboundedPreceding, currentRow) is used by default. + .. note:: Experimental .. versionadded:: 1.4 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 1caa243f8d118..cd819bab1b14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -33,6 +33,10 @@ import org.apache.spark.sql.catalyst.expressions._ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) * }}} * + * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding, + * unboundedFollowing) is used by default. When ordering is defined, a growing window frame + * (rangeFrame, unboundedPreceding, currentRow) is used by default. + * * @since 1.4.0 */ @InterfaceStability.Stable From c926acf719a6deb9d884a0f19bde075c312bfe5a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 18:42:14 +0200 Subject: [PATCH 0575/1161] [SPARK-23882][CORE] UTF8StringSuite.writeToOutputStreamUnderflow() is not expected to be supported ## What changes were proposed in this pull request? This PR excludes an existing UT [`writeToOutputStreamUnderflow()`](https://github.com/apache/spark/blob/master/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java#L519-L532) in `UTF8StringSuite`. As discussed [here](https://github.com/apache/spark/pull/19222#discussion_r177692142), the behavior of this test looks surprising. This test seems to access metadata area of the JVM object where is reserved by `Platform.BYTE_ARRAY_OFFSET`. This test is introduced thru #16089 by NathanHowell. More specifically, [the commit](https://github.com/apache/spark/pull/16089/commits/27c102deb1701fe62f776fe4da61dac959270b73) `Improve test coverage of UTFString.write` introduced this UT. However, I cannot find any discussion about this UT. I think that it would be good to exclude this UT. ```java public void writeToOutputStreamUnderflow() throws IOException { // offset underflow is apparently supported? final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { new UTF8String( new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) .writeTo(outputStream); final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); outputStream.reset(); } } ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20995 from kiszk/SPARK-23882. --- .../spark/unsafe/types/UTF8StringSuite.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index bad908fcaf136..652c40a35527f 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -515,22 +515,6 @@ public void soundex() { assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } - @Test - public void writeToOutputStreamUnderflow() throws IOException { - // offset underflow is apparently supported? - final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); - - for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { - new UTF8String( - new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)) - .writeTo(outputStream); - final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); - assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); - outputStream.reset(); - } - } - @Test public void writeToOutputStreamSlice() throws IOException { final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); From d23a805f975f209f273db2b52de3f336be17d873 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Fri, 6 Apr 2018 10:09:55 -0700 Subject: [PATCH 0576/1161] [SPARK-23859][ML] Initial PR for Instrumentation improvements: UUID and logging levels ## What changes were proposed in this pull request? Initial PR for Instrumentation improvements: UUID and logging levels. This PR takes over #20837 Closes #20837 ## How was this patch tested? Manual. Author: Bago Amirbekian Author: WeichenXu Closes #20982 from WeichenXu123/better-instrumentation. --- .../classification/LogisticRegression.scala | 15 ++++--- .../spark/ml/util/Instrumentation.scala | 40 +++++++++++++++---- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 3ae4db3f3f965..ee4b01058c75c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -517,6 +517,9 @@ class LogisticRegression @Since("1.2.0") ( (new MultivariateOnlineSummarizer, new MultiClassSummarizer) )(seqOp, combOp, $(aggregationDepth)) } + instr.logNamedValue(Instrumentation.loggerTags.numExamples, summarizer.count) + instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString) + instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString) val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -560,15 +563,15 @@ class LogisticRegression @Since("1.2.0") ( if (numInvalid != 0) { val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " + s"Found $numInvalid invalid labels." - logError(msg) + instr.logError(msg) throw new SparkException(msg) } val isConstantLabel = histogram.count(_ != 0.0) == 1 if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { - logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + - s"will be zeros. Training is not needed.") + instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " + + s"coefficients will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures, new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double], @@ -581,7 +584,7 @@ class LogisticRegression @Since("1.2.0") ( (coefMatrix, interceptVec, Array.empty[Double]) } else { if (!$(fitIntercept) && isConstantLabel) { - logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + + instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " + s"dangerous ground, so the algorithm may not converge.") } @@ -590,7 +593,7 @@ class LogisticRegression @Since("1.2.0") ( if (!$(fitIntercept) && (0 until numFeatures).exists { i => featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { - logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + + instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant " + "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.") } @@ -708,7 +711,7 @@ class LogisticRegression @Since("1.2.0") ( (_initialModel.interceptVector.size == numCoefficientSets) && (_initialModel.getFitIntercept == $(fitIntercept)) if (!modelIsValid) { - logWarning(s"Initial coefficients will be ignored! Its dimensions " + + instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " + s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " + s"expected size ($numCoefficientSets, $numFeatures)") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 7c46f45c59717..e694bc27b2f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import java.util.concurrent.atomic.AtomicLong +import java.util.UUID import org.json4s._ import org.json4s.JsonDSL._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.Dataset private[spark] class Instrumentation[E <: Estimator[_]] private ( estimator: E, dataset: RDD[_]) extends Logging { - private val id = Instrumentation.counter.incrementAndGet() + private val id = UUID.randomUUID() private val prefix = { val className = estimator.getClass.getSimpleName s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " @@ -56,12 +56,31 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } /** - * Logs a message with a prefix that uniquely identifies the training session. + * Logs a warning message with a prefix that uniquely identifies the training session. */ - def log(msg: String): Unit = { - logInfo(prefix + msg) + override def logWarning(msg: => String): Unit = { + super.logWarning(prefix + msg) } + /** + * Logs a error message with a prefix that uniquely identifies the training session. + */ + override def logError(msg: => String): Unit = { + super.logError(prefix + msg) + } + + /** + * Logs an info message with a prefix that uniquely identifies the training session. + */ + override def logInfo(msg: => String): Unit = { + super.logInfo(prefix + msg) + } + + /** + * Alias for logInfo, see above. + */ + def log(msg: String): Unit = logInfo(msg) + /** * Logs the value of the given parameters for the estimator being used in this session. */ @@ -77,11 +96,11 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( } def logNumFeatures(num: Long): Unit = { - log(compact(render("numFeatures" -> num))) + logNamedValue(Instrumentation.loggerTags.numFeatures, num) } def logNumClasses(num: Long): Unit = { - log(compact(render("numClasses" -> num))) + logNamedValue(Instrumentation.loggerTags.numClasses, num) } /** @@ -107,7 +126,12 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Some common methods for logging information about a training session. */ private[spark] object Instrumentation { - private val counter = new AtomicLong(0) + + object loggerTags { + val numFeatures = "numFeatures" + val numClasses = "numClasses" + val numExamples = "numExamples" + } /** * Creates an instrumentation object for a training session. From b6935ffb4dfb1d9fdf36ba402ac07bd02978c012 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 6 Apr 2018 10:23:26 -0700 Subject: [PATCH 0577/1161] [SPARK-10399][SPARK-23879][HOTFIX] Fix Java lint errors ## What changes were proposed in this pull request? This PR fixes the following errors in [Java lint](https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-lint/7717/console) after #19222 has been merged. These errors were pointed by ueshin . ``` [ERROR] src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java:[57] (sizes) LineLength: Line is longer than 100 characters (found 106). [ERROR] src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java:[26,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java:[23,10] (modifier) ModifierOrder: 'public' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[64,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[69,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[74,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[79,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[84,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[89,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[94,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[99,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[104,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[109,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[114,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[119,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[124,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java:[129,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[60,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[65,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[70,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[75,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[80,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[85,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[90,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[95,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[100,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[105,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[110,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[115,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[120,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java:[125,10] (modifier) RedundantModifier: Redundant 'final' modifier. [ERROR] src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java:[114,16] (modifier) ModifierOrder: 'static' modifier out of order with the JLS suggestions. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[30,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.memory.MemoryBlock. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[126,15] (naming) MethodName: Method name 'ByteArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[143,15] (naming) MethodName: Method name 'OnHeapMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java:[160,15] (naming) MethodName: Method name 'OffHeapArrayMemoryBlockTest' must match pattern '^[a-z][a-z0-9][a-zA-Z0-9_]*$'. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[19,8] (imports) UnusedImports: Unused import - com.google.common.primitives.Ints. [ERROR] src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java:[21,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. [ERROR] src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java:[20,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform. ``` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #20991 from kiszk/SPARK-10399-jlint. --- .../sql/catalyst/expressions/HiveHasher.java | 1 - .../spark/unsafe/array/ByteArrayMethods.java | 4 +-- .../unsafe/memory/ByteArrayMemoryBlock.java | 28 +++++++++---------- .../unsafe/memory/HeapMemoryAllocator.java | 2 -- .../spark/unsafe/memory/MemoryBlock.java | 2 +- .../unsafe/memory/OffHeapMemoryBlock.java | 2 +- .../unsafe/memory/OnHeapMemoryBlock.java | 28 +++++++++---------- .../spark/unsafe/memory/MemoryBlockSuite.java | 6 ++-- .../spark/unsafe/types/UTF8StringSuite.java | 1 - .../spark/sql/catalyst/expressions/XXH64.java | 3 -- .../catalyst/expressions/HiveHasherSuite.java | 1 - 11 files changed, 35 insertions(+), 43 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 5d905943a3aa7..c34e36903a93e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index c334c9651cf6b..4bc9955090fd7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -54,7 +54,7 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { * @return true if the arrays are equal, false otherwise */ public static boolean arrayEqualsBlock( - MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, long length) { return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); } @@ -64,7 +64,7 @@ public static boolean arrayEqualsBlock( * @return true if the arrays are equal, false otherwise */ public static boolean arrayEquals( - Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { + Object leftBase, long leftOffset, Object rightBase, long rightOffset, long length) { int i = 0; // check if starts align and we can get both offsets to be aligned diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java index 99a9868a49a79..9f238632bc87a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -57,72 +57,72 @@ public static ByteArrayMemoryBlock fromArray(final byte[] array) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index acf28fd7ee59b..36caf80888cda 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -23,8 +23,6 @@ import java.util.LinkedList; import java.util.Map; -import org.apache.spark.unsafe.Platform; - /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index b086941108522..ca7213bbf92da 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -111,7 +111,7 @@ public final void fill(byte value) { /** * Instantiate MemoryBlock for given object type with new offset */ - public final static MemoryBlock allocateFromObject(Object obj, long offset, long length) { + public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { MemoryBlock mb = null; if (obj instanceof byte[]) { byte[] array = (byte[])obj; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java index f90f62bf21dcb..3431b08980eb8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -20,7 +20,7 @@ import org.apache.spark.unsafe.Platform; public class OffHeapMemoryBlock extends MemoryBlock { - static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + public static final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); public OffHeapMemoryBlock(long address, long size) { super(null, address, size); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java index 12f67c7bd593e..ee42bc27c9c5f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -61,72 +61,72 @@ public static OnHeapMemoryBlock fromArray(final long[] array, long size) { } @Override - public final int getInt(long offset) { + public int getInt(long offset) { return Platform.getInt(array, this.offset + offset); } @Override - public final void putInt(long offset, int value) { + public void putInt(long offset, int value) { Platform.putInt(array, this.offset + offset, value); } @Override - public final boolean getBoolean(long offset) { + public boolean getBoolean(long offset) { return Platform.getBoolean(array, this.offset + offset); } @Override - public final void putBoolean(long offset, boolean value) { + public void putBoolean(long offset, boolean value) { Platform.putBoolean(array, this.offset + offset, value); } @Override - public final byte getByte(long offset) { + public byte getByte(long offset) { return Platform.getByte(array, this.offset + offset); } @Override - public final void putByte(long offset, byte value) { + public void putByte(long offset, byte value) { Platform.putByte(array, this.offset + offset, value); } @Override - public final short getShort(long offset) { + public short getShort(long offset) { return Platform.getShort(array, this.offset + offset); } @Override - public final void putShort(long offset, short value) { + public void putShort(long offset, short value) { Platform.putShort(array, this.offset + offset, value); } @Override - public final long getLong(long offset) { + public long getLong(long offset) { return Platform.getLong(array, this.offset + offset); } @Override - public final void putLong(long offset, long value) { + public void putLong(long offset, long value) { Platform.putLong(array, this.offset + offset, value); } @Override - public final float getFloat(long offset) { + public float getFloat(long offset) { return Platform.getFloat(array, this.offset + offset); } @Override - public final void putFloat(long offset, float value) { + public void putFloat(long offset, float value) { Platform.putFloat(array, this.offset + offset, value); } @Override - public final double getDouble(long offset) { + public double getDouble(long offset) { return Platform.getDouble(array, this.offset + offset); } @Override - public final void putDouble(long offset, double value) { + public void putDouble(long offset, double value) { Platform.putDouble(array, this.offset + offset, value); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 47f05c928f2e5..5d5fdc1c55a75 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -123,7 +123,7 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } @Test - public void ByteArrayMemoryBlockTest() { + public void testByteArrayMemoryBlock() { byte[] obj = new byte[56]; long offset = Platform.BYTE_ARRAY_OFFSET; int length = obj.length; @@ -140,7 +140,7 @@ public void ByteArrayMemoryBlockTest() { } @Test - public void OnHeapMemoryBlockTest() { + public void testOnHeapMemoryBlock() { long[] obj = new long[7]; long offset = Platform.LONG_ARRAY_OFFSET; int length = obj.length * 8; @@ -157,7 +157,7 @@ public void OnHeapMemoryBlockTest() { } @Test - public void OffHeapArrayMemoryBlockTest() { + public void testOffHeapArrayMemoryBlock() { MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator(); MemoryBlock memory = memoryAllocator.allocate(56); Object obj = memory.getBaseObject(); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 652c40a35527f..2c08535a16465 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -27,7 +27,6 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index 883748932ad33..fe727f6011cbf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -16,9 +16,6 @@ */ package org.apache.spark.sql.catalyst.expressions; -import com.google.common.primitives.Ints; - -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; // scalastyle: off diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java index 8ffc1d7c24d61..76930f9368514 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; From e998250588de0df250e2800278da4d3e3705c259 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 6 Apr 2018 11:51:36 -0700 Subject: [PATCH 0578/1161] [SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have constructor from labels ## What changes were proposed in this pull request? The Scala StringIndexerModel has an alternate constructor that will create the model from an array of label strings. Add the corresponding Python API: model = StringIndexerModel.from_labels(["a", "b", "c"]) ## How was this patch tested? Add doctest and unit test. Author: Huaxin Gao Closes #20968 from huaxingao/spark-23828. --- python/pyspark/ml/feature.py | 88 ++++++++++++++++++++++++++---------- python/pyspark/ml/tests.py | 41 ++++++++++++++++- 2 files changed, 104 insertions(+), 25 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fcb0dfc563720..5a3e0dd655150 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2342,9 +2342,38 @@ def mean(self): return self._call_java("mean") +class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, HasOutputCol): + """ + Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`. + """ + + stringOrderType = Param(Params._dummy(), "stringOrderType", + "How to order labels of string column. The first label after " + + "ordering is assigned an index of 0. Supported options: " + + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", + typeConverter=TypeConverters.toString) + + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "or NULL values) in features and label column of string type. " + + "Options are 'skip' (filter out rows with invalid data), " + + "error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + + def __init__(self, *args): + super(_StringIndexerParams, self).__init__(*args) + self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") + + @since("2.3.0") + def getStringOrderType(self): + """ + Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. + """ + return self.getOrDefault(self.stringOrderType) + + @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] + >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"], + ... inputCol="label", outputCol="indexed", handleInvalid="error") + >>> result = fromlabelsModel.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, result.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)] .. versionadded:: 1.4.0 """ - stringOrderType = Param(Params._dummy(), "stringOrderType", - "How to order labels of string column. The first label after " + - "ordering is assigned an index of 0. Supported options: " + - "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", - typeConverter=TypeConverters.toString) - - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + - "or NULL values) in features and label column of string type. " + - "Options are 'skip' (filter out rows with invalid data), " + - "error (throw an error), or 'keep' (put invalid data " + - "in a special additional bucket, at index numLabels).", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): @@ -2414,7 +2436,6 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) - self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -2440,21 +2461,33 @@ def setStringOrderType(self, value): """ return self._set(stringOrderType=value) - @since("2.3.0") - def getStringOrderType(self): - """ - Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. - """ - return self.getOrDefault(self.stringOrderType) - -class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): +class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`StringIndexer`. .. versionadded:: 1.4.0 """ + @classmethod + @since("2.4.0") + def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None): + """ + Construct the model directly from an array of label strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jlabels = StringIndexerModel._new_java_array(labels, java_class) + model = StringIndexerModel._create_from_java_class( + "org.apache.spark.ml.feature.StringIndexerModel", jlabels) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if handleInvalid is not None: + model.setHandleInvalid(handleInvalid) + return model + @property @since("1.5.0") def labels(self): @@ -2463,6 +2496,13 @@ def labels(self): """ return self._call_java("labels") + @since("2.4.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + @inherit_doc class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c2c4861e2aff4..4ce54547eab09 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -800,6 +800,43 @@ def test_string_indexer_handle_invalid(self): expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] self.assertEqual(actual2, expected2) + def test_string_indexer_from_labels(self): + model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label", + outputCol="indexed", handleInvalid="keep") + self.assertEqual(model.labels, ["a", "b", "c"]) + + df1 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, None), + (3, "b"), + (4, "b")], ["id", "label"]) + + result1 = model.transform(df1) + actual1 = result1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0), + Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)] + self.assertEqual(actual1, expected1) + + model_empty_labels = StringIndexerModel.from_labels( + [], inputCol="label", outputCol="indexed", handleInvalid="keep") + actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0), + Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)] + self.assertEqual(actual2, expected2) + + # Test model with default settings can transform + model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label") + df2 = self.spark.createDataFrame([ + (0, "a"), + (1, "c"), + (2, "b"), + (3, "b"), + (4, "b")], ["id", "label"]) + transformed_list = model_default.transform(df2)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 5) + class HasInducedError(Params): @@ -2097,9 +2134,11 @@ def test_java_params(self): ParamTests.check_params(self, cls(), check_params_exist=False) # Additional classes that need explicit construction - from pyspark.ml.feature import CountVectorizerModel + from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), check_params_exist=False) + ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'), + check_params_exist=False) def _squared_distance(a, b): From 6ab134ca7d8f7802a6d196929513cc02b9b4d35d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 6 Apr 2018 15:00:13 -0700 Subject: [PATCH 0579/1161] [SPARK-21898][ML][FOLLOWUP] Fix Scala 2.12 build. ## What changes were proposed in this pull request? This is a follow-up pr of #19108 which broke Scala 2.12 build. ``` [error] /Users/ueshin/workspace/apache-spark/spark/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala:86: overloaded method value test with alternatives: [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: org.apache.spark.api.java.function.Function[java.lang.Double,java.lang.Double])org.apache.spark.sql.DataFrame [error] (dataset: org.apache.spark.sql.DataFrame,sampleCol: String,cdf: scala.Double => scala.Double)org.apache.spark.sql.DataFrame [error] cannot be applied to (org.apache.spark.sql.DataFrame, String, scala.Double => java.lang.Double) [error] test(dataset, sampleCol, (x: Double) => cdf.call(x)) [error] ^ [error] one error found ``` ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #20994 from ueshin/issues/SPARK-21898/fix_scala-2.12. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index 8d80e7768cb6e..c62d7463288f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -83,7 +83,8 @@ object KolmogorovSmirnovTest { @Since("2.4.0") def test(dataset: DataFrame, sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + val f: Double => Double = x => cdf.call(x) + test(dataset, sampleCol, f) } /** From 2c1fe647575e97e28b2232478ca86847d113e185 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 8 Apr 2018 12:09:06 +0800 Subject: [PATCH 0580/1161] [SPARK-23847][PYTHON][SQL] Add asc_nulls_first, asc_nulls_last to PySpark ## What changes were proposed in this pull request? Column.scala and Functions.scala have asc_nulls_first, asc_nulls_last, desc_nulls_first and desc_nulls_last. Add the corresponding python APIs in column.py and functions.py ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #20962 from huaxingao/spark-23847. --- python/pyspark/sql/column.py | 56 +++++++++++++++++-- python/pyspark/sql/functions.py | 13 +++++ python/pyspark/sql/tests.py | 14 +++++ .../scala/org/apache/spark/sql/Column.scala | 4 +- .../org/apache/spark/sql/functions.scala | 2 +- 5 files changed, 82 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 922c7cf288f8f..e7dec11c69b57 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -447,24 +447,72 @@ def isin(self, *cols): # order _asc_doc = """ - Returns a sort expression based on the ascending order of the given column name + Returns a sort expression based on ascending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.asc()).collect() [Row(name=u'Alice'), Row(name=u'Tom')] """ + _asc_nulls_first_doc = """ + Returns a sort expression based on ascending order of the column, and null values + return before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_first()).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + + .. versionadded:: 2.4 + """ + _asc_nulls_last_doc = """ + Returns a sort expression based on ascending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.asc_nulls_last()).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + + .. versionadded:: 2.4 + """ _desc_doc = """ - Returns a sort expression based on the descending order of the given column name. + Returns a sort expression based on the descending order of the column. >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([Row(name=u'Tom', height=80), Row(name=u'Alice', height=None)]) + >>> df = spark.createDataFrame([('Tom', 80), ('Alice', None)], ["name", "height"]) >>> df.select(df.name).orderBy(df.name.desc()).collect() [Row(name=u'Tom'), Row(name=u'Alice')] """ + _desc_nulls_first_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear before non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_first()).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + + .. versionadded:: 2.4 + """ + _desc_nulls_last_doc = """ + Returns a sort expression based on the descending order of the column, and null values + appear after non-null values. + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([('Tom', 80), (None, 60), ('Alice', None)], ["name", "height"]) + >>> df.select(df.name).orderBy(df.name.desc_nulls_last()).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + + .. versionadded:: 2.4 + """ asc = ignore_unicode_prefix(_unary_op("asc", _asc_doc)) + asc_nulls_first = ignore_unicode_prefix(_unary_op("asc_nulls_first", _asc_nulls_first_doc)) + asc_nulls_last = ignore_unicode_prefix(_unary_op("asc_nulls_last", _asc_nulls_last_doc)) desc = ignore_unicode_prefix(_unary_op("desc", _desc_doc)) + desc_nulls_first = ignore_unicode_prefix(_unary_op("desc_nulls_first", _desc_nulls_first_doc)) + desc_nulls_last = ignore_unicode_prefix(_unary_op("desc_nulls_last", _desc_nulls_last_doc)) _isNull_doc = """ True if the current expression is null. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ad3e37c872628..1b192680f0795 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -138,6 +138,17 @@ def _(): 'bitwiseNOT': 'Computes bitwise not.', } +_functions_2_4 = { + 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values return before non-null values.', + 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' + + ' column name, and null values appear after non-null values.', + 'desc_nulls_first': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear before non-null values.', + 'desc_nulls_last': 'Returns a sort expression based on the descending order of the given' + + ' column name, and null values appear after non-null values', +} + _collect_list_doc = """ Aggregate function: returns a list of objects with duplicates. @@ -250,6 +261,8 @@ def _(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) +for _name, _doc in _functions_2_4.items(): + globals()[_name] = since(2.4)(_create_function(_name, _doc)) del _name, _doc diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5181053a0d318..dd04ffb4ed393 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,6 +2991,20 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() + def test_2_4_functions(self): + from pyspark.sql import functions + + df = self.spark.createDataFrame( + [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 92988680871a4..ad0efbae89830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1083,10 +1083,10 @@ class Column(val expr: Expression) extends Logging { * and null values return before non-null values. * {{{ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first. - * df.sort(df("age").asc_nulls_last) + * df.sort(df("age").asc_nulls_first) * * // Java - * df.sort(df.col("age").asc_nulls_last()); + * df.sort(df.col("age").asc_nulls_first()); * }}} * * @group expr_ops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c9ca9a8996344..c658f25ced053 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -132,7 +132,7 @@ object functions { * Returns a sort expression based on ascending order of the column, * and null values return before non-null values. * {{{ - * df.sort(asc_nulls_last("dept"), desc("age")) + * df.sort(asc_nulls_first("dept"), desc("age")) * }}} * * @group sort_funcs From 6a734575a80e6b4ec4963206254451f05d64b742 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 7 Apr 2018 21:44:32 -0700 Subject: [PATCH 0581/1161] [SPARK-23849][SQL] Tests for the samplingRatio option of JSON datasource ## What changes were proposed in this pull request? Proposed tests checks that only subset of input dataset is touched during schema inferring. Author: Maxim Gekk Closes #20963 from MaxGekk/json-sampling-tests. --- .../datasources/json/JsonSuite.scala | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 10bac0554484a..70aee561ff0f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Paths, StandardOpenOption} import java.sql.{Date, Timestamp} import java.util.Locale @@ -2127,4 +2127,39 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(df.schema === expectedSchema) } } + + test("SPARK-23849: schema inferring touches less data if samplingRation < 1.0") { + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + withTempPath { path => + val writer = Files.newBufferedWriter(Paths.get(path.getAbsolutePath), + StandardCharsets.UTF_8, StandardOpenOption.CREATE_NEW) + for (i <- 0 until 100) { + if (predefinedSample.contains(i)) { + writer.write(s"""{"f1":${i.toString}}""" + "\n") + } else { + writer.write(s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n") + } + } + writer.close() + + val ds = spark.read.option("samplingRatio", 0.1).json(path.getCanonicalPath) + assert(ds.schema == new StructType().add("f1", LongType)) + } + } + + test("SPARK-23849: usage of samplingRation while parsing of dataset of strings") { + val dstr = spark.sparkContext.parallelize(0 until 100, 1).map { i => + val predefinedSample = Set[Int](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, + 57, 62, 68, 72) + if (predefinedSample.contains(i)) { + s"""{"f1":${i.toString}}""" + "\n" + } else { + s"""{"f1":${(i.toDouble + 0.1).toString}}""" + "\n" + } + }.toDS() + val ds = spark.read.option("samplingRatio", 0.1).json(dstr) + + assert(ds.schema == new StructType().add("f1", LongType)) + } } From 710a68cec27a94c2df10d8b4022a755a94a5443b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:26:31 +0200 Subject: [PATCH 0582/1161] [SPARK-23892][TEST] Improve converge and fix lint error in UTF8String-related tests ## What changes were proposed in this pull request? This PR improves test coverage in `UTF8StringSuite` and code efficiency in `UTF8StringPropertyCheckSuite`. This PR also fixes lint-java issue in `UTF8StringSuite` reported at [here](https://github.com/apache/spark/pull/20995#issuecomment-379325527) ```[ERROR] src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java:[28,8] (imports) UnusedImports: Unused import - org.apache.spark.unsafe.Platform.``` ## How was this patch tested? Existing UT Author: Kazuaki Ishizaki Closes #21000 from kiszk/SPARK-23892. --- .../java/org/apache/spark/unsafe/types/UTF8StringSuite.java | 5 ++--- .../spark/unsafe/types/UTF8StringPropertyCheckSuite.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 2c08535a16465..42dda30480702 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -25,7 +25,6 @@ import java.util.*; import com.google.common.collect.ImmutableMap; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; import org.junit.Test; @@ -53,8 +52,8 @@ private static void checkBasic(String str, int len) { assertTrue(s1.contains(s2)); assertTrue(s2.contains(s1)); - assertTrue(s1.startsWith(s1)); - assertTrue(s1.endsWith(s1)); + assertTrue(s1.startsWith(s2)); + assertTrue(s1.endsWith(s2)); } @Test diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 62d4176d00f94..48004e812a8bf 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -164,7 +164,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def padding(origin: String, pad: String, length: Int, isLPad: Boolean): String = { if (length <= 0) return "" if (length <= origin.length) { - if (length <= 0) "" else origin.substring(0, length) + origin.substring(0, length) } else { if (pad.length == 0) return origin val toPad = length - origin.length From 8d40a79a077a30024a8ef921781b68f6f7e542d1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 8 Apr 2018 20:40:27 +0200 Subject: [PATCH 0583/1161] [SPARK-23893][CORE][SQL] Avoid possible integer overflow in multiplication ## What changes were proposed in this pull request? This PR avoids possible overflow at an operation `long = (long)(int * int)`. The multiplication of large positive integer values may set one to MSB. This leads to a negative value in long while we expected a positive value (e.g. `0111_0000_0000_0000 * 0000_0000_0000_0010`). This PR performs long cast before the multiplication to avoid this situation. ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21002 from kiszk/SPARK-23893. --- .../util/collection/unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../util/collection/unsafe/sort/UnsafeSortDataFormat.java | 2 +- .../src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 2 +- .../scala/org/apache/spark/InternalAccumulatorSuite.scala | 2 +- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- .../test/scala/org/apache/spark/util/JsonProtocolSuite.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/HashBenchmark.scala | 2 +- .../scala/org/apache/spark/sql/HashByteArrayBenchmark.scala | 3 ++- .../org/apache/spark/sql/UnsafeProjectionBenchmark.scala | 2 +- .../columnar/compression/CompressionSchemeBenchmark.scala | 4 ++-- .../sql/execution/vectorized/ColumnarBatchBenchmark.scala | 2 +- 11 files changed, 15 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 20a7a8b267438..717823ebbd320 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -124,7 +124,7 @@ public UnsafeInMemorySorter( int initialSize, boolean canUseRadixSort) { this(consumer, memoryManager, recordComparator, prefixComparator, - consumer.allocateArray(initialSize * 2), canUseRadixSort); + consumer.allocateArray(initialSize * 2L), canUseRadixSort); } public UnsafeInMemorySorter( diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d9f84d10e9051..37772f41caa87 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -84,7 +84,7 @@ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int @Override public LongArray allocate(int length) { - assert (length * 2 <= buffer.size()) : + assert (length * 2L <= buffer.size()) : "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2); return buffer; } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index c9ed12f4e1bd4..13db4985b0b80 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -90,7 +90,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // Otherwise, interpolate the number of partitions we need to try, but overestimate it // by 50%. We also cap the estimation in the end. if (results.size == 0) { - numPartsToTry = partsScanned * 4 + numPartsToTry = partsScanned * 4L } else { // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 8d7be77f51fe9..62824a5bec9d1 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -135,7 +135,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { // This job runs 2 stages, and we're in the second stage. Therefore, any task attempt // ID that's < 2 * numPartitions belongs to the first attempt of this stage. val taskContext = TaskContext.get() - val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2 + val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L if (isFirstStageAttempt) { throw new FetchFailedException( SparkEnv.get.blockManager.blockManagerId, diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fde5f25bce456..0ba57bf4563c1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -382,8 +382,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val log = newLogFile("downloadApp1", Some(s"attempt$i"), inProgress = false) writeFile(log, true, None, SparkListenerApplicationStart( - "downloadApp1", Some("downloadApp1"), 5000 * i, "test", Some(s"attempt$i")), - SparkListenerApplicationEnd(5001 * i) + "downloadApp1", Some("downloadApp1"), 5000L * i, "test", Some(s"attempt$i")), + SparkListenerApplicationEnd(5001L * i) ) log } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 4abbb8e7894f5..74b72d940eeef 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -317,7 +317,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("SparkListenerJobStart backward compatibility") { // Prior to Spark 1.2.0, SparkListenerJobStart did not have a "Stage Infos" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) + val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) @@ -331,7 +331,7 @@ class JsonProtocolSuite extends SparkFunSuite { // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. // Also, SparkListenerJobEnd did not have a "Completion Time" property. val stageIds = Seq[Int](1, 2, 3, 4) - val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) .removeField({ _._1 == "Submission Time"}) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 2d94b66a1e122..9a89e6290e695 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -40,7 +40,7 @@ object HashBenchmark { safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy() ).toArray - val benchmark = new Benchmark("Hash For " + name, iters * numRows) + val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong) benchmark.addCase("interpreted version") { _: Int => var sum = 0 for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala index 2a753a0c84ed5..f6c8111f5bc57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashByteArrayBenchmark.scala @@ -36,7 +36,8 @@ object HashByteArrayBenchmark { bytes } - val benchmark = new Benchmark("Hash byte arrays with length " + length, iters * numArrays) + val benchmark = + new Benchmark("Hash byte arrays with length " + length, iters * numArrays.toLong) benchmark.addCase("Murmur3_x86_32") { _: Int => var sum = 0L for (_ <- 0L until iters) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 769addf3b29e6..6c63769945312 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -38,7 +38,7 @@ object UnsafeProjectionBenchmark { val iters = 1024 * 16 val numRows = 1024 * 16 - val benchmark = new Benchmark("unsafe projection", iters * numRows) + val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong) val schema1 = new StructType().add("l", LongType, false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 9005ec93e786e..619b76fabdd5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -77,7 +77,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) @@ -101,7 +101,7 @@ object CompressionSchemeBenchmark extends AllCompressionSchemes { count: Int, tpe: NativeColumnType[T], input: ByteBuffer): Unit = { - val benchmark = new Benchmark(name, iters * count) + val benchmark = new Benchmark(name, iters * count.toLong) schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, _, buf) = prepareEncodeInternal(count, tpe, scheme, input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 1f31aa45a1220..8aeb06d428951 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -295,7 +295,7 @@ object ColumnarBatchBenchmark { def booleanAccess(iters: Int): Unit = { val count = 8 * 1024 - val benchmark = new Benchmark("Boolean Read/Write", iters * count) + val benchmark = new Benchmark("Boolean Read/Write", iters * count.toLong) benchmark.addCase("Bitset") { i: Int => { val b = new BitSet(count) var sum = 0L From 32471ba0af52b59141b44a8375025b6a7eafae70 Mon Sep 17 00:00:00 2001 From: Nolan Emirot Date: Mon, 9 Apr 2018 08:04:02 -0500 Subject: [PATCH 0584/1161] Fix typo in Python docstring kinesis example ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Nolan Emirot Closes #20990 from emirot/kinesis_stream_example_typo. --- .../src/main/python/examples/streaming/kinesis_wordcount_asl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py index 4d7fc9a549bfb..49794faab88c4 100644 --- a/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py +++ b/external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -34,7 +34,7 @@ $ export AWS_SECRET_KEY= # run the example - $ bin/spark-submit -jar external/kinesis-asl/target/scala-*/\ + $ bin/spark-submit -jars external/kinesis-asl/target/scala-*/\ spark-streaming-kinesis-asl-assembly_*.jar \ external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com From d81f29ecafe8fc9816e36087e3b8acdc93d6cc1b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 9 Apr 2018 10:19:22 -0700 Subject: [PATCH 0585/1161] [SPARK-23881][CORE][TEST] Fix flaky test JobCancellationSuite."interruptible iterator of shuffle reader" ## What changes were proposed in this pull request? The test case JobCancellationSuite."interruptible iterator of shuffle reader" has been flaky because `KillTask` event is handled asynchronously, so it can happen that the semaphore is released but the task is still running. Actually we only have to check if the total number of processed elements is less than the input elements number, so we know the task get cancelled. ## How was this patch tested? The new test case still fails without the purposed patch, and succeeded in current master. Author: Xingbo Jiang Closes #20993 from jiangxb1987/JobCancellationSuite. --- .../apache/spark/JobCancellationSuite.scala | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 3b793bb231cf3..61da4138896cd 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -332,13 +332,15 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft import JobCancellationSuite._ sc = new SparkContext("local[2]", "test interruptible iterator") + // Increase the number of elements to be proceeded to avoid this test being flaky. + val numElements = 10000 val taskCompletedSem = new Semaphore(0) sc.addSparkListener(new SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { // release taskCancelledSemaphore when cancelTasks event has been posted if (stageCompleted.stageInfo.stageId == 1) { - taskCancelledSemaphore.release(1000) + taskCancelledSemaphore.release(numElements) } } @@ -349,28 +351,31 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft } }) - val f = sc.parallelize(1 to 1000).map { i => (i, i) } + // Explicitly disable interrupt task thread on cancelling tasks, so the task thread can only be + // interrupted by `InterruptibleIterator`. + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + + val f = sc.parallelize(1 to numElements).map { i => (i, i) } .repartitionAndSortWithinPartitions(new HashPartitioner(1)) .mapPartitions { iter => taskStartedSemaphore.release() iter }.foreachAsync { x => - if (x._1 >= 10) { - // This block of code is partially executed. It will be blocked when x._1 >= 10 and the - // next iteration will be cancelled if the source iterator is interruptible. Then in this - // case, the maximum num of increment would be 10(|1...10|) - taskCancelledSemaphore.acquire() - } + // Block this code from being executed, until the job get cancelled. In this case, if the + // source iterator is interruptible, the max number of increment should be under + // `numElements`. + taskCancelledSemaphore.acquire() executionOfInterruptibleCounter.getAndIncrement() } taskStartedSemaphore.acquire() // Job is cancelled when: // 1. task in reduce stage has been started, guaranteed by previous line. - // 2. task in reduce stage is blocked after processing at most 10 records as - // taskCancelledSemaphore is not released until cancelTasks event is posted - // After job being cancelled, task in reduce stage will be cancelled and no more iteration are - // executed. + // 2. task in reduce stage is blocked as taskCancelledSemaphore is not released until + // JobCancelled event is posted. + // After job being cancelled, task in reduce stage will be cancelled asynchronously, thus + // partial of the inputs should not get processed (It's very unlikely that Spark can process + // 10000 elements between JobCancelled is posted and task is really killed). f.cancel() val e = intercept[SparkException](f.get()).getCause @@ -378,7 +383,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft // Make sure tasks are indeed completed. taskCompletedSem.acquire() - assert(executionOfInterruptibleCounter.get() <= 10) + assert(executionOfInterruptibleCounter.get() < numElements) } def testCount() { From 10f45bb8233e6ac838dd4f053052c8556f5b54bd Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 9 Apr 2018 11:31:21 -0700 Subject: [PATCH 0586/1161] [SPARK-23816][CORE] Killed tasks should ignore FetchFailures. SPARK-19276 ensured that FetchFailures do not get swallowed by other layers of exception handling, but it also meant that a killed task could look like a fetch failure. This is particularly a problem with speculative execution, where we expect to kill tasks as they are reading shuffle data. The fix is to ensure that we always check for killed tasks first. Added a new unit test which fails before the fix, ran it 1k times to check for flakiness. Full suite of tests on jenkins. Author: Imran Rashid Closes #20987 from squito/SPARK-23816. --- .../org/apache/spark/executor/Executor.scala | 26 +++--- .../apache/spark/executor/ExecutorSuite.scala | 92 +++++++++++++++---- 2 files changed, 88 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index dcec3ec21b546..c325222b764b8 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -480,6 +480,19 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { @@ -494,19 +507,6 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case t: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - - case _: InterruptedException | NonFatal(_) if - task != null && task.reasonIfKilled.isDefined => - val killReason = task.reasonIfKilled.getOrElse("unknown reason") - logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) - case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskCommitDeniedReason setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 105a178f2d94e..1a7bebe2c53cd 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.Map import scala.concurrent.duration._ @@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug // the fetch failure. The executor should still tell the driver that the task failed due to a // fetch failure, not a generic exception from user code. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true) + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError]) + } + + test("SPARK-23816: interrupts are not masked by a FetchFailure") { + // If killing the task causes a fetch failure, we still treat it as a task that was killed, + // as the fetch failure could easily be caused by interrupting the thread. + val (failReason, _) = testFetchFailureHandling(false) + assert(failReason.isInstanceOf[TaskKilled]) + } + + /** + * Helper for testing some cases where a FetchFailure should *not* get sent back, because its + * superceded by another error, either an OOM or intentionally killing a task. + * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the + * FetchFailure + */ + private def testFetchFailureHandling( + oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. + // SPARK-23816 also handle interrupts the same way, as killing an obsolete speculative task + // does not represent a real fetch failure. val conf = new SparkConf().setMaster("local").setAppName("executor suite test") sc = new SparkContext(conf) val serializer = SparkEnv.get.closureSerializer.newInstance() val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size - // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat - // the fetch failure as a false positive, and just do normal OOM handling. + // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We + // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling. val inputRDD = new FetchFailureThrowingRDD(sc) - val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + if (!oom) { + // we are trying to setup a case where a task is killed after a fetch failure -- this + // is just a helper to coordinate between the task thread and this thread that will + // kill the task + ExecutorSuiteHelper.latches = new ExecutorSuiteHelper() + } + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom) val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() val task = new ResultTask( @@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serTask = serializer.serialize(task) val taskDescription = createFakeTaskDescription(serTask) - val (failReason, uncaughtExceptionHandler) = - runTaskGetFailReasonAndExceptionHandler(taskDescription) - // make sure the task failure just looks like a OOM, not a fetch failure - assert(failReason.isInstanceOf[ExceptionFailure]) - val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) - verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) - assert(exceptionCaptor.getAllValues.size === 1) - assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) - } + runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom) + } test("Gracefully handle error in task deserialization") { val conf = new SparkConf @@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { - runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1 } private def runTaskGetFailReasonAndExceptionHandler( - taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + taskDescription: TaskDescription, + killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null + val timedOut = new AtomicBoolean(false) try { executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) + if (killTask) { + val killingThread = new Thread("kill-task") { + override def run(): Unit = { + // wait to kill the task until it has thrown a fetch failure + if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) { + // now we can kill the task + executor.killAllTasks(true, "Killed task, eg. because of speculative execution") + } else { + timedOut.set(true) + } + } + } + killingThread.start() + } eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } + assert(!timedOut.get(), "timed out waiting to be ready to kill tasks") } finally { if (executor != null) { executor.stop() @@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) orderedMock.verify(mockBackend) .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED orderedMock.verify(mockBackend) - .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture()) // first statusUpdate for RUNNING has empty data assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting @@ -321,7 +364,8 @@ class SimplePartition extends Partition { class FetchFailureHidingRDD( sc: SparkContext, val input: FetchFailureThrowingRDD, - throwOOM: Boolean) extends RDD[Int](input) { + throwOOM: Boolean, + interrupt: Boolean) extends RDD[Int](input) { override def compute(split: Partition, context: TaskContext): Iterator[Int] = { val inItr = input.compute(split, context) try { @@ -330,6 +374,15 @@ class FetchFailureHidingRDD( case t: Throwable => if (throwOOM) { throw new OutOfMemoryError("OOM while handling another exception") + } else if (interrupt) { + // make sure our test is setup correctly + assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined) + // signal our test is ready for the task to get killed + ExecutorSuiteHelper.latches.latch1.countDown() + // then wait for another thread in the test to kill the task -- this latch + // is never actually decremented, we just wait to get killed. + ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS) + throw new IllegalStateException("timed out waiting to be interrupted") } else { throw new RuntimeException("User Exception that hides the original exception", t) } @@ -352,6 +405,11 @@ private class ExecutorSuiteHelper { @volatile var testFailedReason: TaskFailedReason = _ } +// helper for coordinating killing tasks +private object ExecutorSuiteHelper { + var latches: ExecutorSuiteHelper = null +} + private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { def writeExternal(out: ObjectOutput): Unit = {} def readExternal(in: ObjectInput): Unit = { From 7c1654e2159662e7e663ba141719d755002f770a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Apr 2018 11:54:35 -0700 Subject: [PATCH 0587/1161] [SPARK-22856][SQL] Add wrappers for codegen output and nullability ## What changes were proposed in this pull request? The codegen output of `Expression`, aka `ExprCode`, now encapsulates only strings of output value (`value`) and nullability (`isNull`). It makes difficulty for us to know what the output really is. I think it is better if we can add wrappers for the value and nullability that let us to easily know that. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #20043 from viirya/SPARK-22856. --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 16 ++-- .../MonotonicallyIncreasingID.scala | 4 +- .../expressions/SparkPartitionID.scala | 4 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 16 ++-- .../expressions/codegen/CodegenFallback.scala | 2 +- .../expressions/codegen/ExprValue.scala | 76 +++++++++++++++++++ .../codegen/GenerateMutableProjection.scala | 6 +- .../codegen/GenerateSafeProjection.scala | 19 +++-- .../codegen/GenerateUnsafeProjection.scala | 8 +- .../expressions/collectionOperations.scala | 4 +- .../expressions/complexTypeCreator.scala | 8 +- .../expressions/conditionalExpressions.scala | 3 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 4 +- .../spark/sql/catalyst/expressions/hash.scala | 4 +- .../catalyst/expressions/inputFileBlock.scala | 8 +- .../sql/catalyst/expressions/literals.scala | 24 +++--- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 25 ++++-- .../expressions/objects/objects.scala | 28 ++++--- .../sql/catalyst/expressions/predicates.scala | 10 +-- .../expressions/randomExpressions.scala | 6 +- .../expressions/CodeGenerationSuite.scala | 6 +- .../expressions/codegen/ExprValueSuite.scala | 46 +++++++++++ .../sql/execution/ColumnarBatchScan.scala | 10 ++- .../spark/sql/execution/ExpandExec.scala | 5 +- .../spark/sql/execution/GenerateExec.scala | 13 ++-- .../sql/execution/WholeStageCodegenExec.scala | 13 ++-- .../aggregate/HashAggregateExec.scala | 3 +- .../aggregate/HashMapGenerator.scala | 5 +- .../execution/basicPhysicalOperators.scala | 6 +- .../joins/BroadcastHashJoinExec.scala | 6 +- .../execution/joins/SortMergeJoinExec.scala | 8 +- 35 files changed, 294 insertions(+), 120 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 89ffbb0016916..5021a567592e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ /** @@ -76,7 +76,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") + ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 38caf67d465d8..7a5e49cb5206b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", isNull, value)) + val eval = doGenCode(ctx, ExprCode("", + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(dataType)))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -118,10 +120,10 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { - val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { + val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = globalIsNull + eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) s"$globalIsNull = $localIsNull;" } else { "" @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = newValue + eval.value = VariableValue(newValue, javaType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -446,7 +448,7 @@ abstract class UnaryExpression extends Expression { boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -546,7 +548,7 @@ abstract class BinaryExpression extends Expression { ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } @@ -690,7 +692,7 @@ abstract class TernaryExpression extends Expression { ${midGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - $resultCode""", isNull = "false") + $resultCode""", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index dd523d312e3b4..ad1e7bdb31987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + $countTerm++;""", isNull = FalseLiteral) } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index cc6a769d032d3..787bcaf5e81de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", - isNull = "false") + isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 508bdd5050b54..478ff3a7c1011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -601,7 +601,8 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -680,7 +681,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 84b1e3fbda876..c9c60ef1be640 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,16 +56,17 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { def forNullValue(dataType: DataType): ExprCode = { val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode(code = "", isNull = TrueLiteral, + value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) } - def forNonNullValue(value: String): ExprCode = { - ExprCode(code = "", isNull = "false", value = value) + def forNonNullValue(value: ExprValue): ExprCode = { + ExprCode(code = "", isNull = FalseLiteral, value = value) } } @@ -77,7 +78,7 @@ object ExprCode { * @param value A term for a value of a common sub-expression. Not valid if `isNull` * is set to `true`. */ -case class SubExprEliminationState(isNull: String, value: String) +case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) /** * Codes and common subexpressions mapping used for subexpression elimination. @@ -330,7 +331,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, "false", value) + ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) } def declareMutableStates(): String = { @@ -1003,7 +1004,8 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), + GlobalValue(value, javaType(expr.dataType))) subExprEliminationExprs ++= e.map(_ -> state).toMap } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index e12420bb5dfdd..a91989e129664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -59,7 +59,7 @@ trait CodegenFallback extends Expression { $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; - """, isNull = "false") + """, isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala new file mode 100644 index 0000000000000..df5f1c58b1b2d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import scala.language.implicitConversions + +import org.apache.spark.sql.types.DataType + +// An abstraction that represents the evaluation result of [[ExprCode]]. +abstract class ExprValue { + + val javaType: String + + // Whether we can directly access the evaluation value anywhere. + // For example, a variable created outside a method can not be accessed inside the method. + // For such cases, we may need to pass the evaluation as parameter. + val canDirectAccess: Boolean + + def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + +// A literal evaluation of [[ExprCode]]. +class LiteralValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +object LiteralValue { + def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) + def unapply(literal: LiteralValue): Option[(String, String)] = + Some((literal.value, literal.javaType)) +} + +// A variable evaluation of [[ExprCode]]. +case class VariableValue( + val variableName: String, + val javaType: String) extends ExprValue { + override def toString: String = variableName + override val canDirectAccess: Boolean = false +} + +// A statement evaluation of [[ExprCode]]. +case class StatementValue( + val statement: String, + val javaType: String, + val canDirectAccess: Boolean = false) extends ExprValue { + override def toString: String = statement +} + +// A global variable evaluation of [[ExprCode]]. +case class GlobalValue(val value: String, val javaType: String) extends ExprValue { + override def toString: String = value + override val canDirectAccess: Boolean = true +} + +case object TrueLiteral extends LiteralValue("true", "boolean") +case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d35fd8ecb4d63..3ae0b54c754cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { + val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, isNull, value, i) + """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) } else { (s""" |${ev.code} @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, value) + val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f92f70ee71fef..a30a0b22cd305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt)), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -74,7 +76,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) } private def createCodeForArray( @@ -89,8 +91,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe( - ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) + val elementConverter = convertToSafe(ctx, + StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), + CodeGenerator.javaType(elementType)), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -104,7 +107,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) } private def createCodeForMap( @@ -125,19 +128,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, "false", output) + ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) } @tailrec private def convertToSafe( ctx: CodegenContext, - input: String, + input: ExprValue, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", "false", input) + case _ => ExprCode("", FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index ab2254cd9f70a..4a4d76313a543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) + ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), + StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), + CodeGenerator.javaType(dt))) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -334,7 +336,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $evalSubexpr $writeExpressions """ - ExprCode(code, "false", s"$rowWriter.getRow()") + // `rowWriter` is declared as a class field, so we can access it directly in methods. + ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", + canDirectAccess = true)) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index beb84694c44e8..91188da8b0bd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : - (${childGen.value}).numElements();""", isNull = "false") + (${childGen.value}).numElements();""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 85facdad43db7..49a8d12057188 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = arrayData, - isNull = "false") + value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + isNull = FalseLiteral) } override def prettyName: String = "array" @@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } override def prettyName: String = "named_struct" @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ExprCode(code = eval.code, isNull = "false", value = eval.value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index f4e9619bac59d..409c0b6b79b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,7 +191,8 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) + ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + CodeGenerator.javaType(dataType)) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1ae4e5a2f716b..49dd988b4b53c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,7 +813,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", "true", "(UTF8String) null") + ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", + CodeGenerator.javaType(dataType))) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 4f4d49166e88c..3af4bfebad45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator { s""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index b76b64ab5096f..df29c38d64d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -644,7 +644,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = "false" + ev.isNull = FalseLiteral val childHash = ctx.freshName("childHash") val childrenHash = children.map { child => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 07785e7448586..2a3cc580273ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = "false") + s"$className.getInputFilePath();", isNull = FalseLiteral) } } @@ -66,7 +66,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = "false") + s"$className.getStartOffset();", isNull = FalseLiteral) } } @@ -89,6 +89,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = "false") + s"$className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 7395609a04ba5..742a650eb445d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -283,36 +283,36 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { } else { dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(value.toString) + ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue("Float.NaN") + ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) case Float.PositiveInfinity => - ExprCode.forNonNullValue("Float.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) case Float.NegativeInfinity => - ExprCode.forNonNullValue("Float.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}F") + ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue("Double.NaN") + ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) case Double.PositiveInfinity => - ExprCode.forNonNullValue("Double.POSITIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) case Double.NegativeInfinity => - ExprCode.forNonNullValue("Double.NEGATIVE_INFINITY") + ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) case _ => - ExprCode.forNonNullValue(s"${value}D") + ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) } case ByteType | ShortType => - ExprCode.forNonNullValue(s"($javaType)$value") + ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) case TimestampType | LongType => - ExprCode.forNonNullValue(s"${value}L") + ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(constRef) + ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a390f8ef7fd9a..7081a5e096d56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -91,7 +91,8 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = "true", value = "null") + |}""".stripMargin, isNull = TrueLiteral, + value = LiteralValue("null", CodeGenerator.javaType(dataType))) } override def sql: String = s"assert_true(${child.sql})" @@ -150,7 +151,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Uuid = Uuid(randomSeed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b35fa72e95d1e..55b6e346be82a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) + ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), + CodeGenerator.JAVA_BOOLEAN) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -235,7 +236,7 @@ case class IsNaN(child: Expression) extends UnaryExpression ev.copy(code = s""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) } } } @@ -320,7 +321,12 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = eval.isNull) + val value = if (eval.isNull.isInstanceOf[LiteralValue]) { + LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NULL)" @@ -346,7 +352,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") + val value = if (eval.isNull == TrueLiteral) { + FalseLiteral + } else if (eval.isNull == FalseLiteral) { + TrueLiteral + } else { + StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + } + ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -441,6 +454,6 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate | $codes |} while (false); |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; - """.stripMargin, isNull = "false") + """.stripMargin, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9252425f86473..b2cca3178cd2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -61,13 +61,13 @@ trait InvokeLike extends Expression with NonSQLExpression { * @param ctx a [[CodegenContext]] * @return (code to prepare arguments, argument string, result of argument null check) */ - def prepareArguments(ctx: CodegenContext): (String, String, String) = { + def prepareArguments(ctx: CodegenContext): (String, String, ExprValue) = { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - resultIsNull + GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) } else { - "false" + FalseLiteral } val argValues = arguments.map { e => val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") @@ -244,7 +244,7 @@ case class StaticInvoke( val prepareIsNull = if (nullable) { s"boolean ${ev.isNull} = $resultIsNull;" } else { - ev.isNull = "false" + ev.isNull = FalseLiteral "" } @@ -546,7 +546,7 @@ case class WrapOption(child: Expression, optType: DataType) ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -568,7 +568,13 @@ case class LambdaVariable( } override def genCode(ctx: CodegenContext): ExprCode = { - ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false") + val isNullValue = if (nullable) { + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } + ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), + isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make @@ -840,7 +846,7 @@ case class MapObjects private( // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" - val genFunctionValue = lambdaFunction.dataType match { + val genFunctionValue: String = lambdaFunction.dataType match { case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) @@ -1343,7 +1349,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """.stripMargin - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } @@ -1538,7 +1544,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) throw new NullPointerException($errMsgField); } """ - ev.copy(code = code, isNull = "false", value = childGen.value) + ev.copy(code = code, isNull = FalseLiteral, value = childGen.value) } } @@ -1589,7 +1595,7 @@ case class GetExternalRowField( final Object ${ev.value} = ${row.value}.get($index); """ - ev.copy(code = code, isNull = "false") + ev.copy(code = code, isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4b85d9adbe311..e195ec17f3bcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -405,7 +405,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -461,7 +461,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" + ev.isNull = FalseLiteral ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; @@ -469,7 +469,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - }""", isNull = "false") + }""", isNull = FalseLiteral) } else { ev.copy(code = s""" ${eval1.code} @@ -615,7 +615,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = "false") + (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f36633867316e..70186053617f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -83,7 +83,7 @@ case class Rand(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Rand = Rand(child) @@ -120,7 +120,7 @@ case class Randn(child: Expression) extends RDG { s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", - isNull = "false") + isNull = FalseLiteral) } override def freshCopy(): Randn = Randn(child) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 398b6767654fa..8e83b35c3809c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,6 +448,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) + val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), + VariableValue("dummy", "boolean")) // raw testing of basic functionality { @@ -457,7 +459,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) assert(ctx.subExprEliminationExprs.contains(ref)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { assert(ctx.subExprEliminationExprs.contains(add1)) assert(!ctx.subExprEliminationExprs.contains(ref)) Seq.empty @@ -475,7 +477,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE assert(ctx.subExprEliminationExprs.contains(add1)) // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) { + ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { assert(ctx.subExprEliminationExprs.contains(ref)) assert(!ctx.subExprEliminationExprs.contains(add1)) Seq.empty diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala new file mode 100644 index 0000000000000..c8f4cff7db48d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + +class ExprValueSuite extends SparkFunSuite { + + test("TrueLiteral and FalseLiteral should be LiteralValue") { + val trueLit = TrueLiteral + val falseLit = FalseLiteral + + assert(trueLit.value == "true") + assert(falseLit.value == "false") + + assert(trueLit.isPrimitive) + assert(falseLit.isPrimitive) + + trueLit match { + case LiteralValue(value, javaType) => + assert(value == "true" && javaType == "boolean") + case _ => fail() + } + + falseLit match { + case LiteralValue(value, javaType) => + assert(value == "false" && javaType == "boolean") + case _ => fail() + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 392906a022903..434214a10e1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -51,7 +51,11 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { nullable: Boolean): ExprCode = { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) - val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val isNullVar = if (nullable) { + VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + } else { + FalseLiteral + } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { @@ -62,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, valueVar) + ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 12ae1ea4a7c13..0d9a62cace62a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,7 +157,8 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 384f0398a1ec0..85c5ebfdaa689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -170,9 +170,10 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", s"$index == -1", index)) + Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), + VariableValue(index, CodeGenerator.JAVA_INT))) } else { - Seq(ExprCode("", "false", index)) + Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) } } else { Seq.empty @@ -315,9 +316,11 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, javaType)) } else { - ExprCode(s"$javaType $value = $getter;", "false", value) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, + VariableValue(value, javaType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 6ddaacfee1a40..805ff3cf001ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", "false", row) + ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -126,10 +126,10 @@ trait CodegenSupport extends SparkPlan { |$evaluateInputs |${ev.code.trim} """.stripMargin.trim - ExprCode(code, "false", ev.value) + ExprCode(code, FalseLiteral, ev.value) } else { // There is no columns - ExprCode("", "false", "unsafeRow") + ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) } } } @@ -241,15 +241,16 @@ trait CodegenSupport extends SparkPlan { parameters += s"$paramType $paramName" val paramIsNull = if (!attributes(i).nullable) { // Use constant `false` without passing `isNull` for non-nullable variable. - "false" + FalseLiteral } else { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - isNull + VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) } - paramVars += ExprCode("", paramIsNull, paramName) + paramVars += ExprCode("", paramIsNull, + VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 1926e9373bc55..8f7f10243d4cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,7 +194,8 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 6b60b414ffe5f..4978954271311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} import org.apache.spark.sql.types._ /** @@ -54,7 +54,8 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), + GlobalValue(value, CodeGenerator.javaType(e.dataType))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 4707022f74547..cab7081400ce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -192,7 +192,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) // generate better code (remove dead branches). val resultVars = input.zipWithIndex.map { case (ev, i) => if (notNullAttributes.contains(child.output(i).exprId)) { - ev.isNull = "false" + ev.isNull = FalseLiteral } ev } @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", "false", value) + val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 487d6a2383318..fa62a32d51f3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -192,7 +192,8 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))) } } } @@ -487,7 +488,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + val resultVar = input ++ Seq(ExprCode("", FalseLiteral, + VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 5a511b30e4fd9..b61acb8d4fda9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -531,11 +531,13 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, isNull, value), leftVarsDecl) + (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, "false", value), leftVarsDecl) + (ExprCode(code, FalseLiteral, + VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) } }.unzip } From 252468a744b95082400ba9e8b2e3b3d9d50ab7fa Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 9 Apr 2018 12:18:07 -0700 Subject: [PATCH 0588/1161] [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes ## What changes were proposed in this pull request? API: ``` trait ClassificationNode extends Node def getLabelCount(label: Int): Double trait RegressionNode extends Node def getCount(): Double def getSum(): Double def getSquareSum(): Double // turn LeafNode to be trait trait LeafNode extends Node { def prediction: Double def impurity: Double ... } class ClassificationLeafNode extends ClassificationNode with LeafNode class RegressionLeafNode extends RegressionNode with LeafNode // turn InternalNode to be trait trait InternalNode extends Node{ def gain: Double def leftChild: Node def rightChild: Node def split: Split ... } class ClassificationInternalNode extends ClassificationNode with InternalNode override def leftChild: ClassificationNode override def rightChild: ClassificationNode class RegressionInternalNode extends RegressionNode with InternalNode override val leftChild: RegressionNode override val rightChild: RegressionNode class DecisionTreeClassificationModel override val rootNode: ClassificationNode class DecisionTreeRegressionModel override val rootNode: RegressionNode ``` Closes #17466 ## How was this patch tested? UT will be added soon. Author: WeichenXu Author: jkbradley Closes #20786 from WeichenXu123/tree_stat_api_2. --- .../DecisionTreeClassifier.scala | 14 +- .../ml/classification/GBTClassifier.scala | 6 +- .../RandomForestClassifier.scala | 6 +- .../ml/regression/DecisionTreeRegressor.scala | 13 +- .../spark/ml/regression/GBTRegressor.scala | 6 +- .../ml/regression/RandomForestRegressor.scala | 6 +- .../scala/org/apache/spark/ml/tree/Node.scala | 247 ++++++++++++++---- .../spark/ml/tree/impl/RandomForest.scala | 10 +- .../org/apache/spark/ml/tree/treeModels.scala | 36 ++- .../DecisionTreeClassifierSuite.scala | 31 ++- .../classification/GBTClassifierSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 5 +- .../DecisionTreeRegressorSuite.scala | 14 + .../ml/tree/impl/RandomForestSuite.scala | 22 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 12 +- project/MimaExcludes.scala | 9 +- 16 files changed, 333 insertions(+), 108 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 65cce697d8202..771cd4fe91dcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -165,7 +165,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi @Since("1.4.0") class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, - @Since("1.4.0")override val rootNode: Node, + @Since("1.4.0")override val rootNode: ClassificationNode, @Since("1.6.0")override val numFeatures: Int, @Since("1.5.0")override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] @@ -178,7 +178,7 @@ class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override def predict(features: Vector): Double = { @@ -276,8 +276,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true) + val model = new DecisionTreeClassificationModel(metadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -292,9 +293,10 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") // Can't infer number of features from old model, so default to -1 - new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) + new DecisionTreeClassificationModel(uid, + rootNode.asInstanceOf[ClassificationNode], numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index cd44489f618b2..c0255103bc313 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -371,14 +371,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 78a4972adbdbb..bb972e9706fc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -310,15 +310,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeClassificationModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + val tree = new DecisionTreeClassificationModel(treeMetadata.uid, + root.asInstanceOf[ClassificationNode], numFeatures, numClasses) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ad154fcd010cc..5cef5c9f21f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor @Since("1.4.0") class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: Node, + override val rootNode: RegressionNode, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable { @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numFeatures: Int) = + private[ml] def this(rootNode: RegressionNode, numFeatures: Int) = this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override def predict(features: Vector): Double = { @@ -279,8 +279,9 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sparkSession) - val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) + val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false) + val model = new DecisionTreeRegressionModel(metadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -295,8 +296,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") - val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode, numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 6569ff2a5bfc1..834aaa0e362d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -302,15 +302,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2d594460c2475..7f77398ba2a22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -269,13 +269,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => - val tree = - new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + val tree = new DecisionTreeRegressionModel(treeMetadata.uid, + root.asInstanceOf[RegressionNode], numFeatures) DefaultParamsReader.getAndSetParams(tree, treeMetadata) tree } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d30be452a436e..0242bc76698d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.Since import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} /** * Decision tree node interface. */ -sealed abstract class Node extends Serializable { +sealed trait Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree // code into the new API and deprecate the old API. SPARK-3727 @@ -84,35 +86,86 @@ private[ml] object Node { /** * Create a new Node from the old Node format, recursively creating child nodes as needed. */ - def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + def fromOld( + oldNode: OldNode, + categoricalFeatures: Map[Int, Int], + isClassification: Boolean): Node = { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, - impurity = oldNode.impurity, impurityStats = null) + if (isClassification) { + new ClassificationLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } else { + new RegressionLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) + } } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain } else { 0.0 } - new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, - gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), - rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + if (isClassification) { + new ClassificationInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true) + .asInstanceOf[ClassificationNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } else { + new RegressionInternalNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, gain = gain, + leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false) + .asInstanceOf[RegressionNode], + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) + } } } } -/** - * Decision tree leaf node. - * @param prediction Prediction this node makes - * @param impurity Impurity measure at this node (for training data) - */ -class LeafNode private[ml] ( - override val prediction: Double, - override val impurity: Double, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { +@Since("2.4.0") +sealed trait ClassificationNode extends Node { + + /** + * Get count of training examples for specified label in this node + * @param label label number in the range [0, numClasses) + */ + @Since("2.4.0") + def getLabelCount(label: Int): Double = { + require(label >= 0 && label < impurityStats.stats.length, + "label should be in the range between 0 (inclusive) " + + s"and ${impurityStats.stats.length} (exclusive).") + impurityStats.stats(label) + } +} + +@Since("2.4.0") +sealed trait RegressionNode extends Node { + + /** Number of training data points in this node */ + @Since("2.4.0") + def getCount: Double = impurityStats.stats(0) + + /** Sum over training data points of the labels in this node */ + @Since("2.4.0") + def getSum: Double = impurityStats.stats(1) + + /** Sum over training data points of the square of the labels in this node */ + @Since("2.4.0") + def getSumOfSquares: Double = impurityStats.stats(2) +} + +@Since("2.4.0") +sealed trait LeafNode extends Node { + + /** Prediction this node makes. */ + def prediction: Double + + def impurity: Double override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" @@ -135,32 +188,58 @@ class LeafNode private[ml] ( override private[ml] def maxSplitFeatureIndex(): Int = -1 +} + +/** + * Decision tree leaf node for classification. + */ +@Since("2.4.0") +class ClassificationLeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with LeafNode { + override private[tree] def deepCopy(): Node = { - new LeafNode(prediction, impurity, impurityStats) + new ClassificationLeafNode(prediction, impurity, impurityStats) } } /** - * Internal Decision Tree node. - * @param prediction Prediction this node would make if it were a leaf node - * @param impurity Impurity measure at this node (for training data) - * @param gain Information gain value. Values less than 0 indicate missing values; - * this quirk will be removed with future updates. - * @param leftChild Left-hand child node - * @param rightChild Right-hand child node - * @param split Information about the test used to split to the left or right child. + * Decision tree leaf node for regression. */ -class InternalNode private[ml] ( +@Since("2.4.0") +class RegressionLeafNode private[ml] ( override val prediction: Double, override val impurity: Double, - val gain: Double, - val leftChild: Node, - val rightChild: Node, - val split: Split, - override private[ml] val impurityStats: ImpurityCalculator) extends Node { + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with LeafNode { - // Note to developers: The constructor argument impurityStats should be reconsidered before we - // make the constructor public. We may be able to improve the representation. + override private[tree] def deepCopy(): Node = { + new RegressionLeafNode(prediction, impurity, impurityStats) + } +} + +/** + * Internal Decision Tree node. + */ +@Since("2.4.0") +sealed trait InternalNode extends Node { + + /** + * Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + */ + def gain: Double + + /** Left-hand child node */ + def leftChild: Node + + /** Right-hand child node */ + def rightChild: Node + + /** Information about the test used to split to the left or right child. */ + def split: Split override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" @@ -205,11 +284,6 @@ class InternalNode private[ml] ( math.max(split.featureIndex, math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) } - - override private[tree] def deepCopy(): Node = { - new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), rightChild.deepCopy(), - split, impurityStats) - } } private object InternalNode { @@ -240,6 +314,57 @@ private object InternalNode { } } +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class ClassificationInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: ClassificationNode, + override val rightChild: ClassificationNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends ClassificationNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new ClassificationInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[ClassificationNode], + rightChild.deepCopy().asInstanceOf[ClassificationNode], + split, impurityStats) + } +} + +/** + * Internal Decision Tree node for regression. + */ +@Since("2.4.0") +class RegressionInternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + override val gain: Double, + override val leftChild: RegressionNode, + override val rightChild: RegressionNode, + override val split: Split, + override private[ml] val impurityStats: ImpurityCalculator) + extends RegressionNode with InternalNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override private[tree] def deepCopy(): Node = { + new RegressionInternalNode(prediction, impurity, gain, + leftChild.deepCopy().asInstanceOf[RegressionNode], + rightChild.deepCopy().asInstanceOf[RegressionNode], + split, impurityStats) + } +} + + /** * Version of a node used in learning. This uses vars so that we can modify nodes as we split the * tree by adding children, etc. @@ -265,30 +390,52 @@ private[tree] class LearningNode( var isLeaf: Boolean, var stats: ImpurityStats) extends Serializable { - def toNode: Node = toNode(prune = true) + def toNode(isClassification: Boolean): Node = toNode(isClassification, prune = true) + + def toClassificationNode(prune: Boolean = true): ClassificationNode = { + toNode(true, prune).asInstanceOf[ClassificationNode] + } + + def toRegressionNode(prune: Boolean = true): RegressionNode = { + toNode(false, prune).asInstanceOf[RegressionNode] + } /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ - def toNode(prune: Boolean = true): Node = { + def toNode(isClassification: Boolean, prune: Boolean): Node = { if (!leftChild.isEmpty || !rightChild.isEmpty) { assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + (leftChild.get.toNode(isClassification, prune), + rightChild.get.toNode(isClassification, prune)) match { case (l: LeafNode, r: LeafNode) if prune && l.prediction == r.prediction => - new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + if (isClassification) { + new ClassificationLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } else { + new RegressionLeafNode(l.prediction, stats.impurity, stats.impurityCalculator) + } case (l, r) => - new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, - l, r, split.get, stats.impurityCalculator) + if (isClassification) { + new ClassificationInternalNode(stats.impurityCalculator.predict, stats.impurity, + stats.gain, l.asInstanceOf[ClassificationNode], r.asInstanceOf[ClassificationNode], + split.get, stats.impurityCalculator) + } else { + new RegressionInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode], + split.get, stats.impurityCalculator) + } } } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + val impurity = if (stats.valid) stats.impurity else -1.0 + if (isClassification) { + new ClassificationLeafNode(stats.impurityCalculator.predict, impurity, stats.impurityCalculator) } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + new RegressionLeafNode(stats.impurityCalculator.predict, impurity, + stats.impurityCalculator) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 16f32d76b9984..056a94b351f79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -224,23 +224,23 @@ private[spark] object RandomForest extends Logging { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), numFeatures, - strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toClassificationNode(prune), + numFeatures, strategy.getNumClasses) } } else { topNodes.map { rootNode => - new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), numFeatures) + new DecisionTreeRegressionModel(uid, rootNode.toRegressionNode(prune), numFeatures) } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode(prune), numFeatures, + new DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), numFeatures, strategy.getNumClasses) } } else { topNodes.map(rootNode => - new DecisionTreeRegressionModel(rootNode.toNode(prune), numFeatures)) + new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), numFeatures)) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 4aa4c3617e7fd..f027b14f1d476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -219,8 +219,10 @@ private[ml] object TreeEnsembleModel { importances.changeValue(feature, scaledGain, _ + scaledGain) computeFeatureImportance(n.leftChild, importances) computeFeatureImportance(n.rightChild, importances) - case n: LeafNode => + case _: LeafNode => // do nothing + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -317,6 +319,8 @@ private[ml] object DecisionTreeModelReadWrite { (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) + case _ => + throw new IllegalArgumentException(s"Unknown node type: ${node.getClass.toString}") } } @@ -327,7 +331,7 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sparkSession: SparkSession): Node = { + sparkSession: SparkSession, isClassification: Boolean): Node = { import sparkSession.implicits._ implicit val format = DefaultFormats @@ -339,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType) + buildTreeFromNodes(data.collect(), impurityType, isClassification) } /** @@ -348,7 +352,8 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String, + isClassification: Boolean): Node = { // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -364,10 +369,21 @@ private[ml] object DecisionTreeModelReadWrite { val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) - new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, - n.split.getSplit, impurityStats) + if (isClassification) { + new ClassificationInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode], + n.split.getSplit, impurityStats) + } else { + new RegressionInternalNode(n.prediction, n.impurity, n.gain, + leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode], + n.split.getSplit, impurityStats) + } } else { - new LeafNode(n.prediction, n.impurity, impurityStats) + if (isClassification) { + new ClassificationLeafNode(n.prediction, n.impurity, impurityStats) + } else { + new RegressionLeafNode(n.prediction, n.impurity, impurityStats) + } } finalNodes(n.id) = node } @@ -421,7 +437,8 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SparkSession, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { + treeClassName: String, + isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -449,7 +466,8 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes( + nodeData.toArray, impurityType, isClassification) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2930f4900d50e..d3dbb4e754d3d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -61,7 +61,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) + val model = new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -375,6 +376,32 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(model) } + + test("label/impurity stats") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2) + val dt1 = new DecisionTreeClassifier() + .setImpurity("entropy") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model1 = dt1.fit(df) + + val rootNode1 = model1.rootNode + assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === Array(2.0, 1.0)) + + val dt2 = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val model2 = dt2.fit(df) + + val rootNode2 = model2.rootNode + assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === Array(2.0, 1.0)) + } } private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 57796069f6052..f0ee5496f9d1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.RegressionLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -69,7 +69,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), + Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 0.0, null), 1)), Array(1.0), 1, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ba4a9cf082785..3062aa9f3d274 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.tree.LeafNode +import org.apache.spark.ml.tree.ClassificationLeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} @@ -71,7 +71,8 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", + new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 29a438396516b..9ae27339b11d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -191,6 +191,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } + + test("label/impurity stats") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val dtr = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8) + val model = dtr.fit(df) + val statInfo = model.rootNode + + assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0 + && statInfo.getSumOfSquares == 600.0) + } } private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 743dacf146fe7..4dbbd75d2466d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(topNode.stats.impurity > 0.0) // set impurity and predict for child nodes - assert(topNode.leftChild.get.toNode.prediction === 0.0) - assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.toNode(isClassification = true).prediction === 0.0) + assert(topNode.rightChild.get.toNode(isClassification = true).prediction === 1.0) assert(topNode.leftChild.get.stats.impurity === 0.0) assert(topNode.rightChild.get.stats.impurity === 0.0) } @@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left right */ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) - val left = new LeafNode(0.0, leftImp.calculate(), leftImp) + val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) - val right = new LeafNode(2.0, rightImp.calculate(), rightImp) + val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp) - val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) + val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5), true) val parentImp = parent.impurityStats val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) - val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) + val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp) - val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) + val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0), true) val grandImp = grandParent.impurityStats // Test feature importance computed at different subtrees. @@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) - .asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode], + numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel] } val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..3f03d909d4a4c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite { * @param split Split for parent node * @return Parent node with children attached */ - def buildParentNode(left: Node, right: Node, split: Split): Node = { + def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) @@ -168,7 +168,15 @@ private[ml] object TreeTests extends SparkFunSuite { val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict - new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp) + if (isClassification) { + new ClassificationInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[ClassificationNode], right.asInstanceOf[ClassificationNode], + split, parentImp) + } else { + new RegressionInternalNode(pred, parentImp.calculate(), gain, + left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode], + split, parentImp) + } } /** diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1b6d1dec69d49..b37b4d51775e8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,7 +55,14 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), + + // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") ) // Exclude rules for 2.3.x From 61b724724cc4a18818774ecaaa5a45b70fdb8dae Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Apr 2018 14:07:33 -0700 Subject: [PATCH 0589/1161] [INFRA] Close stale PRs. Closes #20957 Closes #20792 From f94f3624ea81053653a06560808cb71f510c6828 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Mon, 9 Apr 2018 21:07:28 -0700 Subject: [PATCH 0590/1161] [SPARK-23947][SQL] Add hashUTF8String convenience method to hasher classes ## What changes were proposed in this pull request? Add `hashUTF8String()` to the hasher classes to allow Spark SQL codegen to generate cleaner code for hashing `UTF8String`s. No change in behavior otherwise. Although with the introduction of SPARK-10399, the code size for hashing `UTF8String` is already smaller, it's still good to extract a separate function in the hasher classes so that the generated code can stay clean. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21016 from rednaxelafx/hashutf8. --- .../apache/spark/sql/catalyst/expressions/HiveHasher.java | 5 +++++ .../java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java | 7 ++++++- .../org/apache/spark/sql/catalyst/expressions/XXH64.java | 5 +++++ .../org/apache/spark/sql/catalyst/expressions/hash.scala | 6 ++---- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index c34e36903a93e..62b75ae8aa01d 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -51,4 +52,8 @@ public static int hashUnsafeBytesBlock(MemoryBlock mb) { public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); } + + public static int hashUTF8String(UTF8String str) { + return hashUnsafeBytesBlock(str.getMemoryBlock()); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index f372b19fac119..aff6e93d647fe 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -20,6 +20,7 @@ import com.google.common.primitives.Ints; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. @@ -82,6 +83,10 @@ public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { return fmix(h1, lengthInBytes); } + public static int hashUTF8String(UTF8String str, int seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } @@ -91,7 +96,7 @@ public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, } public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { - // This is compatible with original and another implementations. + // This is compatible with original and other implementations. // Use this method for new components after Spark 2.3. int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java index fe727f6011cbf..8e9c0a2e9dc81 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; // scalastyle: off /** @@ -107,6 +108,10 @@ public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) { return fmix(hash); } + public static long hashUTF8String(UTF8String str, long seed) { + return hashUnsafeBytesBlock(str.getMemoryBlock(), seed); + } + public static long hashUnsafeBytes(Object base, long offset, int length, long seed) { return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index df29c38d64d3d..ef790338bdd27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -361,8 +361,7 @@ abstract class HashExpression[E] extends Expression { } protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);" + s"$result = $hasherClassName.hashUTF8String($input, $result);" } protected def genHashForMap( @@ -725,8 +724,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """ override protected def genHashString(input: String, result: String): String = { - val mb = s"$input.getMemoryBlock()" - s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);" + s"$result = $hasherClassName.hashUTF8String($input);" } override protected def genHashForArray( From 64988841540464e261b0cbaede43058e7bd36261 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 9 Apr 2018 21:49:49 -0700 Subject: [PATCH 0591/1161] [SPARK-23898][SQL] Simplify add & subtract code generation ## What changes were proposed in this pull request? Code generation for the `Add` and `Subtract` expressions was not done using the `BinaryArithmetic.doCodeGen` method because these expressions also support `CalendarInterval`. This leads to a bit of duplication. This PR gets rid of that duplication by adding `calendarIntervalMethod` to `BinaryArithmetic` and doing the code generation for `CalendarInterval` in `BinaryArithmetic` instead. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21005 from hvanhovell/SPARK-23898. --- .../sql/catalyst/expressions/arithmetic.scala | 50 ++++++++----------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 478ff3a7c1011..defd6f3cd8849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") // codegen would fail to compile if we just write (-($c)) @@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) - case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { @@ -104,7 +104,7 @@ case class Abs(child: Expression) private lazy val numeric = TypeUtils.getNumeric(dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") @@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { override def dataType: DataType = left.dataType - override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") + /** Name of the function for this expression on a [[CalendarInterval]] type. */ + def calendarIntervalMethod: String = + sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => + case _: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + case CalendarIntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, @@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" + + override def calendarIntervalMethod: String = "add" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { numeric.plus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" + override def decimalMethod: String = "$minus" + + override def calendarIntervalMethod: String = "subtract" + private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti numeric.minus(input1, input2) } } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { - case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") - case ByteType | ShortType => - defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") - case CalendarIntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") - case _ => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") - } } @ExpressionDescription( @@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "pmod" - protected def checkTypesInternal(t: DataType) = + protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeUtils.checkForNumericExpr(t, "pmod") override def inputType: AbstractDataType = NumericType From 95034af69623bb8be5b9f5eabf50980bdeca48e6 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 10 Apr 2018 08:51:35 -0500 Subject: [PATCH 0592/1161] [SPARK-23841][ML] NodeIdCache should unpersist the last cached nodeIdsForInstances ## What changes were proposed in this pull request? unpersist the last cached nodeIdsForInstances in `deleteAllCheckpoints` ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20956 from zhengruifeng/NodeIdCache_cleanup. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index a7c5f489dea86..5b14a63ada4ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -95,7 +95,7 @@ private[spark] class NodeIdCache( splits: Array[Array[Split]]): Unit = { if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } prevNodeIdsForInstances = nodeIdsForInstances @@ -166,9 +166,13 @@ private[spark] class NodeIdCache( } } } + if (nodeIdsForInstances != null) { + // Unpersist current one if one exists. + nodeIdsForInstances.unpersist(false) + } if (prevNodeIdsForInstances != null) { // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + prevNodeIdsForInstances.unpersist(false) } } } From 3323b156f9c0beb0b3c2b724a6faddc6ffdfe99a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 10 Apr 2018 17:32:00 +0200 Subject: [PATCH 0593/1161] [SPARK-23864][SQL] Add unsafe object writing to UnsafeWriter ## What changes were proposed in this pull request? This PR moves writing of `UnsafeRow`, `UnsafeArrayData` & `UnsafeMapData` out of the `GenerateUnsafeProjection`/`InterpretedUnsafeProjection` classes into the `UnsafeWriter` interface. This cleans up the code a little bit, and it should also result in less byte code for the code generated path. ## How was this patch tested? Existing tests Author: Herman van Hovell Closes #20986 from hvanhovell/SPARK-23864. --- .../expressions/codegen/UnsafeWriter.java | 72 ++-- .../InterpretedUnsafeProjection.scala | 46 +-- .../codegen/GenerateUnsafeProjection.scala | 322 ++++++++---------- .../spark/sql/types/UserDefinedType.scala | 10 + 4 files changed, 204 insertions(+), 246 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index de0eb6dbb76be..2781655002000 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeMapData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -103,21 +106,7 @@ protected final void zeroOutPaddingBytes(int numBytes) { public abstract void write(int ordinal, Decimal input, int precision, int scale); public final void write(int ordinal, UTF8String input) { - final int numBytes = input.numBytes(); - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - - // grow the global buffer before writing data. - grow(roundedSize); - - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - input.writeToMemory(getBuffer(), cursor()); - - setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. - increaseCursor(roundedSize); + writeUnalignedBytes(ordinal, input.getBaseObject(), input.getBaseOffset(), input.numBytes()); } public final void write(int ordinal, byte[] input) { @@ -125,20 +114,19 @@ public final void write(int ordinal, byte[] input) { } public final void write(int ordinal, byte[] input, int offset, int numBytes) { - final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + writeUnalignedBytes(ordinal, input, Platform.BYTE_ARRAY_OFFSET + offset, numBytes); + } - // grow the global buffer before writing data. + private void writeUnalignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); grow(roundedSize); - zeroOutPaddingBytes(numBytes); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - input, Platform.BYTE_ARRAY_OFFSET + offset, getBuffer(), cursor(), numBytes); - + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); setOffsetAndSize(ordinal, numBytes); - - // move the cursor forward. increaseCursor(roundedSize); } @@ -156,6 +144,40 @@ public final void write(int ordinal, CalendarInterval input) { increaseCursor(16); } + public final void write(int ordinal, UnsafeRow row) { + writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); + } + + public final void write(int ordinal, UnsafeMapData map) { + writeAlignedBytes(ordinal, map.getBaseObject(), map.getBaseOffset(), map.getSizeInBytes()); + } + + public final void write(UnsafeArrayData array) { + // Unsafe arrays both can be written as a regular array field or as part of a map. This makes + // updating the offset and size dependent on the code path, this is why we currently do not + // provide an method for writing unsafe arrays that also updates the size and offset. + int numBytes = array.getSizeInBytes(); + grow(numBytes); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + getBuffer(), + cursor(), + numBytes); + increaseCursor(numBytes); + } + + private void writeAlignedBytes( + int ordinal, + Object baseObject, + long baseOffset, + int numBytes) { + grow(numBytes); + Platform.copyMemory(baseObject, baseOffset, getBuffer(), cursor(), numBytes); + setOffsetAndSize(ordinal, numBytes); + increaseCursor(numBytes); + } + protected final void writeBoolean(long offset, boolean value) { Platform.putBoolean(getBuffer(), offset, value); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index b31466f5c92d1..6d69d69b1c802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -173,21 +173,17 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { val rowWriter = new UnsafeRowWriter(writer, numFields) val structWriter = generateStructWriter(rowWriter, fields) (v, i) => { - val previousCursor = writer.cursor() v.getStruct(i, fields.length) match { case row: UnsafeRow => - writeUnsafeData( - rowWriter, - row.getBaseObject, - row.getBaseOffset, - row.getSizeInBytes) + writer.write(i, row) case row => + val previousCursor = writer.cursor() // Nested struct. We don't know where this will start because a row can be // variable length, so we need to update the offsets and zero out the bit mask. rowWriter.resetRowWriter() structWriter.apply(row) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case ArrayType(elementType, containsNull) => @@ -214,15 +210,12 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { valueType, valueContainsNull) (v, i) => { - val previousCursor = writer.cursor() v.getMap(i) match { case map: UnsafeMapData => - writeUnsafeData( - valueArrayWriter, - map.getBaseObject, - map.getBaseOffset, - map.getSizeInBytes) + writer.write(i, map) case map => + val previousCursor = writer.cursor() + // preserve 8 bytes to write the key array numBytes later. valueArrayWriter.grow(8) valueArrayWriter.increaseCursor(8) @@ -237,8 +230,8 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { // Write the values. writeArray(valueArrayWriter, valueWriter, map.valueArray()) + writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } - writer.setOffsetAndSizeFromPreviousCursor(i, previousCursor) } case udt: UserDefinedType[_] => @@ -318,11 +311,7 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { elementWriter: (SpecializedGetters, Int) => Unit, array: ArrayData): Unit = array match { case unsafe: UnsafeArrayData => - writeUnsafeData( - arrayWriter, - unsafe.getBaseObject, - unsafe.getBaseOffset, - unsafe.getSizeInBytes) + arrayWriter.write(unsafe) case _ => val numElements = array.numElements() arrayWriter.initialize(numElements) @@ -332,23 +321,4 @@ object InterpretedUnsafeProjection extends UnsafeProjectionCreator { i += 1 } } - - /** - * Write an opaque block of data to the buffer. This is used to copy - * [[UnsafeRow]], [[UnsafeArrayData]] and [[UnsafeMapData]] objects. - */ - private def writeUnsafeData( - writer: UnsafeWriter, - baseObject: AnyRef, - baseOffset: Long, - sizeInBytes: Int) : Unit = { - writer.grow(sizeInBytes) - Platform.copyMemory( - baseObject, - baseOffset, - writer.getBuffer, - writer.cursor, - sizeInBytes) - writer.increaseCursor(sizeInBytes) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4a4d76313a543..2fb441ac4500e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -32,14 +32,13 @@ import org.apache.spark.sql.types._ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { /** Returns true iff we support this data type. */ - def canSupport(dataType: DataType): Boolean = dataType match { + def canSupport(dataType: DataType): Boolean = UserDefinedType.sqlType(dataType) match { case NullType => true - case t: AtomicType => true + case _: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -47,6 +46,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private def writeStructToBuffer( ctx: CodegenContext, input: String, + index: String, fieldTypes: Seq[DataType], rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. @@ -60,15 +60,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val rowWriterClass = classOf[UnsafeRowWriter].getName val structRowWriter = ctx.addMutableState(rowWriterClass, "rowWriter", v => s"$v = new $rowWriterClass($rowWriter, ${fieldEvals.length});") - + val previousCursor = ctx.freshName("previousCursor") s""" - final InternalRow $tmpInput = $input; - if ($tmpInput instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", structRowWriter)} - } else { - ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} - } - """ + |final InternalRow $tmpInput = $input; + |if ($tmpInput instanceof UnsafeRow) { + | $rowWriter.write($index, (UnsafeRow) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, structRowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin } private def writeExpressionsToBuffer( @@ -95,10 +99,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writeFields = inputs.zip(inputTypes).zipWithIndex.map { case ((input, dataType), index) => - val dt = dataType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val dt = UserDefinedType.sqlType(dataType) val setNull = dt match { case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => @@ -106,58 +107,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } - val previousCursor = ctx.freshName("previousCursor") - - val writeField = dt match { - case t: StructType => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeStructToBuffer(ctx, input.value, t.map(_.dataType), rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeArrayToBuffer(ctx, input.value, et, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - // Remember the current cursor so that we can calculate how many bytes are - // written later. - final int $previousCursor = $rowWriter.cursor(); - ${writeMapToBuffer(ctx, input.value, kt, vt, rowWriter)} - $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$rowWriter.write($index, ${input.value});" - } + val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) if (input.isNull == "false") { s""" - ${input.code} - ${writeField.trim} - """ + |${input.code} + |${writeField.trim} + """.stripMargin } else { s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { - ${writeField.trim} - } - """ + |${input.code} + |if (${input.isNull}) { + | ${setNull.trim} + |} else { + | ${writeField.trim} + |} + """.stripMargin } } @@ -171,11 +136,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro funcName = "writeFields", arguments = Seq("InternalRow" -> row)) } - s""" - $resetWriter - $writeFieldsCode - """.trim + |$resetWriter + |$writeFieldsCode + """.stripMargin } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -189,10 +153,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val et = elementType match { - case udt: UserDefinedType[_] => udt.sqlType - case other => other - } + val et = UserDefinedType.sqlType(elementType) val jt = CodeGenerator.javaType(et) @@ -205,106 +166,100 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.addMutableState(arrayWriterClass, "arrayWriter", v => s"$v = new $arrayWriterClass($rowWriter, $elementOrOffsetSize);") - val previousCursor = ctx.freshName("previousCursor") val element = CodeGenerator.getValue(tmpInput, et, index) - val writeElement = et match { - case t: StructType => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeStructToBuffer(ctx, element, t.map(_.dataType), arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case a @ ArrayType(et, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeArrayToBuffer(ctx, element, et, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case m @ MapType(kt, vt, _) => - s""" - final int $previousCursor = $arrayWriter.cursor(); - ${writeMapToBuffer(ctx, element, kt, vt, arrayWriter)} - $arrayWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); - """ - - case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" - - case NullType => "" - - case _ => s"$arrayWriter.write($index, $element);" - } - val primitiveTypeName = - if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" - final ArrayData $tmpInput = $input; - if ($tmpInput instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", arrayWriter)} - } else { - final int $numElements = $tmpInput.numElements(); - $arrayWriter.initialize($numElements); - - for (int $index = 0; $index < $numElements; $index++) { - if ($tmpInput.isNullAt($index)) { - $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); - } else { - $writeElement - } - } - } - """ + |final ArrayData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeArrayData) { + | $rowWriter.write((UnsafeArrayData) $tmpInput); + |} else { + | final int $numElements = $tmpInput.numElements(); + | $arrayWriter.initialize($numElements); + | + | for (int $index = 0; $index < $numElements; $index++) { + | if ($tmpInput.isNullAt($index)) { + | $arrayWriter.setNull${elementOrOffsetSize}Bytes($index); + | } else { + | ${writeElement(ctx, element, index, et, arrayWriter)} + | } + | } + |} + """.stripMargin } // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( ctx: CodegenContext, input: String, + index: String, keyType: DataType, valueType: DataType, rowWriter: String): String = { // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") + val previousCursor = ctx.freshName("previousCursor") // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - final MapData $tmpInput = $input; - if ($tmpInput instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", rowWriter)} - } else { - // preserve 8 bytes to write the key array numBytes later. - $rowWriter.grow(8); - $rowWriter.increaseCursor(8); + |final MapData $tmpInput = $input; + |if ($tmpInput instanceof UnsafeMapData) { + | $rowWriter.write($index, (UnsafeMapData) $tmpInput); + |} else { + | // Remember the current cursor so that we can calculate how many bytes are + | // written later. + | final int $previousCursor = $rowWriter.cursor(); + | + | // preserve 8 bytes to write the key array numBytes later. + | $rowWriter.grow(8); + | $rowWriter.increaseCursor(8); + | + | // Remember the current cursor so that we can write numBytes of key array later. + | final int $tmpCursor = $rowWriter.cursor(); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} + | + | // Write the numBytes of key array into the first 8 bytes. + | Platform.putLong( + | $rowWriter.getBuffer(), + | $tmpCursor - 8, + | $rowWriter.cursor() - $tmpCursor); + | + | ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} + | $rowWriter.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + |} + """.stripMargin + } - // Remember the current cursor so that we can write numBytes of key array later. - final int $tmpCursor = $rowWriter.cursor(); + private def writeElement( + ctx: CodegenContext, + input: String, + index: String, + dt: DataType, + writer: String): String = dt match { + case t: StructType => + writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer) + + case ArrayType(et, _) => + val previousCursor = ctx.freshName("previousCursor") + s""" + |// Remember the current cursor so that we can calculate how many bytes are + |// written later. + |final int $previousCursor = $writer.cursor(); + |${writeArrayToBuffer(ctx, input, et, writer)} + |$writer.setOffsetAndSizeFromPreviousCursor($index, $previousCursor); + """.stripMargin - ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, rowWriter)} - // Write the numBytes of key array into the first 8 bytes. - Platform.putLong($rowWriter.getBuffer(), $tmpCursor - 8, $rowWriter.cursor() - $tmpCursor); + case MapType(kt, vt, _) => + writeMapToBuffer(ctx, input, index, kt, vt, writer) - ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, rowWriter)} - } - """ - } + case DecimalType.Fixed(precision, scale) => + s"$writer.write($index, $input, $precision, $scale);" - /** - * If the input is already in unsafe format, we don't need to go through all elements/fields, - * we can directly write it. - */ - private def writeUnsafeData(ctx: CodegenContext, input: String, rowWriter: String) = { - val sizeInBytes = ctx.freshName("sizeInBytes") - s""" - final int $sizeInBytes = $input.getSizeInBytes(); - // grow the global buffer before writing data. - $rowWriter.grow($sizeInBytes); - $input.writeToMemory($rowWriter.getBuffer(), $rowWriter.cursor()); - $rowWriter.increaseCursor($sizeInBytes); - """ + case NullType => "" + + case _ => s"$writer.write($index, $input);" } def createCode( @@ -332,10 +287,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val code = s""" - $rowWriter.reset(); - $evalSubexpr - $writeExpressions - """ + |$rowWriter.reset(); + |$evalSubexpr + |$writeExpressions + """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", canDirectAccess = true)) @@ -363,38 +318,39 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val codeBody = s""" - public java.lang.Object generate(Object[] references) { - return new SpecificUnsafeProjection(references); - } - - class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { - - private Object[] references; - ${ctx.declareMutableStates()} - - public SpecificUnsafeProjection(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public void initialize(int partitionIndex) { - ${ctx.initPartition()} - } - - // Scala.Function1 need this - public java.lang.Object apply(java.lang.Object row) { - return apply((InternalRow) row); - } - - public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code.trim} - return ${eval.value}; - } - - ${ctx.declareAddedFunctions()} - } - """ + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificUnsafeProjection(references); + |} + | + |class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} { + | + | private Object[] references; + | ${ctx.declareMutableStates()} + | + | public SpecificUnsafeProjection(Object[] references) { + | this.references = references; + | ${ctx.initMutableStates()} + | } + | + | public void initialize(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | // Scala.Function1 need this + | public java.lang.Object apply(java.lang.Object row) { + | return apply((InternalRow) row); + | } + | + | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { + | ${eval.code.trim} + | return ${eval.value}; + | } + | + | ${ctx.declareAddedFunctions()} + |} + """.stripMargin val code = CodeFormatter.stripOverlappingComments( new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 5a944e763e099..6af16e2dba105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -97,6 +97,16 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def catalogString: String = sqlType.simpleString } +private[spark] object UserDefinedType { + /** + * Get the sqlType of a (potential) [[UserDefinedType]]. + */ + def sqlType(dt: DataType): DataType = dt match { + case udt: UserDefinedType[_] => udt.sqlType + case _ => dt + } +} + /** * The user defined type in Python. * From e179658914963de472120a81621396706584c949 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 10 Apr 2018 09:33:09 -0700 Subject: [PATCH 0594/1161] [SPARK-19724][SQL][FOLLOW-UP] Check location of managed table when ignoreIfExists is true ## What changes were proposed in this pull request? In the PR #20886, I mistakenly check the table location only when `ignoreIfExists` is false, which was following the original deprecated PR. That was wrong. When `ignoreIfExists` is true and the target table doesn't exist, we should also check the table location. In other word, **`ignoreIfExists` has nothing to do with table location validation**. This is a follow-up PR to fix the mistake. ## How was this patch tested? Add one unit test. Author: Gengliang Wang Closes #21001 from gengliangwang/SPARK-19724-followup. --- .../spark/sql/catalyst/catalog/SessionCatalog.scala | 11 +++++++++-- .../execution/command/createDataSourceTables.scala | 2 +- .../apache/spark/sql/execution/command/DDLSuite.scala | 9 +++++++++ .../spark/sql/hive/execution/HiveDDLSuite.scala | 2 +- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 52ed89ef8d781..c390337c03ff5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -286,7 +286,10 @@ class SessionCatalog( * Create a metastore table in the database specified in `tableDefinition`. * If no such database is specified, create it in the current database. */ - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + def createTable( + tableDefinition: CatalogTable, + ignoreIfExists: Boolean, + validateLocation: Boolean = true): Unit = { val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) val tableIdentifier = TableIdentifier(table, Some(db)) @@ -305,7 +308,11 @@ class SessionCatalog( } requireDbExists(db) - if (!ignoreIfExists) { + if (tableExists(newTableDefinition.identifier)) { + if (!ignoreIfExists) { + throw new TableAlreadyExistsException(db = db, table = table) + } + } else if (validateLocation) { validateTableLocation(newTableDefinition) } externalCatalog.createTable(newTableDefinition, ignoreIfExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f7c3e9b019258..f6ef433f2ce15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -182,7 +182,7 @@ case class CreateDataSourceTableAsSelectCommand( // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). schema = result.schema) // Table location is already validated. No need to check it again during table creation. - sessionState.catalog.createTable(newTable, ignoreIfExists = true) + sessionState.catalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) result match { case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 4304d0b6f6b16..cbd7f9d6f67be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -425,6 +425,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE tab1 (col1 int, col2 string) USING ${dataSource}") }.getMessage assert(ex.contains(exMsgWithDefaultDB)) + + // Always check location of managed table, with or without (IF NOT EXISTS) + withTable("tab2") { + sql(s"CREATE TABLE tab2 (col1 int, col2 string) USING ${dataSource}") + ex = intercept[AnalysisException] { + sql(s"CREATE TABLE IF NOT EXISTS tab1 LIKE tab2") + }.getMessage + assert(ex.contains(exMsgWithDefaultDB)) + } } finally { waitForTasksToFinish() Utils.deleteRecursively(tableLoc) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index db76ec9d084cb..c85db78c732de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1461,7 +1461,7 @@ class HiveDDLSuite assert(e2.getMessage.contains(forbiddenPrefix + "foo")) val e3 = intercept[AnalysisException] { - sql(s"CREATE TABLE tbl (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") + sql(s"CREATE TABLE tbl2 (a INT) TBLPROPERTIES ('${forbiddenPrefix}foo'='anything')") } assert(e3.getMessage.contains(forbiddenPrefix + "foo")) } From adb222b957f327a69929b8f16fa5ebc071fa99e3 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 10 Apr 2018 11:18:14 -0700 Subject: [PATCH 0595/1161] [SPARK-23751][ML][PYSPARK] Kolmogorov-Smirnoff test Python API in pyspark.ml ## What changes were proposed in this pull request? Kolmogorov-Smirnoff test Python API in `pyspark.ml` **Note** API with `CDF` is a little difficult to support in python. We can add it in following PR. ## How was this patch tested? doctest Author: WeichenXu Closes #20904 from WeichenXu123/ks-test-py. --- .../spark/ml/stat/KolmogorovSmirnovTest.scala | 29 +-- python/pyspark/ml/stat.py | 181 ++++++++++++------ 2 files changed, 138 insertions(+), 72 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index c62d7463288f7..af8ff64d33ffe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -24,7 +24,7 @@ import org.apache.spark.api.java.function.Function import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.stat.{Statistics => OldStatistics} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col /** @@ -59,7 +59,7 @@ object KolmogorovSmirnovTest { * distribution of the sample data and the theoretical distribution we can provide a test for the * the null hypothesis that the sample data comes from that theoretical distribution. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param cdf a `Double => Double` function to calculate the theoretical CDF at a given value * @return DataFrame containing the test result for the input sampled data. @@ -68,10 +68,10 @@ object KolmogorovSmirnovTest { * - `statistic: Double` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, cdf: Double => Double): DataFrame = { + def test(dataset: Dataset[_], sampleCol: String, cdf: Double => Double): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, cdf) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) @@ -81,10 +81,11 @@ object KolmogorovSmirnovTest { * Java-friendly version of `test(dataset: DataFrame, sampleCol: String, cdf: Double => Double)` */ @Since("2.4.0") - def test(dataset: DataFrame, sampleCol: String, - cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - val f: Double => Double = x => cdf.call(x) - test(dataset, sampleCol, f) + def test( + dataset: Dataset[_], + sampleCol: String, + cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { + test(dataset, sampleCol, (x: Double) => cdf.call(x)) } /** @@ -92,10 +93,11 @@ object KolmogorovSmirnovTest { * distribution equality. Currently supports the normal distribution, taking as parameters * the mean and standard deviation. * - * @param dataset a `DataFrame` containing the sample of data to test + * @param dataset A `Dataset` or a `DataFrame` containing the sample of data to test * @param sampleCol Name of sample column in dataset, of any numerical type * @param distName a `String` name for a theoretical distribution, currently only support "norm". - * @param params `Double*` specifying the parameters to be used for the theoretical distribution + * @param params `Double*` specifying the parameters to be used for the theoretical distribution. + * For "norm" distribution, the parameters includes mean and variance. * @return DataFrame containing the test result for the input sampled data. * This DataFrame will contain a single Row with the following fields: * - `pValue: Double` @@ -103,10 +105,13 @@ object KolmogorovSmirnovTest { */ @Since("2.4.0") @varargs - def test(dataset: DataFrame, sampleCol: String, distName: String, params: Double*): DataFrame = { + def test( + dataset: Dataset[_], + sampleCol: String, distName: String, + params: Double*): DataFrame = { val spark = dataset.sparkSession - val rdd = getSampleRDD(dataset, sampleCol) + val rdd = getSampleRDD(dataset.toDF(), sampleCol) val testResult = OldStatistics.kolmogorovSmirnovTest(rdd, distName, params: _*) spark.createDataFrame(Seq(KolmogorovSmirnovTestResult( testResult.pValue, testResult.statistic))) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 0eeb5e528434a..93d0f4fd9148f 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -32,32 +32,6 @@ class ChiSquareTest(object): The null hypothesis is that the occurrence of the outcomes is statistically independent. - :param dataset: - DataFrame of categorical labels and categorical features. - Real-valued features will be treated as categorical for each distinct value. - :param featuresCol: - Name of features column in dataset, of type `Vector` (`VectorUDT`). - :param labelCol: - Name of label column in dataset, of any numerical type. - :return: - DataFrame containing the test result for every feature against the label. - This DataFrame will contain a single Row with the following fields: - - `pValues: Vector` - - `degreesOfFreedom: Array[Int]` - - `statistics: Vector` - Each of these fields has one value per feature. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import ChiSquareTest - >>> dataset = [[0, Vectors.dense([0, 0, 1])], - ... [0, Vectors.dense([1, 0, 1])], - ... [1, Vectors.dense([2, 1, 1])], - ... [1, Vectors.dense([3, 1, 1])]] - >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) - >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') - >>> chiSqResult.select("degreesOfFreedom").collect()[0] - Row(degreesOfFreedom=[3, 1, 0]) - .. versionadded:: 2.2.0 """ @@ -66,6 +40,32 @@ class ChiSquareTest(object): def test(dataset, featuresCol, labelCol): """ Perform a Pearson's independence test using dataset. + + :param dataset: + DataFrame of categorical labels and categorical features. + Real-valued features will be treated as categorical for each distinct value. + :param featuresCol: + Name of features column in dataset, of type `Vector` (`VectorUDT`). + :param labelCol: + Name of label column in dataset, of any numerical type. + :return: + DataFrame containing the test result for every feature against the label. + This DataFrame will contain a single Row with the following fields: + - `pValues: Vector` + - `degreesOfFreedom: Array[Int]` + - `statistics: Vector` + Each of these fields has one value per feature. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import ChiSquareTest + >>> dataset = [[0, Vectors.dense([0, 0, 1])], + ... [0, Vectors.dense([1, 0, 1])], + ... [1, Vectors.dense([2, 1, 1])], + ... [1, Vectors.dense([3, 1, 1])]] + >>> dataset = spark.createDataFrame(dataset, ["label", "features"]) + >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label') + >>> chiSqResult.select("degreesOfFreedom").collect()[0] + Row(degreesOfFreedom=[3, 1, 0]) """ sc = SparkContext._active_spark_context javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest @@ -85,40 +85,6 @@ class Correlation(object): which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'` to avoid recomputing the common lineage. - :param dataset: - A dataset or a dataframe. - :param column: - The name of the column of vectors for which the correlation coefficient needs - to be computed. This must be a column of the dataset, and it must contain - Vector objects. - :param method: - String specifying the method to use for computing correlation. - Supported: `pearson` (default), `spearman`. - :return: - A dataframe that contains the correlation matrix of the column of vectors. This - dataframe contains a single row and a single column of name - '$METHODNAME($COLUMN)'. - - >>> from pyspark.ml.linalg import Vectors - >>> from pyspark.ml.stat import Correlation - >>> dataset = [[Vectors.dense([1, 0, 0, -2])], - ... [Vectors.dense([4, 5, 0, 3])], - ... [Vectors.dense([6, 7, 0, 8])], - ... [Vectors.dense([9, 0, 0, 1])]] - >>> dataset = spark.createDataFrame(dataset, ['features']) - >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] - >>> print(str(pearsonCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], - [ 0.0556..., 1. , NaN, 0.9135...], - [ NaN, NaN, 1. , NaN], - [ 0.4004..., 0.9135..., NaN, 1. ]]) - >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] - >>> print(str(spearmanCorr).replace('nan', 'NaN')) - DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], - [ 0.1054..., 1. , NaN, 0.9486... ], - [ NaN, NaN, 1. , NaN], - [ 0.4 , 0.9486... , NaN, 1. ]]) - .. versionadded:: 2.2.0 """ @@ -127,6 +93,40 @@ class Correlation(object): def corr(dataset, column, method="pearson"): """ Compute the correlation matrix with specified method using dataset. + + :param dataset: + A Dataset or a DataFrame. + :param column: + The name of the column of vectors for which the correlation coefficient needs + to be computed. This must be a column of the dataset, and it must contain + Vector objects. + :param method: + String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman`. + :return: + A DataFrame that contains the correlation matrix of the column of vectors. This + DataFrame contains a single row and a single column of name + '$METHODNAME($COLUMN)'. + + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.ml.stat import Correlation + >>> dataset = [[Vectors.dense([1, 0, 0, -2])], + ... [Vectors.dense([4, 5, 0, 3])], + ... [Vectors.dense([6, 7, 0, 8])], + ... [Vectors.dense([9, 0, 0, 1])]] + >>> dataset = spark.createDataFrame(dataset, ['features']) + >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0] + >>> print(str(pearsonCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.0556..., NaN, 0.4004...], + [ 0.0556..., 1. , NaN, 0.9135...], + [ NaN, NaN, 1. , NaN], + [ 0.4004..., 0.9135..., NaN, 1. ]]) + >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0] + >>> print(str(spearmanCorr).replace('nan', 'NaN')) + DenseMatrix([[ 1. , 0.1054..., NaN, 0.4 ], + [ 0.1054..., 1. , NaN, 0.9486... ], + [ NaN, NaN, 1. , NaN], + [ 0.4 , 0.9486... , NaN, 1. ]]) """ sc = SparkContext._active_spark_context javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation @@ -134,6 +134,67 @@ def corr(dataset, column, method="pearson"): return _java2py(sc, javaCorrObj.corr(*args)) +class KolmogorovSmirnovTest(object): + """ + .. note:: Experimental + + Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a continuous + distribution. + + By comparing the largest difference between the empirical cumulative + distribution of the sample data and the theoretical distribution we can provide a test for the + the null hypothesis that the sample data comes from that theoretical distribution. + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def test(dataset, sampleCol, distName, *params): + """ + Conduct a one-sample, two-sided Kolmogorov-Smirnov test for probability distribution + equality. Currently supports the normal distribution, taking as parameters the mean and + standard deviation. + + :param dataset: + a Dataset or a DataFrame containing the sample of data to test. + :param sampleCol: + Name of sample column in dataset, of any numerical type. + :param distName: + a `string` name for a theoretical distribution, currently only support "norm". + :param params: + a list of `Double` values specifying the parameters to be used for the theoretical + distribution. For "norm" distribution, the parameters includes mean and variance. + :return: + A DataFrame that contains the Kolmogorov-Smirnov test result for the input sampled data. + This DataFrame will contain a single Row with the following fields: + - `pValue: Double` + - `statistic: Double` + + >>> from pyspark.ml.stat import KolmogorovSmirnovTest + >>> dataset = [[-1.0], [0.0], [1.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + >>> dataset = [[2.0], [3.0], [4.0]] + >>> dataset = spark.createDataFrame(dataset, ['sample']) + >>> ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 3.0, 1.0).first() + >>> round(ksResult.pValue, 3) + 1.0 + >>> round(ksResult.statistic, 3) + 0.175 + """ + sc = SparkContext._active_spark_context + javaTestObj = _jvm().org.apache.spark.ml.stat.KolmogorovSmirnovTest + dataset = _py2java(sc, dataset) + params = [float(param) for param in params] + return _java2py(sc, javaTestObj.test(dataset, sampleCol, distName, + _jvm().PythonUtils.toSeq(params))) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 4f1e8b9bb7d795d4ca3d5cd5dcc0f9419e52dfae Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 10 Apr 2018 15:41:45 -0700 Subject: [PATCH 0596/1161] [SPARK-23871][ML][PYTHON] add python api for VectorAssembler handleInvalid ## What changes were proposed in this pull request? add python api for VectorAssembler handleInvalid ## How was this patch tested? Add doctest Author: Huaxin Gao Closes #21003 from huaxingao/spark-23871. --- .../spark/ml/feature/VectorAssembler.scala | 12 +++--- python/pyspark/ml/feature.py | 42 ++++++++++++++++--- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 6bf4aa38b1fcb..4061154b39c14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -71,12 +71,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) */ @Since("2.4.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - """Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with - |invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the - |output). Column lengths are taken from the size of ML Attribute Group, which can be set using - |`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred - |from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'. - |""".stripMargin.replaceAll("\n", " "), + """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out + |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN + |in the output). Column lengths are taken from the size of ML Attribute Group, which can be + |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also + |be inferred from first rows of the data since it is safe to do so but only in case of 'error' + |or 'skip'.""".stripMargin.replaceAll("\n", " "), ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 5a3e0dd655150..cdda30cfab482 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ A feature transformer that merges multiple columns into a vector column. @@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath) >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs True + >>> dfWithNullsAndNaNs = spark.createDataFrame( + ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"]) + >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features", + ... handleInvalid="keep") + >>> vecAssembler2.transform(dfWithNullsAndNaNs).show() + +---+---+----+-------------+ + | a| b| c| features| + +---+---+----+-------------+ + |1.0|2.0|null|[1.0,2.0,NaN]| + |3.0|NaN| 4.0|[3.0,NaN,4.0]| + |5.0|6.0| 7.0|[5.0,6.0,7.0]| + +---+---+----+-------------+ + ... + >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show() + +---+---+---+-------------+ + | a| b| c| features| + +---+---+---+-------------+ + |5.0|6.0|7.0|[5.0,6.0,7.0]| + +---+---+---+-------------+ + ... .. versionadded:: 1.4.0 """ + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " + + "and NaN values). Options are 'skip' (filter out rows with invalid " + + "data), 'error' (throw an error), or 'keep' (return relevant number " + + "of NaN in the output). Column lengths are taken from the size of ML " + + "Attribute Group, which can be set using `VectorSizeHint` in a " + + "pipeline before `VectorAssembler`. Column lengths can also be " + + "inferred from first rows of the data since it is safe to do so but " + + "only in case of 'error' or 'skip').", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, inputCols=None, outputCol=None): + def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - __init__(self, inputCols=None, outputCol=None) + __init__(self, inputCols=None, outputCol=None, handleInvalid="error") """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) + self._setDefault(handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, inputCols=None, outputCol=None): + def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"): """ - setParams(self, inputCols=None, outputCol=None) + setParams(self, inputCols=None, outputCol=None, handleInvalid="error") Sets params for this VectorAssembler. """ kwargs = self._input_kwargs From 7c7570d466a8ded51e580eb6a28583bd9a9c5337 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Tue, 10 Apr 2018 17:26:06 -0700 Subject: [PATCH 0597/1161] [SPARK-23944][ML] Add the set method for the two LSHModel ## What changes were proposed in this pull request? Add two set method for LSHModel in LSH.scala, BucketedRandomProjectionLSH.scala, and MinHashLSH.scala ## How was this patch tested? New test for the param setup was added into - BucketedRandomProjectionLSHSuite.scala - MinHashLSHSuite.scala Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21015 from ludatabricks/SPARK-23944. --- .../spark/ml/feature/BucketedRandomProjectionLSH.scala | 8 ++++++++ .../src/main/scala/org/apache/spark/ml/feature/LSH.scala | 6 ++++++ .../scala/org/apache/spark/ml/feature/MinHashLSH.scala | 8 ++++++++ .../ml/feature/BucketedRandomProjectionLSHSuite.scala | 8 ++++++++ .../org/apache/spark/ml/feature/MinHashLSHSuite.scala | 8 ++++++++ 5 files changed, 38 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 36a46ca6ff4b7..41eaaf9679914 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml]( private[ml] val randUnitVectors: Array[Vector]) extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { key: Vector => { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 1c9f47a0b201d..a70931f783f45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] extends Model[T] with LSHParams with MLWritable { self: T => + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + /** * The hash function of LSH, mapping an input feature vector to multiple hash vectors. * @return The mapping of LSH function. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 145422a059196..556848e45532d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -51,6 +51,14 @@ class MinHashLSHModel private[ml]( private[ml] val randCoefficients: Array[(Int, Int)]) extends LSHModel[MinHashLSHModel] { + /** @group setParam */ + @Since("2.4.0") + override def setInputCol(value: String): this.type = super.set(inputCol, value) + + /** @group setParam */ + @Since("2.4.0") + override def setOutputCol(value: String): this.type = super.set(outputCol, value) + @Since("2.1.0") override protected[ml] val hashFunction: Vector => Array[Vector] = { elems: Vector => { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ed9a39d8d1512..9b823259b1deb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -48,6 +48,14 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest ParamsSuite.checkParams(model) } + test("setters") { + val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("BucketedRandomProjectionLSH: default params") { val brp = new BucketedRandomProjectionLSH assert(brp.getNumHashTables === 1.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 96df68dbdf053..3da0fb7da01ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -43,6 +43,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa ParamsSuite.checkParams(model) } + test("setters") { + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + .setInputCol("testkeys") + .setOutputCol("testvalues") + assert(model.getInputCol === "testkeys") + assert(model.getOutputCol === "testvalues") + } + test("MinHashLSH: default params") { val rp = new MinHashLSH assert(rp.getNumHashTables === 1.0) From c7622befdadfea725797d76e820e3dfc76fec927 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:42:09 +0800 Subject: [PATCH 0598/1161] [SPARK-23847][FOLLOWUP][PYTHON][SQL] Actually test [desc|acs]_nulls_[first|last] functions in PySpark ## What changes were proposed in this pull request? There was a mistake in `tests.py` missing `assertEquals`. ## How was this patch tested? Fixed tests. Author: hyukjinkwon Closes #21035 from HyukjinKwon/SPARK-23847. --- python/pyspark/sql/tests.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index dd04ffb4ed393..96c2a776a5049 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2991,19 +2991,23 @@ def test_create_dateframe_from_pandas_with_dst(self): os.environ['TZ'] = orig_env_tz time.tzset() - def test_2_4_functions(self): + def test_sort_with_nulls_order(self): from pyspark.sql import functions df = self.spark.createDataFrame( [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"]) - df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')] - df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect() - [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)] - df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect() - [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')] - df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect() - [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)] + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')]) + self.assertEquals( + df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(), + [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(), + [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')]) + self.assertEquals( + df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(), + [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)]) class HiveSparkSubmitTests(SparkSubmitTests): From 87611bba222a95158fc5b638a566bdf47346da8e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Apr 2018 19:44:01 +0800 Subject: [PATCH 0599/1161] [MINOR][DOCS] Fix R documentation generation instruction for roxygen2 ## What changes were proposed in this pull request? This PR proposes to fix `roxygen2` to `5.0.1` in `docs/README.md` for SparkR documentation generation. If I use higher version and creates the doc, it shows the diff below. Not a big deal but it bothered me. ```diff diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f..159fca61e06 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION -57,6 +57,6 Collate: 'types.R' 'utils.R' 'window.R' -RoxygenNote: 5.0.1 +RoxygenNote: 6.0.1 VignetteBuilder: knitr NeedsCompilation: no ``` ## How was this patch tested? Manually tested. I met this every time I set the new environment for Spark dev but I have kept forgetting to fix it. Author: hyukjinkwon Closes #21020 from HyukjinKwon/minor-r-doc. --- docs/README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/docs/README.md b/docs/README.md index 9eac4ba35c458..dbea4d64c4298 100644 --- a/docs/README.md +++ b/docs/README.md @@ -22,10 +22,13 @@ $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs $ sudo pip install sphinx pypandoc mkdocs -$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' +$ sudo Rscript -e 'devtools::install_version("roxygen2", version = "5.0.1", repos="http://cran.stat.ucla.edu/")' ``` -(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) +Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0. + +Note: Other versions of roxygen2 might work in SparkR documentation generation but `RoxygenNote` field in `$SPARK_HOME/R/pkg/DESCRIPTION` is 5.0.1, which is updated if the version is mismatched. ## Generating the Documentation HTML @@ -62,12 +65,12 @@ $ PRODUCTION=1 jekyll build ## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs) -You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `SPARK_HOME` directory. +You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `$SPARK_HOME` directory. Similarly, you can build just the PySpark docs by running `make html` from the -`SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as -public in `__init__.py`. The SparkR docs can be built by running `SPARK_HOME/R/create-docs.sh`, and -the SQL docs can be built by running `SPARK_HOME/sql/create-docs.sh` +`$SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as +public in `__init__.py`. The SparkR docs can be built by running `$SPARK_HOME/R/create-docs.sh`, and +the SQL docs can be built by running `$SPARK_HOME/sql/create-docs.sh` after [building Spark](https://github.com/apache/spark#building-spark) first. When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various From c604d659e19c1b2704cdf8c8ea97edaf50d8cb6b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 11 Apr 2018 20:11:03 +0800 Subject: [PATCH 0600/1161] [SPARK-23951][SQL] Use actual java class instead of string representation. ## What changes were proposed in this pull request? This PR slightly refactors the newly added `ExprValue` API by quite a bit. The following changes are introduced: 1. `ExprValue` now uses the actual class instead of the class name as its type. This should give some more flexibility with generating code in the future. 2. Renamed `StatementValue` to `SimpleExprValue`. The statement concept is broader then an expression (untyped and it cannot be on the right hand side of an assignment), and this was not really what we were using it for. I have added a top level `JavaCode` trait that can be used in the future to reinstate (no pun intended) a statement a-like code fragment. 3. Added factory methods to the `JavaCode` companion object to make it slightly less verbose to create `JavaCode`/`ExprValue` objects. This is also what makes the diff quite large. 4. Added one more factory method to `ExprCode` to make it easier to create code-less expressions. ## How was this patch tested? Existing tests. Author: Herman van Hovell Closes #21026 from hvanhovell/SPARK-23951. --- .../sql/catalyst/expressions/Expression.scala | 10 +- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 35 +++- .../expressions/codegen/ExprValue.scala | 76 -------- .../codegen/GenerateMutableProjection.scala | 36 ++-- .../codegen/GenerateSafeProjection.scala | 25 +-- .../codegen/GenerateUnsafeProjection.scala | 11 +- .../expressions/codegen/javaCode.scala | 166 ++++++++++++++++++ .../expressions/complexTypeCreator.scala | 2 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 3 +- .../sql/catalyst/expressions/literals.scala | 27 +-- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/nullExpressions.scala | 20 +-- .../expressions/objects/objects.scala | 7 +- .../expressions/CodeGenerationSuite.scala | 5 +- .../expressions/codegen/ExprValueSuite.scala | 14 +- .../sql/execution/ColumnarBatchScan.scala | 6 +- .../spark/sql/execution/ExpandExec.scala | 8 +- .../spark/sql/execution/GenerateExec.scala | 15 +- .../sql/execution/WholeStageCodegenExec.scala | 11 +- .../aggregate/HashAggregateExec.scala | 6 +- .../aggregate/HashMapGenerator.scala | 8 +- .../execution/basicPhysicalOperators.scala | 2 +- .../joins/BroadcastHashJoinExec.scala | 9 +- .../execution/joins/SortMergeJoinExec.scala | 12 +- 26 files changed, 315 insertions(+), 212 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 7a5e49cb5206b..97dff6ae88299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -104,9 +104,9 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val eval = doGenCode(ctx, ExprCode("", - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(dataType)))) + val eval = doGenCode(ctx, ExprCode( + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) if (eval.code.nonEmpty) { // Add `this` in the comment. @@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull - eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN) + eval.isNull = JavaCode.isNullGlobal(globalIsNull) s"$globalIsNull = $localIsNull;" } else { "" @@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] { |} """.stripMargin) - eval.value = VariableValue(newValue, javaType) + eval.value = JavaCode.variable(newValue, dataType) eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index defd6f3cd8849..9212c3de1f814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -591,8 +591,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -671,8 +670,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) val evals = evalChildren.map(eval => s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c9c60ef1be640..0abfc9fa4c465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,10 +59,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) object ExprCode { + def apply(isNull: ExprValue, value: ExprValue): ExprCode = { + ExprCode(code = "", isNull, value) + } + def forNullValue(dataType: DataType): ExprCode = { - val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) - ExprCode(code = "", isNull = TrueLiteral, - value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType))) + ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { @@ -331,7 +333,7 @@ class CodegenContext { case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" case _ => s"$value = $initCode;" } - ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType))) + ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } def declareMutableStates(): String = { @@ -1004,8 +1006,9 @@ class CodegenContext { // at least two nodes) as the cost of doing it is expected to be low. subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN), - GlobalValue(value, javaType(expr.dataType))) + val state = SubExprEliminationState( + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType)) subExprEliminationExprs ++= e.map(_ -> state).toMap } } @@ -1479,6 +1482,26 @@ object CodeGenerator extends Logging { case _ => "Object" } + def javaClass(dt: DataType): Class[_] = dt match { + case BooleanType => java.lang.Boolean.TYPE + case ByteType => java.lang.Byte.TYPE + case ShortType => java.lang.Short.TYPE + case IntegerType | DateType => java.lang.Integer.TYPE + case LongType | TimestampType => java.lang.Long.TYPE + case FloatType => java.lang.Float.TYPE + case DoubleType => java.lang.Double.TYPE + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case udt: UserDefinedType[_] => javaClass(udt.sqlType) + case ObjectType(cls) => cls + case _ => classOf[Object] + } + /** * Returns the boxed type in Java. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala deleted file mode 100644 index df5f1c58b1b2d..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import scala.language.implicitConversions - -import org.apache.spark.sql.types.DataType - -// An abstraction that represents the evaluation result of [[ExprCode]]. -abstract class ExprValue { - - val javaType: String - - // Whether we can directly access the evaluation value anywhere. - // For example, a variable created outside a method can not be accessed inside the method. - // For such cases, we may need to pass the evaluation as parameter. - val canDirectAccess: Boolean - - def isPrimitive: Boolean = CodeGenerator.isPrimitiveType(javaType) -} - -object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString -} - -// A literal evaluation of [[ExprCode]]. -class LiteralValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -object LiteralValue { - def apply(value: String, javaType: String): LiteralValue = new LiteralValue(value, javaType) - def unapply(literal: LiteralValue): Option[(String, String)] = - Some((literal.value, literal.javaType)) -} - -// A variable evaluation of [[ExprCode]]. -case class VariableValue( - val variableName: String, - val javaType: String) extends ExprValue { - override def toString: String = variableName - override val canDirectAccess: Boolean = false -} - -// A statement evaluation of [[ExprCode]]. -case class StatementValue( - val statement: String, - val javaType: String, - val canDirectAccess: Boolean = false) extends ExprValue { - override def toString: String = statement -} - -// A global variable evaluation of [[ExprCode]]. -case class GlobalValue(val value: String, val javaType: String) extends ExprValue { - override def toString: String = value - override val canDirectAccess: Boolean = true -} - -case object TrueLiteral extends LiteralValue("true", "boolean") -case object FalseLiteral extends LiteralValue("false", "boolean") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 3ae0b54c754cf..33d14329ec95c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { val ctx = newCodeGenContext() - val (validExpr, index) = expressions.zipWithIndex.filter { + val validExpr = expressions.zipWithIndex.filter { case (NoOp, _) => false case _ => true - }.unzip - val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + } + val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) - val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map { - case (ev, i) => - val e = expressions(i) - val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") - if (e.nullable) { + val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { + case ((e, i), ev) => + val value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"), + e.dataType) + val (code, isNull) = if (e.nullable) { val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; |$value = ${ev.value}; - """.stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i) + """.stripMargin, JavaCode.isNullGlobal(isNull)) } else { (s""" |${ev.code} |$value = ${ev.value}; - """.stripMargin, ev.isNull, value, i) + """.stripMargin, FalseLiteral) } + val update = CodeGenerator.updateColumn( + "mutableRow", + e.dataType, + i, + ExprCode(isNull, value), + e.nullable) + (code, update) } // Evaluate all the subexpressions. val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(projectionCodes).map { - case (e, (_, isNull, value, i)) => - val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType))) - CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) - } - val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) - val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) + val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) val codeBody = s""" public java.lang.Object generate(Object[] references) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index a30a0b22cd305..01c350e9dbf69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.annotation.tailrec +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -53,9 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt)), dt) + val converter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt), + dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -76,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] |final InternalRow $output = new $rowClass($values); """.stripMargin - ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow])) } private def createCodeForArray( @@ -91,9 +93,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe(ctx, - StatementValue(CodeGenerator.getValue(tmpInput, elementType, index), - CodeGenerator.javaType(elementType)), elementType) + val elementConverter = convertToSafe( + ctx, + JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), + elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -107,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData])) } private def createCodeForMap( @@ -128,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - ExprCode(code, FalseLiteral, VariableValue(output, "MapData")) + ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData])) } @tailrec @@ -140,7 +143,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => ExprCode("", FalseLiteral, input) + case _ => ExprCode(FalseLiteral, input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2fb441ac4500e..01b4d6c4529bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,9 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN), - StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString), - CodeGenerator.javaType(dt))) + ExprCode( + JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"), + JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt)) } val rowWriterClass = classOf[UnsafeRowWriter].getName @@ -109,7 +109,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) - if (input.isNull == "false") { + if (input.isNull == FalseLiteral) { s""" |${input.code} |${writeField.trim} @@ -292,8 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro |$writeExpressions """.stripMargin // `rowWriter` is declared as a class field, so we can access it directly in methods. - ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow", - canDirectAccess = true)) + ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow])) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala new file mode 100644 index 0000000000000..74ff018488863 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import java.lang.{Boolean => JBool} + +import scala.language.{existentials, implicitConversions} + +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * Trait representing an opaque fragments of java code. + */ +trait JavaCode { + def code: String + override def toString: String = code +} + +/** + * Utility functions for creating [[JavaCode]] fragments. + */ +object JavaCode { + /** + * Create a java literal. + */ + def literal(v: String, dataType: DataType): LiteralValue = dataType match { + case BooleanType if v == "true" => TrueLiteral + case BooleanType if v == "false" => FalseLiteral + case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a default literal. This is null for reference types, false for boolean types and + * -1 for other primitive types. + */ + def defaultLiteral(dataType: DataType): LiteralValue = { + new LiteralValue( + CodeGenerator.defaultValue(dataType, typedNull = true), + CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, dataType: DataType): VariableValue = { + variable(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a local java variable. + */ + def variable(name: String, javaClass: Class[_]): VariableValue = { + VariableValue(name, javaClass) + } + + /** + * Create a local isNull variable. + */ + def isNullVariable(name: String): VariableValue = variable(name, BooleanType) + + /** + * Create a global java variable. + */ + def global(name: String, dataType: DataType): GlobalValue = { + global(name, CodeGenerator.javaClass(dataType)) + } + + /** + * Create a global java variable. + */ + def global(name: String, javaClass: Class[_]): GlobalValue = { + GlobalValue(name, javaClass) + } + + /** + * Create a global isNull variable. + */ + def isNullGlobal(name: String): GlobalValue = global(name, BooleanType) + + /** + * Create an expression fragment. + */ + def expression(code: String, dataType: DataType): SimpleExprValue = { + expression(code, CodeGenerator.javaClass(dataType)) + } + + /** + * Create an expression fragment. + */ + def expression(code: String, javaClass: Class[_]): SimpleExprValue = { + SimpleExprValue(code, javaClass) + } + + /** + * Create a isNull expression fragment. + */ + def isNullExpression(code: String): SimpleExprValue = { + expression(code, BooleanType) + } +} + +/** + * A typed java fragment that must be a valid java expression. + */ +trait ExprValue extends JavaCode { + def javaType: Class[_] + def isPrimitive: Boolean = javaType.isPrimitive +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString +} + + +/** + * A java expression fragment. + */ +case class SimpleExprValue(expr: String, javaType: Class[_]) extends ExprValue { + override def code: String = s"($expr)" +} + +/** + * A local variable java expression. + */ +case class VariableValue(variableName: String, javaType: Class[_]) extends ExprValue { + override def code: String = variableName +} + +/** + * A global variable java expression. + */ +case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { + override def code: String = value +} + +/** + * A literal java expression. + */ +class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable { + override def code: String = value + + override def equals(arg: Any): Boolean = arg match { + case l: LiteralValue => l.javaType == javaType && l.value == value + case _ => false + } + + override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() +} + +case object TrueLiteral extends LiteralValue("true", JBool.TYPE) +case object FalseLiteral extends LiteralValue("false", JBool.TYPE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 49a8d12057188..67876a8565488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( code = preprocess + assigns + postprocess, - value = VariableValue(arrayData, CodeGenerator.javaType(dataType)), + value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 409c0b6b79b81..205d77f6a9acf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -191,8 +191,9 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = GlobalValue(ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), - CodeGenerator.javaType(dataType)) + ev.value = JavaCode.global( + ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value), + dataType) // these blocks are meant to be inside a // do { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 49dd988b4b53c..32fdb13afbbfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -813,8 +813,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { - ExprCode("", TrueLiteral, LiteralValue("(UTF8String) null", - CodeGenerator.javaType(dataType))) + ExprCode.forNullValue(StringType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 742a650eb445d..246025b82d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -281,38 +281,41 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { if (value == null) { ExprCode.forNullValue(dataType) } else { + def toExprCode(code: String): ExprCode = { + ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) + } dataType match { case BooleanType | IntegerType | DateType => - ExprCode.forNonNullValue(LiteralValue(value.toString, javaType)) + toExprCode(value.toString) case FloatType => value.asInstanceOf[Float] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Float.NaN", javaType)) + toExprCode("Float.NaN") case Float.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.POSITIVE_INFINITY", javaType)) + toExprCode("Float.POSITIVE_INFINITY") case Float.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Float.NEGATIVE_INFINITY", javaType)) + toExprCode("Float.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}F", javaType)) + toExprCode(s"${value}F") } case DoubleType => value.asInstanceOf[Double] match { case v if v.isNaN => - ExprCode.forNonNullValue(LiteralValue("Double.NaN", javaType)) + toExprCode("Double.NaN") case Double.PositiveInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.POSITIVE_INFINITY", javaType)) + toExprCode("Double.POSITIVE_INFINITY") case Double.NegativeInfinity => - ExprCode.forNonNullValue(LiteralValue("Double.NEGATIVE_INFINITY", javaType)) + toExprCode("Double.NEGATIVE_INFINITY") case _ => - ExprCode.forNonNullValue(LiteralValue(s"${value}D", javaType)) + toExprCode(s"${value}D") } case ByteType | ShortType => - ExprCode.forNonNullValue(LiteralValue(s"($javaType)$value", javaType)) + ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) case TimestampType | LongType => - ExprCode.forNonNullValue(LiteralValue(s"${value}L", javaType)) + toExprCode(s"${value}L") case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(GlobalValue(constRef, javaType)) + ExprCode.forNonNullValue(JavaCode.global(constRef, dataType)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7081a5e096d56..7eda65a867028 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -92,7 +92,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, - value = LiteralValue("null", CodeGenerator.javaType(dataType))) + value = JavaCode.defaultLiteral(dataType)) } override def sql: String = s"assert_true(${child.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 55b6e346be82a..0787342bce6bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull), - CodeGenerator.JAVA_BOOLEAN) + ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -321,12 +320,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull.isInstanceOf[LiteralValue]) { - LiteralValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } else { - VariableValue(eval.isNull, CodeGenerator.JAVA_BOOLEAN) - } - ExprCode(code = eval.code, isNull = FalseLiteral, value = value) + ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.isNull) } override def sql: String = s"(${child.sql} IS NULL)" @@ -352,12 +346,10 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - val value = if (eval.isNull == TrueLiteral) { - FalseLiteral - } else if (eval.isNull == FalseLiteral) { - TrueLiteral - } else { - StatementValue(s"(!(${eval.isNull}))", CodeGenerator.javaType(dataType)) + val value = eval.isNull match { + case TrueLiteral => FalseLiteral + case FalseLiteral => TrueLiteral + case v => JavaCode.isNullExpression(s"!$v") } ExprCode(code = eval.code, isNull = FalseLiteral, value = value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index b2cca3178cd2a..50e90ca550807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -65,7 +65,7 @@ trait InvokeLike extends Expression with NonSQLExpression { val resultIsNull = if (needNullCheck) { val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") - GlobalValue(resultIsNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullGlobal(resultIsNull) } else { FalseLiteral } @@ -569,12 +569,11 @@ case class LambdaVariable( override def genCode(ctx: CodegenContext): ExprCode = { val isNullValue = if (nullable) { - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } else { FalseLiteral } - ExprCode(code = "", value = VariableValue(value, CodeGenerator.javaType(dataType)), - isNull = isNullValue) + ExprCode(value = JavaCode.variable(value, dataType), isNull = isNullValue) } // This won't be called as `genCode` is overrided, just overriding it to make diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8e83b35c3809c..f7c023111ff59 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -448,8 +448,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val ref = BoundReference(0, IntegerType, true) val add1 = Add(ref, ref) val add2 = Add(add1, add1) - val dummy = SubExprEliminationState(VariableValue("dummy", "boolean"), - VariableValue("dummy", "boolean")) + val dummy = SubExprEliminationState( + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType)) // raw testing of basic functionality { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index c8f4cff7db48d..378b8bc055e34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.BooleanType class ExprValueSuite extends SparkFunSuite { @@ -31,16 +32,7 @@ class ExprValueSuite extends SparkFunSuite { assert(trueLit.isPrimitive) assert(falseLit.isPrimitive) - trueLit match { - case LiteralValue(value, javaType) => - assert(value == "true" && javaType == "boolean") - case _ => fail() - } - - falseLit match { - case LiteralValue(value, javaType) => - assert(value == "false" && javaType == "boolean") - case _ => fail() - } + assert(trueLit === JavaCode.literal("true", BooleanType)) + assert(falseLit === JavaCode.literal("false", BooleanType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 434214a10e1e3..fc3dbc1c5591b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -52,7 +52,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { val javaType = CodeGenerator.javaType(dataType) val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { - VariableValue(ctx.freshName("isNull"), CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(ctx.freshName("isNull")) } else { FalseLiteral } @@ -66,7 +66,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } else { s"$javaType $valueVar = $value;" }).trim - ExprCode(code, isNullVar, VariableValue(valueVar, javaType)) + ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 0d9a62cace62a..e4812f3d338fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -157,8 +157,10 @@ case class ExpandExec( |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(firstExpr.dataType))) + ExprCode( + code, + JavaCode.isNullVariable(isNull), + JavaCode.variable(value, firstExpr.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 85c5ebfdaa689..f40c50df74ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types._ /** * For lazy computing, be sure the generator.terminate() called in the very last @@ -170,10 +170,11 @@ case class GenerateExec( // Add position val position = if (e.position) { if (outer) { - Seq(ExprCode("", StatementValue(s"$index == -1", CodeGenerator.JAVA_BOOLEAN), - VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode( + JavaCode.isNullExpression(s"$index == -1"), + JavaCode.variable(index, IntegerType))) } else { - Seq(ExprCode("", FalseLiteral, VariableValue(index, CodeGenerator.JAVA_INT))) + Seq(ExprCode(FalseLiteral, JavaCode.variable(index, IntegerType))) } } else { Seq.empty @@ -316,11 +317,9 @@ case class GenerateExec( |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, javaType)) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, - VariableValue(value, javaType)) + ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 805ff3cf001ba..828b51fa199de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -111,7 +111,7 @@ trait CodegenSupport extends SparkPlan { private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { if (row != null) { - ExprCode("", FalseLiteral, VariableValue(row, "UnsafeRow")) + ExprCode.forNonNullValue(JavaCode.variable(row, classOf[UnsafeRow])) } else { if (colVars.nonEmpty) { val colExprs = output.zipWithIndex.map { case (attr, i) => @@ -128,8 +128,8 @@ trait CodegenSupport extends SparkPlan { """.stripMargin.trim ExprCode(code, FalseLiteral, ev.value) } else { - // There is no columns - ExprCode("", FalseLiteral, VariableValue("unsafeRow", "UnsafeRow")) + // There are no columns + ExprCode.forNonNullValue(JavaCode.variable("unsafeRow", classOf[UnsafeRow])) } } } @@ -246,11 +246,10 @@ trait CodegenSupport extends SparkPlan { val isNull = ctx.freshName(s"exprIsNull_$i") arguments += ev.isNull parameters += s"boolean $isNull" - VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN) + JavaCode.isNullVariable(isNull) } - paramVars += ExprCode("", paramIsNull, - VariableValue(paramName, CodeGenerator.javaType(attributes(i).dataType))) + paramVars += ExprCode(paramIsNull, JavaCode.variable(paramName, attributes(i).dataType)) } (arguments, parameters, paramVars) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f7f10243d4cc..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -194,8 +194,10 @@ case class HashAggregateExec( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } val initBufVar = evaluateVariables(bufVars) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 4978954271311..de2d630de3fdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GlobalValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ /** @@ -54,8 +54,10 @@ abstract class HashMapGenerator( | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), - GlobalValue(value, CodeGenerator.javaType(e.dataType))) + ExprCode( + ev.code + initVars, + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, e.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index cab7081400ce9..1edfdc888afd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -368,7 +368,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") - val ev = ExprCode("", FalseLiteral, VariableValue(value, CodeGenerator.JAVA_LONG)) + val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType)) val BigInt = classOf[java.math.BigInteger].getName // Inline mutable state since not many Range operations in a task diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index fa62a32d51f3e..6fa716d9fadee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, LongType} import org.apache.spark.util.TaskCompletionListener /** @@ -192,8 +192,7 @@ case class BroadcastHashJoinExec( | $value = ${ev.value}; |} """.stripMargin - ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))) + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) } } } @@ -488,8 +487,8 @@ case class BroadcastHashJoinExec( s"$existsVar = true;" } - val resultVar = input ++ Seq(ExprCode("", FalseLiteral, - VariableValue(existsVar, CodeGenerator.JAVA_BOOLEAN))) + val resultVar = input ++ Seq(ExprCode.forNonNullValue( + JavaCode.variable(existsVar, BooleanType))) if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index b61acb8d4fda9..d8261f0f33b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,11 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, VariableValue} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, -ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -531,13 +530,12 @@ case class SortMergeJoinExec( |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin - (ExprCode(code, VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN), - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), + leftVarsDecl) } else { val code = s"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, FalseLiteral, - VariableValue(value, CodeGenerator.javaType(a.dataType))), leftVarsDecl) + (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } }.unzip } From 271c891b91917d660d1f6b995de397c47c7a6058 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 11 Apr 2018 21:52:48 +0800 Subject: [PATCH 0601/1161] [SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient ## What changes were proposed in this pull request? Mark `HashAggregateExec.bufVars` as transient to avoid it from being serialized. Also manually null out this field at the end of `doProduceWithoutKeys()` to shorten its lifecycle, because it'll no longer be used after that. ## How was this patch tested? Existing tests. Author: Kris Mok Closes #21039 from rednaxelafx/codegen-improve. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..965950ed94fe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used for aggregation without keys. - private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. + @transient private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,6 +238,8 @@ case class HashAggregateExec( | } """.stripMargin) + bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner + val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 653fe02415a537299e15f92b56045569864b6183 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 09:49:25 -0500 Subject: [PATCH 0602/1161] [SPARK-6951][CORE] Speed up parsing of event logs during listing. This change introduces two optimizations to help speed up generation of listing data when parsing events logs. The first one allows the parser to be stopped when enough data to create the listing entry has been read. This is currently the start event plus environment info, to capture UI ACLs. If the end event is needed, the code will skip to the end of the log to try to find that information, instead of parsing the whole log file. Unfortunately this works better with uncompressed logs. Skipping bytes on compressed logs only saves the work of parsing lines and some events, so not a lot of gains are observed. The second optimization deals with in-progress logs. It works in two ways: first, it completely avoids parsing the rest of the log for these apps when enough listing data is read. This, unlike the above, also speeds things up for compressed logs, since only the very beginning of the log has to be read. On top of that, the code that decides whether to re-parse logs to get updated listing data will ignore in-progress applications until they've completed. Both optimizations can be disabled but are enabled by default. I tested this on some fake event logs to see the effect. I created 500 logs of about 60M each (so ~30G uncompressed; each log was 1.7M when compressed with zstd). Below, C = completed, IP = in-progress, the size means the amount of data re-parsed at the end of logs when necessary. ``` none/C none/IP zstd/C zstd/IP On / 16k 2s 2s 22s 2s On / 1m 3s 2s 24s 2s Off 1.1m 1.1m 26s 24s ``` This was with 4 threads on a single local SSD. As expected from the previous explanations, there are considerable gains for in-progress logs, and for uncompressed logs, but not so much when looking at the full compressed log. As a side note, I removed the custom code to get the scan time by creating a file on HDFS; since file mod times are not used to detect changed logs anymore, local time is enough for the current use of the SHS. Author: Marcelo Vanzin Closes #20952 from vanzin/SPARK-6951. --- .../deploy/history/FsHistoryProvider.scala | 251 ++++++++++++------ .../apache/spark/deploy/history/config.scala | 15 ++ .../spark/scheduler/ReplayListenerBus.scala | 11 + .../org/apache/spark/util/ListenerBus.scala | 5 +- .../history/FsHistoryProviderSuite.scala | 78 ++++-- 5 files changed, 264 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index ace6d9e00c838..56db9359e033f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,12 +18,13 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} -import java.util.{Date, ServiceLoader, UUID} +import java.util.{Date, ServiceLoader} import java.util.concurrent.{ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.io.Source import scala.util.Try import scala.xml.Node @@ -58,10 +59,10 @@ import org.apache.spark.util.kvstore._ * * == How new and updated attempts are detected == * - * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any - * entries in the log dir whose modification time is greater than the last scan time - * are considered new or updated. These are replayed to create a new attempt info entry - * and update or create a matching application info element in the list of applications. + * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any entries in the + * log dir whose size changed since the last scan time are considered new or updated. These are + * replayed to create a new attempt info entry and update or create a matching application info + * element in the list of applications. * - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the * attempt is replaced by another one with a larger log size. * @@ -125,6 +126,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) + private val fastInProgressParsing = conf.get(FAST_IN_PROGRESS_PARSING) // Visible for testing. private[history] val listing: KVStore = storePath.map { path => @@ -402,13 +404,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { - val newLastScanTime = getNewLastScanTime() + val newLastScanTime = clock.getTimeMillis() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") val updated = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // FsHistoryProvider used to generate a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && @@ -417,15 +419,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) .filter { entry => try { val info = listing.read(classOf[LogInfo], entry.getPath().toString()) - if (info.fileSize < entry.getLen()) { - // Log size has changed, it should be parsed. - true - } else { + + if (info.appId.isDefined) { // If the SHS view has a valid application, update the time the file was last seen so - // that the entry is not deleted from the SHS listing. - if (info.appId.isDefined) { - listing.write(info.copy(lastProcessed = newLastScanTime)) + // that the entry is not deleted from the SHS listing. Also update the file size, in + // case the code below decides we don't need to parse the log. + listing.write(info.copy(lastProcessed = newLastScanTime, fileSize = entry.getLen())) + } + + if (info.fileSize < entry.getLen()) { + if (info.appId.isDefined && fastInProgressParsing) { + // When fast in-progress parsing is on, we don't need to re-parse when the + // size changes, but we do need to invalidate any existing UIs. + invalidateUI(info.appId.get, info.attemptId) + false + } else { + true } + } else { false } } catch { @@ -449,7 +460,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val tasks = updated.map { entry => try { replayExecutor.submit(new Runnable { - override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) + override def run(): Unit = mergeApplicationListing(entry, newLastScanTime, true) }) } catch { // let the iteration over the updated entries break, since an exception on @@ -542,25 +553,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - private[history] def getNewLastScanTime(): Long = { - val fileName = "." + UUID.randomUUID().toString - val path = new Path(logDir, fileName) - val fos = fs.create(path) - - try { - fos.close() - fs.getFileStatus(path).getModificationTime - } catch { - case e: Exception => - logError("Exception encountered when attempting to update last scan time", e) - lastScanTime.get() - } finally { - if (!fs.delete(path, true)) { - logWarning(s"Error deleting ${path}") - } - } - } - override def writeEventLogs( appId: String, attemptId: Option[String], @@ -607,7 +599,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the given log file, saving the application in the listing db. */ - protected def mergeApplicationListing(fileStatus: FileStatus, scanTime: Long): Unit = { + protected def mergeApplicationListing( + fileStatus: FileStatus, + scanTime: Long, + enableOptimizations: Boolean): Unit = { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || eventString.startsWith(APPL_END_EVENT_PREFIX) || @@ -616,32 +611,118 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() + val appCompleted = isCompleted(logPath.getName()) + val reparseChunkSize = conf.get(END_EVENT_REPARSE_CHUNK_SIZE) + + // Enable halt support in listener if: + // - app in progress && fast parsing enabled + // - skipping to end event is enabled (regardless of in-progress state) + val shouldHalt = enableOptimizations && + ((!appCompleted && fastInProgressParsing) || reparseChunkSize > 0) + val bus = new ReplayListenerBus() - val listener = new AppListingListener(fileStatus, clock) + val listener = new AppListingListener(fileStatus, clock, shouldHalt) bus.addListener(listener) - replay(fileStatus, bus, eventsFilter = eventsFilter) - - val (appId, attemptId) = listener.applicationInfo match { - case Some(app) => - // Invalidate the existing UI for the reloaded app attempt, if any. See LoadedAppUI for a - // discussion on the UI lifecycle. - synchronized { - activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => - ui.invalidate() - ui.ui.store.close() + + logInfo(s"Parsing $logPath for listing data...") + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + bus.replay(in, logPath.toString, !appCompleted, eventsFilter) + } + + // If enabled above, the listing listener will halt parsing when there's enough information to + // create a listing entry. When the app is completed, or fast parsing is disabled, we still need + // to replay until the end of the log file to try to find the app end event. Instead of reading + // and parsing line by line, this code skips bytes from the underlying stream so that it is + // positioned somewhere close to the end of the log file. + // + // Because the application end event is written while some Spark subsystems such as the + // scheduler are still active, there is no guarantee that the end event will be the last + // in the log. So, to be safe, the code uses a configurable chunk to be re-parsed at + // the end of the file, and retries parsing the whole log later if the needed data is + // still not found. + // + // Note that skipping bytes in compressed files is still not cheap, but there are still some + // minor gains over the normal log parsing done by the replay bus. + // + // This code re-opens the file so that it knows where it's skipping to. This isn't as cheap as + // just skipping from the current position, but there isn't a a good way to detect what the + // current position is, since the replay listener bus buffers data internally. + val lookForEndEvent = shouldHalt && (appCompleted || !fastInProgressParsing) + if (lookForEndEvent && listener.applicationInfo.isDefined) { + Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => + val target = fileStatus.getLen() - reparseChunkSize + if (target > 0) { + logInfo(s"Looking for end event; skipping $target bytes from $logPath...") + var skipped = 0L + while (skipped < target) { + skipped += in.skip(target - skipped) } } + val source = Source.fromInputStream(in).getLines() + + // Because skipping may leave the stream in the middle of a line, read the next line + // before replaying. + if (target > 0) { + source.next() + } + + bus.replay(source, logPath.toString, !appCompleted, eventsFilter) + } + } + + logInfo(s"Finished parsing $logPath") + + listener.applicationInfo match { + case Some(app) if !lookForEndEvent || app.attempts.head.info.completed => + // In this case, we either didn't care about the end event, or we found it. So the + // listing data is good. + invalidateUI(app.info.id, app.attempts.head.info.attemptId) addListing(app) - (Some(app.info.id), app.attempts.head.info.attemptId) + listing.write(LogInfo(logPath.toString(), scanTime, Some(app.info.id), + app.attempts.head.info.attemptId, fileStatus.getLen())) + + // For a finished log, remove the corresponding "in progress" entry from the listing DB if + // the file is really gone. + if (appCompleted) { + val inProgressLog = logPath.toString() + EventLoggingListener.IN_PROGRESS + try { + // Fetch the entry first to avoid an RPC when it's already removed. + listing.read(classOf[LogInfo], inProgressLog) + if (!fs.isFile(new Path(inProgressLog))) { + listing.delete(classOf[LogInfo], inProgressLog) + } + } catch { + case _: NoSuchElementException => + } + } + + case Some(_) => + // In this case, the attempt is still not marked as finished but was expected to. This can + // mean the end event is before the configured threshold, so call the method again to + // re-parse the whole log. + logInfo(s"Reparsing $logPath since end event was not found.") + mergeApplicationListing(fileStatus, scanTime, false) case _ => // If the app hasn't written down its app ID to the logs, still record the entry in the // listing db, with an empty ID. This will make the log eligible for deletion if the app // does not make progress after the configured max log age. - (None, None) + listing.write(LogInfo(logPath.toString(), scanTime, None, None, fileStatus.getLen())) + } + } + + /** + * Invalidate an existing UI for a given app attempt. See LoadedAppUI for a discussion on the + * UI lifecycle. + */ + private def invalidateUI(appId: String, attemptId: Option[String]): Unit = { + synchronized { + activeUIs.get((appId, attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() + } } - listing.write(LogInfo(logPath.toString(), scanTime, appId, attemptId, fileStatus.getLen())) } /** @@ -696,29 +777,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. - * `ReplayEventsFilter` determines what events are replayed. - */ - private def replay( - eventLog: FileStatus, - bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { - val logPath = eventLog.getPath() - val isCompleted = !logPath.getName().endsWith(EventLoggingListener.IN_PROGRESS) - logInfo(s"Replaying log path: $logPath") - // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, - // and when we read the file here. That is OK -- it may result in an unnecessary refresh - // when there is no update, but will not result in missing an update. We *must* prevent - // an error the other way -- if we report a size bigger (ie later) than the file that is - // actually read, we may never refresh the app. FileStatus is guaranteed to be static - // after it's created, so we get a file size that is no bigger than what is actually read. - Utils.tryWithResource(EventLoggingListener.openEventLog(logPath, fs)) { in => - bus.replay(in, logPath.toString, !isCompleted, eventsFilter) - logInfo(s"Finished parsing $logPath") - } - } - /** * Rebuilds the application state store from its event log. */ @@ -741,8 +799,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } replayBus.addListener(listener) try { - replay(eventLog, replayBus) + val path = eventLog.getPath() + logInfo(s"Parsing $path to re-build UI...") + Utils.tryWithResource(EventLoggingListener.openEventLog(path, fs)) { in => + replayBus.replay(in, path.toString(), maybeTruncated = !isCompleted(path.toString())) + } trackingStore.close(false) + logInfo(s"Finished parsing $path") } catch { case e: Exception => Utils.tryLogNonFatalError { @@ -881,6 +944,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + private def isCompleted(name: String): Boolean = { + !name.endsWith(EventLoggingListener.IN_PROGRESS) + } + } private[history] object FsHistoryProvider { @@ -945,11 +1012,17 @@ private[history] class ApplicationInfoWrapper( } -private[history] class AppListingListener(log: FileStatus, clock: Clock) extends SparkListener { +private[history] class AppListingListener( + log: FileStatus, + clock: Clock, + haltEnabled: Boolean) extends SparkListener { private val app = new MutableApplicationInfo() private val attempt = new MutableAttemptInfo(log.getPath().getName(), log.getLen()) + private var gotEnvUpdate = false + private var halted = false + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { app.id = event.appId.orNull app.name = event.appName @@ -958,6 +1031,8 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends attempt.startTime = new Date(event.time) attempt.lastUpdated = new Date(clock.getTimeMillis()) attempt.sparkUser = event.sparkUser + + checkProgress() } override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { @@ -968,11 +1043,18 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { - val allProperties = event.environmentDetails("Spark Properties").toMap - attempt.viewAcls = allProperties.get("spark.ui.view.acls") - attempt.adminAcls = allProperties.get("spark.admin.acls") - attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + // Only parse the first env update, since any future changes don't have any effect on + // the ACLs set for the UI. + if (!gotEnvUpdate) { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + + gotEnvUpdate = true + checkProgress() + } } override def onOtherEvent(event: SparkListenerEvent): Unit = event match { @@ -989,6 +1071,17 @@ private[history] class AppListingListener(log: FileStatus, clock: Clock) extends } } + /** + * Throws a halt exception to stop replay if enough data to create the app listing has been + * read. + */ + private def checkProgress(): Unit = { + if (haltEnabled && !halted && app.id != null && gotEnvUpdate) { + halted = true + throw new HaltReplayException() + } + } + private class MutableApplicationInfo { var id: String = null var name: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index efdbf672bb52f..25ba9edb9e014 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -49,4 +49,19 @@ private[spark] object config { .intConf .createWithDefault(18080) + val FAST_IN_PROGRESS_PARSING = + ConfigBuilder("spark.history.fs.inProgressOptimization.enabled") + .doc("Enable optimized handling of in-progress logs. This option may leave finished " + + "applications that fail to rename their event logs listed as in-progress.") + .booleanConf + .createWithDefault(true) + + val END_EVENT_REPARSE_CHUNK_SIZE = + ConfigBuilder("spark.history.fs.endEventReparseChunkSize") + .doc("How many bytes to parse at the end of log files looking for the end event. " + + "This is used to speed up generation of application listings by skipping unnecessary " + + "parts of event log files. It can be disabled by setting this config to 0.") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("1m") + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index c9cd662f5709d..226c23733c870 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -115,6 +115,8 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } } catch { + case e: HaltReplayException => + // Just stop replay. case _: EOFException if maybeTruncated => case ioe: IOException => throw ioe @@ -124,8 +126,17 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { } } + override protected def isIgnorableException(e: Throwable): Boolean = { + e.isInstanceOf[HaltReplayException] + } + } +/** + * Exception that can be thrown by listeners to halt replay. This is handled by ReplayListenerBus + * only, and will cause errors if thrown when using other bus implementations. + */ +private[spark] class HaltReplayException extends RuntimeException private[spark] object ReplayListenerBus { diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 76a56298aaebc..b25a731401f23 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -81,7 +81,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { try { doPostEvent(listener, event) } catch { - case NonFatal(e) => + case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { if (maybeTimerContext != null) { @@ -97,6 +97,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ protected def doPostEvent(listener: L, event: E): Unit + /** Allows bus implementations to prevent error logging for certain exceptions. */ + protected def isIgnorableException(e: Throwable): Boolean = false + private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 0ba57bf4563c1..77b239489d489 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any -import org.mockito.Mockito.{doReturn, mock, spy, verify} +import org.mockito.Mockito.{mock, spy, verify} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -151,8 +151,9 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc var mergeApplicationListingCall = 0 override protected def mergeApplicationListing( fileStatus: FileStatus, - lastSeen: Long): Unit = { - super.mergeApplicationListing(fileStatus, lastSeen) + lastSeen: Long, + enableSkipToEnd: Boolean): Unit = { + super.mergeApplicationListing(fileStatus, lastSeen, enableSkipToEnd) mergeApplicationListingCall += 1 } } @@ -256,14 +257,13 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) updateAndCheck(provider) { list => - list should not be (null) list.size should be (1) list.head.attempts.size should be (3) list.head.attempts.head.attemptId should be (Some("attempt3")) } val app2Attempt1 = newLogFile("app2", Some("attempt1"), inProgress = false) - writeFile(attempt1, true, None, + writeFile(app2Attempt1, true, None, SparkListenerApplicationStart("app2", Some("app2"), 5L, "test", Some("attempt1")), SparkListenerApplicationEnd(6L) ) @@ -649,8 +649,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Add more info to the app log, and trigger the provider to update things. writeFile(appLog, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), - SparkListenerJobStart(0, 1L, Nil, null), - SparkListenerApplicationEnd(5L) + SparkListenerJobStart(0, 1L, Nil, null) ) provider.checkForLogs() @@ -668,11 +667,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("clean up stale app information") { val storeDir = Utils.createTempDir() val conf = createTestConf().set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) - val provider = spy(new FsHistoryProvider(conf)) + val clock = new ManualClock() + val provider = spy(new FsHistoryProvider(conf, clock)) val appId = "new1" // Write logs for two app attempts. - doReturn(1L).when(provider).getNewLastScanTime() + clock.advance(1) val attempt1 = newLogFile(appId, Some("1"), inProgress = false) writeFile(attempt1, true, None, SparkListenerApplicationStart(appId, Some(appId), 1L, "test", Some("1")), @@ -697,7 +697,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Delete the underlying log file for attempt 1 and rescan. The UI should go away, but since // attempt 2 still exists, listing data should be there. - doReturn(2L).when(provider).getNewLastScanTime() + clock.advance(1) attempt1.delete() updateAndCheck(provider) { list => assert(list.size === 1) @@ -708,7 +708,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(provider.getAppUI(appId, None) === None) // Delete the second attempt's log file. Now everything should go away. - doReturn(3L).when(provider).getNewLastScanTime() + clock.advance(1) attempt2.delete() updateAndCheck(provider) { list => assert(list.isEmpty) @@ -718,9 +718,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-21571: clean up removes invalid history files") { val clock = new ManualClock() val conf = createTestConf().set(MAX_LOG_AGE_S.key, s"2d") - val provider = new FsHistoryProvider(conf, clock) { - override def getNewLastScanTime(): Long = clock.getTimeMillis() - } + val provider = new FsHistoryProvider(conf, clock) // Create 0-byte size inprogress and complete files var logCount = 0 @@ -772,6 +770,54 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(new File(testDir.toURI).listFiles().size === validLogCount) } + test("always find end event for finished apps") { + // Create a log file where the end event is before the configure chunk to be reparsed at + // the end of the file. The correct listing should still be generated. + val log = newLogFile("end-event-test", None, inProgress = false) + writeFile(log, true, None, + Seq( + SparkListenerApplicationStart("end-event-test", Some("end-event-test"), 1L, "test", None), + SparkListenerEnvironmentUpdate(Map( + "Spark Properties" -> Seq.empty, + "JVM Information" -> Seq.empty, + "System Properties" -> Seq.empty, + "Classpath Entries" -> Seq.empty + )), + SparkListenerApplicationEnd(5L) + ) ++ (1 to 1000).map { i => SparkListenerJobStart(i, i, Nil) }: _*) + + val conf = createTestConf().set(END_EVENT_REPARSE_CHUNK_SIZE.key, s"1k") + val provider = new FsHistoryProvider(conf) + updateAndCheck(provider) { list => + assert(list.size === 1) + assert(list(0).attempts.size === 1) + assert(list(0).attempts(0).completed) + } + } + + test("parse event logs with optimizations off") { + val conf = createTestConf() + .set(END_EVENT_REPARSE_CHUNK_SIZE, 0L) + .set(FAST_IN_PROGRESS_PARSING, false) + val provider = new FsHistoryProvider(conf) + + val complete = newLogFile("complete", None, inProgress = false) + writeFile(complete, true, None, + SparkListenerApplicationStart("complete", Some("complete"), 1L, "test", None), + SparkListenerApplicationEnd(5L) + ) + + val incomplete = newLogFile("incomplete", None, inProgress = true) + writeFile(incomplete, true, None, + SparkListenerApplicationStart("incomplete", Some("incomplete"), 1L, "test", None) + ) + + updateAndCheck(provider) { list => + list.size should be (2) + list.count(_.attempts.head.completed) should be (1) + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -815,7 +861,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc private def createTestConf(inMemory: Boolean = false): SparkConf = { val conf = new SparkConf() - .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + .set(EVENT_LOG_DIR, testDir.getAbsolutePath()) + .set(FAST_IN_PROGRESS_PARSING, true) if (!inMemory) { conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) @@ -848,4 +895,3 @@ class TestGroupsMappingProvider extends GroupMappingServiceProvider { mappings.get(username).map(Set(_)).getOrElse(Set.empty) } } - From 3cb82047f2f51af553df09b9323796af507d36f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 11 Apr 2018 10:13:44 -0500 Subject: [PATCH 0603/1161] [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher. The current in-process launcher implementation just calls the SparkSubmit object, which, in case of errors, will more often than not exit the JVM. This is not desirable since this launcher is meant to be used inside other applications, and that would kill the application. The change turns SparkSubmit into a class, and abstracts aways some of the functionality used to print error messages and abort the submission process. The default implementation uses the logging system for messages, and throws exceptions for errors. As part of that I also moved some code that doesn't really belong in SparkSubmit to a better location. The command line invocation of spark-submit now uses a special implementation of the SparkSubmit class that overrides those behaviors to do what is expected from the command line version (print to the terminal, exit the JVM, etc). A lot of the changes are to replace calls to methods such as "printErrorAndExit" with the new API. As part of adding tests for this, I had to fix some small things in the launcher option parser so that things like "--version" can work when used in the launcher library. There is still code that prints directly to the terminal, like all the Ivy-related code in SparkSubmitUtils, and other areas where some re-factoring would help, like the CommandLineUtils class, but I chose to leave those alone to keep this change more focused. Aside from existing and added unit tests, I ran command line tools with a bunch of different arguments to make sure messages and errors behave like before. Author: Marcelo Vanzin Closes #20925 from vanzin/SPARK-22941. --- .../apache/spark/deploy/DependencyUtils.scala | 30 +- .../org/apache/spark/deploy/SparkSubmit.scala | 318 +++++++++--------- .../spark/deploy/SparkSubmitArguments.scala | 90 +++-- .../spark/deploy/worker/DriverWrapper.scala | 4 +- .../apache/spark/util/CommandLineUtils.scala | 18 +- .../spark/launcher/SparkLauncherSuite.java | 37 +- .../spark/deploy/SparkSubmitSuite.scala | 69 ++-- .../rest/StandaloneRestSubmitSuite.scala | 2 +- .../spark/launcher/AbstractLauncher.java | 6 +- .../spark/launcher/InProcessLauncher.java | 14 +- .../launcher/SparkSubmitCommandBuilder.java | 82 +++-- project/MimaExcludes.scala | 7 +- .../deploy/mesos/MesosClusterDispatcher.scala | 10 +- .../MesosClusterDispatcherArguments.scala | 6 +- 14 files changed, 401 insertions(+), 292 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index fac834a70b893..178bdcfccb603 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{SecurityManager, SparkConf, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.util.{MutableURLClassLoader, Utils} -private[deploy] object DependencyUtils { +private[deploy] object DependencyUtils extends Logging { def resolveMavenDependencies( packagesExclusions: String, @@ -75,7 +76,7 @@ private[deploy] object DependencyUtils { def addJarsToClassPath(jars: String, loader: MutableURLClassLoader): Unit = { if (jars != null) { for (jar <- jars.split(",")) { - SparkSubmit.addJarToClasspath(jar, loader) + addJarToClasspath(jar, loader) } } } @@ -151,6 +152,31 @@ private[deploy] object DependencyUtils { }.mkString(",") } + def addJarToClasspath(localJar: String, loader: MutableURLClassLoader): Unit = { + val uri = Utils.resolveURI(localJar) + uri.getScheme match { + case "file" | "local" => + val file = new File(uri.getPath) + if (file.exists()) { + loader.addURL(file.toURI.toURL) + } else { + logWarning(s"Local jar $file does not exist, skipping.") + } + case _ => + logWarning(s"Skip remote jar $uri.") + } + } + + /** + * Merge a sequence of comma-separated file lists, some of which may be null to indicate + * no files, into a single comma-separated string. + */ + def mergeFileLists(lists: String*): String = { + val merged = lists.filterNot(StringUtils.isBlank) + .flatMap(Utils.stringToSeq) + if (merged.nonEmpty) merged.mkString(",") else null + } + private def splitOnFragment(path: String): (URI, Option[String]) = { val uri = Utils.resolveURI(path) val withoutFragment = new URI(uri.getScheme, uri.getSchemeSpecificPart, null) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index eddbedeb1024d..427c797755b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -58,7 +58,7 @@ import org.apache.spark.util._ */ private[deploy] object SparkSubmitAction extends Enumeration { type SparkSubmitAction = Value - val SUBMIT, KILL, REQUEST_STATUS = Value + val SUBMIT, KILL, REQUEST_STATUS, PRINT_VERSION = Value } /** @@ -67,78 +67,32 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. */ -object SparkSubmit extends CommandLineUtils with Logging { +private[spark] class SparkSubmit extends Logging { import DependencyUtils._ + import SparkSubmit._ - // Cluster managers - private val YARN = 1 - private val STANDALONE = 2 - private val MESOS = 4 - private val LOCAL = 8 - private val KUBERNETES = 16 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES - - // Deploy modes - private val CLIENT = 1 - private val CLUSTER = 2 - private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - - // Special primary resource names that represent shells rather than application jars. - private val SPARK_SHELL = "spark-shell" - private val PYSPARK_SHELL = "pyspark-shell" - private val SPARKR_SHELL = "sparkr-shell" - private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" - private val R_PACKAGE_ARCHIVE = "rpkg.zip" - - private val CLASS_NOT_FOUND_EXIT_STATUS = 101 - - // Following constants are visible for testing. - private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.yarn.YarnClusterApplication" - private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() - private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() - private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = - "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" - - // scalastyle:off println - private[spark] def printVersionAndExit(): Unit = { - printStream.println("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ - """.format(SPARK_VERSION)) - printStream.println("Using Scala %s, %s, %s".format( - Properties.versionString, Properties.javaVmName, Properties.javaVersion)) - printStream.println("Branch %s".format(SPARK_BRANCH)) - printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) - printStream.println("Revision %s".format(SPARK_REVISION)) - printStream.println("Url %s".format(SPARK_REPO_URL)) - printStream.println("Type --help for more information.") - exitFn(0) - } - // scalastyle:on println - - override def main(args: Array[String]): Unit = { + def doSubmit(args: Array[String]): Unit = { // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to // be reset before the application starts. val uninitLog = initializeLogIfNecessary(true, silent = true) - val appArgs = new SparkSubmitArguments(args) + val appArgs = parseArguments(args) if (appArgs.verbose) { - // scalastyle:off println - printStream.println(appArgs) - // scalastyle:on println + logInfo(appArgs.toString) } appArgs.action match { case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog) case SparkSubmitAction.KILL => kill(appArgs) case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.PRINT_VERSION => printVersion() } } + protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) + } + /** * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. */ @@ -156,6 +110,24 @@ object SparkSubmit extends CommandLineUtils with Logging { .requestSubmissionStatus(args.submissionToRequestStatusFor) } + /** Print version information to the log. */ + private def printVersion(): Unit = { + logInfo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + logInfo("Using Scala %s, %s, %s".format( + Properties.versionString, Properties.javaVmName, Properties.javaVersion)) + logInfo(s"Branch $SPARK_BRANCH") + logInfo(s"Compiled by user $SPARK_BUILD_USER on $SPARK_BUILD_DATE") + logInfo(s"Revision $SPARK_REVISION") + logInfo(s"Url $SPARK_REPO_URL") + logInfo("Type --help for more information.") + } + /** * Submit the application using the provided parameters. * @@ -185,10 +157,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // makes the message printed to the output by the JVM not very helpful. Instead, // detect exceptions with empty stack traces here, and treat them differently. if (e.getStackTrace().length == 0) { - // scalastyle:off println - printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") - // scalastyle:on println - exitFn(1) + error(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") } else { throw e } @@ -210,14 +179,11 @@ object SparkSubmit extends CommandLineUtils with Logging { // to use the legacy gateway if the master endpoint turns out to be not a REST server. if (args.isStandaloneCluster && args.useRest) { try { - // scalastyle:off println - printStream.println("Running Spark using the REST application submission protocol.") - // scalastyle:on println - doRunMain() + logInfo("Running Spark using the REST application submission protocol.") } catch { // Fail over to use the legacy submission gateway case e: SubmitRestConnectionException => - printWarning(s"Master endpoint ${args.master} was not a REST server. " + + logWarning(s"Master endpoint ${args.master} was not a REST server. " + "Falling back to legacy submission gateway instead.") args.useRest = false submit(args, false) @@ -245,19 +211,6 @@ object SparkSubmit extends CommandLineUtils with Logging { args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) : (Seq[String], Seq[String], SparkConf, String) = { - try { - doPrepareSubmitEnvironment(args, conf) - } catch { - case e: SparkException => - printErrorAndExit(e.getMessage) - throw e - } - } - - private def doPrepareSubmitEnvironment( - args: SparkSubmitArguments, - conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], SparkConf, String) = { // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() @@ -268,7 +221,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val clusterManager: Int = args.master match { case "yarn" => YARN case "yarn-client" | "yarn-cluster" => - printWarning(s"Master ${args.master} is deprecated since 2.0." + + logWarning(s"Master ${args.master} is deprecated since 2.0." + " Please use master \"yarn\" with specified deploy mode instead.") YARN case m if m.startsWith("spark") => STANDALONE @@ -276,7 +229,7 @@ object SparkSubmit extends CommandLineUtils with Logging { case m if m.startsWith("k8s") => KUBERNETES case m if m.startsWith("local") => LOCAL case _ => - printErrorAndExit("Master must either be yarn or start with spark, mesos, k8s, or local") + error("Master must either be yarn or start with spark, mesos, k8s, or local") -1 } @@ -284,7 +237,9 @@ object SparkSubmit extends CommandLineUtils with Logging { var deployMode: Int = args.deployMode match { case "client" | null => CLIENT case "cluster" => CLUSTER - case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 + case _ => + error("Deploy mode must be either client or cluster") + -1 } // Because the deprecated way of specifying "yarn-cluster" and "yarn-client" encapsulate both @@ -296,16 +251,16 @@ object SparkSubmit extends CommandLineUtils with Logging { deployMode = CLUSTER args.master = "yarn" case ("yarn-cluster", "client") => - printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") + error("Client deploy mode is not compatible with master \"yarn-cluster\"") case ("yarn-client", "cluster") => - printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") + error("Cluster deploy mode is not compatible with master \"yarn-client\"") case (_, mode) => args.master = "yarn" } // Make sure YARN is included in our build if we're trying to use it if (!Utils.classIsLoadable(YARN_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load YARN classes. " + "This copy of Spark may not have been compiled with YARN support.") } @@ -315,7 +270,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.master = Utils.checkAndGetK8sMasterUrl(args.master) // Make sure KUBERNETES is included in our build if we're trying to use it if (!Utils.classIsLoadable(KUBERNETES_CLUSTER_SUBMIT_CLASS) && !Utils.isTesting) { - printErrorAndExit( + error( "Could not load KUBERNETES classes. " + "This copy of Spark may not have been compiled with KUBERNETES support.") } @@ -324,23 +279,23 @@ object SparkSubmit extends CommandLineUtils with Logging { // Fail fast, the following modes are not supported or applicable (clusterManager, deployMode) match { case (STANDALONE, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + error("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") case (STANDALONE, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + + error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") case (KUBERNETES, _) if args.isPython => - printErrorAndExit("Python applications are currently not supported for Kubernetes.") + error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => - printErrorAndExit("R applications are currently not supported for Kubernetes.") + error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => - printErrorAndExit("Cluster deploy mode is not compatible with master \"local\"") + error("Cluster deploy mode is not compatible with master \"local\"") case (_, CLUSTER) if isShell(args.primaryResource) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + error("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + error("Cluster deploy mode is not applicable to Spark SQL shell.") case (_, CLUSTER) if isThriftServer(args.mainClass) => - printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") + error("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } @@ -493,11 +448,11 @@ object SparkSubmit extends CommandLineUtils with Logging { if (args.isR && clusterManager == YARN) { val sparkRPackagePath = RUtils.localSparkRPackagePath if (sparkRPackagePath.isEmpty) { - printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") + error("SPARK_HOME does not exist for R application in YARN mode.") } val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) if (!sparkRPackageFile.exists()) { - printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") + error(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString @@ -510,7 +465,7 @@ object SparkSubmit extends CommandLineUtils with Logging { val rPackageFile = RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) if (!rPackageFile.exists()) { - printErrorAndExit("Failed to zip all the built R packages.") + error("Failed to zip all the built R packages.") } val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString @@ -521,12 +476,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // TODO: Support distributing R packages with standalone cluster if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + error("Distributing R packages with standalone cluster is not supported.") } // TODO: Support distributing R packages with mesos cluster if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { - printErrorAndExit("Distributing R packages with mesos cluster is not supported.") + error("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -799,9 +754,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" - // scalastyle:off println - printStream.println(s"Setting ${key} to ${shortUserName}") - // scalastyle:off println + logInfo(s"Setting ${key} to ${shortUserName}") sparkConf.set(key, shortUserName) } @@ -817,16 +770,14 @@ object SparkSubmit extends CommandLineUtils with Logging { sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { - // scalastyle:off println if (verbose) { - printStream.println(s"Main class:\n$childMainClass") - printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") + logInfo(s"Main class:\n$childMainClass") + logInfo(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") - printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") - printStream.println("\n") + logInfo(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") + logInfo(s"Classpath elements:\n${childClasspath.mkString("\n")}") + logInfo("\n") } - // scalastyle:on println val loader = if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { @@ -848,23 +799,19 @@ object SparkSubmit extends CommandLineUtils with Logging { mainClass = Utils.classForName(childMainClass) } catch { case e: ClassNotFoundException => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass.", e) if (childMainClass.contains("thriftserver")) { - // scalastyle:off println - printStream.println(s"Failed to load main class $childMainClass.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load main class $childMainClass.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) case e: NoClassDefFoundError => - e.printStackTrace(printStream) + logWarning(s"Failed to load $childMainClass: ${e.getMessage()}") if (e.getMessage.contains("org/apache/hadoop/hive")) { - // scalastyle:off println - printStream.println(s"Failed to load hive class.") - printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") - // scalastyle:on println + logInfo(s"Failed to load hive class.") + logInfo("You need to build Spark with -Phive and -Phive-thriftserver.") } - System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + throw new SparkUserAppException(CLASS_NOT_FOUND_EXIT_STATUS) } val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { @@ -872,7 +819,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } else { // SPARK-4170 if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + logWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") } new JavaMainApplication(mainClass) } @@ -891,29 +838,90 @@ object SparkSubmit extends CommandLineUtils with Logging { app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => - findCause(t) match { - case SparkUserAppException(exitCode) => - System.exit(exitCode) - - case t: Throwable => - throw t - } + throw findCause(t) } } - private[deploy] def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) { - val uri = Utils.resolveURI(localJar) - uri.getScheme match { - case "file" | "local" => - val file = new File(uri.getPath) - if (file.exists()) { - loader.addURL(file.toURI.toURL) - } else { - printWarning(s"Local jar $file does not exist, skipping.") + /** Throw a SparkException with the given error message. */ + private def error(msg: String): Unit = throw new SparkException(msg) + +} + + +/** + * This entry point is used by the launcher library to start in-process Spark applications. + */ +private[spark] object InProcessSparkSubmit { + + def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() + submit.doSubmit(args) + } + +} + +object SparkSubmit extends CommandLineUtils with Logging { + + // Cluster managers + private val YARN = 1 + private val STANDALONE = 2 + private val MESOS = 4 + private val LOCAL = 8 + private val KUBERNETES = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | KUBERNETES + + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER + + // Special primary resource names that represent shells rather than application jars. + private val SPARK_SHELL = "spark-shell" + private val PYSPARK_SHELL = "pyspark-shell" + private val SPARKR_SHELL = "sparkr-shell" + private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" + + private val CLASS_NOT_FOUND_EXIT_STATUS = 101 + + // Following constants are visible for testing. + private[deploy] val YARN_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.yarn.YarnClusterApplication" + private[deploy] val REST_CLUSTER_SUBMIT_CLASS = classOf[RestSubmissionClientApp].getName() + private[deploy] val STANDALONE_CLUSTER_SUBMIT_CLASS = classOf[ClientApp].getName() + private[deploy] val KUBERNETES_CLUSTER_SUBMIT_CLASS = + "org.apache.spark.deploy.k8s.submit.KubernetesClientApplication" + + override def main(args: Array[String]): Unit = { + val submit = new SparkSubmit() { + self => + + override protected def parseArguments(args: Array[String]): SparkSubmitArguments = { + new SparkSubmitArguments(args) { + override protected def logInfo(msg: => String): Unit = self.logInfo(msg) + + override protected def logWarning(msg: => String): Unit = self.logWarning(msg) } - case _ => - printWarning(s"Skip remote jar $uri.") + } + + override protected def logInfo(msg: => String): Unit = printMessage(msg) + + override protected def logWarning(msg: => String): Unit = printMessage(s"Warning: $msg") + + override def doSubmit(args: Array[String]): Unit = { + try { + super.doSubmit(args) + } catch { + case e: SparkUserAppException => + exitFn(e.exitCode) + case e: SparkException => + printErrorAndExit(e.getMessage()) + } + } + } + + submit.doSubmit(args) } /** @@ -962,17 +970,6 @@ object SparkSubmit extends CommandLineUtils with Logging { res == SparkLauncher.NO_RESOURCE } - /** - * Merge a sequence of comma-separated file lists, some of which may be null to indicate - * no files, into a single comma-separated string. - */ - private[deploy] def mergeFileLists(lists: String*): String = { - val merged = lists.filterNot(StringUtils.isBlank) - .flatMap(_.split(",")) - .mkString(",") - if (merged == "") null else merged - } - } /** Provides utility functions to be used inside SparkSubmit. */ @@ -1000,12 +997,12 @@ private[spark] object SparkSubmitUtils { override def toString: String = s"$groupId:$artifactId:$version" } -/** - * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided - * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. - * @param coordinates Comma-delimited string of maven coordinates - * @return Sequence of Maven coordinates - */ + /** + * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided + * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. + * @param coordinates Comma-delimited string of maven coordinates + * @return Sequence of Maven coordinates + */ def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = { coordinates.split(",").map { p => val splits = p.replace("/", ":").split(":") @@ -1304,6 +1301,13 @@ private[spark] object SparkSubmitUtils { rule } + def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => throw new SparkException(s"Spark config without '=': $pair") + } + } + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 8e7070593687b..0733fdb72cafb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -29,7 +29,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.io.Source import scala.util.Try +import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.deploy.SparkSubmitAction._ +import org.apache.spark.internal.Logging import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils @@ -40,7 +42,7 @@ import org.apache.spark.util.Utils * The env argument is used for testing. */ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) - extends SparkSubmitArgumentsParser { + extends SparkSubmitArgumentsParser with Logging { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -85,8 +87,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() - // scalastyle:off println - if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") + if (verbose) { + logInfo(s"Using properties file: $propertiesFile") + } Option(propertiesFile).foreach { filename => val properties = Utils.getPropertiesFromFile(filename) properties.foreach { case (k, v) => @@ -95,21 +98,16 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Property files may contain sensitive information, so redact before printing if (verbose) { Utils.redact(properties).foreach { case (k, v) => - SparkSubmit.printStream.println(s"Adding default property: $k=$v") + logInfo(s"Adding default property: $k=$v") } } } - // scalastyle:on println defaultProperties } // Set parameters from command line arguments - try { - parse(args.asJava) - } catch { - case e: IllegalArgumentException => - SparkSubmit.printErrorAndExit(e.getMessage()) - } + parse(args.asJava) + // Populate `sparkProperties` map from properties file mergeDefaultSparkProperties() // Remove keys that don't start with "spark." from `sparkProperties`. @@ -141,7 +139,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S sparkProperties.foreach { case (k, v) => if (!k.startsWith("spark.")) { sparkProperties -= k - SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v") + logWarning(s"Ignoring non-spark config property: $k=$v") } } } @@ -215,10 +213,10 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } } catch { case _: Exception => - SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource") + error(s"Cannot load main class from JAR $primaryResource") } case _ => - SparkSubmit.printErrorAndExit( + error( s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " + "Please specify a class through --class.") } @@ -248,6 +246,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case SUBMIT => validateSubmitArguments() case KILL => validateKillArguments() case REQUEST_STATUS => validateStatusRequestArguments() + case PRINT_VERSION => } } @@ -256,62 +255,61 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S printUsageAndExit(-1) } if (primaryResource == null) { - SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)") + error("Must specify a primary resource (JAR or Python or R file)") } if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) { - SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class") + error("No main class set in JAR; please specify one with --class") } if (driverMemory != null && Try(JavaUtils.byteStringAsBytes(driverMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Driver Memory must be a positive number") + error("Driver memory must be a positive number") } if (executorMemory != null && Try(JavaUtils.byteStringAsBytes(executorMemory)).getOrElse(-1L) <= 0) { - SparkSubmit.printErrorAndExit("Executor Memory cores must be a positive number") + error("Executor memory must be a positive number") } if (executorCores != null && Try(executorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Executor cores must be a positive number") + error("Executor cores must be a positive number") } if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Total executor cores must be a positive number") + error("Total executor cores must be a positive number") } if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { - SparkSubmit.printErrorAndExit("Number of executors must be a positive number") + error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { - SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script") + error("--py-files given but primary resource is not a Python script") } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { - throw new Exception(s"When running with master '$master' " + + error(s"When running with master '$master' " + "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") } } if (proxyUser != null && principal != null) { - SparkSubmit.printErrorAndExit("Only one of --proxy-user or --principal can be provided.") + error("Only one of --proxy-user or --principal can be provided.") } } private def validateKillArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( - "Killing submissions is only supported in standalone or Mesos mode!") + error("Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to kill.") + error("Please specify a submission to kill.") } } private def validateStatusRequestArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { - SparkSubmit.printErrorAndExit( + error( "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { - SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") + error("Please specify a submission to request status for.") } } @@ -368,7 +366,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case DEPLOY_MODE => if (value != "client" && value != "cluster") { - SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") + error("--deploy-mode must be either \"client\" or \"cluster\"") } deployMode = value @@ -405,14 +403,14 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case KILL_SUBMISSION => submissionToKill = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + error(s"Action cannot be both $action and $KILL.") } action = KILL case STATUS => submissionToRequestStatusFor = value if (action != null) { - SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + error(s"Action cannot be both $action and $REQUEST_STATUS.") } action = REQUEST_STATUS @@ -444,7 +442,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + val (confName, confValue) = SparkSubmitUtils.parseSparkConfProperty(value) sparkProperties(confName) = confValue case PROXY_USER => @@ -463,15 +461,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S verbose = true case VERSION => - SparkSubmit.printVersionAndExit() + action = SparkSubmitAction.PRINT_VERSION case USAGE_ERROR => printUsageAndExit(1) case _ => - throw new IllegalArgumentException(s"Unexpected argument '$opt'.") + error(s"Unexpected argument '$opt'.") } - true + action != SparkSubmitAction.PRINT_VERSION } /** @@ -482,7 +480,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S */ override protected def handleUnknown(opt: String): Boolean = { if (opt.startsWith("-")) { - SparkSubmit.printErrorAndExit(s"Unrecognized option '$opt'.") + error(s"Unrecognized option '$opt'.") } primaryResource = @@ -501,20 +499,18 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { - // scalastyle:off println - val outStream = SparkSubmit.printStream if (unknownParam != null) { - outStream.println("Unknown/unsupported param " + unknownParam) + logInfo("Unknown/unsupported param " + unknownParam) } val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] |Usage: spark-submit --status [submission ID] --master [spark://...] |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) - outStream.println(command) + logInfo(command) val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB - outStream.println( + logInfo( s""" |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, @@ -596,12 +592,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S ) if (SparkSubmit.isSqlShell(mainClass)) { - outStream.println("CLI options:") - outStream.println(getSqlShellOptions()) + logInfo("CLI options:") + logInfo(getSqlShellOptions()) } - // scalastyle:on println - SparkSubmit.exitFn(exitCode) + throw new SparkUserAppException(exitCode) } /** @@ -655,4 +650,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S System.setErr(currentErr) } } + + private def error(msg: String): Unit = throw new SparkException(msg) + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 3f71237164a15..8d6a2b80ef5f2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -25,7 +25,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Utility object for launching driver programs such that they share fate with the Worker process. @@ -93,7 +93,7 @@ object DriverWrapper extends Logging { val jars = { val jarsProp = sys.props.get("spark.jars").orNull if (!StringUtils.isBlank(resolvedMavenCoordinates)) { - SparkSubmit.mergeFileLists(jarsProp, resolvedMavenCoordinates) + DependencyUtils.mergeFileLists(jarsProp, resolvedMavenCoordinates) } else { jarsProp } diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala index d73901686b705..4b6602b50aa1c 100644 --- a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -33,24 +33,14 @@ private[spark] trait CommandLineUtils { private[spark] var printStream: PrintStream = System.err // scalastyle:off println - - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + private[spark] def printMessage(str: String): Unit = printStream.println(str) + // scalastyle:on println private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") + printMessage("Error: " + str) + printMessage("Run with --help for usage help or --verbose for debug output") exitFn(1) } - // scalastyle:on println - - private[spark] def parseSparkConfProperty(pair: String): (String, String) = { - pair.split("=", 2).toSeq match { - case Seq(k, v) => (k, v) - case _ => printErrorAndExit(s"Spark config without '=': $pair") - throw new SparkException(s"Spark config without '=': $pair") - } - } - def main(args: Array[String]): Unit } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 2225591a4ff75..6a1a38c1a54f4 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -109,7 +109,7 @@ public void testChildProcLauncher() throws Exception { .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, - "-Dfoo=bar -Dtest.appender=childproc") + "-Dfoo=bar -Dtest.appender=console") .setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path")) .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow") .setMainClass(SparkLauncherTestApp.class.getName()) @@ -192,6 +192,41 @@ private void inProcessLauncherTestImpl() throws Exception { } } + @Test + public void testInProcessLauncherDoesNotKillJvm() throws Exception { + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); + List wrongArgs = Arrays.asList( + new String[] { "--unknown" }, + new String[] { opts.DEPLOY_MODE, "invalid" }); + + for (String[] args : wrongArgs) { + InProcessLauncher launcher = new InProcessLauncher() + .setAppResource(SparkLauncher.NO_RESOURCE); + switch (args.length) { + case 2: + launcher.addSparkArg(args[0], args[1]); + break; + + case 1: + launcher.addSparkArg(args[0]); + break; + + default: + fail("FIXME: invalid test."); + } + + SparkAppHandle handle = launcher.startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.FAILED, handle.getState()); + } + + // Run --version, which is useless as a use case, but should succeed and not exit the JVM. + // The expected state is "LOST" since "--version" doesn't report state back to the handle. + SparkAppHandle handle = new InProcessLauncher().addSparkArg(opts.VERSION).startApplication(); + waitFor(handle); + assertEquals(SparkAppHandle.State.LOST, handle.getState()); + } + public static class SparkLauncherTestApp { public static void main(String[] args) throws Exception { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 0d7c342a5eacd..7451e07b25a1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -42,6 +42,7 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -109,6 +110,8 @@ class SparkSubmitSuite private val emptyIvySettings = File.createTempFile("ivy", ".xml") FileUtils.write(emptyIvySettings, "", StandardCharsets.UTF_8) + private val submit = new SparkSubmit() + override def beforeEach() { super.beforeEach() } @@ -128,13 +131,16 @@ class SparkSubmitSuite } test("handle binary specified but not class") { - testPrematureExit(Array("foo.jar"), "No main class") + val jar = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + testPrematureExit(Array(jar.toString()), "No main class") } test("handles arguments with --key=val") { val clArgs = Seq( "--jars=one.jar,two.jar,three.jar", - "--name=myApp") + "--name=myApp", + "--class=org.FooBar", + SparkLauncher.NO_RESOURCE) val appArgs = new SparkSubmitArguments(clArgs) appArgs.jars should include regex (".*one.jar,.*two.jar,.*three.jar") appArgs.name should be ("myApp") @@ -182,7 +188,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") conf.get("spark.submit.deployMode") should be ("client") @@ -192,11 +198,11 @@ class SparkSubmitSuite "--master", "yarn", "--deploy-mode", "cluster", "--conf", "spark.submit.deployMode=client", - "-class", "org.SomeClass", + "--class", "org.SomeClass", "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") conf1.get("spark.submit.deployMode") should be ("cluster") @@ -210,7 +216,7 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") conf2.get("spark.submit.deployMode") should be ("client") } @@ -233,7 +239,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -276,7 +282,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -322,7 +328,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -359,7 +365,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -381,7 +387,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -403,7 +409,7 @@ class SparkSubmitSuite "/home/thejar.jar", "arg1") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) val childArgsMap = childArgs.grouped(2).map(a => a(0) -> a(1)).toMap childArgsMap.get("--primary-java-resource") should be (Some("file:/home/thejar.jar")) @@ -428,7 +434,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.executor.memory") should be ("5g") conf.get("spark.master") should be ("yarn") conf.get("spark.submit.deployMode") should be ("cluster") @@ -441,12 +447,12 @@ class SparkSubmitSuite val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1) conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) } @@ -625,7 +631,7 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -640,7 +646,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should fullyMatch regex ("file:/archive1,file:.*#archive3") conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -656,7 +662,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -708,7 +714,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) conf.get("spark.files") should be(Utils.resolveURIs(files)) @@ -725,7 +731,7 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = submit.prepareSubmitEnvironment(appArgs2) conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -740,7 +746,7 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + val (_, _, conf3, _) = submit.prepareSubmitEnvironment(appArgs3) conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -757,7 +763,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -778,17 +784,17 @@ class SparkSubmitSuite } test("SPARK_CONF_DIR overrides spark-defaults.conf") { - forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => + forConfDir(Map("spark.executor.memory" -> "3g")) { path => val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", unusedJar.toString) - val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) + val appArgs = new SparkSubmitArguments(args, env = Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("3g") } } @@ -809,6 +815,9 @@ class SparkSubmitSuite val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir) val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir) + val tempPyFile = File.createTempFile("tmpApp", ".py") + tempPyFile.deleteOnExit() + val args = Seq( "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), "--name", "testApp", @@ -818,10 +827,10 @@ class SparkSubmitSuite "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*", "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*", "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip", - jar2.toString) + tempPyFile.toURI().toString()) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs) conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) conf.get("spark.yarn.dist.files").split(",").toSet should be @@ -947,7 +956,7 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. conf.get("spark.yarn.dist.jars") should be (tmpJarPath) @@ -1007,7 +1016,7 @@ class SparkSubmitSuite ) ++ forceDownloadArgs ++ Seq(s"s3a://$mainResource") val appArgs = new SparkSubmitArguments(args) - val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) val jars = conf.get("spark.yarn.dist.jars").split(",").toSet @@ -1058,7 +1067,7 @@ class SparkSubmitSuite "hello") val exception = intercept[SparkException] { - SparkSubmit.main(args) + submit.doSubmit(args) } assert(exception.getMessage() === "hello") diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index e505bc018857d..54c168a8218f3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,7 +445,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = new SparkSubmit().prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java index 44e69fc45dffa..4e02843480e8f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractLauncher.java @@ -139,7 +139,7 @@ public T setMainClass(String mainClass) { public T addSparkArg(String arg) { SparkSubmitOptionParser validator = new ArgumentValidator(false); validator.parse(Arrays.asList(arg)); - builder.sparkArgs.add(arg); + builder.userArgs.add(arg); return self(); } @@ -187,8 +187,8 @@ public T addSparkArg(String name, String value) { } } else { validator.parse(Arrays.asList(name, value)); - builder.sparkArgs.add(name); - builder.sparkArgs.add(value); + builder.userArgs.add(name); + builder.userArgs.add(value); } return self(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java index 6d726b4a69a86..688e1f763c205 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessLauncher.java @@ -89,10 +89,18 @@ Method findSparkSubmit() throws IOException { } Class sparkSubmit; + // SPARK-22941: first try the new SparkSubmit interface that has better error handling, + // but fall back to the old interface in case someone is mixing & matching launcher and + // Spark versions. try { - sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); - } catch (Exception e) { - throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", e); + sparkSubmit = cl.loadClass("org.apache.spark.deploy.InProcessSparkSubmit"); + } catch (Exception e1) { + try { + sparkSubmit = cl.loadClass("org.apache.spark.deploy.SparkSubmit"); + } catch (Exception e2) { + throw new IOException("Cannot find SparkSubmit; make sure necessary jars are available.", + e2); + } } Method main; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index e0ef22d7d5058..5cb6457bf5c21 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -88,8 +88,9 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkLauncher.NO_RESOURCE); } - final List sparkArgs; - private final boolean isAppResourceReq; + final List userArgs; + private final List parsedArgs; + private final boolean requiresAppResource; private final boolean isExample; /** @@ -99,17 +100,27 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { */ private boolean allowsMixedArguments; + /** + * This constructor is used when creating a user-configurable launcher. It allows the + * spark-submit argument list to be modified after creation. + */ SparkSubmitCommandBuilder() { - this.sparkArgs = new ArrayList<>(); - this.isAppResourceReq = true; + this.requiresAppResource = true; this.isExample = false; + this.parsedArgs = new ArrayList<>(); + this.userArgs = new ArrayList<>(); } + /** + * This constructor is used when invoking spark-submit; it parses and validates arguments + * provided by the user on the command line. + */ SparkSubmitCommandBuilder(List args) { this.allowsMixedArguments = false; - this.sparkArgs = new ArrayList<>(); + this.parsedArgs = new ArrayList<>(); boolean isExample = false; List submitArgs = args; + this.userArgs = Collections.emptyList(); if (args.size() > 0) { switch (args.get(0)) { @@ -131,21 +142,21 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } this.isExample = isExample; - OptionParser parser = new OptionParser(); + OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.isAppResourceReq = parser.isAppResourceReq; - } else { + this.requiresAppResource = parser.requiresAppResource; + } else { this.isExample = isExample; - this.isAppResourceReq = false; + this.requiresAppResource = false; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && isAppResourceReq) { + if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && isAppResourceReq) { + } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -154,9 +165,19 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); - SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + OptionParser parser = new OptionParser(false); + final boolean requiresAppResource; + + // If the user args array is not empty, we need to parse it to detect exactly what + // the user is trying to run, so that checks below are correct. + if (!userArgs.isEmpty()) { + parser.parse(userArgs); + requiresAppResource = parser.requiresAppResource; + } else { + requiresAppResource = this.requiresAppResource; + } - if (!allowsMixedArguments && isAppResourceReq) { + if (!allowsMixedArguments && requiresAppResource) { checkArgument(appResource != null, "Missing application resource."); } @@ -208,15 +229,16 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isAppResourceReq) { - checkArgument(!isExample || mainClass != null, "Missing example class name."); + if (isExample) { + checkArgument(mainClass != null, "Missing example class name."); } + if (mainClass != null) { args.add(parser.CLASS); args.add(mainClass); } - args.addAll(sparkArgs); + args.addAll(parsedArgs); if (appResource != null) { args.add(appResource); } @@ -399,7 +421,12 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean isAppResourceReq = true; + boolean requiresAppResource = true; + private final boolean errorOnUnknownArgs; + + OptionParser(boolean errorOnUnknownArgs) { + this.errorOnUnknownArgs = errorOnUnknownArgs; + } @Override protected boolean handle(String opt, String value) { @@ -443,23 +470,23 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - isAppResourceReq = false; - sparkArgs.add(opt); - sparkArgs.add(value); + requiresAppResource = false; + parsedArgs.add(opt); + parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; case VERSION: - isAppResourceReq = false; - sparkArgs.add(opt); + requiresAppResource = false; + parsedArgs.add(opt); break; default: - sparkArgs.add(opt); + parsedArgs.add(opt); if (value != null) { - sparkArgs.add(value); + parsedArgs.add(value); } break; } @@ -483,12 +510,13 @@ protected boolean handleUnknown(String opt) { mainClass = className; appResource = SparkLauncher.NO_RESOURCE; return false; - } else { + } else if (errorOnUnknownArgs) { checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); checkState(appResource == null, "Found unrecognized argument but resource is already set."); appResource = opt; return false; } + return true; } @Override diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b37b4d51775e8..a87fa68422c34 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,12 +36,17 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"), + // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), - + // [SPARK-20659] Remove StorageStatus, or make it private ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index aa378c9d340f1..ccf33e8d4283c 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer @@ -100,7 +100,13 @@ private[mesos] object MesosClusterDispatcher Thread.setDefaultUncaughtExceptionHandler(new SparkUncaughtExceptionHandler) Utils.initDaemon(log) val conf = new SparkConf - val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + val dispatcherArgs = try { + new MesosClusterDispatcherArguments(args, conf) + } catch { + case e: SparkException => + printErrorAndExit(e.getMessage()) + null + } conf.setMaster(dispatcherArgs.masterUrl) conf.setAppName(dispatcherArgs.name) dispatcherArgs.zookeeperUrl.foreach { z => diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 096bb4e1af688..267a4283db9e6 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import scala.collection.mutable import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkSubmitUtils import org.apache.spark.util.{IntParam, Utils} private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { @@ -95,9 +96,8 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: parse(tail) case ("--conf") :: value :: tail => - val pair = MesosClusterDispatcher. - parseSparkConfProperty(value) - confProperties(pair._1) = pair._2 + val (k, v) = SparkSubmitUtils.parseSparkConfProperty(value) + confProperties(k) = v parse(tail) case ("--help") :: tail => From 75a183071c4ed2e407c930edfdf721779662b3ee Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 11 Apr 2018 09:59:38 -0700 Subject: [PATCH 0604/1161] [SPARK-22883] ML test for StructuredStreaming: spark.ml.feature, I-M ## What changes were proposed in this pull request? Adds structured streaming tests using testTransformer for these suites: * IDF * Imputer * Interaction * MaxAbsScaler * MinHashLSH * MinMaxScaler * NGram ## How was this patch tested? It is a bunch of tests! Author: Joseph K. Bradley Closes #20964 from jkbradley/SPARK-22883-part2. --- .../apache/spark/ml/feature/IDFSuite.scala | 14 +++--- .../spark/ml/feature/ImputerSuite.scala | 31 ++++++++++--- .../spark/ml/feature/InteractionSuite.scala | 46 ++++++++++--------- .../spark/ml/feature/MaxAbsScalerSuite.scala | 14 +++--- .../spark/ml/feature/MinHashLSHSuite.scala | 25 ++++++++-- .../spark/ml/feature/MinMaxScalerSuite.scala | 14 +++--- .../apache/spark/ml/feature/NGramSuite.scala | 2 +- 7 files changed, 89 insertions(+), 57 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 005edf73d29be..cdd62be43b54c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class IDFSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -57,7 +55,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((numOfData + 1.0) / (x + 1.0)) }) @@ -72,7 +70,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead MLTestingUtils.checkCopyAndUids(idfEst, idfModel) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } @@ -85,7 +83,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(0.0, 1.0, 2.0, 3.0), Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) ) - val numOfData = data.size + val numOfData = data.length val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 }) @@ -99,7 +97,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead .setMinDocFreq(1) .fit(df) - idfModel.transform(df).select("idfValue", "expected").collect().foreach { + testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { case Row(x: Vector, y: Vector) => assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index c08b35b419266..75f63a623e6d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -16,13 +16,12 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.SparkException +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Double with default missing Value NaN") { val df = spark.createDataFrame( Seq( @@ -76,6 +75,28 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default ImputerSuite.iterateStrategyTest(imputer, df) } + test("Imputer should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + val df = Seq[(java.lang.Double, Double)]( + (4.0, 4.0), + (10.0, 10.0), + (10.0, 10.0), + (Double.NaN, 8.0), + (null, 8.0) + ).toDF("value", "expected_mean_value") + val imputer = new Imputer() + .setInputCols(Array("value")) + .setOutputCols(Array("out")) + .setStrategy("mean") + val model = imputer.fit(df) + testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") { + case Row(exp: java.lang.Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + } + } + test("Imputer throws exception when surrogate cannot be computed") { val df = spark.createDataFrame( Seq( (0, Double.NaN, 1.0, 1.0), @@ -164,8 +185,6 @@ object ImputerSuite { * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" */ def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { - val inputCols = imputer.getInputCols - Seq("mean", "median").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 54f059e5f143e..eea31fc7ae3f2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class InteractionSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -63,9 +63,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("numeric interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -73,14 +73,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def NumericAttribute.defaultAttr.withName("bar"))) val df = data.select( col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", @@ -92,9 +93,9 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def test("nominal interaction") { val data = Seq( - (2, Vectors.dense(3.0, 4.0)), - (1, Vectors.dense(1.0, 5.0)) - ).toDF("a", "b") + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) + ).toDF("a", "b", "expected") val groupAttr = new AttributeGroup( "b", Array[Attribute]( @@ -103,14 +104,15 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def val df = data.select( col("a").as( "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), - col("b").as("b", groupAttr.toMetadata())) + col("b").as("b", groupAttr.toMetadata()), + col("expected")) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + testTransformer[(Int, Vector, Vector)](df, trans, "features", "expected") { + case Row(features: Vector, expected: Vector) => + assert(features === expected) + } + val res = trans.transform(df) - val expected = Seq( - (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), - (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0)) - ).toDF("a", "b", "features") - assert(res.collect() === expected.collect()) val attrs = AttributeGroup.fromStructField(res.schema("features")) val expectedAttrs = new AttributeGroup( "features", diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index 918da4f9388d4..8dd0f0cb91e37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -14,15 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -45,9 +44,10 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setOutputCol("scaled") val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), s"MaxAbsScaler ut error: $vector2 should be $vector1") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(expectedVec: Vector, actualVec: Vector) => + assert(expectedVec === actualVec, + s"MaxAbsScaler error: Expected $expectedVec but computed $actualVec") } MLTestingUtils.checkCopyAndUids(scaler, model) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3da0fb7da01ae..1c2956cb82908 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Dataset +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{Dataset, Row} -class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class MinHashLSHSuite extends MLTest with DefaultReadWriteTest { @transient var dataset: Dataset[_] = _ @@ -175,4 +174,20 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(precision == 1.0) assert(recall >= 0.7) } + + test("MinHashLSHModel.transform should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + + val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0))) + model.set(model.inputCol, "keys") + testTransformer[Tuple1[Vector]](dataset.toDF(), model, "keys", model.getOutputCol) { + case Row(_: Vector, output: Seq[_]) => + assert(output.length === model.randCoefficients.length) + // no AND-amplification yet: SPARK-18450, so each hash output is of length 1 + output.foreach { + case hashOutput: Vector => assert(hashOutput.size === 1) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 51db74eb739ca..2d965f2ca2c54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -48,9 +46,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De .setMax(5) val model = scaler.fit(df) - model.transform(df).select("expected", "scaled").collect() - .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 === vector2, "Transformed vector is different with expected.") } MLTestingUtils.checkCopyAndUids(scaler, model) @@ -114,7 +112,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De val model = scaler.fit(df) model.transform(df).select("expected", "scaled").collect() .foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1.equals(vector2), "Transformed vector is different with expected.") + assert(vector1 === vector2, "Transformed vector is different with expected.") } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index e5956ee9942aa..201a335e0d7be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -84,7 +84,7 @@ class NGramSuite extends MLTest with DefaultReadWriteTest { def testNGram(t: NGram, dataFrame: DataFrame): Unit = { testTransformer[(Seq[String], Seq[String])](dataFrame, t, "nGrams", "wantedNGrams") { - case Row(actualNGrams : Seq[String], wantedNGrams: Seq[String]) => + case Row(actualNGrams : Seq[_], wantedNGrams: Seq[_]) => assert(actualNGrams === wantedNGrams) } } From 9d960de0814a1128318676cc2e91f447cdf0137f Mon Sep 17 00:00:00 2001 From: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Date: Wed, 11 Apr 2018 15:52:13 -0700 Subject: [PATCH 0605/1161] typo rawPredicition changed to rawPrediction MultilayerPerceptronClassifier had 4 occurrences ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: JBauerKogentix <37910022+JBauerKogentix@users.noreply.github.com> Closes #21030 from JBauerKogentix/patch-1. --- python/pyspark/ml/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index fbbe3d0307c81..ec17653a1adf9 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1543,12 +1543,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction") + rawPredictionCol="rawPrediction") """ super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( @@ -1562,12 +1562,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, solver="l-bfgs", initialWeights=None, probabilityCol="probability", - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \ solver="l-bfgs", initialWeights=None, probabilityCol="probability", \ - rawPredicitionCol="rawPrediction"): + rawPredictionCol="rawPrediction"): Sets params for MultilayerPerceptronClassifier. """ kwargs = self._input_kwargs From e904dfaf0d16f9fa0cc4d2f46a3dec1b1d77de75 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 11 Apr 2018 17:04:34 -0700 Subject: [PATCH 0606/1161] Revert "[SPARK-23960][SQL][MINOR] Mark HashAggregateExec.bufVars as transient" This reverts commit 271c891b91917d660d1f6b995de397c47c7a6058. --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 965950ed94fe8..a5dc6ebf2b0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -174,8 +174,8 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer. Only used in codegen for aggregation without keys. - @transient private var bufVars: Seq[ExprCode] = _ + // The variables used as aggregation buffer. Only used for aggregation without keys. + private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") @@ -238,8 +238,6 @@ case class HashAggregateExec( | } """.stripMargin) - bufVars = null // explicitly null this field out to allow the referent to be GC'd sooner - val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") From 6a2289ecf020a99cd9b3bcea7da5e78fb4e0303a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 12 Apr 2018 15:58:04 +0800 Subject: [PATCH 0607/1161] [SPARK-23962][SQL][TEST] Fix race in currentExecutionIds(). SQLMetricsTestUtils.currentExecutionIds() was racing with the listener bus, which lead to some flaky tests. We should wait till the listener bus is empty. I tested by adding some Thread.sleep()s in SQLAppStatusListener, which reproduced the exceptions I saw on Jenkins. With this change, they went away. Author: Imran Rashid Closes #21041 from squito/SPARK-23962. --- .../apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala index 534d8bb629b8c..dcc540fc4f109 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala @@ -34,6 +34,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils { import testImplicits._ protected def currentExecutionIds(): Set[Long] = { + spark.sparkContext.listenerBus.waitUntilEmpty(10000) statusStore.executionsList.map(_.executionId).toSet } From 0b19122d434e39eb117ccc3174a0688c9c874d48 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Apr 2018 22:21:30 +0800 Subject: [PATCH 0608/1161] [SPARK-23762][SQL] UTF8StringBuffer uses MemoryBlock ## What changes were proposed in this pull request? This PR tries to use `MemoryBlock` in `UTF8StringBuffer`. In general, there are two advantages to use `MemoryBlock`. 1. Has clean API calls rather than using a Java array or `PlatformMemory` 2. Improve runtime performance of memory access instead of using `Object`. ## How was this patch tested? Added `UTF8StringBufferSuite` Author: Kazuaki Ishizaki Closes #20871 from kiszk/SPARK-23762. --- .../codegen/UTF8StringBuilder.java | 35 +++++++--------- .../codegen/UTF8StringBuilderSuite.scala | 42 +++++++++++++++++++ 2 files changed, 56 insertions(+), 21 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java index f0f66bae245fd..f8000d78cd1b6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -19,6 +19,8 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.types.UTF8String; /** @@ -29,43 +31,34 @@ public class UTF8StringBuilder { private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; - private byte[] buffer; - private int cursor = Platform.BYTE_ARRAY_OFFSET; + private ByteArrayMemoryBlock buffer; + private int length = 0; public UTF8StringBuilder() { // Since initial buffer size is 16 in `StringBuilder`, we set the same size here - this.buffer = new byte[16]; + this.buffer = new ByteArrayMemoryBlock(16); } // Grows the buffer by at least `neededSize` private void grow(int neededSize) { - if (neededSize > ARRAY_MAX - totalSize()) { + if (neededSize > ARRAY_MAX - length) { throw new UnsupportedOperationException( "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + "exceeds size limitation " + ARRAY_MAX); } - final int length = totalSize() + neededSize; - if (buffer.length < length) { - int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; - final byte[] tmp = new byte[newLength]; - Platform.copyMemory( - buffer, - Platform.BYTE_ARRAY_OFFSET, - tmp, - Platform.BYTE_ARRAY_OFFSET, - totalSize()); + final int requestedSize = length + neededSize; + if (buffer.size() < requestedSize) { + int newLength = requestedSize < ARRAY_MAX / 2 ? requestedSize * 2 : ARRAY_MAX; + final ByteArrayMemoryBlock tmp = new ByteArrayMemoryBlock(newLength); + MemoryBlock.copyMemory(buffer, tmp, length); buffer = tmp; } } - private int totalSize() { - return cursor - Platform.BYTE_ARRAY_OFFSET; - } - public void append(UTF8String value) { grow(value.numBytes()); - value.writeToMemory(buffer, cursor); - cursor += value.numBytes(); + value.writeToMemory(buffer.getByteArray(), length + Platform.BYTE_ARRAY_OFFSET); + length += value.numBytes(); } public void append(String value) { @@ -73,6 +66,6 @@ public void append(String value) { } public UTF8String build() { - return UTF8String.fromBytes(buffer, 0, totalSize()); + return UTF8String.fromBytes(buffer.getByteArray(), 0, length); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala new file mode 100644 index 0000000000000..1b25a4b191f86 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilderSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String + +class UTF8StringBuilderSuite extends SparkFunSuite { + + test("basic test") { + val sb = new UTF8StringBuilder() + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("") + assert(sb.build() === UTF8String.EMPTY_UTF8) + + sb.append("abcd") + assert(sb.build() === UTF8String.fromString("abcd")) + + sb.append(UTF8String.fromString("1234")) + assert(sb.build() === UTF8String.fromString("abcd1234")) + + // expect to grow an internal buffer + sb.append(UTF8String.fromString("efgijk567890")) + assert(sb.build() === UTF8String.fromString("abcd1234efgijk567890")) + } +} From 0f93b91a71444a1a938acfd8ea2191c54fb0187c Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 12 Apr 2018 15:47:42 -0600 Subject: [PATCH 0609/1161] [SPARK-23751][FOLLOW-UP] fix build for scala-2.12 ## What changes were proposed in this pull request? fix build for scala-2.12 ## How was this patch tested? Manual. Author: WeichenXu Closes #21051 from WeichenXu123/fix_build212. --- .../scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index af8ff64d33ffe..adf8145726711 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -85,7 +85,7 @@ object KolmogorovSmirnovTest { dataset: Dataset[_], sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + test(dataset, sampleCol, (x: Double) => cdf.call(x).toDouble) } /** From 682002b6da844ed11324ee5ff4d00fc0294c0b31 Mon Sep 17 00:00:00 2001 From: Patrick Pisciuneri Date: Fri, 13 Apr 2018 09:45:27 +0800 Subject: [PATCH 0610/1161] [SPARK-23867][SCHEDULER] use droppedCount in logWarning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Get the count of dropped events for output in log message. ## How was this patch tested? The fix is pretty trivial, but `./dev/run-tests` were run and were successful. Please review http://spark.apache.org/contributing.html before opening a pull request. vanzin cloud-fan The contribution is my original work and I license the work to the project under the project’s open source license. Author: Patrick Pisciuneri Closes #20977 from phpisciuneri/fix-log-warning. --- .../main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index 7e14938acd8e0..c1fedd63f6a90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -166,7 +166,7 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi val prevLastReportTimestamp = lastReportTimestamp lastReportTimestamp = System.currentTimeMillis() val previous = new java.util.Date(prevLastReportTimestamp) - logWarning(s"Dropped $droppedEvents events from $name since $previous.") + logWarning(s"Dropped $droppedCount events from $name since $previous.") } } } From 14291b061b9b40eadbf4ed442f9a5021b8e09597 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 12 Apr 2018 20:00:25 -0700 Subject: [PATCH 0611/1161] [SPARK-23748][SS] Fix SS continuous process doesn't support SubqueryAlias issue ## What changes were proposed in this pull request? Current SS continuous doesn't support processing on temp table or `df.as("xxx")`, SS will throw an exception as LogicalPlan not supported, details described in [here](https://issues.apache.org/jira/browse/SPARK-23748). So here propose to add this support. ## How was this patch tested? new UT. Author: jerryshao Closes #21017 from jerryshao/SPARK-23748. --- .../UnsupportedOperationChecker.scala | 2 +- .../continuous/ContinuousSuite.scala | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index b55043c270644..ff9d6d7a7dded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,7 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index f5884b9c8de12..ef74efef156d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -171,6 +171,25 @@ class ContinuousSuite extends ContinuousSuiteBase { "Continuous processing does not support current time operations.")) } + test("subquery alias") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .createOrReplaceTempView("rate") + val test = spark.sql("select value from rate where value > 5") + + testStream(test, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + } + test("repeatedly restart") { val df = spark.readStream .format("rate") From ab7b961a4fe96ca02b8352d16b0fa80c972b67fc Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 13 Apr 2018 11:28:13 +0800 Subject: [PATCH 0612/1161] [SPARK-23942][PYTHON][SQL] Makes collect in PySpark as action for a query executor listener ## What changes were proposed in this pull request? This PR proposes to add `collect` to a query executor as an action. Seems `collect` / `collect` with Arrow are not recognised via `QueryExecutionListener` as an action. For example, if we have a custom listener as below: ```scala package org.apache.spark.sql import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener class TestQueryExecutionListener extends QueryExecutionListener with Logging { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { logError("Look at me! I'm 'onSuccess'") } override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } } ``` and set `spark.sql.queryExecutionListeners` to `org.apache.spark.sql.TestQueryExecutionListener` Other operations in PySpark or Scala side seems fine: ```python >>> sql("SELECT * FROM range(1)").show() ``` ``` 18/04/09 17:02:04 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' +---+ | id| +---+ | 0| +---+ ``` ```scala scala> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:58:41 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' res1: Array[org.apache.spark.sql.Row] = Array([0]) ``` but .. **Before** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` id 0 0 ``` **After** ```python >>> sql("SELECT * FROM range(1)").collect() ``` ``` 18/04/09 16:57:58 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' [Row(id=0)] ``` ```python >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> sql("SELECT * FROM range(1)").toPandas() ``` ``` 18/04/09 17:53:26 ERROR TestQueryExecutionListener: Look at me! I'm 'onSuccess' id 0 0 ``` ## How was this patch tested? I have manually tested as described above and unit test was added. Author: hyukjinkwon Closes #21007 from HyukjinKwon/SPARK-23942. --- python/pyspark/sql/tests.py | 87 ++++++++++++++++--- .../scala/org/apache/spark/sql/Dataset.scala | 20 +++-- .../sql/TestQueryExecutionListener.scala | 44 ++++++++++ 3 files changed, 134 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 96c2a776a5049..4e99c8e3c6b10 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -186,16 +186,12 @@ def __init__(self, key, value): self.value = value -class ReusedSQLTestCase(ReusedPySparkTestCase): - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ @contextmanager def sql_conf(self, pairs): @@ -204,6 +200,7 @@ def sql_conf(self, pairs): `value` to the configuration `key` and then restores it back when it exits. """ assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." keys = pairs.keys() new_values = pairs.values() @@ -219,6 +216,18 @@ def sql_conf(self, pairs): else: self.spark.conf.set(key, old_value) + +class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3066,6 +3075,64 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): + # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "TestQueryExecutionListener.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.TestQueryExecutionListener' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. + cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def tearDown(self): + self.spark._jvm.OnSuccessCall.clear() + + def test_query_execution_listener_on_collect(self): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be called before 'collect'") + self.spark.sql("SELECT * FROM range(1)").collect() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'collect'") + + @unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) + def test_query_execution_listener_on_collect_with_arrow(self): + with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be " + "called before 'toPandas'") + self.spark.sql("SELECT * FROM range(1)").toPandas() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'toPandas'") + + class SparkSessionTests(PySparkTestCase): # This test is separate because it's closely related with session's start and stop. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0aee1d7be5788..917168162b236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3189,10 +3189,10 @@ class Dataset[T] private[sql]( private[sql] def collectToPython(): Int = { EvaluatePython.registerPicklers() - withNewExecutionId { + withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) - val iter = new SerDeUtil.AutoBatchedPickler( - queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + plan.executeCollect().iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") } } @@ -3201,8 +3201,9 @@ class Dataset[T] private[sql]( * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + withAction("collectAsArrowToPython", queryExecution) { plan => + val iter: Iterator[Array[Byte]] = + toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) PythonRDD.serveIterator(iter, "serve-Arrow") } } @@ -3311,14 +3312,19 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { + private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone - queryExecution.toRdd.mapPartitionsInternal { iter => + plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toPayloadIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } + + // This is only used in tests, for now. + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + toArrowPayload(queryExecution.executedPlan) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala new file mode 100644 index 0000000000000..d2a6358ee822b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener + + +class TestQueryExecutionListener extends QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + OnSuccessCall.isOnSuccessCalled.set(true) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } +} + +/** + * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for + * the test case in PySpark. See SPARK-23942. + */ +object OnSuccessCall { + val isOnSuccessCalled = new AtomicBoolean(false) + + def isCalled(): Boolean = isOnSuccessCalled.get() + + def clear(): Unit = isOnSuccessCalled.set(false) +} From 1018be44d6c52cf18e14d84160850063f0e60a1d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 12 Apr 2018 22:30:59 -0700 Subject: [PATCH 0613/1161] [SPARK-23971] Should not leak Spark sessions across test suites ## What changes were proposed in this pull request? Many suites currently leak Spark sessions (sometimes with stopped SparkContexts) via the thread-local active Spark session and default Spark session. We should attempt to clean these up and detect when this happens to improve the reproducibility of tests. ## How was this patch tested? Existing tests Author: Eric Liang Closes #21058 from ericl/clear-session. --- .../org/apache/spark/SharedSparkSession.java | 9 ++++++-- .../org/apache/spark/sql/SparkSession.scala | 23 +++++++++++++++++-- .../apache/spark/sql/SessionStateSuite.scala | 2 ++ .../spark/sql/test/SharedSparkSession.scala | 22 ++++++++++++++---- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java index 43779878890db..35a250955b282 100644 --- a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -42,7 +42,12 @@ public void setUp() throws IOException { @After public void tearDown() { - spark.stop(); - spark = null; + try { + spark.stop(); + spark = null; + } finally { + SparkSession.clearDefaultSession(); + SparkSession.clearActiveSession(); + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index b107492fbb330..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} /** @@ -81,6 +81,9 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => + // The call site where this SparkSession was constructed. + private val creationSite: CallSite = Utils.getCallSite() + private[sql] def this(sc: SparkContext) { this(sc, None, None, new SparkSessionExtensions) } @@ -763,7 +766,7 @@ class SparkSession private( @InterfaceStability.Stable -object SparkSession { +object SparkSession extends Logging { /** * Builder for [[SparkSession]]. @@ -1090,4 +1093,20 @@ object SparkSession { } } + private[spark] def cleanupAnyExistingSession(): Unit = { + val session = getActiveSession.orElse(getDefaultSession) + if (session.isDefined) { + logWarning( + s"""An existing Spark session exists as the active or default session. + |This probably means another suite leaked it. Attempting to stop it before continuing. + |This existing Spark session was created at: + | + |${session.get.creationSite.longForm} + | + """.stripMargin) + session.get.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 4efae4c46c2e1..7d1366092d1e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -44,6 +44,8 @@ class SessionStateSuite extends SparkFunSuite { if (activeSession != null) { activeSession.stop() activeSession = null + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index e758c865b908f..8968dbf36d507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -60,6 +60,7 @@ trait SharedSparkSession protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() new TestSparkSession(sparkConf) } @@ -92,11 +93,22 @@ trait SharedSparkSession * Stop the underlying [[org.apache.spark.SparkContext]], if any. */ protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } } } From 4b07036799b01894826b47c73142fe282c607a57 Mon Sep 17 00:00:00 2001 From: Fangshi Li Date: Fri, 13 Apr 2018 13:46:34 +0800 Subject: [PATCH 0614/1161] [SPARK-23815][CORE] Spark writer dynamic partition overwrite mode may fail to write output on multi level partition ## What changes were proposed in this pull request? Spark introduced new writer mode to overwrite only related partitions in SPARK-20236. While we are using this feature in our production cluster, we found a bug when writing multi-level partitions on HDFS. A simple test case to reproduce this issue: val df = Seq(("1","2","3")).toDF("col1", "col2","col3") df.write.partitionBy("col1","col2").mode("overwrite").save("/my/hdfs/location") If HDFS location "/my/hdfs/location" does not exist, there will be no output. This seems to be caused by the job commit change in SPARK-20236 in HadoopMapReduceCommitProtocol. In the commit job process, the output has been written into staging dir /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2, and then the code calls fs.rename to rename /my/hdfs/location/.spark-staging.xxx/col1=1/col2=2 to /my/hdfs/location/col1=1/col2=2. However, in our case the operation will fail on HDFS because /my/hdfs/location/col1=1 does not exists. HDFS rename can not create directory for more than one level. This does not happen in the new unit test added with SPARK-20236 which uses local file system. We are proposing a fix. When cleaning current partition dir /my/hdfs/location/col1=1/col2=2 before the rename op, if the delete op fails (because /my/hdfs/location/col1=1/col2=2 may not exist), we call mkdirs op to create the parent dir /my/hdfs/location/col1=1 (if the parent dir does not exist) so the following rename op can succeed. Reference: in official HDFS document(https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html), the rename command has precondition "dest must be root, or have a parent that exists" ## How was this patch tested? We have tested this patch on our production cluster and it fixed the problem Author: Fangshi Li Closes #20931 from fangshil/master. --- .../internal/io/HadoopMapReduceCommitProtocol.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6d20ef1f98a3c..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -186,7 +186,17 @@ class HadoopMapReduceCommitProtocol( logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") for (part <- partitionPaths) { val finalPartPath = new Path(path, part) - fs.delete(finalPartPath, true) + if (!fs.delete(finalPartPath, true) && !fs.exists(finalPartPath.getParent)) { + // According to the official hadoop FileSystem API spec, delete op should assume + // the destination is no longer present regardless of return value, thus we do not + // need to double check if finalPartPath exists before rename. + // Also in our case, based on the spec, delete returns false only when finalPartPath + // does not exist. When this happens, we need to take action if parent of finalPartPath + // also does not exist(e.g. the scenario described on SPARK-23815), because + // FileSystem API spec on rename op says the rename dest(finalPartPath) must have + // a parent that exists, otherwise we may get unexpected result on the rename. + fs.mkdirs(finalPartPath.getParent) + } fs.rename(new Path(stagingDir, part), finalPartPath) } } From 0323e61465ee747c9a57a70e9d6108876499546e Mon Sep 17 00:00:00 2001 From: yucai Date: Fri, 13 Apr 2018 00:00:04 -0700 Subject: [PATCH 0615/1161] [SPARK-23905][SQL] Add UDF weekday ## What changes were proposed in this pull request? Add UDF weekday ## How was this patch tested? A new test Author: yucai Closes #21009 from yucai/SPARK-23905. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/datetimeExpressions.scala | 55 +++++++++++++++---- .../expressions/DateExpressionsSuite.scala | 11 ++++ .../resources/sql-tests/inputs/datetime.sql | 2 + .../sql-tests/results/datetime.sql.out | 9 ++- 5 files changed, 67 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..131b958239e41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -395,6 +395,7 @@ object FunctionRegistry { expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[DayOfWeek]("dayofweek"), + expression[WeekDay]("weekday"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 32fdb13afbbfa..b9b2cd5bdb9f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -426,36 +426,71 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa """, since = "2.3.0") // scalastyle:on line.size.limit -case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfWeek(child: Expression) extends DayWeek { - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType + override protected def nullSafeEval(date: Any): Any = { + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + cal.get(Calendar.DAY_OF_WEEK) + } - @transient private lazy val c = { - Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val c = "calDayOfWeek" + ctx.addImmutableStateIfNotExists(cal, c, + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $c.get($cal.DAY_OF_WEEK); + """ + }) } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday).", + examples = """ + Examples: + > SELECT _FUNC_('2009-07-30'); + 3 + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class WeekDay(child: Expression) extends DayWeek { override protected def nullSafeEval(date: Any): Any = { - c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) - c.get(Calendar.DAY_OF_WEEK) + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + (cal.get(Calendar.DAY_OF_WEEK) + 5 ) % 7 } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calDayOfWeek" + val c = "calWeekDay" ctx.addImmutableStateIfNotExists(cal, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.DAY_OF_WEEK); + ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7; """ }) } } +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient protected lazy val cal: Calendar = { + Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 786266a2c13c0..080ec487cfa6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -211,6 +211,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) } + test("WeekDay") { + checkEvaluation(WeekDay(Literal.create(null, DateType)), null) + checkEvaluation(WeekDay(Literal(d)), 2) + checkEvaluation(WeekDay(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2) + checkEvaluation(WeekDay(Cast(Literal(ts), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), 5) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 4) + checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType) + } + test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index adea2bfa82cd3..547c2bef02b24 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -25,3 +25,5 @@ create temporary view ttf2 as select * from values select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; select a, b from ttf2 order by a, current_date; + +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index bbb6851e69c7e..4e1cfa6e48c1c 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -81,3 +81,10 @@ struct -- !query 8 output 1 2 2 3 + +-- !query 9 +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') +-- !query 3 schema +struct +-- !query 3 output +5 3 5 NULL 4 From a83ae0d9bc1b8f4909b9338370efe4020079bea7 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 13 Apr 2018 08:43:58 -0700 Subject: [PATCH 0616/1161] [SPARK-22839][K8S] Refactor to unify driver and executor pod builder APIs ## What changes were proposed in this pull request? Breaks down the construction of driver pods and executor pods in a way that uses a common abstraction for both spark-submit creating the driver and KubernetesClusterSchedulerBackend creating the executor. Encourages more code reuse and is more legible than the older approach. The high-level design is discussed in more detail on the JIRA ticket. This pull request is the implementation of that design with some minor changes in the implementation details. No user-facing behavior should break as a result of this change. ## How was this patch tested? Migrated all unit tests from the old submission steps architecture to the new architecture. Integration tests should not have to change and pass given that this shouldn't change any outward behavior. Author: mcheah Closes #20910 from mccheah/spark-22839-incremental. --- .../org/apache/spark/deploy/k8s/Config.scala | 2 +- .../spark/deploy/k8s/KubernetesConf.scala | 184 +++++++++++++ .../deploy/k8s/KubernetesDriverSpec.scala} | 25 +- .../spark/deploy/k8s/KubernetesUtils.scala | 11 - .../deploy/k8s/MountSecretsBootstrap.scala | 72 ----- ...ConfigurationStep.scala => SparkPod.scala} | 24 +- .../k8s/features/BasicDriverFeatureStep.scala | 136 ++++++++++ .../features/BasicExecutorFeatureStep.scala | 179 +++++++++++++ ...iverKubernetesCredentialsFeatureStep.scala | 216 +++++++++++++++ .../features/DriverServiceFeatureStep.scala | 97 +++++++ .../KubernetesFeatureConfigStep.scala | 71 +++++ .../features/MountSecretsFeatureStep.scala | 62 +++++ .../k8s/submit/DriverConfigOrchestrator.scala | 145 ----------- .../submit/KubernetesClientApplication.scala | 80 +++--- .../k8s/submit/KubernetesDriverBuilder.scala | 56 ++++ .../k8s/submit/KubernetesDriverSpec.scala | 47 ---- .../steps/BasicDriverConfigurationStep.scala | 163 ------------ .../steps/DependencyResolutionStep.scala | 61 ----- .../DriverKubernetesCredentialsStep.scala | 245 ------------------ .../submit/steps/DriverMountSecretsStep.scala | 38 --- .../steps/DriverServiceBootstrapStep.scala | 104 -------- .../cluster/k8s/ExecutorPodFactory.scala | 227 ---------------- .../k8s/KubernetesClusterManager.scala | 12 +- .../KubernetesClusterSchedulerBackend.scala | 20 +- .../k8s/KubernetesExecutorBuilder.scala | 41 +++ .../deploy/k8s/KubernetesConfSuite.scala | 175 +++++++++++++ .../BasicDriverFeatureStepSuite.scala | 153 +++++++++++ .../BasicExecutorFeatureStepSuite.scala | 179 +++++++++++++ ...bernetesCredentialsFeatureStepSuite.scala} | 101 +++++--- .../DriverServiceFeatureStepSuite.scala | 227 ++++++++++++++++ .../KubernetesFeaturesTestUtils.scala | 61 +++++ .../MountSecretsFeatureStepSuite.scala} | 29 ++- .../spark/deploy/k8s/submit/ClientSuite.scala | 216 +++++++-------- .../DriverConfigOrchestratorSuite.scala | 131 ---------- .../submit/KubernetesDriverBuilderSuite.scala | 102 ++++++++ .../BasicDriverConfigurationStepSuite.scala | 122 --------- .../steps/DependencyResolutionStepSuite.scala | 69 ----- .../DriverServiceBootstrapStepSuite.scala | 180 ------------- .../cluster/k8s/ExecutorPodFactorySuite.scala | 195 -------------- ...bernetesClusterSchedulerBackendSuite.scala | 37 ++- .../k8s/KubernetesExecutorBuilderSuite.scala | 75 ++++++ 41 files changed, 2289 insertions(+), 2081 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala rename resource-managers/kubernetes/core/src/{test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala => main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala} (57%) delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverConfigurationStep.scala => SparkPod.scala} (64%) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverKubernetesCredentialsStepSuite.scala => features/DriverKubernetesCredentialsFeatureStepSuite.scala} (67%) create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit/steps/DriverMountSecretsStepSuite.scala => features/MountSecretsFeatureStepSuite.scala} (64%) delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala delete mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 82f6c714f3555..4086970ffb256 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -167,5 +167,5 @@ private[spark] object Config extends Logging { val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." - val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala new file mode 100644 index 0000000000000..77b634ddfabcc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.internal.config.ConfigEntry + +private[spark] sealed trait KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark driver. + */ +private[spark] case class KubernetesDriverSpecificConf( + mainAppResource: Option[MainAppResource], + mainClass: String, + appName: String, + appArgs: Seq[String]) extends KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark executor. + */ +private[spark] case class KubernetesExecutorSpecificConf( + executorId: String, + driverPod: Pod) + extends KubernetesRoleSpecificConf + +/** + * Structure containing metadata for Kubernetes logic to build Spark pods. + */ +private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( + sparkConf: SparkConf, + roleSpecificConf: T, + appResourceNamePrefix: String, + appId: String, + roleLabels: Map[String, String], + roleAnnotations: Map[String, String], + roleSecretNamesToMountPaths: Map[String, String], + roleEnvs: Map[String, String]) { + + def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) + + def sparkJars(): Seq[String] = sparkConf + .getOption("spark.jars") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def sparkFiles(): Seq[String] = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + + def imagePullSecrets(): Seq[LocalObjectReference] = { + sparkConf + .get(IMAGE_PULL_SECRETS) + .map(_.split(",")) + .getOrElse(Array.empty[String]) + .map(_.trim) + .map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + } + + def nodeSelector(): Map[String, String] = + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + + def get[T](config: ConfigEntry[T]): T = sparkConf.get(config) + + def get(conf: String): String = sparkConf.get(conf) + + def get(conf: String, defaultValue: String): String = sparkConf.get(conf, defaultValue) + + def getOption(key: String): Option[String] = sparkConf.getOption(key) +} + +private[spark] object KubernetesConf { + def createDriverConf( + sparkConf: SparkConf, + appName: String, + appResourceNamePrefix: String, + appId: String, + mainAppResource: Option[MainAppResource], + mainClass: String, + appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + val sparkConfWithMainAppJar = sparkConf.clone() + mainAppResource.foreach { + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + } + + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) + require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + + s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + val driverLabels = driverCustomLabels ++ Map( + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + + KubernetesConf( + sparkConfWithMainAppJar, + KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), + appResourceNamePrefix, + appId, + driverLabels, + driverAnnotations, + driverSecretNamesToMountPaths, + driverEnvs) + } + + def createExecutorConf( + sparkConf: SparkConf, + executorId: String, + appId: String, + driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorCustomLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorCustomLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + val executorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorCustomLabels + val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnv = sparkConf.getExecutorEnv.toMap + + KubernetesConf( + sparkConf.clone(), + KubernetesExecutorSpecificConf(executorId, driverPod), + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appId, + executorLabels, + executorAnnotations, + executorSecrets, + executorEnv) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala similarity index 57% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala index cf41b22e241af..0c5ae022f4070 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala @@ -16,21 +16,16 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference +import io.fabric8.kubernetes.api.model.HasMetadata -import org.apache.spark.SparkFunSuite - -class KubernetesUtilsTest extends SparkFunSuite { - - test("testParseImagePullSecrets") { - val noSecrets = KubernetesUtils.parseImagePullSecrets(None) - assert(noSecrets === Nil) - - val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) - assert(oneSecret === new LocalObjectReference("imagePullSecret") :: Nil) - - val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) - assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) - } +private[spark] case class KubernetesDriverSpec( + pod: SparkPod, + driverKubernetesResources: Seq[HasMetadata], + systemProperties: Map[String, String]) +private[spark] object KubernetesDriverSpec { + def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec( + SparkPod.initialPod(), + Seq.empty, + initialProps) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5b2bb819cdb14..ee629068ad90d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -37,17 +37,6 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } - /** - * Parses comma-separated list of imagePullSecrets into K8s-understandable format - */ - def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { - imagePullSecrets match { - case Some(secretsCommaSeparated) => - secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList - case None => Nil - } - } - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala deleted file mode 100644 index c35e7db51d407..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} - -/** - * Bootstraps a driver or executor container or an init-container with needed secrets mounted. - */ -private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { - - /** - * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. - * - * @param pod the pod into which the secret volumes are being added. - * @return the updated pod with the secret volumes added. - */ - def addSecretVolumes(pod: Pod): Pod = { - var podBuilder = new PodBuilder(pod) - secretNamesToMountPaths.keys.foreach { name => - podBuilder = podBuilder - .editOrNewSpec() - .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() - .endSpec() - } - - podBuilder.build() - } - - /** - * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the - * given container. - * - * @param container the container into which the secret volumes are being mounted. - * @return the updated container with the secrets mounted. - */ - def mountSecrets(container: Container): Container = { - var containerBuilder = new ContainerBuilder(container) - secretNamesToMountPaths.foreach { case (name, path) => - containerBuilder = containerBuilder - .addNewVolumeMount() - .withName(secretVolumeName(name)) - .withMountPath(path) - .endVolumeMount() - } - - containerBuilder.build() - } - - private def secretVolumeName(secretName: String): String = { - secretName + "-volume" - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala similarity index 64% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 17614e040e587..345dd117fd35f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -14,17 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -/** - * Represents a step in configuring the Spark driver pod. - */ -private[spark] trait DriverConfigurationStep { +private[spark] case class SparkPod(pod: Pod, container: Container) - /** - * Apply some transformation to the previous state of the driver to add a new feature to it. - */ - def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec +private[spark] object SparkPod { + def initialPod(): SparkPod = { + SparkPod( + new PodBuilder() + .withNewMetadata() + .endMetadata() + .withNewSpec() + .endSpec() + .build(), + new ContainerBuilder().build()) + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala new file mode 100644 index 0000000000000..07bdccbe0479d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher + +private[spark] class BasicDriverFeatureStep( + conf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + + private val driverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"${conf.appResourceNamePrefix}-driver") + + private val driverContainerImage = conf + .get(DRIVER_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver container image")) + + // CPU settings + private val driverCpuCores = conf.get("spark.driver.cores", "1") + private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES) + + // Memory settings + private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + private val memoryOverheadMiB = conf + .get(DRIVER_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + + override def configurePod(pod: SparkPod): SparkPod = { + val driverCustomEnvs = conf.roleEnvs + .toSeq + .map { env => + new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + } + + val driverCpuQuantity = new QuantityBuilder(false) + .withAmount(driverCpuCores) + .build() + val driverMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") + .build() + val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => + ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) + } + + val driverContainer = new ContainerBuilder(pod.container) + .withName(DRIVER_CONTAINER_NAME) + .withImage(driverContainerImage) + .withImagePullPolicy(conf.imagePullPolicy()) + .addAllToEnv(driverCustomEnvs.asJava) + .addNewEnv() + .withName(ENV_DRIVER_BIND_ADDRESS) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .endEnv() + .withNewResources() + .addToRequests("cpu", driverCpuQuantity) + .addToLimits(maybeCpuLimitQuantity.toMap.asJava) + .addToRequests("memory", driverMemoryQuantity) + .addToLimits("memory", driverMemoryQuantity) + .endResources() + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", conf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(conf.roleSpecificConf.appArgs: _*) + .build() + + val driverPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(driverPodName) + .addToLabels(conf.roleLabels.asJava) + .addToAnnotations(conf.roleAnnotations.asJava) + .endMetadata() + .withNewSpec() + .withRestartPolicy("Never") + .withNodeSelector(conf.nodeSelector().asJava) + .addToImagePullSecrets(conf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(driverPod, driverContainer) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val additionalProps = mutable.Map( + KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, + "spark.app.id" -> conf.appId, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, + KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true") + + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkJars()) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkFiles()) + if (resolvedSparkJars.nonEmpty) { + additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) + } + if (resolvedSparkFiles.nonEmpty) { + additionalProps.put("spark.files", resolvedSparkFiles.mkString(",")) + } + additionalProps.toMap + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala new file mode 100644 index 0000000000000..d22097587aafe --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class BasicExecutorFeatureStep( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) + extends KubernetesFeatureConfigStep { + + // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf + private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH) + private val executorContainerImage = kubernetesConf + .get(EXECUTOR_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor container image")) + private val blockManagerPort = kubernetesConf + .sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix + + private val driverUrl = RpcEndpointAddress( + kubernetesConf.get("spark.driver.host"), + kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY) + private val executorMemoryString = kubernetesConf.get( + EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = kubernetesConf + .get(EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = + if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } + private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def configurePod(pod: SparkPod): SparkPod = { + val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCoresRequest) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = kubernetesConf + .get(EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + (ENV_EXECUTOR_CORES, executorCores.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, kubernetesConf.appId), + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++ + kubernetesConf.roleEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder(pod.container) + .withName("executor") + .withImage(executorContainerImage) + .withImagePullPolicy(kubernetesConf.imagePullPolicy()) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .addToArgs("executor") + .build() + val containerWithLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + val driverPod = kubernetesConf.roleSpecificConf.driverPod + val executorPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(name) + .withLabels(kubernetesConf.roleLabels.asJava) + .withAnnotations(kubernetesConf.roleAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .editOrNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(kubernetesConf.nodeSelector().asJava) + .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(executorPod, containerWithLimitCores) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala new file mode 100644 index 0000000000000..ff5ad6673b309 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret, SecretBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_]) + extends KubernetesFeatureConfigStep { + // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory. + // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive. + + private val maybeMountedOAuthTokenFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") + private val maybeMountedClientKeyFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") + private val maybeMountedClientCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") + private val maybeMountedCaCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") + private val driverServiceAccount = kubernetesConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) + + private val oauthTokenBase64 = kubernetesConf + .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") + .map { token => + BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) + } + + private val caCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "Driver CA cert file") + private val clientKeyDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "Driver client key file") + private val clientCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "Driver client cert file") + + // TODO decide whether or not to apply this step entirely in the caller, i.e. the builder. + private val shouldMountSecret = oauthTokenBase64.isDefined || + caCertDataBase64.isDefined || + clientKeyDataBase64.isDefined || + clientCertDataBase64.isDefined + + private val driverCredentialsSecretName = + s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials" + + override def configurePod(pod: SparkPod): SparkPod = { + if (!shouldMountSecret) { + pod.copy( + pod = driverServiceAccount.map { account => + new PodBuilder(pod.pod) + .editOrNewSpec() + .withServiceAccount(account) + .withServiceAccountName(account) + .endSpec() + .build() + }.getOrElse(pod.pod)) + } else { + val driverPodWithMountedKubernetesCredentials = + new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withNewSecret().withSecretName(driverCredentialsSecretName).endSecret() + .endVolume() + .endSpec() + .build() + + val driverContainerWithMountedSecretVolume = + new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + .endVolumeMount() + .build() + SparkPod(driverPodWithMountedKubernetesCredentials, driverContainerWithMountedSecretVolume) + } + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val resolvedMountedOAuthTokenFile = resolveSecretLocation( + maybeMountedOAuthTokenFile, + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) + val resolvedMountedClientKeyFile = resolveSecretLocation( + maybeMountedClientKeyFile, + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_PATH) + val resolvedMountedClientCertFile = resolveSecretLocation( + maybeMountedClientCertFile, + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_PATH) + val resolvedMountedCaCertFile = resolveSecretLocation( + maybeMountedCaCertFile, + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_PATH) + + val redactedTokens = kubernetesConf.sparkConf.getAll + .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)) + .toMap + .mapValues( _ => "") + redactedTokens ++ + resolvedMountedCaCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientKeyFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedOAuthTokenFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (shouldMountSecret) { + Seq(createCredentialsSecret()) + } else { + Seq.empty + } + } + + private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { + kubernetesConf.getOption(conf) + .map(new File(_)) + .map { file => + require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", + fileType, file.getAbsolutePath)) + BaseEncoding.base64().encode(Files.toByteArray(file)) + } + } + + /** + * Resolve a Kubernetes secret data entry from an optional client credential used by the + * driver to talk to the Kubernetes API server. + * + * @param userSpecifiedCredential the optional user-specified client credential. + * @param secretName name of the Kubernetes secret storing the client credential. + * @return a secret data entry in the form of a map from the secret name to the secret data, + * which may be empty if the user-specified credential is empty. + */ + private def resolveSecretData( + userSpecifiedCredential: Option[String], + secretName: String): Map[String, String] = { + userSpecifiedCredential.map { valueBase64 => + Map(secretName -> valueBase64) + }.getOrElse(Map.empty[String, String]) + } + + private def resolveSecretLocation( + mountedUserSpecified: Option[String], + valueMountedFromSubmitter: Option[String], + mountedCanonicalLocation: String): Option[String] = { + mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => + mountedCanonicalLocation + }) + } + + private def createCredentialsSecret(): Secret = { + val allSecretData = + resolveSecretData( + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ + resolveSecretData( + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ + resolveSecretData( + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ + resolveSecretData( + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) + + new SecretBuilder() + .withNewMetadata() + .withName(driverCredentialsSecretName) + .endMetadata() + .withData(allSecretData.asJava) + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala new file mode 100644 index 0000000000000..f2d7bbd08f305 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, SystemClock} + +private[spark] class DriverServiceFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], + clock: Clock = new SystemClock) + extends KubernetesFeatureConfigStep with Logging { + import DriverServiceFeatureStep._ + + require(kubernetesConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + + "address is managed and set to the driver pod's IP address.") + require(kubernetesConf.getOption(DRIVER_HOST_KEY).isEmpty, + s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + + "managed via a Kubernetes service.") + + private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX" + private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { + preferredServiceName + } else { + val randomServiceId = clock.getTimeMillis() + val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" + logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + + s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + + s"$shorterServiceName as the driver service's name.") + shorterServiceName + } + + private val driverPort = kubernetesConf.sparkConf.getInt( + "spark.driver.port", DEFAULT_DRIVER_PORT) + private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + + override def configurePod(pod: SparkPod): SparkPod = pod + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc" + Map(DRIVER_HOST_KEY -> driverHostname, + "spark.driver.port" -> driverPort.toString, + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> + driverBlockManagerPort.toString) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + val driverService = new ServiceBuilder() + .withNewMetadata() + .withName(resolvedServiceName) + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(kubernetesConf.roleLabels.asJava) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withPort(driverBlockManagerPort) + .withNewTargetPort(driverBlockManagerPort) + .endPort() + .endSpec() + .build() + Seq(driverService) + } +} + +private[spark] object DriverServiceFeatureStep { + val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_SVC_POSTFIX = "-driver-svc" + val MAX_SERVICE_NAME_LENGTH = 63 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala new file mode 100644 index 0000000000000..4c1be3bb13293 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.HasMetadata + +import org.apache.spark.deploy.k8s.SparkPod + +/** + * A collection of functions that together represent a "feature" in pods that are launched for + * Spark drivers and executors. + */ +private[spark] trait KubernetesFeatureConfigStep { + + /** + * Apply modifications on the given pod in accordance to this feature. This can include attaching + * volumes, adding environment variables, and adding labels/annotations. + *

    + * Note that we should return a SparkPod that keeps all of the properties of the passed SparkPod + * object. So this is correct: + *

    +   * {@code val configuredPod = new PodBuilder(pod.pod)
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder(pod.container)
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + * This is incorrect: + *
    +   * {@code val configuredPod = new PodBuilder() // Loses the original state
    +   *     .editSpec()
    +   *     ...
    +   *     .build()
    +   *   val configuredContainer = new ContainerBuilder() // Loses the original state
    +   *     ...
    +   *     .build()
    +   *   SparkPod(configuredPod, configuredContainer)
    +   *  }
    +   * 
    + */ + def configurePod(pod: SparkPod): SparkPod + + /** + * Return any system properties that should be set on the JVM in accordance to this feature. + */ + def getAdditionalPodSystemProperties(): Map[String, String] + + /** + * Return any additional Kubernetes resources that should be added to support this feature. Only + * applicable when creating the driver in cluster mode. + */ + def getAdditionalKubernetesResources(): Seq[HasMetadata] +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala new file mode 100644 index 0000000000000..97fa9499b2edb --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class MountSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedVolumes = kubernetesConf + .roleSecretNamesToMountPaths + .keys + .map(secretName => + new VolumeBuilder() + .withName(secretVolumeName(secretName)) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .build()) + val podWithVolumes = new PodBuilder(pod.pod) + .editOrNewSpec() + .addToVolumes(addedVolumes.toSeq: _*) + .endSpec() + .build() + val addedVolumeMounts = kubernetesConf + .roleSecretNamesToMountPaths + .map { + case (secretName, mountPath) => + new VolumeMountBuilder() + .withName(secretVolumeName(secretName)) + .withMountPath(mountPath) + .build() + } + val containerWithMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(addedVolumeMounts.toSeq: _*) + .build() + SparkPod(podWithVolumes, containerWithMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def secretVolumeName(secretName: String): String = s"$secretName-volume" +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala deleted file mode 100644 index b4d3f04a1bc32..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.SystemClock -import org.apache.spark.util.Utils - -/** - * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to - * configure the Spark driver pod. The returned steps will be applied one by one in the given - * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. - */ -private[spark] class DriverConfigOrchestrator( - kubernetesAppId: String, - kubernetesResourceNamePrefix: String, - mainAppResource: Option[MainAppResource], - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) { - - // The resource name prefix is derived from the Spark application name, making it easy to connect - // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the - // application the user submitted. - - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - - def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_LABEL_PREFIX) - require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + - s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + - s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - - val allDriverLabels = driverCustomLabels ++ Map( - SPARK_APP_ID_LABEL -> kubernetesAppId, - SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - - val initialSubmissionStep = new BasicDriverConfigurationStep( - kubernetesAppId, - kubernetesResourceNamePrefix, - allDriverLabels, - imagePullPolicy, - appName, - mainClass, - appArgs, - sparkConf) - - val serviceBootstrapStep = new DriverServiceBootstrapStep( - kubernetesResourceNamePrefix, - allDriverLabels, - sparkConf, - new SystemClock) - - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - sparkConf, kubernetesResourceNamePrefix) - - val additionalMainAppJar = if (mainAppResource.nonEmpty) { - val mayBeResource = mainAppResource.get match { - case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE => - Some(resource) - case _ => None - } - mayBeResource - } else { - None - } - - val sparkJars = sparkConf.getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty[String]) ++ - additionalMainAppJar.toSeq - val sparkFiles = sparkConf.getOption("spark.files") - .map(_.split(",")) - .getOrElse(Array.empty[String]) - - // TODO(SPARK-23153): remove once submission client local dependencies are supported. - if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { - throw new SparkException("The Kubernetes mode does not yet support referencing application " + - "dependencies in the local file system.") - } - - val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Seq(new DependencyResolutionStep( - sparkJars, - sparkFiles)) - } else { - Nil - } - - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq( - initialSubmissionStep, - serviceBootstrapStep, - kubernetesCredentialsStep) ++ - dependencyResolutionStep ++ - mountSecretsStep - } - - private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme == "file" - } - } - - private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme != "local" - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index e16d1add600b2..a97f5650fb869 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,12 +27,10 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -43,9 +41,9 @@ import org.apache.spark.util.Utils * @param driverArgs arguments to the driver */ private[spark] case class ClientArguments( - mainAppResource: Option[MainAppResource], - mainClass: String, - driverArgs: Array[String]) + mainAppResource: Option[MainAppResource], + mainClass: String, + driverArgs: Array[String]) private[spark] object ClientArguments { @@ -80,8 +78,9 @@ private[spark] object ClientArguments { * watcher that monitors and logs the application status. Waits for the application to terminate if * spark.kubernetes.submission.waitAppCompletion is true. * - * @param submissionSteps steps that collectively configure the driver - * @param sparkConf the submission client Spark configuration + * @param builder Responsible for building the base driver pod based on a composition of + * implemented features. + * @param kubernetesConf application configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete @@ -89,31 +88,21 @@ private[spark] object ClientArguments { * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( - submissionSteps: Seq[DriverConfigurationStep], - sparkConf: SparkConf, + builder: KubernetesDriverBuilder, + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, watcher: LoggingPodStatusWatcher, kubernetesResourceNamePrefix: String) extends Logging { - /** - * Run command that initializes a DriverSpec that will be updated after each - * DriverConfigurationStep in the sequence that is passed in. The final KubernetesDriverSpec - * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources - */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) - // submissionSteps contain steps necessary to take, to resolve varying - // client arguments that are passed in, created by orchestrator - for (nextStep <- submissionSteps) { - currentDriverSpec = nextStep.configureDriver(currentDriverSpec) - } + val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf) val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" - val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the // Spark command builder to pickup on the Java Options present in the ConfigMap - val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) + val resolvedDriverContainer = new ContainerBuilder(resolvedDriverSpec.pod.container) .addNewEnv() .withName(ENV_SPARK_CONF_DIR) .withValue(SPARK_CONF_DIR_INTERNAL) @@ -123,7 +112,7 @@ private[spark] class Client( .withMountPath(SPARK_CONF_DIR_INTERNAL) .endVolumeMount() .build() - val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) + val resolvedDriverPod = new PodBuilder(resolvedDriverSpec.pod.pod) .editSpec() .addToContainers(resolvedDriverContainer) .addNewVolume() @@ -141,12 +130,10 @@ private[spark] class Client( .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { - if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = - currentDriverSpec.otherKubernetesResources ++ Seq(configMap) - addDriverOwnerReference(createdDriverPod, otherKubernetesResources) - kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() - } + val otherKubernetesResources = + resolvedDriverSpec.driverKubernetesResources ++ Seq(configMap) + addDriverOwnerReference(createdDriverPod, otherKubernetesResources) + kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } catch { case NonFatal(e) => kubernetesClient.pods().delete(createdDriverPod) @@ -180,20 +167,17 @@ private[spark] class Client( } // Build a Config Map that will house spark conf properties in a single file for spark-submit - private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + private def buildConfigMap(configMapName: String, conf: Map[String, String]): ConfigMap = { val properties = new Properties() - conf.getAll.foreach { case (k, v) => + conf.foreach { case (k, v) => properties.setProperty(k, v) } val propertiesWriter = new StringWriter() properties.store(propertiesWriter, s"Java properties built from Kubernetes config map with name: $configMapName") - - val namespace = conf.get(KUBERNETES_NAMESPACE) new ConfigMapBuilder() .withNewMetadata() .withName(configMapName) - .withNamespace(namespace) .endMetadata() .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) .build() @@ -211,7 +195,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = { - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) + val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") // For constructing the app ID, we can't use the Spark application name, as the app ID is going // to be added as a label to group resources belonging to the same application. Label values are // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate @@ -219,10 +203,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + val kubernetesConf = KubernetesConf.createDriverConf( + sparkConf, + appName, + kubernetesResourceNamePrefix, + kubernetesAppId, + clientArguments.mainAppResource, + clientArguments.mainClass, + clientArguments.driverArgs) + val builder = new KubernetesDriverBuilder + val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -230,15 +223,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val orchestrator = new DriverConfigOrchestrator( - kubernetesAppId, - kubernetesResourceNamePrefix, - clientArguments.mainAppResource, - appName, - clientArguments.mainClass, - clientArguments.driverArgs, - sparkConf) - Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( master, Some(namespace), @@ -247,8 +231,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - orchestrator.getAllConfigurationSteps, - sparkConf, + builder, + kubernetesConf, kubernetesClient, waitForAppCompletion, appName, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala new file mode 100644 index 0000000000000..c7579ed8cb689 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesDriverBuilder( + provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = + new BasicDriverFeatureStep(_), + provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf]) + => DriverKubernetesCredentialsFeatureStep = + new DriverKubernetesCredentialsFeatureStep(_), + provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep = + new DriverServiceFeatureStep(_), + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountSecretsFeatureStep) = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideCredentialsStep(kubernetesConf), + provideServiceStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) + for (feature <- allFeatures) { + val configuredPod = feature.configurePod(spec.pod) + val addedSystemProperties = feature.getAdditionalPodSystemProperties() + val addedResources = feature.getAdditionalKubernetesResources() + spec = KubernetesDriverSpec( + configuredPod, + spec.driverKubernetesResources ++ addedResources, + spec.systemProperties ++ addedSystemProperties) + } + spec + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala deleted file mode 100644 index db13f09387ef9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder} - -import org.apache.spark.SparkConf - -/** - * Represents the components and characteristics of a Spark driver. The driver can be considered - * as being comprised of the driver pod itself, any other Kubernetes resources that the driver - * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver - * container should be operated on via the specific field of this case class as opposed to trying - * to edit the container directly on the pod. The driver container should be attached at the - * end of executing all submission steps. - */ -private[spark] case class KubernetesDriverSpec( - driverPod: Pod, - driverContainer: Container, - otherKubernetesResources: Seq[HasMetadata], - driverSparkConf: SparkConf) - -private[spark] object KubernetesDriverSpec { - def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = { - KubernetesDriverSpec( - // Set new metadata and a new spec so that submission steps can use - // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely. - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - new ContainerBuilder().build(), - Seq.empty[HasMetadata], - initialSparkConf.clone()) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala deleted file mode 100644 index fcb1db8008053..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} -import org.apache.spark.launcher.SparkLauncher - -/** - * Performs basic configuration for the driver pod. - */ -private[spark] class BasicDriverConfigurationStep( - kubernetesAppId: String, - resourceNamePrefix: String, - driverLabels: Map[String, String], - imagePullPolicy: String, - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) extends DriverConfigurationStep { - - private val driverPodName = sparkConf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$resourceNamePrefix-driver") - - private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - - private val driverContainerImage = sparkConf - .get(DRIVER_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the driver container image")) - - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - - // CPU settings - private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) - - // Memory settings - private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val memoryOverheadMiB = sparkConf - .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) - private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(classPath) - .build() - } - - val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) - require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), - s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + - " Spark bookkeeping operations.") - - val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq - .map { env => - new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - } - - val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - - val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) - - val driverCpuQuantity = new QuantityBuilder(false) - .withAmount(driverCpuCores) - .build() - val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryWithOverheadMiB}Mi") - .build() - val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => - ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) - } - - val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) - .withName(DRIVER_CONTAINER_NAME) - .withImage(driverContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(driverCustomEnvs.asJava) - .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_BIND_ADDRESS) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .endEnv() - .withNewResources() - .addToRequests("cpu", driverCpuQuantity) - .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryQuantity) - .addToLimits(maybeCpuLimitQuantity.toMap.asJava) - .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - - val driverContainer = appArgs.toList match { - case "" :: Nil | Nil => driverContainerWithoutArgs.build() - case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() - } - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - val baseDriverPod = new PodBuilder(driverSpec.driverPod) - .editOrNewMetadata() - .withName(driverPodName) - .addToLabels(driverLabels.asJava) - .addToAnnotations(driverAnnotations.asJava) - .endMetadata() - .withNewSpec() - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) - // to set the config variables to allow client-mode spark-submit from driver - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - - driverSpec.copy( - driverPod = baseDriverPod, - driverSparkConf = resolvedSparkConf, - driverContainer = driverContainer) - } - -} - diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala deleted file mode 100644 index 43de329f239ad..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import io.fabric8.kubernetes.api.model.ContainerBuilder - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Step that configures the classpath, spark.jars, and spark.files for the driver given that the - * user may provide remote files or files with local:// schemes. - */ -private[spark] class DependencyResolutionStep( - sparkJars: Seq[String], - sparkFiles: Seq[String]) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) - val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) - - val sparkConf = driverSpec.driverSparkConf.clone() - if (resolvedSparkJars.nonEmpty) { - sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) - } - if (resolvedSparkFiles.nonEmpty) { - sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) - } - val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { - new ContainerBuilder(driverSpec.driverContainer) - .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedSparkJars.mkString(File.pathSeparator)) - .endEnv() - .build() - } else { - driverSpec.driverContainer - } - - driverSpec.copy( - driverContainer = resolvedDriverContainer, - driverSparkConf = sparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala deleted file mode 100644 index 2424e63999a82..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import com.google.common.io.{BaseEncoding, Files} -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials - * to request executors. - */ -private[spark] class DriverKubernetesCredentialsStep( - submissionSparkConf: SparkConf, - kubernetesResourceNamePrefix: String) extends DriverConfigurationStep { - - private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") - private val maybeMountedClientKeyFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") - private val maybeMountedClientCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") - private val maybeMountedCaCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") - private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverSparkConf = driverSpec.driverSparkConf.clone() - - val oauthTokenBase64 = submissionSparkConf - .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") - .map { token => - BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) - } - val caCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - "Driver CA cert file") - val clientKeyDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - "Driver client key file") - val clientCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - "Driver client cert file") - - val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations( - driverSparkConf, - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val kubernetesCredentialsSecret = createCredentialsSecret( - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .addNewVolume() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret() - .endVolume() - .endSpec() - .build() - }.getOrElse( - driverServiceAccount.map { account => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .withServiceAccount(account) - .withServiceAccountName(account) - .endSpec() - .build() - }.getOrElse(driverSpec.driverPod) - ) - - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => - new ContainerBuilder(driverSpec.driverContainer) - .addNewVolumeMount() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) - .endVolumeMount() - .build() - }.getOrElse(driverSpec.driverContainer) - - driverSpec.copy( - driverPod = driverPodWithMountedKubernetesCredentials, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq, - driverSparkConf = driverSparkConfWithCredentialsLocations, - driverContainer = driverContainerWithMountedSecretVolume) - } - - private def createCredentialsSecret( - driverOAuthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): Option[Secret] = { - val allSecretData = - resolveSecretData( - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ - resolveSecretData( - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ - resolveSecretData( - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ - resolveSecretData( - driverOAuthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) - - if (allSecretData.isEmpty) { - None - } else { - Some(new SecretBuilder() - .withNewMetadata() - .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials") - .endMetadata() - .withData(allSecretData.asJava) - .build()) - } - } - - private def setDriverPodKubernetesCredentialLocations( - driverSparkConf: SparkConf, - driverOauthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): SparkConf = { - val resolvedMountedOAuthTokenFile = resolveSecretLocation( - maybeMountedOAuthTokenFile, - driverOauthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) - val resolvedMountedClientKeyFile = resolveSecretLocation( - maybeMountedClientKeyFile, - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_PATH) - val resolvedMountedClientCertFile = resolveSecretLocation( - maybeMountedClientCertFile, - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_PATH) - val resolvedMountedCaCertFile = resolveSecretLocation( - maybeMountedCaCertFile, - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_PATH) - - val sparkConfWithCredentialLocations = driverSparkConf - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - resolvedMountedCaCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - resolvedMountedClientKeyFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - resolvedMountedClientCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", - resolvedMountedOAuthTokenFile) - - // Redact all OAuth token values - sparkConfWithCredentialLocations - .getAll - .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1) - .foreach { - sparkConfWithCredentialLocations.set(_, "") - } - sparkConfWithCredentialLocations - } - - private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { - submissionSparkConf.getOption(conf) - .map(new File(_)) - .map { file => - require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", - fileType, file.getAbsolutePath)) - BaseEncoding.base64().encode(Files.toByteArray(file)) - } - } - - private def resolveSecretLocation( - mountedUserSpecified: Option[String], - valueMountedFromSubmitter: Option[String], - mountedCanonicalLocation: String): Option[String] = { - mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => - mountedCanonicalLocation - }) - } - - /** - * Resolve a Kubernetes secret data entry from an optional client credential used by the - * driver to talk to the Kubernetes API server. - * - * @param userSpecifiedCredential the optional user-specified client credential. - * @param secretName name of the Kubernetes secret storing the client credential. - * @return a secret data entry in the form of a map from the secret name to the secret data, - * which may be empty if the user-specified credential is empty. - */ - private def resolveSecretData( - userSpecifiedCredential: Option[String], - secretName: String): Map[String, String] = { - userSpecifiedCredential.map { valueBase64 => - Map(secretName -> valueBase64) - }.getOrElse(Map.empty[String, String]) - } - - private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = { - new OptionSettableSparkConf(sparkConf) - } -} - -private class OptionSettableSparkConf(sparkConf: SparkConf) { - def setOption(configEntry: String, option: Option[String]): SparkConf = { - option.foreach { opt => - sparkConf.set(configEntry, opt) - } - sparkConf - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala deleted file mode 100644 index 91e9a9f211335..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * A driver configuration step for mounting user-specified secrets onto user-specified paths. - * - * @param bootstrap a utility actually handling mounting of the secrets. - */ -private[spark] class DriverMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) - val container = bootstrap.mountSecrets(driverSpec.driverContainer) - driverSpec.copy( - driverPod = pod, - driverContainer = container - ) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala deleted file mode 100644 index 34af7cde6c1a9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.ServiceBuilder - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.Logging -import org.apache.spark.util.Clock - -/** - * Allows the driver to be reachable by executor pods through a headless service. The service's - * ports should correspond to the ports that the executor will reach the pod at for RPC. - */ -private[spark] class DriverServiceBootstrapStep( - resourceNamePrefix: String, - driverLabels: Map[String, String], - sparkConf: SparkConf, - clock: Clock) extends DriverConfigurationStep with Logging { - - import DriverServiceBootstrapStep._ - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, - s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + - "address is managed and set to the driver pod's IP address.") - require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, - s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + - "managed via a Kubernetes service.") - - val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" - val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { - preferredServiceName - } else { - val randomServiceId = clock.getTimeMillis() - val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" - logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + - s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + - s"$shorterServiceName as the driver service's name.") - shorterServiceName - } - - val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = sparkConf.getInt( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) - val driverService = new ServiceBuilder() - .withNewMetadata() - .withName(resolvedServiceName) - .endMetadata() - .withNewSpec() - .withClusterIP("None") - .withSelector(driverLabels.asJava) - .addNewPort() - .withName(DRIVER_PORT_NAME) - .withPort(driverPort) - .withNewTargetPort(driverPort) - .endPort() - .addNewPort() - .withName(BLOCK_MANAGER_PORT_NAME) - .withPort(driverBlockManagerPort) - .withNewTargetPort(driverBlockManagerPort) - .endPort() - .endSpec() - .build() - - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .set(DRIVER_HOST_KEY, driverHostname) - .set("spark.driver.port", driverPort.toString) - .set( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort) - - driverSpec.copy( - driverSparkConf = resolvedSparkConf, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService)) - } -} - -private[spark] object DriverServiceBootstrapStep { - val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key - val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key - val DRIVER_SVC_POSTFIX = "-driver-svc" - val MAX_SERVICE_NAME_LENGTH = 63 -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala deleted file mode 100644 index 8607d6fba3234..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.util.Utils - -/** - * A factory class for bootstrapping and creating executor pods with the given bootstrapping - * components. - * - * @param sparkConf Spark configuration - * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto - * user-specified paths into the executor container - */ -private[spark] class ExecutorPodFactory( - sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap]) { - - private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - - private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_LABEL_PREFIX) - require( - !executorLabels.contains(SPARK_APP_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") - require( - !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + - " Spark.") - require( - !executorLabels.contains(SPARK_ROLE_LABEL), - s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") - - private val executorAnnotations = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - private val nodeSelector = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_NODE_SELECTOR_PREFIX) - - private val executorContainerImage = sparkConf - .get(EXECUTOR_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the executor container image")) - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - private val blockManagerPort = sparkConf - .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) - - private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - - private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY) - private val executorMemoryString = sparkConf.get( - EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) - - private val memoryOverheadMiB = sparkConf - .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - - private val executorCores = sparkConf.getInt("spark.executor.cores", 1) - private val executorCoresRequest = if (sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { - sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get - } else { - executorCores.toString - } - private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod = { - val name = s"$executorPodNamePrefix-exec-$executorId" - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - // hostname must be no longer than 63 characters, so take the last 63 characters of the pod - // name as the hostname. This preserves uniqueness since the end of name contains - // executorId - val hostname = name.substring(Math.max(0, name.length - 63)) - val resolvedExecutorLabels = Map( - SPARK_EXECUTOR_ID_LABEL -> executorId, - SPARK_APP_ID_LABEL -> applicationId, - SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ - executorLabels - val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") - .build() - val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCoresRequest) - .build() - val executorExtraClasspathEnv = executorExtraClasspath.map { cp => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(cp) - .build() - } - val executorExtraJavaOptionsEnv = sparkConf - .get(EXECUTOR_JAVA_OPTIONS) - .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) - delimitedOpts.zipWithIndex.map { - case (opt, index) => - new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() - } - }.getOrElse(Seq.empty[EnvVar]) - val executorEnv = (Seq( - (ENV_DRIVER_URL, driverUrl), - (ENV_EXECUTOR_CORES, executorCores.toString), - (ENV_EXECUTOR_MEMORY, executorMemoryString), - (ENV_APPLICATION_ID, applicationId), - // This is to set the SPARK_CONF_DIR to be /opt/spark/conf - (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) - .map(env => new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - ) ++ Seq( - new EnvVarBuilder() - .withName(ENV_EXECUTOR_POD_IP) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .build() - ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq - val requiredPorts = Seq( - (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) - .map { case (name, port) => - new ContainerPortBuilder() - .withName(name) - .withContainerPort(port) - .build() - } - - val executorContainer = new ContainerBuilder() - .withName("executor") - .withImage(executorContainerImage) - .withImagePullPolicy(imagePullPolicy) - .withNewResources() - .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryQuantity) - .addToRequests("cpu", executorCpuQuantity) - .endResources() - .addAllToEnv(executorEnv.asJava) - .withPorts(requiredPorts.asJava) - .addToArgs("executor") - .build() - - val executorPod = new PodBuilder() - .withNewMetadata() - .withName(name) - .withLabels(resolvedExecutorLabels.asJava) - .withAnnotations(executorAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() - .endMetadata() - .withNewSpec() - .withHostname(hostname) - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val containerWithLimitCores = executorLimitCores.map { limitCores => - val executorCpuLimitQuantity = new QuantityBuilder(false) - .withAmount(limitCores) - .build() - new ContainerBuilder(executorContainer) - .editResources() - .addToLimits("cpu", executorCpuLimitQuantity) - .endResources() - .build() - }.getOrElse(executorContainer) - - val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = - mountSecretsBootstrap.map { bootstrap => - (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) - }.getOrElse((executorPod, containerWithLimitCores)) - - - new PodBuilder(maybeSecretsMountedPod) - .editSpec() - .addToContainers(maybeSecretsMountedContainer) - .endSpec() - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ff5f6801da2a3..0ea80dfbc0d97 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -48,12 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit scheduler: TaskScheduler): SchedulerBackend = { val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) - val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), @@ -62,8 +56,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) - val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( @@ -71,7 +63,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - executorPodFactory, + new KubernetesExecutorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 9de4b16c30d3c..d86664c81071b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,6 +32,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorPodFactory: ExecutorPodFactory, + executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, allocatorExecutor: ScheduledExecutorService, requestExecutorsService: ExecutorService) @@ -115,14 +116,19 @@ private[spark] class KubernetesClusterSchedulerBackend( for (_ <- 0 until math.min( currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorPod = executorPodFactory.createExecutorPod( + val executorConf = KubernetesConf.createExecutorConf( + conf, executorId, applicationId(), - driverUrl, - conf.getExecutorEnv, - driverPod, - currentNodeToLocalTaskCount) - executorsToAllocate(executorId) = executorPod + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + + executorsToAllocate(executorId) = podWithAttachedContainer logInfo( s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala new file mode 100644 index 0000000000000..22568fe7ea3be --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesExecutorBuilder( + provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + new BasicExecutorFeatureStep(_), + provideSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { + val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + var executorPod = SparkPod.initialPod() + for (feature <- allFeatures) { + executorPod = feature.configurePod(executorPod) + } + executorPod + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala new file mode 100644 index 0000000000000..f10202f7a3546 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource + +class KubernetesConfSuite extends SparkFunSuite { + + private val APP_NAME = "test-app" + private val RESOURCE_NAME_PREFIX = "prefix" + private val APP_ID = "test-id" + private val MAIN_CLASS = "test-class" + private val APP_ARGS = Array("arg1", "arg2") + private val CUSTOM_LABELS = Map( + "customLabel1Key" -> "customLabel1Value", + "customLabel2Key" -> "customLabel2Value") + private val CUSTOM_ANNOTATIONS = Map( + "customAnnotation1Key" -> "customAnnotation1Value", + "customAnnotation2Key" -> "customAnnotation2Value") + private val SECRET_NAMES_TO_MOUNT_PATHS = Map( + "secret1" -> "/mnt/secrets/secret1", + "secret2" -> "/mnt/secrets/secret2") + private val CUSTOM_ENVS = Map( + "customEnvKey1" -> "customEnvValue1", + "customEnvKey2" -> "customEnvValue2") + private val DRIVER_POD = new PodBuilder().build() + private val EXECUTOR_ID = "executor-id" + + test("Basic driver translated fields.") { + val sparkConf = new SparkConf(false) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.appId === APP_ID) + assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) + assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) + assert(conf.roleSpecificConf.appName === APP_NAME) + assert(conf.roleSpecificConf.mainAppResource.isEmpty) + assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) + assert(conf.roleSpecificConf.appArgs === APP_ARGS) + } + + test("Creating driver conf with and without the main app jar influences spark.jars") { + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar")) + val kubernetesConfWithMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppJar, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") + .split(",") + === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) + val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + } + + test("Resolve driver labels, annotations, secret mount paths, and envs.") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) + } + CUSTOM_ENVS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) + } + + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.roleLabels === Map( + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ + CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleEnvs === CUSTOM_ENVS) + } + + test("Basic executor translated fields.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) + assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + } + + test("Image pull secrets.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false) + .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.imagePullSecrets() === + Seq( + new LocalObjectReferenceBuilder().withName("my-secret-1").build(), + new LocalObjectReferenceBuilder().withName("my-secret-2").build())) + } + + test("Set executor labels, annotations, and secrets") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) + } + + val conf = KubernetesConf.createExecutorConf( + sparkConf, + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleLabels === Map( + SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..eee85b8baa730 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class BasicDriverFeatureStepSuite extends SparkFunSuite { + + private val APP_ID = "spark-app-id" + private val RESOURCE_NAME_PREFIX = "spark" + private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" + private val APP_NAME = "spark-test" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") + private val CUSTOM_ANNOTATION_KEY = "customAnnotation" + private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" + private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE) + private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1" + private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2" + private val DRIVER_ENVS = Map( + DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1, + DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2) + private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + + test("Check the pod respects all configurations from the user.") { + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .set("spark.driver.cores", "2") + .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") + .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + DRIVER_ENVS) + + val featureStep = new BasicDriverFeatureStep(kubernetesConf) + val basePod = SparkPod.initialPod() + val configuredPod = featureStep.configurePod(basePod) + + assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME) + assert(configuredPod.container.getImage === "spark-driver:latest") + assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + + assert(configuredPod.container.getEnv.size === 3) + val envs = configuredPod.container + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1)) + assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2)) + + assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === + TEST_IMAGE_PULL_SECRET_OBJECTS) + + assert(configuredPod.container.getEnv.asScala.exists(envVar => + envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && + envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && + envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) + + val resourceRequirements = configuredPod.container.getResources + val requests = resourceRequirements.getRequests.asScala + assert(requests("cpu").getAmount === "2") + assert(requests("memory").getAmount === "456Mi") + val limits = resourceRequirements.getLimits.asScala + assert(limits("memory").getAmount === "456Mi") + assert(limits("cpu").getAmount === "4") + + val driverPodMetadata = configuredPod.pod.getMetadata + assert(driverPodMetadata.getName === "spark-driver-pod") + assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) + assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") + + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") + assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) + } + + test("Additional system properties resolve jars and set cluster-mode confs.") { + val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") + val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .setJars(allJars) + .set("spark.files", allFiles.mkString(",")) + .set(CONTAINER_IMAGE, "spark-driver:latest") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty) + val step = new BasicDriverFeatureStep(kubernetesConf) + val additionalProperties = step.getAdditionalPodSystemProperties() + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true", + "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", + "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") + assert(additionalProperties === expectedSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala new file mode 100644 index 0000000000000..a764f7630b5c8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +class BasicExecutorFeatureStepSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + + private val APP_ID = "app-id" + private val DRIVER_HOSTNAME = "localhost" + private val DRIVER_PORT = 7098 + private val DRIVER_ADDRESS = RpcEndpointAddress( + DRIVER_HOSTNAME, + DRIVER_PORT.toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + private val DRIVER_POD_NAME = "driver-pod" + + private val DRIVER_POD_UID = "driver-uid" + private val RESOURCE_NAME_PREFIX = "base" + private val EXECUTOR_IMAGE = "executor-image" + private val LABELS = Map("label1key" -> "label1value") + private val ANNOTATIONS = Map("annotation1key" -> "annotation1value") + private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + private val DRIVER_POD = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .withUid(DRIVER_POD_UID) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) + .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set("spark.driver.host", DRIVER_HOSTNAME) + .set("spark.driver.port", DRIVER_PORT.toString) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + } + + test("basic executor pod has reasonable defaults") { + val step = new BasicExecutorFeatureStep( + KubernetesConf( + baseConf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + val executor = step.configurePod(SparkPod.initialPod()) + + // The executor pod name and default labels. + assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") + assert(executor.pod.getMetadata.getLabels.asScala === LABELS) + assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.container.getImage === EXECUTOR_IMAGE) + assert(executor.container.getVolumeMounts.isEmpty) + assert(executor.container.getResources.getLimits.size() === 1) + assert(executor.container.getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. + assert(executor.pod.getSpec.getNodeSelector.isEmpty) + assert(executor.pod.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + longPodNamePrefix, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map("qux" -> "quux"))) + val executor = step.configurePod(SparkPod.initialPod()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + ENV_CLASSPATH -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> APP_ID, + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executorPod.container.getEnv.size() === defaultEnvs.size) + val mapEnvs = executorPod.container.getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala similarity index 67% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 64553d25883bb..9f817d3bfc79a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -14,34 +14,35 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features import java.io.File -import scala.collection.JavaConverters._ - import com.google.common.base.Charsets import com.google.common.io.{BaseEncoding, Files} import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} +import org.mockito.{Mock, MockitoAnnotations} import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.util.Utils -class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter { +class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" + private val APP_ID = "k8s-app" private var credentialsTempDirectory: File = _ - private val BASE_DRIVER_SPEC = new KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) + private val BASE_DRIVER_POD = SparkPod.initialPod() + + @Mock + private var driverSpecificConf: KubernetesDriverSpecificConf = _ before { + MockitoAnnotations.initMocks(this) credentialsTempDirectory = Utils.createTempDir() } @@ -50,13 +51,19 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA } test("Don't set any credentials") { - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty) + val kubernetesConf = KubernetesConf( + new SparkConf(false), + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) } test("Only set credentials that are manually mounted.") { @@ -73,14 +80,23 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", "/mnt/secrets/my-ca.pem") + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() + resolvedProperties.foreach { case (propKey, propValue) => + assert(submissionSparkConf.get(propKey) === propValue) + } } test("Mount credentials from the submission client as a secret.") { @@ -100,10 +116,17 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", caCertFile.getAbsolutePath) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver( - BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf)) + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "", s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> @@ -113,16 +136,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> DRIVER_CREDENTIALS_CLIENT_CERT_PATH, s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - DRIVER_CREDENTIALS_CA_CERT_PATH, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> - clientKeyFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> - clientCertFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - caCertFile.getAbsolutePath) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - assert(preparedDriverSpec.otherKubernetesResources.size === 1) - val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret] + DRIVER_CREDENTIALS_CA_CERT_PATH) + assert(resolvedProperties === expectedSparkConf) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().size === 1) + val credentialsSecret = kubernetesCredentialsStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Secret] assert(credentialsSecret.getMetadata.getName === s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") val decodedSecretData = credentialsSecret.getData.asScala.map { data => @@ -134,12 +154,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key", DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert") assert(decodedSecretData === expectedSecretData) - val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala + val driverPod = kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) + val driverPodVolumes = driverPod.pod.getSpec.getVolumes.asScala assert(driverPodVolumes.size === 1) assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverPodVolumes.head.getSecret != null) assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName) - val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala + val driverContainerVolumeMount = driverPod.container.getVolumeMounts.asScala assert(driverContainerVolumeMount.size === 1) assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala new file mode 100644 index 0000000000000..c299d56865ec0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.Service +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Clock + +class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SHORT_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length) + + private val LONG_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length + 1) + private val DRIVER_LABELS = Map( + "label1key" -> "label1value", + "label2key" -> "label2value") + + @Mock + private var clock: Clock = _ + + private var sparkConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + } + + test("Headless service has a port for the driver RPC and the block manager.") { + sparkConf = sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) + assert(configurationStep.getAdditionalKubernetesResources().size === 1) + assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + 9000, + 8080, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + driverService) + } + + test("Hostname and ports are set according to the service name.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Ports should resolve to defaults in SparkConf and in the service.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val resolvedService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + DEFAULT_DRIVER_PORT, + DEFAULT_BLOCKMANAGER_PORT, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + resolvedService) + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) + assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key) + === DEFAULT_BLOCKMANAGER_PORT.toString) + } + + test("Long prefixes should switch to using a generated name.") { + when(clock.getTimeMillis()).thenReturn(10000) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + val expectedServiceName = s"spark-10000${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}" + assert(driverService.getMetadata.getName === expectedServiceName) + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Disallow bind address and driver host to be set explicitly.") { + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver bind address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + } + sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) + sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver host address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") + } + } + + private def verifyService( + driverPort: Int, + blockManagerPort: Int, + expectedServiceName: String, + service: Service): Unit = { + assert(service.getMetadata.getName === expectedServiceName) + assert(service.getSpec.getClusterIP === "None") + assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + assert(service.getSpec.getPorts.size() === 2) + val driverServicePorts = service.getSpec.getPorts.asScala + assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) + assert(driverServicePorts.head.getPort.intValue() === driverPort) + assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) + assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) + assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) + assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) + } + + private def verifySparkConfHostNames( + driverSparkConf: Map[String, String], expectedHostName: String): Unit = { + assert(driverSparkConf( + org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key) === expectedHostName) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala new file mode 100644 index 0000000000000..27bff74ce38af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark.deploy.k8s.SparkPod + +object KubernetesFeaturesTestUtils { + + def getMockConfigStepForStepType[T <: KubernetesFeatureConfigStep]( + stepType: String, stepClass: Class[T]): T = { + val mockStep = mock(stepClass) + when(mockStep.getAdditionalKubernetesResources()).thenReturn( + getSecretsForStepType(stepType)) + + when(mockStep.getAdditionalPodSystemProperties()) + .thenReturn(Map(stepType -> stepType)) + when(mockStep.configurePod(Matchers.any(classOf[SparkPod]))) + .thenAnswer(new Answer[SparkPod]() { + override def answer(invocation: InvocationOnMock): SparkPod = { + val originalPod = invocation.getArgumentAt(0, classOf[SparkPod]) + val configuredPod = new PodBuilder(originalPod.pod) + .editOrNewMetadata() + .addToLabels(stepType, stepType) + .endMetadata() + .build() + SparkPod(configuredPod, originalPod.container) + } + }) + mockStep + } + + def getSecretsForStepType[T <: KubernetesFeatureConfigStep](stepType: String) + : Seq[HasMetadata] = { + Seq(new SecretBuilder() + .withNewMetadata() + .withName(stepType) + .endMetadata() + .build()) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala similarity index 64% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 960d0bda1d011..9d02f56cc206d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -14,29 +14,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod} -class DriverMountSecretsStepSuite extends SparkFunSuite { +class MountSecretsFeatureStepSuite extends SparkFunSuite { private val SECRET_FOO = "foo" private val SECRET_BAR = "bar" private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("mounts all given secrets") { - val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val baseDriverPod = SparkPod.initialPod() val secretNamesToMountPaths = Map( SECRET_FOO -> SECRET_MOUNT_PATH, SECRET_BAR -> SECRET_MOUNT_PATH) + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + secretNamesToMountPaths, + Map.empty) - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) - val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) - val driverPodWithSecretsMounted = configuredDriverSpec.driverPod - val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + val step = new MountSecretsFeatureStep(kubernetesConf) + val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod + val driverContainerWithSecretsMounted = step.configurePod(baseDriverPod).container Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 6a501592f42a3..c1b203e03a357 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -16,22 +16,17 @@ */ package org.apache.spark.deploy.k8s.submit -import scala.collection.JavaConverters._ - -import com.google.common.collect.Iterables import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -39,6 +34,74 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" private val KUBERNETES_RESOURCE_PREFIX = "resource-example" + private val POD_NAME = "driver" + private val CONTAINER_NAME = "container" + private val APP_ID = "app-id" + private val APP_NAME = "app" + private val MAIN_CLASS = "main" + private val APP_ARGS = Seq("arg1", "arg2") + private val RESOLVED_JAVA_OPTIONS = Map( + "conf1key" -> "conf1value", + "conf2key" -> "conf2value") + private val BUILT_DRIVER_POD = + new PodBuilder() + .withNewMetadata() + .withName(POD_NAME) + .endMetadata() + .withNewSpec() + .withHostname("localhost") + .endSpec() + .build() + private val BUILT_DRIVER_CONTAINER = new ContainerBuilder().withName(CONTAINER_NAME).build() + private val ADDITIONAL_RESOURCES = Seq( + new SecretBuilder().withNewMetadata().withName("secret").endMetadata().build()) + + private val BUILT_KUBERNETES_SPEC = KubernetesDriverSpec( + SparkPod(BUILT_DRIVER_POD, BUILT_DRIVER_CONTAINER), + ADDITIONAL_RESOURCES, + RESOLVED_JAVA_OPTIONS) + + private val FULL_EXPECTED_CONTAINER = new ContainerBuilder(BUILT_DRIVER_CONTAINER) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() + .build() + private val FULL_EXPECTED_POD = new PodBuilder(BUILT_DRIVER_POD) + .editSpec() + .addToContainers(FULL_EXPECTED_CONTAINER) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap().withName(s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map").endConfigMap() + .endVolume() + .endSpec() + .build() + + private val POD_WITH_OWNER_REFERENCE = new PodBuilder(FULL_EXPECTED_POD) + .editMetadata() + .withUid(DRIVER_POD_UID) + .endMetadata() + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .build() + + private val ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES = ADDITIONAL_RESOURCES.map { secret => + new SecretBuilder(secret) + .editMetadata() + .addNewOwnerReference() + .withName(POD_NAME) + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .withController(true) + .withUid(DRIVER_POD_UID) + .endOwnerReference() + .endMetadata() + .build() + } private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -56,113 +119,86 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _ + @Mock + private var driverBuilder: KubernetesDriverBuilder = _ + @Mock private var resourceList: ResourceList = _ - private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) + private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ + + private var sparkConf: SparkConf = _ private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ - private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + kubernetesConf = KubernetesConf[KubernetesDriverSpecificConf]( + sparkConf, + KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS), + KUBERNETES_RESOURCE_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.withName(POD_NAME)).thenReturn(namedPods) createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod]) createdResourcesArgumentCaptor = ArgumentCaptor.forClass(classOf[HasMetadata]) - when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] { - override def answer(invocation: InvocationOnMock): Pod = { - new PodBuilder(invocation.getArgumentAt(0, classOf[Pod])) - .editMetadata() - .withUid(DRIVER_POD_UID) - .endMetadata() - .withApiVersion(DRIVER_POD_API_VERSION) - .withKind(DRIVER_POD_KIND) - .build() - } - }) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.create(FULL_EXPECTED_POD)).thenReturn(POD_WITH_OWNER_REFERENCE) when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch]) doReturn(resourceList) .when(kubernetesClient) .resourceList(createdResourcesArgumentCaptor.capture()) } - test("The client should configure the pod with the submission steps.") { + test("The client should configure the pod using the builder.") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue - assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) - assert(createdPod.getMetadata.getLabels.asScala === - Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue)) - assert(createdPod.getMetadata.getAnnotations.asScala === - Map(SecondTestConfigurationStep.annotationKey -> - SecondTestConfigurationStep.annotationValue)) - assert(createdPod.getSpec.getContainers.size() === 1) - assert(createdPod.getSpec.getContainers.get(0).getName === - SecondTestConfigurationStep.containerName) + verify(podOperations).create(FULL_EXPECTED_POD) } test("The client should create Kubernetes resources") { - val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" - val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( - submissionSteps, - new SparkConf(false) - .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues assert(otherCreatedResources.size === 2) - val secrets = otherCreatedResources.toArray - .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val secrets = otherCreatedResources.toArray.filter(_.isInstanceOf[Secret]).toSeq + assert(secrets === ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES) val configMaps = otherCreatedResources.toArray .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) assert(secrets.nonEmpty) - val secret = secrets.head - assert(secret.getMetadata.getName === FirstTestConfigurationStep.secretName) - assert(secret.getData.asScala === - Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) - assert(ownerReference.getName === createdPod.getMetadata.getName) - assert(ownerReference.getKind === DRIVER_POD_KIND) - assert(ownerReference.getUid === DRIVER_POD_UID) - assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) assert(configMaps.nonEmpty) val configMap = configMaps.head assert(configMap.getMetadata.getName === s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( - "spark.custom-conf=custom-conf-value")) - val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) - assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverEnv = driverContainer.getEnv.asScala.head - assert(driverEnv.getName === ENV_SPARK_CONF_DIR) - assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) - val driverMount = driverContainer.getVolumeMounts.asScala.head - assert(driverMount.getName === SPARK_CONF_VOLUME) - assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf1key=conf1value")) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf2key=conf2value")) } test("Waiting for app completion should stall on the watcher") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, true, "spark", @@ -171,56 +207,4 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } - -} - -private object FirstTestConfigurationStep extends DriverConfigurationStep { - - val podName = "test-pod" - val secretName = "test-secret" - val labelKey = "first-submit" - val labelValue = "true" - val secretKey = "secretKey" - val secretData = "secretData" - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .withName(podName) - .addToLabels(labelKey, labelValue) - .endMetadata() - .build() - val additionalResource = new SecretBuilder() - .withNewMetadata() - .withName(secretName) - .endMetadata() - .addToData(secretKey, secretData) - .build() - driverSpec.copy( - driverPod = modifiedPod, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource)) - } -} - -private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" - val annotationValue = "submitted" - val sparkConfKey = "spark.custom-conf" - val sparkConfValue = "custom-conf-value" - val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .addToAnnotations(annotationKey, annotationValue) - .endMetadata() - .build() - val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue) - val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer) - .withName(containerName) - .build() - driverSpec.copy( - driverPod = modifiedPod, - driverSparkConf = resolvedSparkConf, - driverContainer = modifiedContainer) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala deleted file mode 100644 index df34d2dbcb5be..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.steps._ - -class DriverConfigOrchestratorSuite extends SparkFunSuite { - - private val DRIVER_IMAGE = "driver-image" - private val IC_IMAGE = "init-container-image" - private val APP_ID = "spark-app-id" - private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" - private val APP_NAME = "spark" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2") - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/driver" - - test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep]) - } - - test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Option.empty, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep]) - } - - test("Submission steps with driver secrets to mount") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverMountSecretsStep]) - } - - test("Submission using client local dependencies") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - var orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - - sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") - orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - } - - private def validateStepTypes( - orchestrator: DriverConfigOrchestrator, - types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps - assert(steps.size === types.size) - assert(steps.map(_.getClass) === types) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala new file mode 100644 index 0000000000000..161f9afe7bba9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesDriverBuilderSuite extends SparkFunSuite { + + private val BASIC_STEP_TYPE = "basic" + private val CREDENTIALS_STEP_TYPE = "credentials" + private val SERVICE_STEP_TYPE = "service" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) + + private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) + + private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest: KubernetesDriverBuilder = + new KubernetesDriverBuilder( + _ => basicFeatureStep, + _ => credentialsStep, + _ => serviceStep, + _ => secretsStep) + + test("Apply fundamental steps all the time.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) + : Unit = { + assert(resolvedSpec.systemProperties.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) + assert(resolvedSpec.driverKubernetesResources.containsSlice( + KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) + assert(resolvedSpec.systemProperties(stepType) === stepType) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala deleted file mode 100644 index ee450fff8d376..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class BasicDriverConfigurationStepSuite extends SparkFunSuite { - - private val APP_ID = "spark-app-id" - private val RESOURCE_NAME_PREFIX = "spark" - private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") - private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val APP_NAME = "spark-test" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") - private val CUSTOM_ANNOTATION_KEY = "customAnnotation" - private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" - private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" - private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2" - - test("Set all possible configurations from the user.") { - val sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") - .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar") - .set("spark.driver.cores", "2") - .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") - .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") - - val submissionStep = new BasicDriverConfigurationStep( - APP_ID, - RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - CONTAINER_IMAGE_PULL_POLICY, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = basePod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") - assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - - assert(preparedDriverSpec.driverContainer.getEnv.size === 4) - val envs = preparedDriverSpec.driverContainer - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_CLASSPATH) === "/opt/spark/spark-examples.jar") - assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") - assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") - - assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar => - envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && - envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && - envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) - - val resourceRequirements = preparedDriverSpec.driverContainer.getResources - val requests = resourceRequirements.getRequests.asScala - assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "456Mi") - val limits = resourceRequirements.getLimits.asScala - assert(limits("memory").getAmount === "456Mi") - assert(limits("cpu").getAmount === "4") - - val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata - assert(driverPodMetadata.getName === "spark-driver-pod") - assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) - val expectedAnnotations = Map( - CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, - SPARK_APP_NAME_ANNOTATION -> APP_NAME) - assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - - val driverPodSpec = preparedDriverSpec.driverPod.getSpec - assert(driverPodSpec.getRestartPolicy === "Never") - assert(driverPodSpec.getImagePullSecrets.size() === 2) - assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap - val expectedSparkConf = Map( - KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, - "spark.kubernetes.submitInDriver" -> "true") - assert(resolvedSparkConf === expectedSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala deleted file mode 100644 index ca43fc97dc991..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class DependencyResolutionStepSuite extends SparkFunSuite { - - private val SPARK_JARS = Seq( - "apps/jars/jar1.jar", - "local:///var/apps/jars/jar2.jar") - - private val SPARK_FILES = Seq( - "apps/files/file1.txt", - "local:///var/apps/files/file2.txt") - - test("Added dependencies should be resolved in Spark configuration and environment") { - val dependencyResolutionStep = new DependencyResolutionStep( - SPARK_JARS, - SPARK_FILES) - val driverPod = new PodBuilder().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = driverPod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec) - assert(preparedDriverSpec.driverPod === driverPod) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet - val expectedResolvedSparkJars = Set( - "apps/jars/jar1.jar", - "/var/apps/jars/jar2.jar") - assert(resolvedSparkJars === expectedResolvedSparkJars) - val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet - val expectedResolvedSparkFiles = Set( - "apps/files/file1.txt", - "/var/apps/files/file2.txt") - assert(resolvedSparkFiles === expectedResolvedSparkFiles) - val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(driverEnv.size === 1) - assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) - val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = expectedResolvedSparkJars - assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala deleted file mode 100644 index 78c8c3ba1afbd..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.Service -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfter - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.util.Clock - -class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SHORT_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length) - - private val LONG_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1) - private val DRIVER_LABELS = Map( - "label1key" -> "label1value", - "label2key" -> "label2value") - - @Mock - private var clock: Clock = _ - - private var sparkConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - } - - test("Headless service has a port for the driver RPC and the block manager.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - assert(resolvedDriverSpec.otherKubernetesResources.size === 1) - assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service]) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - verifyService( - 9000, - 8080, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - driverService) - } - - test("Hostname and ports are set according to the service name.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - .set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Ports should resolve to defaults in SparkConf and in the service.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf, - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - verifyService( - DEFAULT_DRIVER_PORT, - DEFAULT_BLOCKMANAGER_PORT, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]) - assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") === - DEFAULT_DRIVER_PORT.toString) - assert(resolvedDriverSpec.driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT) - } - - test("Long prefixes should switch to using a generated name.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - when(clock.getTimeMillis()).thenReturn(10000) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" - assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Disallow bind address and driver host to be set explicitly.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), - clock) - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver bind address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" + - " not supported in Kubernetes mode, as the driver's bind address is managed" + - " and set to the driver pod's IP address.") - } - sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) - sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver host address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" + - " not supported in Kubernetes mode, as the driver's hostname will be managed via" + - " a Kubernetes service.") - } - } - - private def verifyService( - driverPort: Int, - blockManagerPort: Int, - expectedServiceName: String, - service: Service): Unit = { - assert(service.getMetadata.getName === expectedServiceName) - assert(service.getSpec.getClusterIP === "None") - assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) - assert(service.getSpec.getPorts.size() === 2) - val driverServicePorts = service.getSpec.getPorts.asScala - assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) - assert(driverServicePorts.head.getPort.intValue() === driverPort) - assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) - assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) - assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) - assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) - } - - private def verifySparkConfHostNames( - driverSparkConf: SparkConf, expectedHostName: String): Unit = { - assert(driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala deleted file mode 100644 index d73df20f0f956..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.MockitoAnnotations -import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { - - private val driverPodName: String = "driver-pod" - private val driverPodUid: String = "driver-uid" - private val executorPrefix: String = "base" - private val executorImage: String = "executor-image" - private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(driverPodName) - .withUid(driverPodUid) - .endMetadata() - .withNewSpec() - .withNodeName("some-node") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private var baseConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - baseConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(CONTAINER_IMAGE, executorImage) - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - .set(IMAGE_PULL_SECRETS, imagePullSecrets) - } - - test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - // The executor pod name and default labels. - assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") - assert(executor.getMetadata.getLabels.size() === 3) - assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - - // There is exactly 1 container with no volume mounts and default memory limits and requests. - // Default memory limit/request is 1024M + 384M (minimum overhead constant). - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getImage === executorImage) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) - assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources - .getRequests.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getContainers.get(0).getResources - .getLimits.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getImagePullSecrets.size() === 2) - assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - // The pod has no node selector, volumes. - assert(executor.getSpec.getNodeSelector.isEmpty) - assert(executor.getSpec.getVolumes.isEmpty) - - checkEnv(executor, Map()) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor core request specification") { - var factory = new ExecutorPodFactory(baseConf, None) - var executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "1") - - val conf = baseConf.clone() - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") - factory = new ExecutorPodFactory(conf, None) - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "0.1") - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - factory = new ExecutorPodFactory(conf, None) - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "100m") - } - - test("executor pod hostnames get truncated to 63 characters") { - val conf = baseConf.clone() - conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, - "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getHostname.length === 63) - } - - test("classpath and extra java options get translated into environment variables") { - val conf = baseConf.clone() - conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") - conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) - - checkEnv(executor, - Map("SPARK_JAVA_OPT_0" -> "foo=bar", - ENV_CLASSPATH -> "bar=baz", - "qux" -> "quux")) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor secrets get mounted") { - val conf = baseConf.clone() - - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") - - // check volume mounted. - assert(executor.getSpec.getVolumes.size() === 1) - assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") - - checkOwnerReferences(executor, driverPodUid) - } - - // There is always exactly one controller reference, and it points to the driver pod. - private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { - assert(executor.getMetadata.getOwnerReferences.size() === 1) - assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) - assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) - } - - // Check that the expected environment variables are present. - private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { - val defaultEnvs = Map( - ENV_EXECUTOR_ID -> "1", - ENV_DRIVER_URL -> "dummy", - ENV_EXECUTOR_CORES -> "1", - ENV_EXECUTOR_MEMORY -> "1g", - ENV_APPLICATION_ID -> "dummy", - ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) - val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { - x => (x.getName, x.getValue) - }.toMap - assert(defaultEnvs === mapEnvs) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index b2f26f205a329..96065e83f069c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} -import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.hamcrest.{BaseMatcher, Description, Matcher} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} import org.mockito.Matchers.{any, eq => mockitoEq} import org.mockito.Mockito.{doNothing, never, times, verify, when} import org.scalatest.BeforeAndAfter @@ -31,6 +32,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc._ @@ -47,8 +49,6 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val SPARK_DRIVER_HOST = "localhost" private val SPARK_DRIVER_PORT = 7077 private val POD_ALLOCATION_INTERVAL = "1m" - private val DRIVER_URL = RpcEndpointAddress( - SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val FIRST_EXECUTOR_POD = new PodBuilder() .withNewMetadata() .withName("pod1") @@ -94,7 +94,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var requestExecutorsService: ExecutorService = _ @Mock - private var executorPodFactory: ExecutorPodFactory = _ + private var executorBuilder: KubernetesExecutorBuilder = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -399,7 +399,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn new KubernetesClusterSchedulerBackend( taskSchedulerImpl, rpcEnv, - executorPodFactory, + executorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) { @@ -428,13 +428,22 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) .endMetadata() .build() - when(executorPodFactory.createExecutorPod( - executorId.toString, - APP_ID, - DRIVER_URL, - sparkConf.getExecutorEnv, - driverPod, - Map.empty)).thenReturn(resolvedPod) - resolvedPod + val resolvedContainer = new ContainerBuilder().build() + when(executorBuilder.buildFromFeatures(Matchers.argThat( + new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any) + : Boolean = { + argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && + argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + .roleSpecificConf.executorId == executorId.toString + } + + override def describeTo(description: Description): Unit = {} + }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) + new PodBuilder(resolvedPod) + .editSpec() + .addToContainers(resolvedContainer) + .endSpec() + .build() } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala new file mode 100644 index 0000000000000..f5270623f8acc --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesExecutorBuilderSuite extends SparkFunSuite { + private val BASIC_STEP_TYPE = "basic" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) + private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest = new KubernetesExecutorBuilder( + _ => basicFeatureStep, + _ => mountSecretsStep) + + test("Basic steps are consistently applied.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { + assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) + } + } +} From 4dfd746de3f4346ed0c2191f8523a7e6cc9f064d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 14 Apr 2018 00:22:38 +0800 Subject: [PATCH 0617/1161] [SPARK-23896][SQL] Improve PartitioningAwareFileIndex ## What changes were proposed in this pull request? Currently `PartitioningAwareFileIndex` accepts an optional parameter `userPartitionSchema`. If provided, it will combine the inferred partition schema with the parameter. However, 1. to get `userPartitionSchema`, we need to combine inferred partition schema with `userSpecifiedSchema` 2. to get the inferred partition schema, we have to create a temporary file index. Only after that, a final version of `PartitioningAwareFileIndex` can be created. This can be improved by passing `userSpecifiedSchema` to `PartitioningAwareFileIndex`. With the improvement, we can reduce redundant code and avoid parsing the file partition twice. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21004 from gengliangwang/PartitioningAwareFileIndex. --- .../datasources/CatalogFileIndex.scala | 2 +- .../execution/datasources/DataSource.scala | 133 ++++++++---------- .../datasources/InMemoryFileIndex.scala | 8 +- .../PartitioningAwareFileIndex.scala | 54 ++++--- .../streaming/MetadataLogFileIndex.scala | 10 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../hive/PartitionedTablePerfStatsSuite.scala | 2 +- 7 files changed, 103 insertions(+), 108 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 4046396d0e614..a66a07673e25f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -85,7 +85,7 @@ class CatalogFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( - sparkSession, rootPaths, table.storage.properties, partitionSchema = None) + sparkSession, rootPaths, table.storage.properties, userSpecifiedSchema = None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..f16d824201e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -103,24 +102,6 @@ case class DataSource( bucket.sortColumnNames, "in the sort definition", equality) } - /** - * In the read path, only managed tables by Hive provide the partition columns properly when - * initializing this class. All other file based data sources will try to infer the partitioning, - * and then cast the inferred types to user specified dataTypes if the partition columns exist - * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or - * inconsistent data types as reported in SPARK-21463. - * @param fileIndex A FileIndex that will perform partition inference - * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` - */ - private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = { - val resolved = fileIndex.partitionSchema.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) - } - /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer * it. In the read path, only managed tables by Hive provide the partition columns properly when @@ -140,31 +121,26 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileStatusCache the shared cache for file statuses to speed up listing + * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. */ private def getOrInferFileFormatSchema( format: FileFormat, - fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to, e.g., + fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { + // The operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below - lazy val tempFileIndex = { - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) + lazy val tempFileIndex = fileIndex.getOrElse { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) } + val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex) + tempFileIndex.partitionSchema } else { // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred // partitioning @@ -356,13 +332,7 @@ case class DataSource( caseInsensitiveOptions.get("path").toSeq ++ paths, sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None) - val fileCatalog = if (userSpecifiedSchema.nonEmpty) { - val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog) - new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema)) - } else { - tempFileCatalog - } + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -384,24 +354,23 @@ case class DataSource( // This is a non-streaming file based datasource. case (format: FileFormat, _) => - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.flatMap( - DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray - - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) - - val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && - catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) + val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog && + catalogTable.get.partitionColumnNames.nonEmpty + val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) { val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes - new CatalogFileIndex( + val index = new CatalogFileIndex( sparkSession, catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) + (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) + val index = createInMemoryFileIndex(globbedPaths) + val (resultDataSchema, resultPartitionSchema) = + getOrInferFileFormatSchema(format, Some(index)) + (index, resultDataSchema, resultPartitionSchema) } HadoopFsRelation( @@ -552,6 +521,40 @@ case class DataSource( sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } } + + /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */ + private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = { + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache) + } + + /** + * Checks and returns files in all the paths. + */ + private def checkAndGlobPathIfNecessary( + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean): Seq[Path] = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() + allPaths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (checkEmptyGlobPath && globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + }.toSeq + } } object DataSource extends Logging { @@ -699,30 +702,6 @@ object DataSource extends Logging { locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } - /** - * If `path` is a file pattern, return all the files that match it. Otherwise, return itself. - * If `checkFilesExist` is `true`, also check the file existence. - */ - private def checkAndGlobPathIfNecessary( - hadoopConf: Configuration, - path: String, - checkFilesExist: Boolean): Seq[Path] = { - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - - if (globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") - } - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") - } - globPath - } - /** * Called before writing into a FileFormat based data source to make sure the * supplied schema is not empty. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 318ada0ceefc5..739d1f456e3ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -41,17 +41,17 @@ import org.apache.spark.util.SerializableConfiguration * @param rootPathsSpecified the list of root table paths to scan (some of which might be * filtered out later) * @param parameters as set of options to control discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, rootPathsSpecified: Seq[Path], parameters: Map[String, String], - partitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( - sparkSession, parameters, partitionSchema, fileStatusCache) { + sparkSession, parameters, userSpecifiedSchema, fileStatusCache) { // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 6b6f6388d54e8..cc8af7b92c454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.{StringType, StructType} * It provides the necessary methods to parse partition data based on a set of files. * * @param parameters as set of options to control partition discovery - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ abstract class PartitioningAwareFileIndex( sparkSession: SparkSession, parameters: Map[String, String], - userPartitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { import PartitioningAwareFileIndex.BASE_PATH_PARAM @@ -126,35 +126,32 @@ abstract class PartitioningAwareFileIndex( val caseInsensitiveOptions = CaseInsensitiveMap(parameters) val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - - userPartitionSchema match { + val inferredPartitionSpec = PartitioningUtils.parsePartitions( + leafDirs, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, + basePaths = basePaths, + timeZoneId = timeZoneId) + userSpecifiedSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - typeInference = false, - basePaths = basePaths, - timeZoneId = timeZoneId) + val userPartitionSchema = + combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec) - // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => + val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType, + Literal.create(row.get(i, dt), dt), + userPartitionSchema.fields(i).dataType, Option(timeZoneId)).eval() }: _*) } - PartitionSpec(userProvidedSchema, spec.partitions.map { part => + PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) case _ => - PartitioningUtils.parsePartitions( - leafDirs, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths, - timeZoneId = timeZoneId) + inferredPartitionSpec } } @@ -236,6 +233,25 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } + + /** + * In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or + * inconsistent data types as reported in SPARK-21463. + * @param spec A partition inference result + * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` + */ + private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = { + val equality = sparkSession.sessionState.conf.resolver + val resolved = spec.partitionColumns.map { partitionField => + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } } object PartitioningAwareFileIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 1da703cefd8ea..5cacdd070b735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.StructType * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. * - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class MetadataLogFileIndex( sparkSession: SparkSession, path: Path, - userPartitionSchema: Option[StructType]) - extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) { + userSpecifiedSchema: Option[StructType]) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") @@ -51,7 +51,7 @@ class MetadataLogFileIndex( } override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { - allFilesFromLog.toArray.groupBy(_.getPath.getParent) + allFilesFromLog.groupBy(_.getPath.getParent) } override def rootPaths: Seq[Path] = path :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..8764f0c42cf9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -401,7 +401,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi sparkSession = spark, rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], - partitionSchema = None) + userSpecifiedSchema = None) // This should not fail. fileCatalog.listLeafFiles(Seq(new Path(tempDir))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 1a86c604d5da3..3af163af0968c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -419,7 +419,7 @@ class PartitionedTablePerfStatsSuite HiveCatalogMetrics.reset() spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) - assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) } } } From 25892f3cc9dcb938220be8020a5b9a17c92dbdbe Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 14 Apr 2018 01:01:00 +0800 Subject: [PATCH 0618/1161] [SPARK-23375][SQL] Eliminate unneeded Sort in Optimizer ## What changes were proposed in this pull request? Added a new rule to remove Sort operation when its child is already sorted. For instance, this simple code: ``` spark.sparkContext.parallelize(Seq(("a", "b"))).toDF("a", "b").registerTempTable("table1") val df = sql(s"""SELECT b | FROM ( | SELECT a, b | FROM table1 | ORDER BY a | ) t | ORDER BY a""".stripMargin) df.explain(true) ``` before the PR produces this plan: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(3) Project [b#7] +- *(3) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(2) Project [b#7, a#6] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 200) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` while after the PR produces: ``` == Parsed Logical Plan == 'Sort ['a ASC NULLS FIRST], true +- 'Project ['b] +- 'SubqueryAlias t +- 'Sort ['a ASC NULLS FIRST], true +- 'Project ['a, 'b] +- 'UnresolvedRelation `table1` == Analyzed Logical Plan == b: string Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [b#7, a#6] +- SubqueryAlias t +- Sort [a#6 ASC NULLS FIRST], true +- Project [a#6, b#7] +- SubqueryAlias table1 +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(assertnotnull(input[0, scala.Tuple2, true]))._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Optimized Logical Plan == Project [b#7] +- Sort [a#6 ASC NULLS FIRST], true +- Project [_1#3 AS a#6, _2#4 AS b#7] +- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- ExternalRDD [obj#2] == Physical Plan == *(2) Project [b#7] +- *(2) Sort [a#6 ASC NULLS FIRST], true, 0 +- Exchange rangepartitioning(a#6 ASC NULLS FIRST, 5) +- *(1) Project [_1#3 AS a#6, _2#4 AS b#7] +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#3, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#4] +- Scan ExternalRDDScan[obj#2] ``` this means that an unnecessary sort operation is not performed after the PR. ## How was this patch tested? added UT Author: Marco Gaido Closes #20560 from mgaido91/SPARK-23375. --- .../sql/catalyst/optimizer/Optimizer.scala | 12 +++ .../catalyst/plans/logical/LogicalPlan.scala | 9 ++ .../plans/logical/basicLogicalOperators.scala | 23 ++-- .../optimizer/RemoveRedundantSortsSuite.scala | 101 ++++++++++++++++++ .../spark/sql/execution/CacheManager.scala | 4 +- .../spark/sql/execution/ExistingRDD.scala | 2 +- .../execution/columnar/InMemoryRelation.scala | 17 +-- .../spark/sql/ConfigBehaviorSuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 15 ++- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +-- 10 files changed, 175 insertions(+), 24 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9a1bbc675e397..5fb59ef350b8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -138,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) operatorOptimizationBatch) :+ Batch("Join Reorder", Once, CostBasedJoinReorder) :+ + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -733,6 +735,16 @@ object EliminateSorts extends Rule[LogicalPlan] { } } +/** + * Removes Sort operation if the child is already sorted + */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + child + } +} + /** * Removes filters that can be evaluated trivially. This can be done through the following ways: * 1) by eliding the filter for cases where it will always evaluate to `true`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c8ccd9bd03994..42034403d6d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -219,6 +219,11 @@ abstract class LogicalPlan * Refreshes (or invalidates) any metadata/data cached in the plan recursively. */ def refresh(): Unit = children.foreach(_.refresh()) + + /** + * Returns the output ordering that this plan generates. + */ + def outputOrdering: Seq[SortOrder] = Nil } /** @@ -274,3 +279,7 @@ abstract class BinaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Seq(left, right) } + +abstract class OrderPreservingUnaryNode extends UnaryNode { + override final def outputOrdering: Seq[SortOrder] = child.outputOrdering +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..10df504795430 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { * This node is inserted at the top of a subquery when it is optimized. This makes sure we can * recognize a subquery as such, and it allows us to write subquery aware transformations. */ -case class Subquery(child: LogicalPlan) extends UnaryNode { +case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output } -case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) + extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows @@ -125,7 +126,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) - extends UnaryNode with PredicateHelper { + extends OrderPreservingUnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows @@ -469,6 +470,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows + override def outputOrdering: Seq[SortOrder] = order } /** Factory for constructing new `Range` nodes. */ @@ -522,6 +524,15 @@ case class Range( override def computeStats(): Statistics = { Statistics(sizeInBytes = LongType.defaultSize * numElements) } + + override def outputOrdering: Seq[SortOrder] = { + val order = if (step > 0) { + Ascending + } else { + Descending + } + output.map(a => SortOrder(a, order)) + } } case class Aggregate( @@ -728,7 +739,7 @@ object Limit { * * See [[Limit]] for more information. */ -case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { @@ -744,7 +755,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN * * See [[Limit]] for more information. */ -case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRowsPerPartition: Option[Long] = { @@ -764,7 +775,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias( alias: String, child: LogicalPlan) - extends UnaryNode { + extends OrderPreservingUnaryNode { override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala new file mode 100644 index 0000000000000..2319ab8046e56 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} + +class RemoveRedundantSortsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :: + Batch("Collapse Project", Once, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("remove redundant order by") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(unnecessaryReordered.analyze) + val correctAnswer = orderedPlan.select('a).analyze + comparePlans(Optimize.execute(optimized), correctAnswer) + } + + test("do not remove sort if the order is different") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(reorderedDifferently.analyze) + val correctAnswer = reorderedDifferently.analyze + comparePlans(optimized, correctAnswer) + } + + test("filters don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.where('a > Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("limits don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.limit(Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("range is already sorted") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.orderBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + + val reversedPlan = inputPlan.orderBy('id.desc) + val reversedOptimized = Optimize.execute(reversedPlan.analyze) + val reversedCorrectAnswer = reversedPlan.analyze + comparePlans(reversedOptimized, reversedCorrectAnswer) + + val negativeStepInputPlan = Range(10L, 1L, -1, 10) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) + val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) + } + + test("sort should not be removed when there is a node which doesn't guarantee any order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) + val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val optimized = Optimize.execute(groupedAndResorted.analyze) + val correctAnswer = groupedAndResorted.analyze + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index d68aeb275afda..a8794be7280c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -99,7 +99,7 @@ class CacheManager extends Logging { sparkSession.sessionState.conf.columnBatchSize, storageLevel, sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, - planToCache.stats) + planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) } } @@ -148,7 +148,7 @@ class CacheManager extends Logging { storageLevel = cd.cachedRepresentation.storageLevel, child = spark.sessionState.executePlan(cd.plan).executedPlan, tableName = cd.cachedRepresentation.tableName, - statsOfPlanToCache = cd.plan.stats) + logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f3555508185fe..be50a1571a2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -125,7 +125,7 @@ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], outputPartitioning: Partitioning = UnknownPartitioning(0), - outputOrdering: Seq[SortOrder] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, override val isStreaming: Boolean = false)(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 2579046e30708..a7ba9b86a176f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -39,9 +39,9 @@ object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String], - statsOfPlanToCache: Statistics): InMemoryRelation = + logicalPlan: LogicalPlan): InMemoryRelation = new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = statsOfPlanToCache) + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) } @@ -64,7 +64,8 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val sizeInBytesStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -76,7 +77,8 @@ case class InMemoryRelation( tableName = None)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache) + statsOfPlanToCache, + outputOrdering) override def producedAttributes: AttributeSet = outputSet @@ -159,7 +161,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { @@ -172,7 +174,8 @@ case class InMemoryRelation( tableName)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache).asInstanceOf[this.type] + statsOfPlanToCache, + outputOrdering).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index cee85ec8af04d..949505e449fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id) + val data = spark.range(0, n, 1, 1).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f8b26f5b28cc7..40915a102bab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -197,6 +197,19 @@ class PlannerSuite extends SharedSQLContext { assert(planned.child.isInstanceOf[CollectLimitExec]) } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { + val query = testData.select('key, 'value).sort('key.desc).cache() + assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) + val resorted = query.sort('key.desc) + assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) + assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == + (1 to 100).reverse) + // with a different order, the sort is needed + val sortedAsc = query.sort('key) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) + assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) + } + test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 26b63e8e8490f..9b7b316211d30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val storageLevel = MEMORY_ONLY val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, - data.logicalPlan.stats) + data.logicalPlan) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) inMemoryRelation.cachedColumnBuffers.collect().head match { @@ -119,7 +120,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("simple columnar query") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) } @@ -138,7 +139,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val logicalPlan = testData.select('value, 'key).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - logicalPlan.stats) + logicalPlan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) @@ -329,7 +330,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats) + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan) // Materialize the data. val expectedAnswer = data.collect() @@ -455,7 +456,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { val attribute = AttributeReference("a", IntegerType)() val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil) - val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null) + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, + LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) From 558f31b31c73b7e9f26f56498b54cf53997b59b8 Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 13 Apr 2018 14:05:04 -0700 Subject: [PATCH 0619/1161] [SPARK-23963][SQL] Properly handle large number of columns in query on text-based Hive table ## What changes were proposed in this pull request? TableReader would get disproportionately slower as the number of columns in the query increased. I fixed the way TableReader was looking up metadata for each column in the row. Previously, it had been looking up this data in linked lists, accessing each linked list by an index (column number). Now it looks up this data in arrays, where indexing by column number works better. ## How was this patch tested? Manual testing All sbt unit tests python sql tests Author: Bruce Robbins Closes #21043 from bersprockets/tabreadfix. --- .../src/main/scala/org/apache/spark/sql/hive/TableReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index cc8907a0bbc93..b5444a4217924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -381,7 +381,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal - }.unzip + }.toArray.unzip /** * Builds specific unwrappers ahead of time according to object inspector From cbb41a0c5b01579c85f06ef42cc0585fbef216c5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 13 Apr 2018 16:31:39 -0700 Subject: [PATCH 0620/1161] [SPARK-23966][SS] Refactoring all checkpoint file writing logic in a common CheckpointFileManager interface ## What changes were proposed in this pull request? Checkpoint files (offset log files, state store files) in Structured Streaming must be written atomically such that no partial files are generated (would break fault-tolerance guarantees). Currently, there are 3 locations which try to do this individually, and in some cases, incorrectly. 1. HDFSOffsetMetadataLog - This uses a FileManager interface to use any implementation of `FileSystem` or `FileContext` APIs. It preferably loads `FileContext` implementation as FileContext of HDFS has atomic renames. 1. HDFSBackedStateStore (aka in-memory state store) - Writing a version.delta file - This uses FileSystem APIs only to perform a rename. This is incorrect as rename is not atomic in HDFS FileSystem implementation. - Writing a snapshot file - Same as above. #### Current problems: 1. State Store behavior is incorrect - HDFS FileSystem implementation does not have atomic rename. 1. Inflexible - Some file systems provide mechanisms other than write-to-temp-file-and-rename for writing atomically and more efficiently. For example, with S3 you can write directly to the final file and it will be made visible only when the entire file is written and closed correctly. Any failure can be made to terminate the writing without making any partial files visible in S3. The current code does not abstract out this mechanism enough that it can be customized. #### Solution: 1. Introduce a common interface that all 3 cases above can use to write checkpoint files atomically. 2. This interface must provide the necessary interfaces that allow customization of the write-and-rename mechanism. This PR does that by introducing the interface `CheckpointFileManager` and modifying `HDFSMetadataLog` and `HDFSBackedStateStore` to use the interface. Similar to earlier `FileManager`, there are implementations based on `FileSystem` and `FileContext` APIs, and the latter implementation is preferred to make it work correctly with HDFS. The key method this interface has is `createAtomic(path, overwrite)` which returns a `CancellableFSDataOutputStream` that has the method `cancel()`. All users of this method need to either call `close()` to successfully write the file, or `cancel()` in case of an error. ## How was this patch tested? New tests in `CheckpointFileManagerSuite` and slightly modified existing tests. Author: Tathagata Das Closes #21048 from tdas/SPARK-23966. --- .../apache/spark/sql/internal/SQLConf.scala | 7 + .../streaming/CheckpointFileManager.scala | 349 ++++++++++++++++++ .../execution/streaming/HDFSMetadataLog.scala | 229 +----------- .../state/HDFSBackedStateStoreProvider.scala | 120 +++--- .../streaming/state/StateStore.scala | 4 +- .../CheckpointFileManagerSuite.scala | 192 ++++++++++ .../CompactibleFileStreamLogSuite.scala | 5 - .../streaming/HDFSMetadataLogSuite.scala | 116 +----- .../streaming/state/StateStoreSuite.scala | 58 ++- 9 files changed, 678 insertions(+), 402 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1c8ab9c62623e..0dc47bfe075d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -930,6 +930,13 @@ object SQLConf { .intConf .createWithDefault(100) + val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = + buildConf("spark.sql.streaming.checkpointFileManagerClass") + .doc("The class used to write checkpoint files atomically. This class must be a subclass " + + "of the interface CheckpointFileManager.") + .internal() + .stringConf + val NDV_MAX_ERROR = buildConf("spark.sql.statistics.ndv.maxError") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala new file mode 100644 index 0000000000000..606ba250ad9d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException, OutputStream} +import java.util.{EnumSet, UUID} + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.local.{LocalFs, RawLocalFs} +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * An interface to abstract out all operation related to streaming checkpoints. Most importantly, + * the key operation this interface provides is `createAtomic(path, overwrite)` which returns a + * `CancellableFSDataOutputStream`. This method is used by [[HDFSMetadataLog]] and + * [[org.apache.spark.sql.execution.streaming.state.StateStore StateStore]] implementations + * to write a complete checkpoint file atomically (i.e. no partial file will be visible), with or + * without overwrite. + * + * This higher-level interface above the Hadoop FileSystem is necessary because + * different implementation of FileSystem/FileContext may have different combination of operations + * to provide the desired atomic guarantees (e.g. write-to-temp-file-and-rename, + * direct-write-and-cancel-on-failure) and this abstraction allow different implementations while + * keeping the usage simple (`createAtomic` -> `close` or `cancel`). + */ +trait CheckpointFileManager { + + import org.apache.spark.sql.execution.streaming.CheckpointFileManager._ + + /** + * Create a file and make its contents available atomically after the output stream is closed. + * + * @param path Path to create + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream + + /** Open a file for reading, or throw exception if it does not exist. */ + def open(path: Path): FSDataInputStream + + /** List the files in a path that match a filter. */ + def list(path: Path, filter: PathFilter): Array[FileStatus] + + /** List all the files in a path. */ + def list(path: Path): Array[FileStatus] = { + list(path, new PathFilter { override def accept(path: Path): Boolean = true }) + } + + /** Make directory at the give path and all its parent directories as needed. */ + def mkdirs(path: Path): Unit + + /** Whether path exists */ + def exists(path: Path): Boolean + + /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ + def delete(path: Path): Unit + + /** Is the default file system this implementation is operating on the local file system. */ + def isLocal: Boolean +} + +object CheckpointFileManager extends Logging { + + /** + * Additional methods in CheckpointFileManager implementations that allows + * [[RenameBasedFSDataOutputStream]] get atomicity by write-to-temp-file-and-rename + */ + sealed trait RenameHelperMethods { self => CheckpointFileManager + /** Create a file with overwrite. */ + def createTempFile(path: Path): FSDataOutputStream + + /** + * Rename a file. + * + * @param srcPath Source path to rename + * @param dstPath Destination path to rename to + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file alraedy exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit + } + + /** + * An interface to add the cancel() operation to [[FSDataOutputStream]]. This is used + * mainly by `CheckpointFileManager.createAtomic` to write a file atomically. + * + * @see [[CheckpointFileManager]]. + */ + abstract class CancellableFSDataOutputStream(protected val underlyingStream: OutputStream) + extends FSDataOutputStream(underlyingStream, null) { + /** Cancel the `underlyingStream` and ensure that the output file is not generated. */ + def cancel(): Unit + } + + /** + * An implementation of [[CancellableFSDataOutputStream]] that writes a file atomically by writing + * to a temporary file and then renames. + */ + sealed class RenameBasedFSDataOutputStream( + fm: CheckpointFileManager with RenameHelperMethods, + finalPath: Path, + tempPath: Path, + overwriteIfPossible: Boolean) + extends CancellableFSDataOutputStream(fm.createTempFile(tempPath)) { + + def this(fm: CheckpointFileManager with RenameHelperMethods, path: Path, overwrite: Boolean) = { + this(fm, path, generateTempPath(path), overwrite) + } + + logInfo(s"Writing atomically to $finalPath using temp file $tempPath") + @volatile private var terminated = false + + override def close(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + try { + fm.renameTempFile(tempPath, finalPath, overwriteIfPossible) + } catch { + case fe: FileAlreadyExistsException => + logWarning( + s"Failed to rename temp file $tempPath to $finalPath because file exists", fe) + if (!overwriteIfPossible) throw fe + } + logInfo(s"Renamed temp file $tempPath to $finalPath") + } finally { + terminated = true + } + } + + override def cancel(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + fm.delete(tempPath) + } catch { + case NonFatal(e) => + logWarning(s"Error cancelling write to $finalPath", e) + } finally { + terminated = true + } + } + } + + + /** Create an instance of [[CheckpointFileManager]] based on the path and configuration. */ + def create(path: Path, hadoopConf: Configuration): CheckpointFileManager = { + val fileManagerClass = hadoopConf.get( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key) + if (fileManagerClass != null) { + return Utils.classForName(fileManagerClass) + .getConstructor(classOf[Path], classOf[Configuration]) + .newInstance(path, hadoopConf) + .asInstanceOf[CheckpointFileManager] + } + try { + // Try to create a manager based on `FileContext` because HDFS's `FileContext.rename() + // gives atomic renames, which is what we rely on for the default implementation + // `CheckpointFileManager.createAtomic`. + new FileContextBasedCheckpointFileManager(path, hadoopConf) + } catch { + case e: UnsupportedFileSystemException => + logWarning( + "Could not use FileContext API for managing Structured Streaming checkpoint files at " + + s"$path. Using FileSystem API instead for managing log files. If the implementation " + + s"of FileSystem.rename() is not atomic, then the correctness and fault-tolerance of" + + s"your Structured Streaming is not guaranteed.") + new FileSystemBasedCheckpointFileManager(path, hadoopConf) + } + } + + private def generateTempPath(path: Path): Path = { + val tc = org.apache.spark.TaskContext.get + val tid = if (tc != null) ".TID" + tc.taskAttemptId else "" + new Path(path.getParent, s".${path.getName}.${UUID.randomUUID}${tid}.tmp") + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileSystem]] API. */ +class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + protected val fs = path.getFileSystem(hadoopConf) + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fs.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fs.mkdirs(path, FsPermission.getDirDefault) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + fs.create(path, true) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fs.open(path) + } + + override def exists(path: Path): Boolean = { + try + return fs.getFileStatus(path) != null + catch { + case e: FileNotFoundException => + return false + } + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + if (!overwriteIfPossible && fs.exists(dstPath)) { + throw new FileAlreadyExistsException( + s"Failed to rename $srcPath to $dstPath as destination already exists") + } + + if (!fs.rename(srcPath, dstPath)) { + // FileSystem.rename() returning false is very ambiguous as it can be for many reasons. + // This tries to make a best effort attempt to return the most appropriate exception. + if (fs.exists(dstPath)) { + if (!overwriteIfPossible) { + throw new FileAlreadyExistsException(s"Failed to rename as $dstPath already exists") + } + } else if (!fs.exists(srcPath)) { + throw new FileNotFoundException(s"Failed to rename as $srcPath was not found") + } else { + val msg = s"Failed to rename temp file $srcPath to $dstPath as rename returned false" + logWarning(msg) + throw new IOException(msg) + } + } + } + + override def delete(path: Path): Unit = { + try { + fs.delete(path, true) + } catch { + case e: FileNotFoundException => + logInfo(s"Failed to delete $path as it does not exist") + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => true + case _ => false + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileContext]] API. */ +class FileContextBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + private val fc = if (path.toUri.getScheme == null) { + FileContext.getFileContext(hadoopConf) + } else { + FileContext.getFileContext(path.toUri, hadoopConf) + } + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fc.util.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fc.mkdir(path, FsPermission.getDirDefault, true) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + import CreateFlag._ + import Options._ + fc.create( + path, EnumSet.of(CREATE, OVERWRITE), CreateOpts.checksumParam(ChecksumOpt.createDisabled())) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fc.open(path) + } + + override def exists(path: Path): Boolean = { + fc.util.exists(path) + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + import Options.Rename._ + fc.rename(srcPath, dstPath, if (overwriteIfPossible) OVERWRITE else NONE) + } + + + override def delete(path: Path): Unit = { + try { + fc.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fc.getDefaultFileSystem match { + case _: LocalFs | _: RawLocalFs => true // LocalFs = RawLocalFs + ChecksumFs + case _ => false + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 00bc215a5dc8c..bd0a46115ceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -57,10 +57,10 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") - import HDFSMetadataLog._ - val metadataPath = new Path(path) - protected val fileManager = createFileManager() + + protected val fileManager = + CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) @@ -109,84 +109,31 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - writeBatch(batchId, metadata) + writeBatchToFile(metadata, batchIdToPath(batchId)) true } } - private def writeTempBatch(metadata: T): Option[Path] = { - while (true) { - val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp") - try { - val output = fileManager.create(tempPath) - try { - serialize(metadata, output) - return Some(tempPath) - } finally { - output.close() - } - } catch { - case e: FileAlreadyExistsException => - // Failed to create "tempPath". There are two cases: - // 1. Someone is creating "tempPath" too. - // 2. This is a restart. "tempPath" has already been created but not moved to the final - // batch file (not committed). - // - // For both cases, the batch has not yet been committed. So we can retry it. - // - // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use - // the same metadata path to create "HDFSMetadataLog" and fail A. However, this is not a - // big problem because it requires the attacker must have the permission to write the - // metadata path. In addition, the old Streaming also have this issue, people can create - // malicious checkpoint files to crash a Streaming application too. - } - } - None - } - - /** - * Write a batch to a temp file then rename it to the batch file. + /** Write a batch to a temp file then rename it to the batch file. * * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, metadata: T): Unit = { - val tempPath = writeTempBatch(metadata).getOrElse( - throw new IllegalStateException(s"Unable to create temp batch file $batchId")) + private def writeBatchToFile(metadata: T, path: Path): Unit = { + val output = fileManager.createAtomic(path, overwriteIfPossible = false) try { - // Try to commit the batch - // It will fail if there is an existing file (someone has committed the batch) - logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") - fileManager.rename(tempPath, batchIdToPath(batchId)) - - // SPARK-17475: HDFSMetadataLog should not leak CRC files - // If the underlying filesystem didn't rename the CRC file, delete it. - val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") - if (fileManager.exists(crcPath)) fileManager.delete(crcPath) + serialize(metadata, output) + output.close() } catch { case e: FileAlreadyExistsException => - // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. - // So throw an exception to tell the user this is not a valid behavior. + output.cancel() + // If next batch file already exists, then another concurrently running query has + // written it. throw new ConcurrentModificationException( - s"Multiple HDFSMetadataLog are using $path", e) - } finally { - fileManager.delete(tempPath) - } - } - - /** - * @return the deserialized metadata in a batch file, or None if file not exist. - * @throws IllegalArgumentException when path does not point to a batch file. - */ - def get(batchFile: Path): Option[T] = { - if (fileManager.exists(batchFile)) { - if (isBatchFile(batchFile)) { - get(pathToBatchId(batchFile)) - } else { - throw new IllegalArgumentException(s"File ${batchFile} is not a batch file!") - } - } else { - None + s"Multiple streaming queries are concurrently using $path", e) + case e: Throwable => + output.cancel() + throw e } } @@ -219,7 +166,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) }.sorted - verifyBatchIds(batchIds, startId, endId) + HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId) batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => @@ -280,19 +227,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - private def createFileManager(): FileManager = { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - try { - new FileContextManager(metadataPath, hadoopConf) - } catch { - case e: UnsupportedFileSystemException => - logWarning("Could not use FileContext API for managing metadata log files at path " + - s"$metadataPath. Using FileSystem API instead for managing log files. The log may be " + - s"inconsistent under failures.") - new FileSystemManager(metadataPath, hadoopConf) - } - } - /** * Parse the log version from the given `text` -- will throw exception when the parsed version * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", @@ -327,135 +261,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: object HDFSMetadataLog { - /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ - trait FileManager { - - /** List the files in a path that match a filter. */ - def list(path: Path, filter: PathFilter): Array[FileStatus] - - /** Make directory at the give path and all its parent directories as needed. */ - def mkdirs(path: Path): Unit - - /** Whether path exists */ - def exists(path: Path): Boolean - - /** Open a file for reading, or throw exception if it does not exist. */ - def open(path: Path): FSDataInputStream - - /** Create path, or throw exception if it already exists */ - def create(path: Path): FSDataOutputStream - - /** - * Atomically rename path, or throw exception if it cannot be done. - * Should throw FileNotFoundException if srcPath does not exist. - * Should throw FileAlreadyExistsException if destPath already exists. - */ - def rename(srcPath: Path, destPath: Path): Unit - - /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ - def delete(path: Path): Unit - } - - /** - * Default implementation of FileManager using newer FileContext API. - */ - class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fc = if (path.toUri.getScheme == null) { - FileContext.getFileContext(hadoopConf) - } else { - FileContext.getFileContext(path.toUri, hadoopConf) - } - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fc.util.listStatus(path, filter) - } - - override def rename(srcPath: Path, destPath: Path): Unit = { - fc.rename(srcPath, destPath) - } - - override def mkdirs(path: Path): Unit = { - fc.mkdir(path, FsPermission.getDirDefault, true) - } - - override def open(path: Path): FSDataInputStream = { - fc.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fc.create(path, EnumSet.of(CreateFlag.CREATE)) - } - - override def exists(path: Path): Boolean = { - fc.util().exists(path) - } - - override def delete(path: Path): Unit = { - try { - fc.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - - /** - * Implementation of FileManager using older FileSystem API. Note that this implementation - * cannot provide atomic renaming of paths, hence can lead to consistency issues. This - * should be used only as a backup option, when FileContextManager cannot be used. - */ - class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fs = path.getFileSystem(hadoopConf) - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fs.listStatus(path, filter) - } - - /** - * Rename a path. Note that this implementation is not atomic. - * @throws FileNotFoundException if source path does not exist. - * @throws FileAlreadyExistsException if destination path already exists. - * @throws IOException if renaming fails for some unknown reason. - */ - override def rename(srcPath: Path, destPath: Path): Unit = { - if (!fs.exists(srcPath)) { - throw new FileNotFoundException(s"Source path does not exist: $srcPath") - } - if (fs.exists(destPath)) { - throw new FileAlreadyExistsException(s"Destination path already exists: $destPath") - } - if (!fs.rename(srcPath, destPath)) { - throw new IOException(s"Failed to rename $srcPath to $destPath") - } - } - - override def mkdirs(path: Path): Unit = { - fs.mkdirs(path, FsPermission.getDirDefault) - } - - override def open(path: Path): FSDataInputStream = { - fs.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fs.create(path, false) - } - - override def exists(path: Path): Boolean = { - fs.exists(path) - } - - override def delete(path: Path): Unit = { - try { - fs.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - /** * Verify if batchIds are continuous and between `startId` and `endId`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3f5002a4e6937..df722b953228b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.io._ import java.nio.channels.ClosedChannelException import java.util.Locale @@ -27,13 +27,16 @@ import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SizeEstimator, Utils} @@ -87,10 +90,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case object ABORTED extends STATE private val newVersion = version + 1 - private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) @volatile private var state: STATE = UPDATING - @volatile private var finalDeltaFile: Path = null + private val finalDeltaFile: Path = deltaFile(newVersion) + private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true) + private lazy val compressedStream = compressStream(deltaFileStream) override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId @@ -103,14 +106,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val keyCopy = key.copy() val valueCopy = value.copy() mapToUpdate.put(keyCopy, valueCopy) - writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) + writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy) } override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") val prevValue = mapToUpdate.remove(key) if (prevValue != null) { - writeRemoveToDeltaFile(tempDeltaFileStream, key) + writeRemoveToDeltaFile(compressedStream, key) } } @@ -126,8 +129,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit verify(state == UPDATING, "Cannot commit after already committed or aborted") try { - finalizeDeltaFile(tempDeltaFileStream) - finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + commitUpdates(newVersion, mapToUpdate, compressedStream) state = COMMITTED logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion @@ -140,23 +142,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { - verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") - try { + // This if statement is to ensure that files are deleted only if there are changes to the + // StateStore. We have two StateStores for each task, one which is used only for reading, and + // the other used for read+write. We don't want the read-only to delete state files. + if (state == UPDATING) { + state = ABORTED + cancelDeltaFile(compressedStream, deltaFileStream) + } else { state = ABORTED - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null) { - fs.delete(tempDeltaFile, true) - } - } catch { - case c: ClosedChannelException => - // This can happen when underlying file output stream has been closed before the - // compression stream. - logDebug(s"Error aborting version $newVersion into $this", c) - - case e: Exception => - logWarning(s"Error aborting version $newVersion into $this", e) } logInfo(s"Aborted version $newVersion for $this") } @@ -212,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf - fs.mkdirs(baseDir) + fm.mkdirs(baseDir) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -251,31 +244,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private lazy val loadedMaps = new mutable.HashMap[Long, MapType] private lazy val baseDir = stateStoreId.storeCheckpointLocation() - private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - /** Commit a set of updates to the store with the given new version */ - private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { - val finalDeltaFile = deltaFile(newVersion) - - // scalastyle:off - // Renaming a file atop an existing one fails on HDFS - // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). - // Hence we should either skip the rename step or delete the target file. Because deleting the - // target file will break speculation, skipping the rename step is the only choice. It's still - // semantically correct because Structured Streaming requires rerunning a batch should - // generate the same output. (SPARK-19677) - // scalastyle:on - if (fs.exists(finalDeltaFile)) { - fs.delete(tempDeltaFile, true) - } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { - throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") - } + finalizeDeltaFile(output) loadedMaps.put(newVersion, map) - finalDeltaFile } } @@ -365,7 +342,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val fileToRead = deltaFile(version) var input: DataInputStream = null val sourceStream = try { - fs.open(fileToRead) + fm.open(fileToRead) } catch { case f: FileNotFoundException => throw new IllegalStateException( @@ -412,12 +389,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } private def writeSnapshotFile(version: Long, map: MapType): Unit = { - val fileToWrite = snapshotFile(version) - val tempFile = - new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}") + val targetFile = snapshotFile(version) + var rawOutput: CancellableFSDataOutputStream = null var output: DataOutputStream = null - Utils.tryWithSafeFinally { - output = compressStream(fs.create(tempFile, false)) + try { + rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true) + output = compressStream(rawOutput) val iter = map.entrySet().iterator() while(iter.hasNext) { val entry = iter.next() @@ -429,16 +406,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit output.write(valueBytes) } output.writeInt(-1) - } { - if (output != null) output.close() + output.close() + } catch { + case e: Throwable => + cancelDeltaFile(compressedStream = output, rawStream = rawOutput) + throw e } - if (fs.exists(fileToWrite)) { - // Skip rename if the file is alreayd created. - fs.delete(tempFile, true) - } else if (!fs.rename(tempFile, fileToWrite)) { - throw new IOException(s"Failed to rename $tempFile to $fileToWrite") + logInfo(s"Written snapshot file for version $version of $this at $targetFile") + } + + /** + * Try to cancel the underlying stream and safely close the compressed stream. + * + * @param compressedStream the compressed stream. + * @param rawStream the underlying stream which needs to be cancelled. + */ + private def cancelDeltaFile( + compressedStream: DataOutputStream, + rawStream: CancellableFSDataOutputStream): Unit = { + try { + if (rawStream != null) rawStream.cancel() + IOUtils.closeQuietly(compressedStream) + } catch { + case e: FSError if e.getCause.isInstanceOf[IOException] => + // Closing the compressedStream causes the stream to write/flush flush data into the + // rawStream. Since the rawStream is already closed, there may be errors. + // Usually its an IOException. However, Hadoop's RawLocalFileSystem wraps + // IOException into FSError. } - logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } private def readSnapshotFile(version: Long): Option[MapType] = { @@ -447,7 +442,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit var input: DataInputStream = null try { - input = decompressStream(fs.open(fileToRead)) + input = decompressStream(fm.open(fileToRead)) var eof = false while (!eof) { @@ -508,7 +503,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case None => // The last map is not loaded, probably some other instance is in charge } - } } catch { case NonFatal(e) => @@ -534,7 +528,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) filesToDelete.foreach { f => - fs.delete(f.path, true) + fm.delete(f.path) } logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) @@ -576,7 +570,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { - fs.listStatus(baseDir) + fm.list(baseDir) } catch { case _: java.io.FileNotFoundException => Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d1d9f95cb0977..7eb68c21569ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -459,7 +459,6 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - logInfo("Env is not null") val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER || env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -467,13 +466,12 @@ object StateStore extends Logging { // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, // always recreate the reference. if (isDriver || _coordRef == null) { - logInfo("Getting StateStoreCoordinatorRef") + logDebug("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { - logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala new file mode 100644 index 0000000000000..fe59cb25d5005 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +abstract class CheckpointFileManagerTests extends SparkFunSuite { + + def createManager(path: Path): CheckpointFileManager + + test("mkdirs, list, createAtomic, open, delete, exists") { + withTempPath { p => + val basePath = new Path(p.getAbsolutePath) + val fm = createManager(basePath) + // Mkdirs + val dir = new Path(s"$basePath/dir/subdir/subsubdir") + assert(!fm.exists(dir)) + fm.mkdirs(dir) + assert(fm.exists(dir)) + fm.mkdirs(dir) + + // List + val acceptAllFilter = new PathFilter { + override def accept(path: Path): Boolean = true + } + val rejectAllFilter = new PathFilter { + override def accept(path: Path): Boolean = false + } + assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) + assert(fm.list(basePath, rejectAllFilter).length === 0) + + // Create atomic without overwrite + var path = new Path(s"$dir/file") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).close() + assert(fm.exists(path)) + quietly { + intercept[IOException] { + // should throw exception since file exists and overwrite is false + fm.createAtomic(path, overwriteIfPossible = false).close() + } + } + + // Create atomic with overwrite if possible + path = new Path(s"$dir/file2") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() + assert(fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() // should not throw exception + + // Open and delete + fm.open(path).close() + fm.delete(path) + assert(!fm.exists(path)) + intercept[IOException] { + fm.open(path) + } + fm.delete(path) // should not throw exception + } + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} + +class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { + + test("CheckpointFileManager.create() should pick up user-specified class from conf") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key -> + classOf[CreateAtomicTestManager].getName) { + val fileManager = + CheckpointFileManager.create(new Path("/"), spark.sessionState.newHadoopConf) + assert(fileManager.isInstanceOf[CreateAtomicTestManager]) + } + } + + test("CheckpointFileManager.create() should fallback from FileContext to FileSystem") { + import CheckpointFileManagerSuiteFileSystem.scheme + spark.conf.set(s"fs.$scheme.impl", classOf[CheckpointFileManagerSuiteFileSystem].getName) + quietly { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) + + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.getLatest() === Some(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) + } + } + } +} + +class FileContextBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileContextBasedCheckpointFileManager(path, new Configuration()) + } +} + +class FileSystemBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileSystemBasedCheckpointFileManager(path, new Configuration()) + } +} + + +/** A fake implementation to test different characteristics of CheckpointFileManager interface */ +class CreateAtomicTestManager(path: Path, hadoopConf: Configuration) + extends FileSystemBasedCheckpointFileManager(path, hadoopConf) { + + import CheckpointFileManager._ + + override def createAtomic(path: Path, overwrite: Boolean): CancellableFSDataOutputStream = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + } + val originalOut = super.createAtomic(path, overwrite) + + new CancellableFSDataOutputStream(originalOut) { + override def close(): Unit = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + throw new IOException("Copy failed intentionally") + } + super.close() + } + + override def cancel(): Unit = { + CreateAtomicTestManager.cancelCalledInCreateAtomic = true + originalOut.cancel() + } + } + } +} + +object CreateAtomicTestManager { + @volatile var shouldFailInCreateAtomic = false + @volatile var cancelCalledInCreateAtomic = false +} + + +/** + * CheckpointFileManagerSuiteFileSystem to test fallback of the CheckpointFileManager + * from FileContext to FileSystem API. + */ +private class CheckpointFileManagerSuiteFileSystem extends RawLocalFileSystem { + import CheckpointFileManagerSuiteFileSystem.scheme + + override def getUri: URI = { + URI.create(s"$scheme:///") + } +} + +private object CheckpointFileManagerSuiteFileSystem { + val scheme = s"CheckpointFileManagerSuiteFileSystem${math.abs(Random.nextInt)}" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 12eaf63415081..ec961a9ecb592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -22,15 +22,10 @@ import java.nio.charset.StandardCharsets._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - import CompactibleFileStreamLog._ /** -- testing of `object CompactibleFileStreamLog` begins -- */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4677769c12a35..9268306ce4275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -17,46 +17,22 @@ package org.apache.spark.sql.execution.streaming -import java.io.{File, FileNotFoundException, IOException} -import java.net.URI +import java.io.File import java.util.ConcurrentModificationException import scala.language.implicitConversions -import scala.util.Random -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs._ import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ -import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - private implicit def toOption[A](a: A): Option[A] = Option(a) - test("FileManager: FileContextManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileContextManager(path, new Configuration)) - } - } - - test("FileManager: FileSystemManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileSystemManager(path, new Configuration)) - } - } - test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir @@ -82,26 +58,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - spark.conf.set( - s"fs.$scheme.impl", - classOf[FakeFileSystem].getName) - withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog.add(0, "batch0")) - assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - - - val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog2.get(0) === Some("batch0")) - assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) - - } - } - test("HDFSMetadataLog: purge") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) @@ -121,7 +77,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { // There should be exactly one file, called "2", in the metadata directory. // This check also tests for regressions of SPARK-17475 - val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + val allFiles = new File(metadataLog.metadataPath.toString).listFiles() + .filter(!_.getName.startsWith(".")).toSeq assert(allFiles.size == 1) assert(allFiles(0).getName() == "2") } @@ -172,7 +129,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("HDFSMetadataLog: metadata directory collision") { + testQuietly("HDFSMetadataLog: metadata directory collision") { withTempDir { temp => val waiter = new Waiter val maxBatchId = 100 @@ -206,60 +163,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - /** Basic test case for [[FileManager]] implementation. */ - private def testFileManager(basePath: Path, fm: FileManager): Unit = { - // Mkdirs - val dir = new Path(s"$basePath/dir/subdir/subsubdir") - assert(!fm.exists(dir)) - fm.mkdirs(dir) - assert(fm.exists(dir)) - fm.mkdirs(dir) - - // List - val acceptAllFilter = new PathFilter { - override def accept(path: Path): Boolean = true - } - val rejectAllFilter = new PathFilter { - override def accept(path: Path): Boolean = false - } - assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) - assert(fm.list(basePath, rejectAllFilter).length === 0) - - // Create - val path = new Path(s"$dir/file") - assert(!fm.exists(path)) - fm.create(path).close() - assert(fm.exists(path)) - intercept[IOException] { - fm.create(path) - } - - // Open and delete - fm.open(path).close() - fm.delete(path) - assert(!fm.exists(path)) - intercept[IOException] { - fm.open(path) - } - fm.delete(path) // should not throw exception - - // Rename - val path1 = new Path(s"$dir/file1") - val path2 = new Path(s"$dir/file2") - fm.create(path1).close() - assert(fm.exists(path1)) - fm.rename(path1, path2) - intercept[FileNotFoundException] { - fm.rename(path1, path2) - } - val path3 = new Path(s"$dir/file3") - fm.create(path3).close() - assert(fm.exists(path3)) - intercept[FileAlreadyExistsException] { - fm.rename(path2, path3) - } - } - test("verifyBatchIds") { import HDFSMetadataLog.verifyBatchIds verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) @@ -277,14 +180,3 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) } } - -/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ -class FakeFileSystem extends RawLocalFileSystem { - override def getUri: URI = { - URI.create(s"$scheme:///") - } -} - -object FakeFileSystem { - val scheme = s"HDFSMetadataLogSuite${math.abs(Random.nextInt)}" -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index c843b65020d8c..73f8705060402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI import java.util.UUID -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,17 +27,17 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.apache.hadoop.fs._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.count import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +137,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(getData(provider, 19) === Set("a" -> 19)) } - test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") @@ -344,7 +343,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } - test("SPARK-18342: commit fails when rename fails") { + testQuietly("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ val dir = scheme + "://" + newDir() val conf = new Configuration() @@ -366,7 +365,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] def numTempFiles: Int = { if (deltaFileDir.exists) { - deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + deltaFileDir.listFiles.map(_.getName).count(n => n.endsWith(".tmp")) } else 0 } @@ -471,6 +470,43 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("error writing [version].delta cancels the output stream") { + + val hadoopConf = new Configuration() + hadoopConf.set( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, + classOf[CreateAtomicTestManager].getName) + val remoteDir = Utils.createTempDir().getAbsolutePath + + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = remoteDir, hadoopConf = hadoopConf) + + // Disable failure of output stream and generate versions + CreateAtomicTestManager.shouldFailInCreateAtomic = false + for (version <- 1 to 10) { + val store = provider.getStore(version - 1) + put(store, version.toString, version) // update "1" -> 1, "2" -> 2, ... + store.commit() + } + val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet + + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store = provider.getStore(10) + // Fail commit for next version and verify that reloading resets the files + CreateAtomicTestManager.shouldFailInCreateAtomic = true + put(store, "11", 11) + val e = intercept[IllegalStateException] { quietly { store.commit() } } + assert(e.getCause.isInstanceOf[IOException]) + CreateAtomicTestManager.shouldFailInCreateAtomic = false + + // Abort commit for next version and verify that reloading resets the files + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store2 = provider.getStore(10) + put(store2, "11", 11) + store2.abort() + assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -720,6 +756,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] * this provider */ def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] + + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } } object StateStoreTestsHelper { From 73f28530d6f6dd8aba758ea818c456cf911a5f41 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Apr 2018 08:59:04 +0800 Subject: [PATCH 0621/1161] [SPARK-23979][SQL] MultiAlias should not be a CodegenFallback ## What changes were proposed in this pull request? Just found `MultiAlias` is a `CodegenFallback`. It should not be as looks like `MultiAlias` won't be evaluated. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21065 from viirya/multialias-without-codegenfallback. --- .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a65f58fa61ff4..71e23175168e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode @@ -335,7 +335,7 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression with CodegenFallback { + extends UnaryExpression with NamedExpression with Unevaluable { override def name: String = throw new UnresolvedException(this, "name") From c0964935d614bf345535439bce01cbd0e60c86aa Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Mon, 16 Apr 2018 12:01:42 +0800 Subject: [PATCH 0622/1161] [SPARK-23956][YARN] Use effective RPC port in AM registration ## What changes were proposed in this pull request? We propose not to hard-code the RPC port in the AM registration. ## How was this patch tested? Tested application reports from a pseudo-distributed cluster ``` 18/04/10 14:56:21 INFO Client: client token: N/A diagnostics: N/A ApplicationMaster host: localhost ApplicationMaster RPC port: 58338 queue: default start time: 1523397373659 final status: UNDEFINED tracking URL: http://localhost:8088/proxy/application_1523370127531_0016/ ``` Author: Gera Shegalov Closes #21047 from gerashegalov/gera/am-to-rm-nmhost. --- .../scala/org/apache/spark/deploy/yarn/YarnRMClient.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index c1ae12aabb8cc..17234b120ae13 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -29,7 +29,6 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. @@ -71,7 +70,8 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) + amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, + trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, From 69310220319163bac18c9ee69d7da6d92227253b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 15 Apr 2018 21:45:55 -0700 Subject: [PATCH 0623/1161] [SPARK-23917][SQL] Add array_max function ## What changes were proposed in this pull request? The PR adds the SQL function `array_max`. It takes an array as argument and returns the maximum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21024 from mgaido91/SPARK-23917. --- python/pyspark/sql/functions.py | 15 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 68 ++++++++++++++++++- .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 133 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..f3492ae42639c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 131b958239e41..05bfa2dd45340 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMax]("array_max"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9212c3de1f814..942dfd4292610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0abfc9fa4c465..c86c5beded9d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..e2614a179aad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ /** @@ -287,3 +287,69 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Returns the maximum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..a2384019533b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..daf407926dca4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..5d5d92c84df6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 083cf223569b7896e35ff1d53a73498a4971b28d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 16 Apr 2018 23:50:50 +0800 Subject: [PATCH 0624/1161] [SPARK-21033][CORE][FOLLOW-UP] Update Spillable ## What changes were proposed in this pull request? Update ```scala SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) ``` to ```scala SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) ``` because of `SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD`'s default value is `Integer.MAX_VALUE`: https://github.com/apache/spark/blob/c99fc9ad9b600095baba003053dbf84304ca392b/core/src/main/scala/org/apache/spark/internal/config/package.scala#L503-L511 ## How was this patch tested? N/A Author: Yuming Wang Closes #21077 from wangyum/SPARK-21033. --- .../org/apache/spark/util/collection/Spillable.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 8183f825592c0..81457b53cd814 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager} /** @@ -41,7 +42,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) protected def forceSpill(): Boolean // Number of elements read from input since last spill - protected def elementsRead: Long = _elementsRead + protected def elementsRead: Int = _elementsRead // Called by subclasses every time a record is read // It's used for checking spilling frequency @@ -54,15 +55,15 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // Force this collection to spill when there are this many elements in memory // For testing only - private[this] val numElementsForceSpillThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue) + private[this] val numElementsForceSpillThreshold: Int = + SparkEnv.get.conf.get(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 @volatile private[this] var myMemoryThreshold = initialMemoryThreshold // Number of elements read from input since last spill - private[this] var _elementsRead = 0L + private[this] var _elementsRead = 0 // Number of bytes spilled in total @volatile private[this] var _memoryBytesSpilled = 0L From 5003736ad60c3231bb18264c9561646c08379170 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 16 Apr 2018 11:27:30 -0500 Subject: [PATCH 0625/1161] [SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel add RawPrediction as output column add numClasses and numFeatures to OneVsRestModel ## What changes were proposed in this pull request? - Add two val numClasses and numFeatures in OneVsRestModel so that we can inherit from Classifier in the future - Add rawPrediction output column in transform, the prediction label in calculated by the rawPrediciton like raw2prediction ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21044 from ludatabricks/SPARK-9312. --- .../spark/ml/classification/OneVsRest.scala | 56 +++++++++++++++---- .../ml/classification/OneVsRestSuite.scala | 7 ++- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f04fde2cbbca1..5348d882cfd67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ @@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams +private[ml] trait OneVsRestParams extends ClassifierParams with ClassifierTypeTrait with HasWeightCol { /** @@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + require(models.nonEmpty, "OneVsRestModel requires at least one model for one class") + + @Since("2.4.0") + val numClasses: Int = models.length + + @Since("2.4.0") + val numFeatures: Int = models.head.numFeatures + /** @group setParam */ @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) @@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset @@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Map[Int, Double]) => - predictions.maxBy(_._2)._1.toDouble - } + if (getRawPredictionCol != "") { + val numClass = models.length - // output label and label metadata as prediction - aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) - .drop(accColName) + // output the RawPrediction as vector + val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => + val predArray = Array.fill[Double](numClass)(0.0) + predictions.foreach { case (idx, value) => predArray(idx) = value } + Vectors.dense(predArray) + } + + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble } + + // output confidence as raw prediction, label and label metadata as prediction + aggregatedDataset + .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) + .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) + .drop(accColName) + } else { + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Map[Int, Double]) => + predictions.maxBy(_._2)._1.toDouble + } + // output label and label metadata as prediction + aggregatedDataset + .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) + .drop(accColName) + } } @Since("1.4.1") @@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + /** * The implementation of parallel one vs. rest runs the classification for * each class in a separate threads. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 11e88367108b4..2c3417c7e4028 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") + assert(ova.getRawPredictionCol === "rawPrediction") val ovaModel = ova.fit(dataset) MLTestingUtils.checkCopyAndUids(ova, ovaModel) - assert(ovaModel.models.length === numClasses) + assert(ovaModel.numClasses === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col @@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) ovaModel.setFeaturesCol("fea") ovaModel.setPredictionCol("pred") + ovaModel.setRawPredictionCol("") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet assert(outputFields === Set("y", "fea", "pred")) @@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val ovr = new OneVsRest() .setClassifier(logReg) val output = ovr.fit(dataset).transform(dataset) - assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + assert(output.schema.fieldNames.toSet + === Set("label", "features", "prediction", "rawPrediction")) } test("SPARK-21306: OneVsRest should support setWeightCol") { From 04614820e103feeae91299dc90dba1dd628fd485 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 16 Apr 2018 11:31:24 -0500 Subject: [PATCH 0626/1161] [SPARK-21088][ML] CrossValidator, TrainValidationSplit support collect all models when fitting: Python API ## What changes were proposed in this pull request? Add python API for collecting sub-models during CrossValidator/TrainValidationSplit fitting. ## How was this patch tested? UT added. Author: WeichenXu Closes #19627 from WeichenXu123/expose-model-list-py. --- .../spark/ml/tuning/CrossValidator.scala | 11 ++ .../ml/tuning/TrainValidationSplit.scala | 11 ++ .../ml/param/_shared_params_code_gen.py | 5 + python/pyspark/ml/param/shared.py | 24 ++++ python/pyspark/ml/tests.py | 78 +++++++++++++ python/pyspark/ml/tuning.py | 107 +++++++++++++----- python/pyspark/ml/util.py | 4 + 7 files changed, 211 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index a0b507d2e718c..c2826dcc08634 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -270,6 +270,17 @@ class CrossValidatorModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[JList[Model[_]]]) + : CrossValidatorModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray.map(_.asScala.toArray)) + } else { + None + } + this + } + /** * @return submodels represented in two dimension array. The index of outer array is the * fold index, and the index of inner array corresponds to the ordering of diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 88ff0dfd75e96..8d1b9a8ddab59 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -262,6 +262,17 @@ class TrainValidationSplitModel private[ml] ( this } + // A Python-friendly auxiliary method + private[tuning] def setSubModels(subModels: JList[Model[_]]) + : TrainValidationSplitModel = { + _subModels = if (subModels != null) { + Some(subModels.asScala.toArray) + } else { + None + } + this + } + /** * @return submodels represented in array. The index of array corresponds to the ordering of * estimatorParamMaps diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index db951d81de1e7..6e9e0a34cdfde 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -157,6 +157,11 @@ def get$Name(self): "TypeConverters.toInt"), ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).", "1", "TypeConverters.toInt"), + ("collectSubModels", "Param for whether to collect a list of sub-models trained during " + + "tuning. If set to false, then only the single best sub-model will be available after " + + "fitting. If set to true, then all sub-models will be available. Warning: For large " + + "models, collecting all sub-models can cause OOMs on the Spark driver.", + "False", "TypeConverters.toBoolean"), ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] code = [] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 474c38764e5a1..08408ee8fbfcc 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -655,6 +655,30 @@ def getParallelism(self): return self.getOrDefault(self.parallelism) +class HasCollectSubModels(Params): + """ + Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver. + """ + + collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean) + + def __init__(self): + super(HasCollectSubModels, self).__init__() + self._setDefault(collectSubModels=False) + + def setCollectSubModels(self, value): + """ + Sets the value of :py:attr:`collectSubModels`. + """ + return self._set(collectSubModels=value) + + def getCollectSubModels(self): + """ + Gets the value of collectSubModels or its default value. + """ + return self.getOrDefault(self.collectSubModels) + + class HasLoss(Params): """ Mixin for param loss: the loss function to be optimized. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 4ce54547eab09..2ec0be60e9fa9 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1018,6 +1018,50 @@ def test_parallel_evaluation(self): cvParallelModel = cv.fit(dataset) self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + + numFolds = 3 + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + numFolds=numFolds, collectSubModels=True) + + def checkSubModels(subModels): + self.assertEqual(len(subModels), numFolds) + for i in range(numFolds): + self.assertEqual(len(subModels[i]), len(grid)) + + cvModel = cv.fit(dataset) + checkSubModels(cvModel.subModels) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testCrossValidatorSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + cvModel.save(savingPathWithSubModels) + cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + checkSubModels(cvModel3.subModels) + cvModel4 = cvModel3.copy() + checkSubModels(cvModel4.subModels) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + self.assertEqual(cvModel2.subModels, None) + + for i in range(numFolds): + for j in range(len(grid)): + self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid) + def test_save_load_nested_estimator(self): temp_path = tempfile.mkdtemp() dataset = self.spark.createDataFrame( @@ -1186,6 +1230,40 @@ def test_parallel_evaluation(self): tvsParallelModel = tvs.fit(dataset) self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics) + def test_expose_sub_models(self): + temp_path = tempfile.mkdtemp() + dataset = self.spark.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, + collectSubModels=True) + tvsModel = tvs.fit(dataset) + self.assertEqual(len(tvsModel.subModels), len(grid)) + + # Test the default value for option "persistSubModel" to be "true" + testSubPath = temp_path + "/testTrainValidationSplitSubModels" + savingPathWithSubModels = testSubPath + "cvModel3" + tvsModel.save(savingPathWithSubModels) + tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + self.assertEqual(len(tvsModel3.subModels), len(grid)) + tvsModel4 = tvsModel3.copy() + self.assertEqual(len(tvsModel4.subModels), len(grid)) + + savingPathWithoutSubModels = testSubPath + "cvModel2" + tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels) + tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + self.assertEqual(tvsModel2.subModels, None) + + for i in range(len(grid)): + self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid) + def test_save_load_nested_estimator(self): # This tests saving and loading the trained model only. # Save/load for TrainValidationSplit will be added later: SPARK-13786 diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 545e24ca05aa5..0c8029f293cfe 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model from pyspark.ml.common import _py2java from pyspark.ml.param import Params, Param, TypeConverters -from pyspark.ml.param.shared import HasParallelism, HasSeed +from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand @@ -33,7 +33,7 @@ 'TrainValidationSplitModel'] -def _parallelFitTasks(est, train, eva, validation, epm): +def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel): """ Creates a list of callables which can be called from different threads to fit and evaluate an estimator in parallel. Each callable returns an `(index, metric)` pair. @@ -43,14 +43,15 @@ def _parallelFitTasks(est, train, eva, validation, epm): :param eva: Evaluator, used to compute `metric` :param validation: DataFrame, validation data set, used for evaluation. :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation. - :return: (int, float), an index into `epm` and the associated metric value. + :param collectSubModel: Whether to collect sub model. + :return: (int, float, subModel), an index into `epm` and the associated metric value. """ modelIter = est.fitMultiple(train, epm) def singleTask(): index, model = next(modelIter) metric = eva.evaluate(model.transform(validation, epm[index])) - return index, metric + return index, metric, model if collectSubModel else None return [singleTask] * len(epm) @@ -194,7 +195,8 @@ def _to_java_impl(self): return java_estimator, java_epms, java_evaluator -class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ K-fold cross validation performs model selection by splitting the dataset into a set of @@ -233,10 +235,10 @@ class CrossValidator(Estimator, ValidatorParams, HasParallelism, MLReadable, MLW @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1) + seed=None, parallelism=1, collectSubModels=False) """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3, parallelism=1) @@ -246,10 +248,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF @keyword_only @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ - seed=None, parallelism=1): + seed=None, parallelism=1, collectSubModels=False): Sets params for cross validator. """ kwargs = self._input_kwargs @@ -282,6 +284,10 @@ def _fit(self, dataset): metrics = [0.0] * numModels pool = ThreadPool(processes=min(self.getParallelism(), numModels)) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [[None for j in range(numModels)] for i in range(nFolds)] for i in range(nFolds): validateLB = i * h @@ -290,9 +296,12 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] += (metric / nFolds) + if collectSubModelsParam: + subModels[i][j] = subModel + validation.unpersist() train.unpersist() @@ -301,7 +310,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(CrossValidatorModel(bestModel, metrics)) + return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels)) @since("1.4.0") def copy(self, extra=None): @@ -345,9 +354,11 @@ def _from_java(cls, java_stage): numFolds = java_stage.getNumFolds() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - numFolds=numFolds, seed=seed, parallelism=parallelism) + numFolds=numFolds, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -367,6 +378,7 @@ def _to_java(self): _java_obj.setSeed(self.getSeed()) _java_obj.setNumFolds(self.getNumFolds()) _java_obj.setParallelism(self.getParallelism()) + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -381,13 +393,15 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 1.4.0 """ - def __init__(self, bestModel, avgMetrics=[]): + def __init__(self, bestModel, avgMetrics=[], subModels=None): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel #: Average cross-validation metrics for each paramMap in #: CrossValidator.estimatorParamMaps, in the corresponding order. self.avgMetrics = avgMetrics + #: sub model list from cross validation + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -399,6 +413,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -407,7 +422,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) avgMetrics = self.avgMetrics - return CrossValidatorModel(bestModel, avgMetrics) + subModels = self.subModels + return CrossValidatorModel(bestModel, avgMetrics, subModels) @since("2.3.0") def write(self): @@ -426,13 +442,17 @@ def _from_java(cls, java_stage): Given a Java CrossValidatorModel, create and return a Python wrapper of it. Used for ML persistence. """ - bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [[JavaParams._from_java(sub_model) + for sub_model in fold_sub_models] + for fold_sub_models in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -454,10 +474,16 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models] + for fold_sub_models in self.subModels] + _java_obj.setSubModels(java_sub_models) return _java_obj -class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadable, MLWritable): +class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, + MLReadable, MLWritable): """ .. note:: Experimental @@ -492,10 +518,10 @@ class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, MLReadabl @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None) + parallelism=1, collectSubModels=False, seed=None) """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75, parallelism=1) @@ -505,10 +531,10 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai @since("2.0.0") @keyword_only def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\ - parallelism=1, seed=None): + parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split. """ kwargs = self._input_kwargs @@ -541,11 +567,19 @@ def _fit(self, dataset): validation = df.filter(condition).cache() train = df.filter(~condition).cache() - tasks = _parallelFitTasks(est, train, eva, validation, epm) + subModels = None + collectSubModelsParam = self.getCollectSubModels() + if collectSubModelsParam: + subModels = [None for i in range(numModels)] + + tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam) pool = ThreadPool(processes=min(self.getParallelism(), numModels)) metrics = [None] * numModels - for j, metric in pool.imap_unordered(lambda f: f(), tasks): + for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks): metrics[j] = metric + if collectSubModelsParam: + subModels[j] = subModel + train.unpersist() validation.unpersist() @@ -554,7 +588,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels)) @since("2.0.0") def copy(self, extra=None): @@ -598,9 +632,11 @@ def _from_java(cls, java_stage): trainRatio = java_stage.getTrainRatio() seed = java_stage.getSeed() parallelism = java_stage.getParallelism() + collectSubModels = java_stage.getCollectSubModels() # Create a new instance of this stage. py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, - trainRatio=trainRatio, seed=seed, parallelism=parallelism) + trainRatio=trainRatio, seed=seed, parallelism=parallelism, + collectSubModels=collectSubModels) py_stage._resetUid(java_stage.uid()) return py_stage @@ -620,7 +656,7 @@ def _to_java(self): _java_obj.setTrainRatio(self.getTrainRatio()) _java_obj.setSeed(self.getSeed()) _java_obj.setParallelism(self.getParallelism()) - + _java_obj.setCollectSubModels(self.getCollectSubModels()) return _java_obj @@ -633,12 +669,14 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): .. versionadded:: 2.0.0 """ - def __init__(self, bestModel, validationMetrics=[]): + def __init__(self, bestModel, validationMetrics=[], subModels=None): super(TrainValidationSplitModel, self).__init__() - #: best model from cross validation + #: best model from train validation split self.bestModel = bestModel #: evaluated validation metrics self.validationMetrics = validationMetrics + #: sub models from train validation split + self.subModels = subModels def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -651,6 +689,7 @@ def copy(self, extra=None): creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. And, this creates a shallow copy of the validationMetrics. + It does not copy the extra Params into the subModels. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance @@ -659,7 +698,8 @@ def copy(self, extra=None): extra = dict() bestModel = self.bestModel.copy(extra) validationMetrics = list(self.validationMetrics) - return TrainValidationSplitModel(bestModel, validationMetrics) + subModels = self.subModels + return TrainValidationSplitModel(bestModel, validationMetrics, subModels) @since("2.3.0") def write(self): @@ -687,6 +727,10 @@ def _from_java(cls, java_stage): py_stage = cls(bestModel=bestModel).setEstimator(estimator) py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator) + if java_stage.hasSubModels(): + py_stage.subModels = [JavaParams._from_java(sub_model) + for sub_model in java_stage.subModels()] + py_stage._resetUid(java_stage.uid()) return py_stage @@ -708,6 +752,11 @@ def _to_java(self): _java_obj.set("evaluator", evaluator) _java_obj.set("estimator", estimator) _java_obj.set("estimatorParamMaps", epms) + + if self.subModels is not None: + java_sub_models = [sub_model._to_java() for sub_model in self.subModels] + _java_obj.setSubModels(java_sub_models) + return _java_obj diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index c3c47bd79459a..a486c6a3fdeb5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -169,6 +169,10 @@ def overwrite(self): self._jwrite.overwrite() return self + def option(self, key, value): + self._jwrite.option(key, value) + return self + def context(self, sqlContext): """ Sets the SQL context to use for saving. From fd990a908b94d1c90c4ca604604f35a13b453d44 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Apr 2018 22:45:57 +0200 Subject: [PATCH 0627/1161] [SPARK-23873][SQL] Use accessors in interpreted LambdaVariable ## What changes were proposed in this pull request? Currently, interpreted execution of `LambdaVariable` just uses `InternalRow.get` to access element. We should use specified accessors if possible. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20981 from viirya/SPARK-23873. --- .../spark/sql/catalyst/InternalRow.scala | 26 ++++++++++++- .../catalyst/expressions/BoundAttribute.scala | 22 ++--------- .../expressions/objects/objects.scala | 8 +++- .../expressions/ExpressionEvalHelper.scala | 4 +- .../expressions/ObjectExpressionsSuite.scala | 38 ++++++++++++++++++- 5 files changed, 75 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 29110640d64f2..274d75e680f03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -119,4 +119,28 @@ object InternalRow { case v: MapData => v.copy() case _ => value } + + /** + * Returns an accessor for an `InternalRow` with given data type. The returned accessor + * actually takes a `SpecializedGetters` input because it can be generalized to other classes + * that implements `SpecializedGetters` (e.g., `ArrayData`) too. + */ + def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType) + case _ => (input, ordinal) => input.get(ordinal, dataType) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5021a567592e0..4cc84b27d9eb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { - if (input.isNullAt(ordinal)) { + if (nullable && input.isNullAt(ordinal)) { null } else { - dataType match { - case BooleanType => input.getBoolean(ordinal) - case ByteType => input.getByte(ordinal) - case ShortType => input.getShort(ordinal) - case IntegerType | DateType => input.getInt(ordinal) - case LongType | TimestampType => input.getLong(ordinal) - case FloatType => input.getFloat(ordinal) - case DoubleType => input.getDouble(ordinal) - case StringType => input.getUTF8String(ordinal) - case BinaryType => input.getBinary(ordinal) - case CalendarIntervalType => input.getInterval(ordinal) - case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => input.getStruct(ordinal, t.size) - case _: ArrayType => input.getArray(ordinal) - case _: MapType => input.getMap(ordinal) - case _ => input.get(ordinal, dataType) - } + accessor(input, ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e90ca550807..77802e89e942b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -560,11 +560,17 @@ case class LambdaVariable( dataType: DataType, nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { assert(input.numFields == 1, "The input row of interpreted LambdaVariable should have only 1 field.") - input.get(0, dataType) + if (nullable && input.isNullAt(0)) { + null + } else { + accessor(input, 0) + } } override def genCode(ctx: CodegenContext): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b68fac4..b4bf6d7107d7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { * Check the equality between result of expression and expected value, it will handle * Array[Byte], Spread[Double], MapData and Row. */ - protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { + protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = { + val dataType = UserDefinedType.sqlType(exprDataType) + (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b1bc67dfac1b5..b0188b0098def 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,13 +21,14 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util._ @@ -381,6 +382,39 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("LambdaVariable should support interpreted execution") { + def genSchema(dt: DataType): Seq[StructType] = { + Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), + StructType(StructField("col_1", dt, nullable = true) :: Nil)) + } + + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val mapTypes = elementTypes.flatMap { elementType => + Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true)) + } + val structTypes = elementTypes.flatMap { elementType => + Seq(StructType(StructField("col1", elementType, false) :: Nil), + StructType(StructField("col1", elementType, true) :: Nil)) + } + + val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes + val random = new Random(100) + testTypes.foreach { dt => + genSchema(dt).map { schema => + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable) + checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) + } + } + } } class TestBean extends Serializable { From 14844a62c025e7299029d7452b8c4003bc221ac8 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 17:55:35 +0900 Subject: [PATCH 0628/1161] [SPARK-23918][SQL] Add array_min function ## What changes were proposed in this pull request? The PR adds the SQL function `array_min`. It takes an array as argument and returns the minimum value in it. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21025 from mgaido91/SPARK-23918. --- python/pyspark/sql/functions.py | 17 ++++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 17 +++++ .../expressions/collectionOperations.scala | 64 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 10 +++ .../org/apache/spark/sql/functions.scala | 8 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++++ 8 files changed, 131 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f3492ae42639c..6ca22b610843d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2080,6 +2080,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_min(col): + """ + Collection function: returns the minimum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_min(df.data).alias('min')).collect() + [Row(min=1), Row(min=-1)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_min(_to_java_column(col))) + + @since(2.4) def array_max(col): """ @@ -2108,7 +2123,7 @@ def sort_array(col, asc=True): [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])] - """ + """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 05bfa2dd45340..4dd1ca509bf2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -409,6 +409,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 942dfd4292610..d4e322d23b95b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -595,11 +595,7 @@ case class Least(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfSmaller(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c86c5beded9d0..d97611c98ac91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is smaller than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, partialResult.value, item.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code for updating `partialResult` if `item` is greater than it. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e2614a179aad8..7c87777eed47a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -288,6 +288,70 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Returns the minimum value in the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 1 + """, since = "2.4.0") +case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfSmaller(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var min: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (min == null || ordering.lt(item, min))) { + min = item + } + ) + min + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_min" +} /** * Returns the maximum value in the array. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2384019533b7..5a31e3a30edd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Array Min") { + checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) + checkEvaluation( + ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "") + checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234) + } + test("Array max") { checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index daf407926dca4..642ac056bb809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3300,6 +3300,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the minimum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) } + /** * Returns the maximum value in the array. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d5d92c84df6d..636e86baedf6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_min function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(1), Row(null), Row(null), Row(-100)) + + checkAnswer(df.select(array_min(df("a"))), answer) + checkAnswer(df.selectExpr("array_min(a)"), answer) + } + test("array_max function") { val df = Seq( Seq[Option[Int]](Some(1), Some(3), Some(2)), From 1cc66a072b7fd3bf140fa41596f6b18f8d1bd7b9 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Tue, 17 Apr 2018 01:59:38 -0700 Subject: [PATCH 0629/1161] [SPARK-23687][SS] Add a memory source for continuous processing. ## What changes were proposed in this pull request? Add a memory source for continuous processing. Note that only one of the ContinuousSuite tests is migrated to minimize the diff here. I'll submit a second PR for SPARK-23688 to change the rest and get rid of waitForRateSourceTriggers. ## How was this patch tested? unit test Author: Jose Torres Closes #20828 from jose-torres/continuousMemory. --- .../continuous/ContinuousExecution.scala | 5 +- .../sql/execution/streaming/memory.scala | 59 +++-- .../sources/ContinuousMemoryStream.scala | 211 ++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 4 +- .../continuous/ContinuousSuite.scala | 31 +-- 5 files changed, 266 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 1758b3844bd62..951d694355ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} +import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} @@ -317,8 +318,10 @@ class ContinuousExecution( synchronized { if (queryExecutionThread.isAlive) { commitLog.add(epoch) - val offset = offsetLog.get(epoch).get.offsets(0).get + val offset = + continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json) committedOffsets ++= Seq(continuousSources(0) -> offset) + continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset]) } else { return } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 352d4ce9fbcaa..628923d367ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -24,17 +24,19 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{OutputMode, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -47,16 +49,43 @@ object MemoryStream { new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) } +/** + * A base class for memory stream implementations. Supports adding data and resetting. + */ +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { + protected val encoder = encoderFor[A] + protected val attributes = encoder.schema.toAttributes + + def toDS(): Dataset[A] = { + Dataset[A](sqlContext.sparkSession, logicalPlan) + } + + def toDF(): DataFrame = { + Dataset.ofRows(sqlContext.sparkSession, logicalPlan) + } + + def addData(data: A*): Offset = { + addData(data.toTraversable) + } + + def readSchema(): StructType = encoder.schema + + protected def logicalPlan: LogicalPlan + + def addData(data: TraversableOnce[A]): Offset +} + /** * A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]] * is intended for use in unit tests as it can only replay data when the object is still * available. */ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MicroBatchReader with SupportsScanUnsafeRow with Logging { - protected val encoder = encoderFor[A] - private val attributes = encoder.schema.toAttributes - protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) + extends MemoryStreamBase[A](sqlContext) + with MicroBatchReader with SupportsScanUnsafeRow with Logging { + + protected val logicalPlan: LogicalPlan = + StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) protected val output = logicalPlan.output /** @@ -70,7 +99,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected var currentOffset: LongOffset = new LongOffset(-1) @GuardedBy("this") - private var startOffset = new LongOffset(-1) + protected var startOffset = new LongOffset(-1) @GuardedBy("this") private var endOffset = new LongOffset(-1) @@ -82,18 +111,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) @GuardedBy("this") protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) - def toDS(): Dataset[A] = { - Dataset(sqlContext.sparkSession, logicalPlan) - } - - def toDF(): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, logicalPlan) - } - - def addData(data: A*): Offset = { - addData(data.toTraversable) - } - def addData(data: TraversableOnce[A]): Offset = { val objects = data.toSeq val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray @@ -114,8 +131,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def readSchema(): StructType = encoder.schema - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) override def getStartOffset: OffsetV2 = synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala new file mode 100644 index 0000000000000..c28919b8b729b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.{util => ju} +import java.util.Optional +import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.{Encoder, Row, SQLContext} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.DataReaderFactory +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.RpcUtils + +/** + * The overall strategy here is: + * * ContinuousMemoryStream maintains a list of records for each partition. addData() will + * distribute records evenly-ish across partitions. + * * RecordEndpoint is set up as an endpoint for executor-side + * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified + * offset within the list, or null if that offset doesn't yet have a record. + */ +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) + private val NUM_PARTITIONS = 2 + + protected val logicalPlan = + StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) + + // ContinuousReader implementation + + @GuardedBy("this") + private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + + @GuardedBy("this") + private var startOffset: ContinuousMemoryStreamOffset = _ + + private val recordEndpoint = new RecordEndpoint() + @volatile private var endpointRef: RpcEndpointRef = _ + + def addData(data: TraversableOnce[A]): Offset = synchronized { + // Distribute data evenly among partition lists. + data.toSeq.zipWithIndex.map { + case (item, index) => records(index % NUM_PARTITIONS) += item + } + + // The new target offset is the offset where all records in all partitions have been processed. + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + } + + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. + startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = synchronized { + startOffset + } + + override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json)) + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset( + offsets.map { + case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + }.toMap + ) + } + + override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + synchronized { + val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" + endpointRef = + recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) + + startOffset.partitionNums.map { + case (part, index) => + new ContinuousMemoryStreamDataReaderFactory( + endpointName, part, index): DataReaderFactory[Row] + }.toList.asJava + } + } + + override def stop(): Unit = { + if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) + } + + override def commit(end: Offset): Unit = {} + + // ContinuousReadSupport implementation + // This is necessary because of how StreamTest finds the source for AddDataMemory steps. + def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + this + } + + /** + * Endpoint for executors to poll for records. + */ + private class RecordEndpoint extends ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => + ContinuousMemoryStream.this.synchronized { + val buf = records(part) + val record = if (buf.size <= index) None else Some(buf(index)) + + context.reply(record.map(Row(_))) + } + } + } +} + +object ContinuousMemoryStream { + case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset) + protected val memoryStreamId = new AtomicInteger(0) + + def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) +} + +/** + * Data reader factory for continuous memory stream. + */ +class ContinuousMemoryStreamDataReaderFactory( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends DataReaderFactory[Row] { + override def createDataReader: ContinuousMemoryStreamDataReader = + new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset) +} + +/** + * Data reader for continuous memory stream. + * + * Polls the driver endpoint for new records. + */ +class ContinuousMemoryStreamDataReader( + driverEndpointName: String, + partition: Int, + startOffset: Int) extends ContinuousDataReader[Row] { + private val endpoint = RpcUtils.makeDriverRef( + driverEndpointName, + SparkEnv.get.conf, + SparkEnv.get.rpcEnv) + + private var currentOffset = startOffset + private var current: Option[Row] = None + + override def next(): Boolean = { + current = None + while (current.isEmpty) { + Thread.sleep(10) + current = endpoint.askSync[Option[Row]]( + GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset))) + } + currentOffset += 1 + true + } + + override def get(): Row = current.get + + override def close(): Unit = {} + + override def getOffset: ContinuousMemoryStreamPartitionOffset = + ContinuousMemoryStreamPartitionOffset(partition, currentOffset) +} + +case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int]) + extends Offset { + private implicit val formats = Serialization.formats(NoTypeHints) + override def json(): String = Serialization.write(partitionNums) +} + +case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int) + extends PartitionOffset diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 00741d660dd2d..af0268fa47871 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -99,7 +99,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * been processed. */ object AddData { - def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] = + def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] = AddDataMemory(source, data) } @@ -131,7 +131,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be def runAction(): Unit } - case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { + case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index ef74efef156d5..c318b951ff992 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -53,32 +54,24 @@ class ContinuousSuiteBase extends StreamTest { // A continuous trigger that will only fire the initial time for the duration of a test. // This allows clean testing with manual epoch advancement. protected val longContinuousTrigger = Trigger.Continuous("1 hour") + + override protected val defaultTrigger = Trigger.Continuous(100) + override protected val defaultUseV2Sink = true } class ContinuousSuite extends ContinuousSuiteBase { import testImplicits._ - test("basic rate source") { - val df = spark.readStream - .format("rate") - .option("numPartitions", "5") - .option("rowsPerSecond", "5") - .load() - .select('value) + test("basic") { + val input = ContinuousMemoryStream[Int] - testStream(df, useV2Sink = true)( - StartStream(longContinuousTrigger), - AwaitEpoch(0), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))), + testStream(input.toDF())( + AddData(input, 0, 1, 2), + CheckAnswer(0, 1, 2), StopStream, - StartStream(longContinuousTrigger), - AwaitEpoch(2), - Execute(waitForRateSourceTriggers(_, 2)), - IncrementEpoch(), - CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))), - StopStream) + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(0, 1, 2, 3, 4, 5)) } test("map") { From 05ae74778a10fbdd7f2cbf7742de7855966b7d35 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Tue, 17 Apr 2018 04:13:17 -0700 Subject: [PATCH 0630/1161] [SPARK-23747][STRUCTURED STREAMING] Add EpochCoordinator unit tests ## What changes were proposed in this pull request? Unit tests for EpochCoordinator that test correct sequencing of committed epochs. Several tests are ignored since they test functionality implemented in SPARK-23503 which is not yet merged, otherwise they fail. Author: Efim Poberezkin Closes #20983 from efimpoberezkin/pr/EpochCoordinator-tests. --- .../continuous/EpochCoordinatorSuite.scala | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala new file mode 100644 index 0000000000000..99e30561f81d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.mockito.InOrder +import org.mockito.Matchers.{any, eq => eqTo} +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.LocalSparkSession +import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.test.TestSparkSession + +class EpochCoordinatorSuite + extends SparkFunSuite + with LocalSparkSession + with MockitoSugar + with BeforeAndAfterEach { + + private var epochCoordinator: RpcEndpointRef = _ + + private var writer: StreamWriter = _ + private var query: ContinuousExecution = _ + private var orderVerifier: InOrder = _ + + override def beforeEach(): Unit = { + val reader = mock[ContinuousReader] + writer = mock[StreamWriter] + query = mock[ContinuousExecution] + orderVerifier = inOrder(writer, query) + + spark = new TestSparkSession() + + epochCoordinator + = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get) + } + + test("single epoch") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + // Here and in subsequent tests this is called to make a synchronous call to EpochCoordinator + // so that mocks would have been acted upon by the time verification happens + makeSynchronousCall() + + verifyCommit(1) + } + + test("single epoch, all but one writer partition has committed") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("single epoch, all but one reader partition has reported an offset") { + setWriterPartitions(3) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + commitPartitionEpoch(2, 1) + reportPartitionOffset(0, 1) + + makeSynchronousCall() + + verifyNoCommitFor(1) + } + + test("consequent epochs, messages for epoch (k + 1) arrive after messages for epoch k") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + reportPartitionOffset(1, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + commitPartitionEpoch(1, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 2) + commitPartitionEpoch(1, 2) + reportPartitionOffset(0, 2) + reportPartitionOffset(1, 2) + + // Message that arrives late + reportPartitionOffset(1, 1) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4)) + } + + ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + commitPartitionEpoch(0, 3) + reportPartitionOffset(0, 3) + + commitPartitionEpoch(0, 5) + reportPartitionOffset(0, 5) + + commitPartitionEpoch(0, 4) + reportPartitionOffset(0, 4) + + commitPartitionEpoch(0, 2) + reportPartitionOffset(0, 2) + + makeSynchronousCall() + + verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) + } + + private def setWriterPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) + } + + private def setReaderPartitions(numPartitions: Int): Unit = { + epochCoordinator.askSync[Unit](SetReaderPartitions(numPartitions)) + } + + private def commitPartitionEpoch(partitionId: Int, epoch: Long): Unit = { + val dummyMessage: WriterCommitMessage = mock[WriterCommitMessage] + epochCoordinator.send(CommitPartitionEpoch(partitionId, epoch, dummyMessage)) + } + + private def reportPartitionOffset(partitionId: Int, epoch: Long): Unit = { + val dummyOffset: PartitionOffset = mock[PartitionOffset] + epochCoordinator.send(ReportPartitionOffset(partitionId, epoch, dummyOffset)) + } + + private def makeSynchronousCall(): Unit = { + epochCoordinator.askSync[Long](GetCurrentEpoch) + } + + private def verifyCommit(epoch: Long): Unit = { + orderVerifier.verify(writer).commit(eqTo(epoch), any()) + orderVerifier.verify(query).commit(epoch) + } + + private def verifyNoCommitFor(epoch: Long): Unit = { + verify(writer, never()).commit(eqTo(epoch), any()) + verify(query, never()).commit(epoch) + } + + private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { + epochs.foreach(verifyCommit) + } +} From 30ffb53cad84283b4f7694bfd60bdd7e1101b04e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 17 Apr 2018 15:09:36 +0200 Subject: [PATCH 0631/1161] [SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? We don't have a good way to sequentially access `UnsafeArrayData` with a common interface such as `Seq`. An example is `MapObject` where we need to access several sequence collection types together. But `UnsafeArrayData` doesn't implement `ArrayData.array`. Calling `toArray` will copy the entire array. We can provide an `IndexedSeq` wrapper for `ArrayData`, so we can avoid copying the entire array. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #20984 from viirya/SPARK-23875. --- .../expressions/objects/objects.scala | 2 +- .../spark/sql/catalyst/util/ArrayData.scala | 30 +++++- .../util/ArrayDataIndexedSeqSuite.scala | 100 ++++++++++++++++++ 3 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 77802e89e942b..72b202b3a5020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -708,7 +708,7 @@ case class MapObjects private( } } case ArrayType(et, _) => - _.asInstanceOf[ArrayData].array + _.asInstanceOf[ArrayData].toSeq[Any](et) } private lazy val mapElements: Seq[_] => Any = customCollectionCls match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 9beef41d639f3..2cf59d567c08c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def array: Array[Any] + def toSeq[T](dataType: DataType): IndexedSeq[T] = + new ArrayDataIndexedSeq[T](this, dataType) + def setNullAt(i: Int): Unit def update(i: Int, value: Any): Unit @@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable { } } } + +/** + * Implements an `IndexedSeq` interface for `ArrayData`. Notice that if the original `ArrayData` + * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`, + * instead of `IndexedSeq[Int]`, in order to keep the null elements. + */ +class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] { + + private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType) + + override def apply(idx: Int): T = + if (0 <= idx && idx < arrayData.numElements()) { + if (arrayData.isNullAt(idx)) { + null.asInstanceOf[T] + } else { + accessor(arrayData, idx).asInstanceOf[T] + } + } else { + throw new IndexOutOfBoundsException( + s"Index $idx must be between 0 and the length of the ArrayData.") + } + + override def length: Int = arrayData.numElements() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala new file mode 100644 index 0000000000000..6400898343ae7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.types._ + +class ArrayDataIndexedSeqSuite extends SparkFunSuite { + private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = { + assert(arrayData.numElements == array.length) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For NaN, etc. + case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e)) + case _ => assert(arrayData.get(i, elementDt) === e) + } + } else { + assert(arrayData.isNullAt(i)) + } + } + + val seq = arrayData.toSeq[Any](elementDt) + array.zipWithIndex.map { case (e, i) => + if (e != null) { + elementDt match { + // For Nan, etc. + case FloatType | DoubleType => assert(seq(i).equals(e)) + case _ => assert(seq(i) === e) + } + } else { + assert(seq(i) == null) + } + } + + intercept[IndexOutOfBoundsException] { + seq(-1) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + + intercept[IndexOutOfBoundsException] { + seq(seq.length) + }.getMessage().contains("must be between 0 and the length of the ArrayData.") + } + + private def testArrayData(): Unit = { + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val random = new Random(100) + arrayTypes.foreach { dt => + val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil) + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + + val unsafeRowConverter = UnsafeProjection.create(schema) + val safeRowConverter = FromUnsafeProjection(schema) + + val unsafeRow = unsafeRowConverter(internalRow) + val safeRow = safeRowConverter(unsafeRow) + + val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData] + val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + + val elementType = dt.elementType + test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) { + compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType)) + } + + test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) { + compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType)) + } + } + } + + testArrayData() +} From 0a9172a05e604a4a94adbb9208c8c02362afca00 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 17 Apr 2018 21:45:20 +0800 Subject: [PATCH 0632/1161] [SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization ## What changes were proposed in this pull request? There was no check on nullability for arguments of `Tuple`s. This could lead to have weird behavior when a null value had to be deserialized into a non-nullable Scala object: in those cases, the `null` got silently transformed in a valid value (like `-1` for `Int`), corresponding to the default value we are using in the SQL codebase. This situation was very likely to happen when deserializing to a Tuple of primitive Scala types (like Double, Int, ...). The PR adds the `AssertNotNull` to arguments of tuples which have been asked to be converted to non-nullable types. ## How was this patch tested? added UT Author: Marco Gaido Closes #20976 from mgaido91/SPARK-23835. --- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaSinkSuite.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 14 +++++++------- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/DatasetSuite.scala | 5 +++++ 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index fc890a0cfdac3..ddfc0c1a4be2d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { @@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val reader = createKafkaReader(topic) .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 42f8b4c7657e2..7079ac6453ffc 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { val reader = createKafkaReader(topic) .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") - .as[(Int, Int)] + .as[(Option[Int], Int)] .map(_._2) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1aae3aea3a31a..e4274aaa9727e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. - if (cls.getName startsWith "scala.Tuple") { + val constructor = if (cls.getName startsWith "scala.Tuple") { deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - val constructor = deserializerFor( + deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + } - if (!nullable) { - AssertNotNull(constructor, newTypePath) - } else { - constructor - } + if (!nullable) { + AssertNotNull(constructor, newTypePath) + } else { + constructor } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 8c3db48a01f12..353b8344658f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite { StructField("_2", NullType, nullable = true))), nullable = true)) } + + test("SPARK-23835: add null check to non-nullable types in Tuples") { + def numberOfCheckedArguments(deserializer: Expression): Int = { + assert(deserializer.isInstanceOf[NewInstance]) + deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull]) + } + assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) + assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9b745befcb611..e0f4d2ba685e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val group2 = cached.groupBy("x").agg(min(col("z")) as "value") checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil) } + + test("SPARK-23835: null primitive data type should throw NullPointerException") { + val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() + intercept[NullPointerException](ds.as[(Int, Int)].collect()) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From ed4101d29f50d54fd7846421e4c00e9ecd3599d0 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 21:52:33 +0800 Subject: [PATCH 0633/1161] [SPARK-22676] Avoid iterating all partition paths when spark.sql.hive.verifyPartitionPath=true MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In current code, it will scanning all partition paths when spark.sql.hive.verifyPartitionPath=true. e.g. table like below: ``` CREATE TABLE `test`( `id` int, `age` int, `name` string) PARTITIONED BY ( `A` string, `B` string) load data local inpath '/tmp/data0' into table test partition(A='00', B='00') load data local inpath '/tmp/data1' into table test partition(A='01', B='01') load data local inpath '/tmp/data2' into table test partition(A='10', B='10') load data local inpath '/tmp/data3' into table test partition(A='11', B='11') ``` If I query with SQL – "select * from test where A='00' and B='01' ", current code will scan all partition paths including '/data/A=00/B=00', '/data/A=00/B=00', '/data/A=01/B=01', '/data/A=10/B=10', '/data/A=11/B=11'. It costs much time and memory cost. This pr proposes to avoid iterating all partition paths. Add a config `spark.files.ignoreMissingFiles` and ignore the `file not found` when `getPartitions/compute`(for hive table scan). This is much like the logic brought by `spark.sql.files.ignoreMissingFiles`(which is for datasource scan). ## How was this patch tested? UT Author: jinxing Closes #19868 from jinxing64/SPARK-22676. --- .../spark/internal/config/package.scala | 6 ++ .../org/apache/spark/rdd/HadoopRDD.scala | 43 +++++++++--- .../org/apache/spark/rdd/NewHadoopRDD.scala | 45 ++++++++---- .../scala/org/apache/spark/FileSuite.scala | 69 ++++++++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../spark/sql/hive/QueryPartitionSuite.scala | 40 +++++++++++ 6 files changed, 181 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 407545aa4a47a..99d779fb600e8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -301,6 +301,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val IGNORE_MISSING_FILES = ConfigBuilder("spark.files.ignoreMissingFiles") + .doc("Whether to ignore missing files. If true, the Spark jobs will continue to run when " + + "encountering missing files and the contents that have been read will still be returned.") + .booleanConf + .createWithDefault(false) + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") .stringConf .createOptional diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2480559a41b7a..44895abc7bd4d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,6 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred._ import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.mapreduce.TaskType +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -134,6 +135,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. @@ -197,17 +200,24 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) - val inputSplits = if (ignoreEmptySplits) { - allInputSplits.filter(_.getLength > 0) - } else { - allInputSplits - } - val array = new Array[Partition](inputSplits.size) - for (i <- 0 until inputSplits.size) { - array(i) = new HadoopPartition(id, i, inputSplits(i)) + try { + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } + val array = new Array[Partition](inputSplits.size) + for (i <- 0 until inputSplits.size) { + array(i) = new HadoopPartition(id, i, inputSplits(i)) + } + array + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${jobConf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - array } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -256,6 +266,12 @@ class HadoopRDD[K, V]( try { inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true @@ -276,6 +292,11 @@ class HadoopRDD[K, V]( try { finished = !reader.next(key, value) } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.inputSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) finished = true diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e4dd1b6a82498..ff66a04859d10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.io.IOException +import java.io.{FileNotFoundException, IOException} import java.text.SimpleDateFormat import java.util.{Date, Locale} @@ -28,7 +28,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileInputFormat, FileSplit, InvalidInputException} import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark._ @@ -90,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreMissingFiles = sparkContext.conf.get(IGNORE_MISSING_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) def getConf: Configuration = { @@ -124,17 +126,25 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala - val rawSplits = if (ignoreEmptySplits) { - allRowSplits.filter(_.getLength > 0) - } else { - allRowSplits - } - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + try { + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } catch { + case e: InvalidInputException if ignoreMissingFiles => + logWarning(s"${_conf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" + + s" partitions returned from this path.", e) + Array.empty[Partition] } - result } override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { @@ -189,6 +199,12 @@ class NewHadoopRDD[K, V]( _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) _reader } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", @@ -213,6 +229,11 @@ class NewHadoopRDD[K, V]( try { finished = !reader.nextKeyValue } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: ${split.serializableHadoopSplit}", e) + finished = true + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e case e: IOException if ignoreCorruptFiles => logWarning( s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 55a9122cf9026..a441b9c8ab97a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -23,6 +23,7 @@ import java.util.zip.GZIPOutputStream import scala.io.Source +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec @@ -32,7 +33,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInp import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.apache.spark.internal.config._ -import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -596,4 +597,70 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { actualPartitionNum = 5, expectedPartitionNum = 2) } + + test("spark.files.ignoreMissingFiles should work both HadoopRDD and NewHadoopRDD") { + // "file not found" can happen both when getPartitions or compute in HadoopRDD/NewHadoopRDD, + // We test both cases here. + + val deletedPath = new Path(tempDir.getAbsolutePath, "test-data-1") + val fs = deletedPath.getFileSystem(new Configuration()) + fs.delete(deletedPath, true) + intercept[FileNotFoundException](fs.open(deletedPath)) + + def collectRDDAndDeleteFileBeforeCompute(newApi: Boolean): Array[_] = { + val dataPath = new Path(tempDir.getAbsolutePath, "test-data-2") + val writer = new OutputStreamWriter(new FileOutputStream(new File(dataPath.toString))) + writer.write("hello\n") + writer.write("world\n") + writer.close() + val rdd = if (newApi) { + sc.newAPIHadoopFile(dataPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]) + } else { + sc.textFile(dataPath.toString) + } + rdd.partitions + fs.delete(dataPath, true) + // Exception happens when initialize record reader in HadoopRDD/NewHadoopRDD.compute + // because partitions' info already cached. + rdd.collect() + } + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=false by default. + sc = new SparkContext("local", "test") + intercept[org.apache.hadoop.mapred.InvalidInputException] { + // Exception happens when HadoopRDD.getPartitions + sc.textFile(deletedPath.toString).collect() + } + + var e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(false) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + intercept[org.apache.hadoop.mapreduce.lib.input.InvalidInputException] { + // Exception happens when NewHadoopRDD.getPartitions + sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect + } + + e = intercept[SparkException] { + collectRDDAndDeleteFileBeforeCompute(true) + } + assert(e.getCause.isInstanceOf[java.io.FileNotFoundException]) + + sc.stop() + + // collect HadoopRDD and NewHadoopRDD when spark.files.ignoreMissingFiles=true. + val conf = new SparkConf().set(IGNORE_MISSING_FILES, true) + sc = new SparkContext("local", "test", conf) + assert(sc.textFile(deletedPath.toString).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(false).isEmpty) + + assert(sc.newAPIHadoopFile(deletedPath.toString, classOf[NewTextInputFormat], + classOf[LongWritable], classOf[Text]).collect().isEmpty) + + assert(collectRDDAndDeleteFileBeforeCompute(true).isEmpty) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0dc47bfe075d0..3729bd5293eca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -437,7 +437,8 @@ object SQLConf { val HIVE_VERIFY_PARTITION_PATH = buildConf("spark.sql.hive.verifyPartitionPath") .doc("When true, check all the partition paths under the table\'s root directory " + - "when reading data stored in HDFS.") + "when reading data stored in HDFS. This configuration will be deprecated in the future " + + "releases and replaced by spark.files.ignoreMissingFiles.") .booleanConf .createWithDefault(false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index b2dc401ce1efc..78156b17fb43b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem +import org.apache.spark.internal.config._ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -70,6 +71,45 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl } } + test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + + sql("DROP TABLE IF EXISTS table_with_partition") + sql("DROP TABLE IF EXISTS createAndInsertTest") + } + } + test("SPARK-21739: Cast expression should initialize timezoneId") { withTable("table_with_timestamp_partition") { sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)") From 3990daaf3b6ca2c5a9f7790030096262efb12cb2 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 17 Apr 2018 08:55:01 -0500 Subject: [PATCH 0634/1161] [SPARK-23948] Trigger mapstage's job listener in submitMissingTasks ## What changes were proposed in this pull request? SparkContext submitted a map stage from `submitMapStage` to `DAGScheduler`, `markMapStageJobAsFinished` is called only in (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L933 and https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1314); But think about below scenario: 1. stage0 and stage1 are all `ShuffleMapStage` and stage1 depends on stage0; 2. We submit stage1 by `submitMapStage`; 3. When stage 1 running, `FetchFailed` happened, stage0 and stage1 got resubmitted as stage0_1 and stage1_1; 4. When stage0_1 running, speculated tasks in old stage1 come as succeeded, but stage1 is not inside `runningStages`. So even though all splits(including the speculated tasks) in stage1 succeeded, job listener in stage1 will not be called; 5. stage0_1 finished, stage1_1 starts running. When `submitMissingTasks`, there is no missing tasks. But in current code, job listener is not triggered. We should call the job listener for map stage in `5`. ## How was this patch tested? Not added yet. Author: jinxing Closes #21019 from jinxing64/SPARK-23948. --- .../apache/spark/scheduler/DAGScheduler.scala | 33 ++++++------ .../spark/scheduler/DAGSchedulerSuite.scala | 52 +++++++++++++++++++ 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c46a84323392..78b6b34b5d2bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1092,17 +1092,16 @@ class DAGScheduler( // the stage as completed here in case there are no tasks to run markStageAsFinished(stage, None) - val debugString = stage match { + stage match { case stage: ShuffleMapStage => - s"Stage ${stage} is actually done; " + - s"(available: ${stage.isAvailable}," + - s"available outputs: ${stage.numAvailableOutputs}," + - s"partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; " + + s"(available: ${stage.isAvailable}," + + s"available outputs: ${stage.numAvailableOutputs}," + + s"partitions: ${stage.numPartitions})") + markMapStageJobsAsFinished(stage) case stage : ResultStage => - s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})" + logDebug(s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})") } - logDebug(debugString) - submitWaitingChildStages(stage) } } @@ -1307,13 +1306,7 @@ class DAGScheduler( shuffleStage.findMissingPartitions().mkString(", ")) submitStage(shuffleStage) } else { - // Mark any map-stage jobs waiting on this stage as finished - if (shuffleStage.mapStageJobs.nonEmpty) { - val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) - for (job <- shuffleStage.mapStageJobs) { - markMapStageJobAsFinished(job, stats) - } - } + markMapStageJobsAsFinished(shuffleStage) submitWaitingChildStages(shuffleStage) } } @@ -1433,6 +1426,16 @@ class DAGScheduler( } } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } + } + /** * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d812b5bd92c1b..8b6ec37625eec 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2146,6 +2146,58 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } + test("Trigger mapstage's job listener in submitMissingTasks") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1), tracker = mapOutputTracker) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + + // Complete the stage0. + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting stage1, trigger a fetch failure. + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + // Stage1 listener should not have a result yet + assert(listener2.results.size === 0) + + // Speculative task succeeded in stage1. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), + Success, + makeMapStatus("hostD", rdd2.partitions.length))) + // stage1 listener still should not have a result, though there's no missing partitions + // in it. Because stage1 has been failed and is not inside `runningStages` at this moment. + assert(listener2.results.size === 0) + + // Stage0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // After stage0 is finished, stage1 will be submitted and found there is no missing + // partitions in it. Then listener got triggered. + assert(listener2.results.size === 1) + assertDataStructuresEmpty() + } + /** * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. We want to make sure the stage is not reported From f39e82ce150b6a7ea038e6858ba7adbaba3cad88 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 18 Apr 2018 00:35:44 +0800 Subject: [PATCH 0635/1161] [SPARK-23986][SQL] freshName can generate non-unique names ## What changes were proposed in this pull request? We are using `CodegenContext.freshName` to get a unique name for any new variable we are adding. Unfortunately, this method currently fails to create a unique name when we request more than one instance of variables with starting name `name1` and an instance with starting name `name11`. The PR changes the way a new name is generated by `CodegenContext.freshName` so that we generate unique names in this scenario too. ## How was this patch tested? added UT Author: Marco Gaido Closes #21080 from mgaido91/SPARK-23986. --- .../catalyst/expressions/codegen/CodeGenerator.scala | 11 +++-------- .../catalyst/expressions/CodeGenerationSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d97611c98ac91..f6b6775923ac6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -572,14 +572,9 @@ class CodegenContext { } else { s"${freshNamePrefix}_$name" } - if (freshNameIds.contains(fullName)) { - val id = freshNameIds(fullName) - freshNameIds(fullName) = id + 1 - s"$fullName$id" - } else { - freshNameIds += fullName -> 1 - fullName - } + val id = freshNameIds.getOrElse(fullName, 0) + freshNameIds(fullName) = id + 1 + s"${fullName}_$id" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f7c023111ff59..5b71becee2de0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -489,4 +489,14 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(!ctx.subExprEliminationExprs.contains(ref)) } } + + test("SPARK-23986: freshName can generate duplicated names") { + val ctx = new CodegenContext + val names1 = ctx.freshName("myName1") :: ctx.freshName("myName1") :: + ctx.freshName("myName11") :: Nil + assert(names1.distinct.length == 3) + val names2 = ctx.freshName("a") :: ctx.freshName("a") :: + ctx.freshName("a_1") :: ctx.freshName("a_0") :: Nil + assert(names2.distinct.length == 4) + } } From 1ca3c50fefb34532c78427fa74872db3ecbf7ba2 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 17 Apr 2018 10:11:08 -0700 Subject: [PATCH 0636/1161] [SPARK-21741][ML][PYSPARK] Python API for DataFrame-based multivariate summarizer ## What changes were proposed in this pull request? Python API for DataFrame-based multivariate summarizer. ## How was this patch tested? doctest added. Author: WeichenXu Closes #20695 from WeichenXu123/py_summarizer. --- python/pyspark/ml/stat.py | 193 +++++++++++++++++++++++++++++++++++++- 1 file changed, 192 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index 93d0f4fd9148f..a06ab31a7a56a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -19,7 +19,9 @@ from pyspark import since, SparkContext from pyspark.ml.common import _java2py, _py2java -from pyspark.ml.wrapper import _jvm +from pyspark.ml.wrapper import JavaWrapper, _jvm +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.functions import lit class ChiSquareTest(object): @@ -195,6 +197,195 @@ def test(dataset, sampleCol, distName, *params): _jvm().PythonUtils.toSeq(params))) +class Summarizer(object): + """ + .. note:: Experimental + + Tools for vectorized statistics on MLlib Vectors. + The methods in this package provide various statistics for Vectors contained inside DataFrames. + This class lets users pick the statistics they would like to extract for a given column. + + >>> from pyspark.ml.stat import Summarizer + >>> from pyspark.sql import Row + >>> from pyspark.ml.linalg import Vectors + >>> summarizer = Summarizer.metrics("mean", "count") + >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + ... Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + >>> df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + +-----------------------------------+ + |aggregate_metrics(features, weight)| + +-----------------------------------+ + |[[1.0,1.0,1.0], 1] | + +-----------------------------------+ + + >>> df.select(summarizer.summary(df.features)).show(truncate=False) + +--------------------------------+ + |aggregate_metrics(features, 1.0)| + +--------------------------------+ + |[[1.0,1.5,2.0], 2] | + +--------------------------------+ + + >>> df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.0,1.0] | + +--------------+ + + >>> df.select(Summarizer.mean(df.features)).show(truncate=False) + +--------------+ + |mean(features)| + +--------------+ + |[1.0,1.5,2.0] | + +--------------+ + + + .. versionadded:: 2.4.0 + + """ + @staticmethod + @since("2.4.0") + def mean(col, weightCol=None): + """ + return a column of mean summary + """ + return Summarizer._get_single_metric(col, weightCol, "mean") + + @staticmethod + @since("2.4.0") + def variance(col, weightCol=None): + """ + return a column of variance summary + """ + return Summarizer._get_single_metric(col, weightCol, "variance") + + @staticmethod + @since("2.4.0") + def count(col, weightCol=None): + """ + return a column of count summary + """ + return Summarizer._get_single_metric(col, weightCol, "count") + + @staticmethod + @since("2.4.0") + def numNonZeros(col, weightCol=None): + """ + return a column of numNonZero summary + """ + return Summarizer._get_single_metric(col, weightCol, "numNonZeros") + + @staticmethod + @since("2.4.0") + def max(col, weightCol=None): + """ + return a column of max summary + """ + return Summarizer._get_single_metric(col, weightCol, "max") + + @staticmethod + @since("2.4.0") + def min(col, weightCol=None): + """ + return a column of min summary + """ + return Summarizer._get_single_metric(col, weightCol, "min") + + @staticmethod + @since("2.4.0") + def normL1(col, weightCol=None): + """ + return a column of normL1 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL1") + + @staticmethod + @since("2.4.0") + def normL2(col, weightCol=None): + """ + return a column of normL2 summary + """ + return Summarizer._get_single_metric(col, weightCol, "normL2") + + @staticmethod + def _check_param(featuresCol, weightCol): + if weightCol is None: + weightCol = lit(1.0) + if not isinstance(featuresCol, Column) or not isinstance(weightCol, Column): + raise TypeError("featureCol and weightCol should be a Column") + return featuresCol, weightCol + + @staticmethod + def _get_single_metric(col, weightCol, metric): + col, weightCol = Summarizer._check_param(col, weightCol) + return Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + metric, + col._jc, weightCol._jc)) + + @staticmethod + @since("2.4.0") + def metrics(*metrics): + """ + Given a list of metrics, provides a builder that it turns computes metrics from a column. + + See the documentation of [[Summarizer]] for an example. + + The following metrics are accepted (case sensitive): + - mean: a vector that contains the coefficient-wise mean. + - variance: a vector tha contains the coefficient-wise variance. + - count: the count of all vectors seen. + - numNonzeros: a vector with the number of non-zeros for each coefficients + - max: the maximum for each coefficient. + - min: the minimum for each coefficient. + - normL2: the Euclidian norm for each coefficient. + - normL1: the L1 norm of each coefficient (sum of the absolute values). + + :param metrics: + metrics that can be provided. + :return: + an object of :py:class:`pyspark.ml.stat.SummaryBuilder` + + Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD + interface. + """ + sc = SparkContext._active_spark_context + js = JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics", + _to_seq(sc, metrics)) + return SummaryBuilder(js) + + +class SummaryBuilder(JavaWrapper): + """ + .. note:: Experimental + + A builder object that provides summary statistics about a given column. + + Users should not directly create such builders, but instead use one of the methods in + :py:class:`pyspark.ml.stat.Summarizer` + + .. versionadded:: 2.4.0 + + """ + def __init__(self, jSummaryBuilder): + super(SummaryBuilder, self).__init__(jSummaryBuilder) + + @since("2.4.0") + def summary(self, featuresCol, weightCol=None): + """ + Returns an aggregate object that contains the summary of the column with the requested + metrics. + + :param featuresCol: + a column that contains features Vector object. + :param weightCol: + a column that contains weight value. Default weight is 1.0. + :return: + an aggregate column that contains the statistics. The exact content of this + structure is determined during the creation of the builder. + """ + featuresCol, weightCol = Summarizer._check_param(featuresCol, weightCol) + return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc)) + + if __name__ == "__main__": import doctest import pyspark.ml.stat From 5fccdae18911793967b315c02c058eb737e46174 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 17 Apr 2018 21:08:42 -0500 Subject: [PATCH 0637/1161] [SPARK-22968][DSTREAM] Throw an exception on partition revoking issue ## What changes were proposed in this pull request? Kafka partitions can be revoked when new consumers joined in the consumer group to rebalance the partitions. But current Spark Kafka connector code makes sure there's no partition revoking scenarios, so trying to get latest offset from revoked partitions will throw exceptions as JIRA mentioned. Partition revoking happens when new consumer joined the consumer group, which means different streaming apps are trying to use same group id. This is fundamentally not correct, different apps should use different consumer group. So instead of throwing an confused exception from Kafka, improve the exception message by identifying revoked partition and directly throw an meaningful exception when partition is revoked. Besides, this PR also fixes bugs in `DirectKafkaWordCount`, this example simply cannot be worked without the fix. ``` 8/01/05 09:48:27 INFO internals.ConsumerCoordinator: Revoking previously assigned partitions [kssh-7, kssh-4, kssh-3, kssh-6, kssh-5, kssh-0, kssh-2, kssh-1] for group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: (Re-)joining group use_a_separate_group_id_for_each_stream 18/01/05 09:48:27 INFO internals.AbstractCoordinator: Successfully joined group use_a_separate_group_id_for_each_stream with generation 4 18/01/05 09:48:27 INFO internals.ConsumerCoordinator: Setting newly assigned partitions [kssh-7, kssh-4, kssh-6, kssh-5] for group use_a_separate_group_id_for_each_stream ``` ## How was this patch tested? This is manually verified in local cluster, unfortunately I'm not sure how to simulate it in UT, so propose the PR without UT added. Author: jerryshao Closes #21038 from jerryshao/SPARK-22968. --- .../streaming/DirectKafkaWordCount.scala | 17 +++++++++++++---- .../kafka010/DirectKafkaInputDStream.scala | 12 ++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index def06026bde96..2082fb71afdf1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -18,6 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.StringDeserializer + import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka010._ @@ -26,18 +29,20 @@ import org.apache.spark.streaming.kafka010._ * Consumes messages from one or more topics in Kafka and does wordcount. * Usage: DirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ object DirectKafkaWordCount { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { System.err.println(s""" |Usage: DirectKafkaWordCount | is a list of one or more Kafka brokers + | is a consumer group name to consume from topics | is a list of one or more kafka topics to consume from | """.stripMargin) @@ -46,7 +51,7 @@ object DirectKafkaWordCount { StreamingExamples.setStreamingLogLevels() - val Array(brokers, topics) = args + val Array(brokers, groupId, topics) = args // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") @@ -54,7 +59,11 @@ object DirectKafkaWordCount { // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet - val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers) + val kafkaParams = Map[String, Object]( + ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers, + ConsumerConfig.GROUP_ID_CONFIG -> groupId, + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer], + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer]) val messages = KafkaUtils.createDirectStream[String, String]( ssc, LocationStrategies.PreferConsistent, diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 215b7cab703fb..c3221481556f5 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -190,8 +190,20 @@ private[spark] class DirectKafkaInputDStream[K, V]( // make sure new partitions are reflected in currentOffsets val newPartitions = parts.diff(currentOffsets.keySet) + + // Check if there's any partition been revoked because of consumer rebalance. + val revokedPartitions = currentOffsets.keySet.diff(parts) + if (revokedPartitions.nonEmpty) { + throw new IllegalStateException(s"Previously tracked partitions " + + s"${revokedPartitions.mkString("[", ",", "]")} been revoked by Kafka because of consumer " + + s"rebalance. This is mostly due to another stream with same group id joined, " + + s"please check if there're different streaming application misconfigure to use same " + + s"group id. Fundamentally different stream should use different group id") + } + // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap + // don't want to consume messages, so pause c.pause(newPartitions.asJava) // find latest available offsets From 1e3b8762a854a07c317f69fba7fa1a7bcdc58ff3 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Apr 2018 10:36:41 +0800 Subject: [PATCH 0638/1161] [SPARK-21479][SQL] Outer join filter pushdown in null supplying table when condition is on one of the joined columns ## What changes were proposed in this pull request? Added `TransitPredicateInOuterJoin` optimization rule that transits constraints from the preserved side of an outer join to the null-supplying side. The constraints of the join operator will remain unchanged. ## How was this patch tested? Added 3 tests in `InferFiltersFromConstraintsSuite`. Author: maryannxue Closes #20816 from maryannxue/spark-21479. --- .../sql/catalyst/optimizer/Optimizer.scala | 42 +++++++++++++++++-- .../plans/logical/QueryPlanConstraints.scala | 25 +++++++++-- .../InferFiltersFromConstraintsSuite.scala | 36 ++++++++++++++++ 3 files changed, 96 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5fb59ef350b8b..913354e4df0e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,8 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * Note: While this optimization is applicable to all types of join, it primarily benefits Inner and - * LeftSemi joins. + * In addition, for left/right outer joins, infer predicate from the preserved side of the Join + * operator and push the inferred filter over to the null-supplying side. For example, if the + * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in + * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will + * be applied to the null-supplying side. */ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { @@ -671,11 +674,42 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe val newConditionOpt = conditionOpt match { case Some(condition) => val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else None + if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt case None => additionalConstraints.reduceOption(And) } - if (newConditionOpt.isDefined) Join(left, right, joinType, newConditionOpt) else join + // Infer filter for left/right outer joins + val newLeftOpt = joinType match { + case RightOuter if newConditionOpt.isDefined => + val inferredConstraints = left.getRelevantConstraints( + left.constraints + .union(right.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(left.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, left)) + case _ => None + } + val newRightOpt = joinType match { + case LeftOuter if newConditionOpt.isDefined => + val inferredConstraints = right.getRelevantConstraints( + right.constraints + .union(left.constraints) + .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) + val newFilters = inferredConstraints + .filterNot(right.constraints.contains) + .reduceLeftOption(And) + newFilters.map(Filter(_, right)) + case _ => None + } + + if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) + || newLeftOpt.isDefined || newRightOpt.isDefined) { + Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) + } else { + join + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 046848875548b..a29f3d29236c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -41,9 +41,7 @@ trait QueryPlanConstraints { self: LogicalPlan => * example, if this set contains the expression `a = 2` then that expression is guaranteed to * evaluate to `true` for all rows produced. */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter { c => - c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic - }) + lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints @@ -55,6 +53,23 @@ trait QueryPlanConstraints { self: LogicalPlan => */ protected def validConstraints: Set[Expression] = Set.empty + /** + * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as + * equality constraints and `isNotNull` constraints, etc., and that only contains references + * to this [[LogicalPlan]] node. + */ + def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { + val allRelevantConstraints = + if (conf.constraintPropagationEnabled) { + constraints + .union(inferAdditionalConstraints(constraints)) + .union(constructIsNotNullConstraints(constraints)) + } else { + constraints + } + ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + } + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this @@ -120,4 +135,8 @@ trait QueryPlanConstraints { self: LogicalPlan => destination: Attribute): Set[Expression] = constraints.map(_ transform { case e: Expression if e.semanticEquals(source) => destination }) + + private def selfReferenceOnly(e: Expression): Boolean = { + e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index f78c2356e35a5..e068f51044589 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -204,4 +204,40 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze + val left = x.where(IsNotNull('a) && 'a === 2) + val right = y.where(IsNotNull('a) && 'a === 2) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze + val left = x.where(IsNotNull('a) && 'a > 5) + val right = y.where(IsNotNull('a) && 'a > 5) + val correctAnswer = left.join(right, RightOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("SPARK-21479: Outer join no filter push down to preserved side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze + val left = x + val right = y.where(IsNotNull('a) && 'a === 1) + val correctAnswer = left.join(right, LeftOuter, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } From 310a8cd06299e434d94a1e391a6eb62944112446 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Apr 2018 11:51:10 +0800 Subject: [PATCH 0639/1161] [SPARK-23341][SQL] define some standard options for data source v2 ## What changes were proposed in this pull request? Each data source implementation can define its own options and teach its users how to set them. Spark doesn't have any restrictions about what options a data source should or should not have. It's possible that some options are very common and many data sources use them. However different data sources may define the common options(key and meaning) differently, which is quite confusing to end users. This PR defines some standard options that data sources can optionally adopt: path, table and database. ## How was this patch tested? a new test case. Author: Wenchen Fan Closes #20535 from cloud-fan/options. --- .../sql/sources/v2/DataSourceOptions.java | 100 ++++++++++++++++++ .../apache/spark/sql/DataFrameReader.scala | 14 ++- .../sources/v2/DataSourceOptionsSuite.scala | 25 +++++ 3 files changed, 135 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java index c32053580f016..83df3be747085 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java @@ -17,16 +17,61 @@ package org.apache.spark.sql.sources.v2; +import java.io.IOException; import java.util.HashMap; import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; + +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.spark.annotation.InterfaceStability; /** * An immutable string-to-string map in which keys are case-insensitive. This is used to represent * data source options. + * + * Each data source implementation can define its own options and teach its users how to set them. + * Spark doesn't have any restrictions about what options a data source should or should not have. + * Instead Spark defines some standard options that data sources can optionally adopt. It's possible + * that some options are very common and many data sources use them. However different data + * sources may define the common options(key and meaning) differently, which is quite confusing to + * end users. + * + * The standard options defined by Spark: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
    Option keyOption value
    pathA path string of the data files/directories, like + * path1, /absolute/file2, path3/*. The path can + * either be relative or absolute, points to either file or directory, and can contain + * wildcards. This option is commonly used by file-based data sources.
    pathsA JSON array style paths string of the data files/directories, like + * ["path1", "/absolute/file2"]. The format of each path is same as the + * path option, plus it should follow JSON string literal format, e.g. quotes + * should be escaped, pa\"th means pa"th. + *
    tableA table name string representing the table name directly without any interpretation. + * For example, db.tbl means a table called db.tbl, not a table called tbl + * inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
    databaseA database name string representing the database name directly without any + * interpretation, which is very similar to the table name option.
    */ @InterfaceStability.Evolving public class DataSourceOptions { @@ -97,4 +142,59 @@ public double getDouble(String key, double defaultValue) { return keyLowerCasedMap.containsKey(lcaseKey) ? Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; } + + /** + * The option key for singular path. + */ + public static final String PATH_KEY = "path"; + + /** + * The option key for multiple paths. + */ + public static final String PATHS_KEY = "paths"; + + /** + * The option key for table name. + */ + public static final String TABLE_KEY = "table"; + + /** + * The option key for database name. + */ + public static final String DATABASE_KEY = "database"; + + /** + * Returns all the paths specified by both the singular path option and the multiple + * paths option. + */ + public String[] paths() { + String[] singularPath = + get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]); + Optional pathsStr = get(PATHS_KEY); + if (pathsStr.isPresent()) { + ObjectMapper objectMapper = new ObjectMapper(); + try { + String[] paths = objectMapper.readValue(pathsStr.get(), String[].class); + return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new); + } catch (IOException e) { + return singularPath; + } + } else { + return singularPath; + } + } + + /** + * Returns the value of the table name option. + */ + public Optional tableName() { + return get(TABLE_KEY); + } + + /** + * Returns the value of the database name option. + */ + public Optional databaseName() { + return get(DATABASE_KEY); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ae3ba1690f696..d640fdc530ce2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,6 +21,8 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ +import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD @@ -34,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -171,7 +173,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def load(path: String): DataFrame = { - option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` + // force invocation of `load(...varargs...)` + option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*) } /** @@ -193,10 +196,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( ds = ds, conf = sparkSession.sessionState.conf) + val pathsOption = { + val objectMapper = new ObjectMapper() + DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + } Dataset.ofRows(sparkSession, DataSourceV2Relation.create( - ds, extraOptions.toMap ++ sessionOptions, + ds, extraOptions.toMap ++ sessionOptions + pathsOption, userSpecifiedSchema = userSpecifiedSchema)) - } else { loadV1Source(paths: _*) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala index 31dfc55b23361..cfa69a86de1a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala @@ -79,4 +79,29 @@ class DataSourceOptionsSuite extends SparkFunSuite { options.getDouble("foo", 0.1d) } } + + test("standard options") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.TABLE_KEY -> "tbl").asJava) + + assert(options.paths().toSeq == Seq("abc")) + assert(options.tableName().get() == "tbl") + assert(!options.databaseName().isPresent) + } + + test("standard options with both singular path and multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATH_KEY -> "abc", + DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava) + + assert(options.paths().toSeq == Seq("abc", "c", "d")) + } + + test("standard options with only multi-paths") { + val options = new DataSourceOptions(Map( + DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava) + + assert(options.paths().toSeq == Seq("c", "d\"e")) + } } From cce469435d61bda5893d9aa6cfdf7ea46fa717df Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 17 Apr 2018 21:03:57 -0700 Subject: [PATCH 0640/1161] [SPARK-24002][SQL] Task not serializable caused by org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes ## What changes were proposed in this pull request? ``` Py4JJavaError: An error occurred while calling o153.sql. : org.apache.spark.SparkException: Job aborted. at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:223) at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:189) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:70) at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:68) at org.apache.spark.sql.execution.command.ExecutedCommandExec.executeCollect(commands.scala:79) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$6.apply(Dataset.scala:190) at org.apache.spark.sql.Dataset$$anonfun$59.apply(Dataset.scala:3021) at org.apache.spark.sql.execution.SQLExecution$.withCustomExecutionEnv(SQLExecution.scala:89) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:127) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3020) at org.apache.spark.sql.Dataset.(Dataset.scala:190) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:74) at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:646) at sun.reflect.GeneratedMethodAccessor153.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380) at py4j.Gateway.invoke(Gateway.java:293) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:226) at java.lang.Thread.run(Thread.java:748) Caused by: org.apache.spark.SparkException: Exception thrown in Future.get: at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:190) at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:267) at org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec.doConsume(BroadcastNestedLoopJoinExec.scala:530) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.ProjectExec.consume(basicPhysicalOperators.scala:37) at org.apache.spark.sql.execution.ProjectExec.doConsume(basicPhysicalOperators.scala:69) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:155) at org.apache.spark.sql.execution.FilterExec.consume(basicPhysicalOperators.scala:144) ... at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:190) ... 23 more Caused by: java.util.concurrent.ExecutionException: org.apache.spark.SparkException: Task not serializable at java.util.concurrent.FutureTask.report(FutureTask.java:122) at java.util.concurrent.FutureTask.get(FutureTask.java:206) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:179) ... 276 more Caused by: org.apache.spark.SparkException: Task not serializable at org.apache.spark.util.ClosureCleaner$.ensureSerializable(ClosureCleaner.scala:340) at org.apache.spark.util.ClosureCleaner$.org$apache$spark$util$ClosureCleaner$$clean(ClosureCleaner.scala:330) at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:156) at org.apache.spark.SparkContext.clean(SparkContext.scala:2380) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:850) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndex$1.apply(RDD.scala:849) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112) at org.apache.spark.rdd.RDD.withScope(RDD.scala:371) at org.apache.spark.rdd.RDD.mapPartitionsWithIndex(RDD.scala:849) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:417) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:89) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:125) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$$anonfun$doExecute$1.apply(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52) at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:116) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:271) at org.apache.spark.sql.execution.aggregate.HashAggregateExec.inputRDDs(HashAggregateExec.scala:181) at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:414) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:123) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:118) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$3.apply(SparkPlan.scala:152) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:149) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:118) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:61) at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:70) at org.apache.spark.sql.execution.SparkPlan.executeCollectResult(SparkPlan.scala:264) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:93) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1$$anonfun$call$1.apply(BroadcastExchangeExec.scala:81) at org.apache.spark.sql.execution.SQLExecution$.withExecutionId(SQLExecution.scala:150) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:80) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec$$anon$1.call(BroadcastExchangeExec.scala:76) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) ... 1 more Caused by: java.nio.BufferUnderflowException at java.nio.HeapByteBuffer.get(HeapByteBuffer.java:151) at java.nio.ByteBuffer.get(ByteBuffer.java:715) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytes(Binary.java:405) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.getBytesUnsafe(Binary.java:414) at org.apache.parquet.io.api.Binary$ByteBufferBackedBinary.writeObject(Binary.java:484) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeWriteObject(ObjectStreamClass.java:1128) at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1496) ``` The Parquet filters are serializable but not thread safe. SparkPlan.prepare() could be called in different threads (BroadcastExchange will call it in a thread pool). Thus, we could serialize the same Parquet filter at the same time. This is not easily reproduced. The fix is to avoid serializing these Parquet filters in the driver. This PR is to avoid serializing these Parquet filters by moving the parquet filter generation from the driver to executors. ## How was this patch tested? Having two queries one is a 1000-line SQL query and a 3000-line SQL query. Need to run at least one hour with a heavy write workload to reproduce once. Author: gatorsmile Closes #21086 from gatorsmile/taskNotSerializable. --- .../parquet/ParquetFileFormat.scala | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 476bd02374364..d8f47eec952de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -321,19 +321,6 @@ class ParquetFileFormat SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sparkSession.sessionState.conf.isParquetINT96AsTimestamp) - // Try to push down filters when filter push-down is enabled. - val pushed = - if (sparkSession.sessionState.conf.parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -351,12 +338,26 @@ class ParquetFileFormat val timestampConversion: Boolean = sparkSession.sessionState.conf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize + val enableParquetFilterPushDown: Boolean = + sparkSession.sessionState.conf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(requiredSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) From f81fa478ff990146e2a8e463ac252271448d96f5 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Wed, 18 Apr 2018 18:41:55 +0900 Subject: [PATCH 0641/1161] [SPARK-23926][SQL] Extending reverse function to support ArrayType arguments ## What changes were proposed in this pull request? This PR extends `reverse` functions to be able to operate over array columns and covers: - Introduction of `Reverse` expression that represents logic for reversing arrays and also strings - Removal of `StringReverse` expression - A wrapper for PySpark ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite ## Codegen examples ### Primitive type ``` val df = Seq( Seq(1, 3, 4, 2), null ).toDF("i") df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = inputadapter_value.copy(); /* 051 */ for(int k = 0; k < project_length / 2; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ boolean isNullAtK = project_value.isNullAt(k); /* 054 */ boolean isNullAtL = project_value.isNullAt(l); /* 055 */ if(!isNullAtK) { /* 056 */ int el = project_value.getInt(k); /* 057 */ if(!isNullAtL) { /* 058 */ project_value.setInt(k, project_value.getInt(l)); /* 059 */ } else { /* 060 */ project_value.setNullAt(k); /* 061 */ } /* 062 */ project_value.setInt(l, el); /* 063 */ } else if (!isNullAtL) { /* 064 */ project_value.setInt(k, project_value.getInt(l)); /* 065 */ project_value.setNullAt(l); /* 066 */ } /* 067 */ } /* 068 */ /* 069 */ } ``` ### Non-primitive type ``` val df = Seq( Seq("a", "c", "d", "b"), null ).toDF("s") df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen ``` Result: ``` /* 032 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 033 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 034 */ null : (inputadapter_row.getArray(0)); /* 035 */ /* 036 */ boolean filter_value = true; /* 037 */ /* 038 */ if (!(!inputadapter_isNull)) { /* 039 */ filter_value = inputadapter_isNull; /* 040 */ } /* 041 */ if (!filter_value) continue; /* 042 */ /* 043 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 044 */ /* 045 */ boolean project_isNull = inputadapter_isNull; /* 046 */ ArrayData project_value = null; /* 047 */ /* 048 */ if (!inputadapter_isNull) { /* 049 */ final int project_length = inputadapter_value.numElements(); /* 050 */ project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]); /* 051 */ for(int k = 0; k < project_length; k++) { /* 052 */ int l = project_length - k - 1; /* 053 */ project_value.update(k, inputadapter_value.getUTF8String(l)); /* 054 */ } /* 055 */ /* 056 */ } ``` Author: mn-mikke Closes #21034 from mn-mikke/feature/array-api-reverse-to-master. --- python/pyspark/sql/functions.py | 20 +++- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/collectionOperations.scala | 88 +++++++++++++++++ .../expressions/stringExpressions.scala | 20 ---- .../CollectionExpressionsSuite.scala | 44 +++++++++ .../expressions/StringExpressionsSuite.scala | 6 +- .../org/apache/spark/sql/functions.scala | 15 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 94 +++++++++++++++++++ 8 files changed, 256 insertions(+), 33 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6ca22b610843d..d3bb0a5d6b36a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1414,7 +1414,6 @@ def hash(*cols): 'uppercase. Words are delimited by whitespace.', 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', - 'reverse': 'Reverses the string column and returns it as a new string column.', 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', @@ -2128,6 +2127,25 @@ def sort_array(col, asc=True): return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc)) +@since(1.5) +@ignore_unicode_prefix +def reverse(col): + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s=u'LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.reverse(_to_java_column(col))) + + @since(2.3) def map_keys(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4dd1ca509bf2c..38c874ad948e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -336,7 +336,6 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReplace]("replace"), - expression[StringReverse]("reverse"), expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), @@ -411,6 +410,7 @@ object FunctionRegistry { expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), + expression[Reverse]("reverse"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7c87777eed47a..76b71f5b86074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } +/** + * Returns a reversed string or an array with reverse order of elements. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + LQS krapS + > SELECT _FUNC_(array(2, 1, 4, 3)); + [3, 4, 1, 2] + """, + since = "1.5.0", + note = "Reverse logic for arrays is available since 2.4.0." +) +case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Input types are utilized by type coercion in ImplicitTypeCasts. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) + + override def dataType: DataType = child.dataType + + lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(input: Any): Any = input match { + case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) + case s: UTF8String => s.reverse() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => dataType match { + case _: StringType => stringCodeGen(ev, c) + case _: ArrayType => arrayCodeGen(ctx, ev, c) + }) + } + + private def stringCodeGen(ev: ExprCode, childName: String): String = { + s"${ev.value} = ($childName).reverse();" + } + + private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = { + val length = ctx.freshName("length") + val javaElementType = CodeGenerator.javaType(elementType) + val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType) + + val initialization = if (isPrimitiveType) { + s"$childName.copy()" + } else { + s"new ${classOf[GenericArrayData].getName()}(new Object[$length])" + } + + val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length + + val swapAssigments = if (isPrimitiveType) { + val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType) + val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index) + s"""|boolean isNullAtK = ${ev.value}.isNullAt(k); + |boolean isNullAtL = ${ev.value}.isNullAt(l); + |if(!isNullAtK) { + | $javaElementType el = ${getCall("k")}; + | if(!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | } else { + | ${ev.value}.setNullAt(k); + | } + | ${ev.value}.$setFunc(l, el); + |} else if (!isNullAtL) { + | ${ev.value}.$setFunc(k, ${getCall("l")}); + | ${ev.value}.setNullAt(l); + |}""".stripMargin + } else { + s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});" + } + + s""" + |final int $length = $childName.numElements(); + |${ev.value} = $initialization; + |for(int k = 0; k < $numberOfIterations; k++) { + | int l = $length - k - 1; + | $swapAssigments + |} + """.stripMargin + } + + override def prettyName: String = "reverse" +} + /** * Checks if the array (left) has the element (right) */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 22fbb8998ed89..5a02ca0d6862c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression) } } -/** - * Returns the reversed given string. - */ -@ExpressionDescription( - usage = "_FUNC_(str) - Returns the reversed given string.", - examples = """ - Examples: - > SELECT _FUNC_('Spark SQL'); - LQS krapS - """) -case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { - override def convert(v: UTF8String): UTF8String = v.reverse() - - override def prettyName: String = "reverse" - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).reverse()") - } -} - /** * Returns a string consisting of n spaces. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5a31e3a30edd6..517639dbc7232 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + + test("Reverse") { + // Primitive-type elements + val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType)) + val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType)) + val ai5 = Literal.create(Seq(1), ArrayType(IntegerType)) + val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType)) + val ai7 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2)) + checkEvaluation(Reverse(ai1), Seq(3, 1, 2)) + checkEvaluation(Reverse(ai2), Seq(3, null, 1, null)) + checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2)) + checkEvaluation(Reverse(ai4), Seq(null, null, null)) + checkEvaluation(Reverse(ai5), Seq(1)) + checkEvaluation(Reverse(ai6), Seq.empty) + checkEvaluation(Reverse(ai7), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType)) + val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType)) + val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType)) + val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType)) + val as5 = Literal.create(Seq("a"), ArrayType(StringType)) + val as6 = Literal.create(Seq.empty, ArrayType(StringType)) + val as7 = Literal.create(null, ArrayType(StringType)) + val aa = Literal.create( + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b")) + checkEvaluation(Reverse(as1), Seq("c", "a", "b")) + checkEvaluation(Reverse(as2), Seq("c", null, "a", null)) + checkEvaluation(Reverse(as3), Seq(null, "d", null, "b")) + checkEvaluation(Reverse(as4), Seq(null, null, null)) + checkEvaluation(Reverse(as5), Seq("a")) + checkEvaluation(Reverse(as6), Seq.empty) + checkEvaluation(Reverse(as7), null) + checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 9a1a4da074ce3..f1a6f9b8889fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("REVERSE") { val s = 'a.string.at(0) val row1 = create_row("abccc") - checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) - checkEvaluation(StringReverse(s), "cccba", row1) - checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) + checkEvaluation(Reverse(Literal("abccc")), "cccba", row1) + checkEvaluation(Reverse(s), "cccba", row1) + checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 642ac056bb809..a55a800f48245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2464,14 +2464,6 @@ object functions { StringRepeat(str.expr, lit(n).expr) } - /** - * Reverses the string column and returns it as a new string column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } - /** * Trim the spaces from right end for the specified string value. * @@ -3316,6 +3308,13 @@ object functions { */ def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** + * Returns a reversed string or an array with reverse order of elements. + * @group collection_funcs + * @since 1.5.0 + */ + def reverse(e: Column): Column = withExpr { Reverse(e.expr) } + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 636e86baedf6f..74c42f2599dca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("reverse function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on + + // String test cases + val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i") + + checkAnswer( + oneRowDF.select(reverse('s)), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(s)"), + Seq(Row("krapS")) + ) + checkAnswer( + oneRowDF.select(reverse('i)), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(i)"), + Seq(Row("5123")) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(null)"), + Seq(Row(null)) + ) + + // Array test cases (primitive-type elements) + val idf = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + checkAnswer( + idf.select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.filter(dummyFilter('i)).select(reverse('i)), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + idf.selectExpr("reverse(i)"), + Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"), + Seq(Row(Seq(null, 2, null, 1))) + ) + + // Array test cases (non-primitive-type elements) + val sdf = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + checkAnswer( + sdf.select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.filter(dummyFilter('s)).select(reverse('s)), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + sdf.selectExpr("reverse(s)"), + Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null)) + ) + checkAnswer( + oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"), + Seq(Row(Seq(Seq(3, 4), Seq(1, 2)))) + ) + + // Error test cases + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(struct(1, 'a'))") + } + intercept[AnalysisException] { + oneRowDF.selectExpr("reverse(map(1, 'a'))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From f09a9e9418c1697d198de18f340b1288f5eb025c Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 18 Apr 2018 08:22:05 -0700 Subject: [PATCH 0642/1161] [SPARK-24007][SQL] EqualNullSafe for FloatType and DoubleType might generate a wrong result by codegen. ## What changes were proposed in this pull request? `EqualNullSafe` for `FloatType` and `DoubleType` might generate a wrong result by codegen. ```scala scala> val df = Seq((Some(-1.0d), None), (None, Some(-1.0d))).toDF() df: org.apache.spark.sql.DataFrame = [_1: double, _2: double] scala> df.show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ scala> df.filter("_1 <=> _2").show() +----+----+ | _1| _2| +----+----+ |-1.0|null| |null|-1.0| +----+----+ ``` The result should be empty but the result remains two rows. ## How was this patch tested? Added a test. Author: Takuya UESHIN Closes #21094 from ueshin/issues/SPARK-24007/equalnullsafe. --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 6 ++++-- .../spark/sql/catalyst/expressions/PredicateSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f6b6775923ac6..cf0a91ff00626 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -582,8 +582,10 @@ class CodegenContext { */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" - case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" - case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" + case FloatType => + s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" + case DoubleType => + s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 8a8f8e10225fa..1bfd180ae4393 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -442,4 +442,11 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } + + test("SPARK-24007: EqualNullSafe for FloatType and DoubleType might generate a wrong result") { + checkEvaluation(EqualNullSafe(Literal(null, FloatType), Literal(-1.0f)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0f), Literal(null, FloatType)), false) + checkEvaluation(EqualNullSafe(Literal(null, DoubleType), Literal(-1.0d)), false) + checkEvaluation(EqualNullSafe(Literal(-1.0d), Literal(null, DoubleType)), false) + } } From a9066478f6d98c3ae634c3bb9b09ee20bd60e111 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 00:05:47 +0200 Subject: [PATCH 0643/1161] [SPARK-23875][SQL][FOLLOWUP] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? Use specified accessor in `ArrayData.foreach` and `toArray`. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21099 from viirya/SPARK-23875-followup. --- .../org/apache/spark/sql/catalyst/util/ArrayData.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 2cf59d567c08c..104b428614849 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -141,28 +141,29 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) val values = new Array[T](size) var i = 0 while (i < size) { if (isNullAt(i)) { values(i) = null.asInstanceOf[T] } else { - values(i) = get(i, elementType).asInstanceOf[T] + values(i) = accessor(this, i).asInstanceOf[T] } i += 1 } values } - // todo: specialize this. def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = { val size = numElements() + val accessor = InternalRow.getAccessor(elementType) var i = 0 while (i < size) { if (isNullAt(i)) { f(i, null) } else { - f(i, get(i, elementType)) + f(i, accessor(this, i)) } i += 1 } From 0c94e48bc50717e1627c0d2acd5382d9adc73c97 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 18 Apr 2018 16:37:41 -0700 Subject: [PATCH 0644/1161] [SPARK-23775][TEST] Make DataFrameRangeSuite not flaky ## What changes were proposed in this pull request? DataFrameRangeSuite.test("Cancelling stage in a query with Range.") stays sometimes in an infinite loop and times out the build. There were multiple issues with the test: 1. The first valid stageId is zero when the test started alone and not in a suite and the following code waits until timeout: ``` eventually(timeout(10.seconds), interval(1.millis)) { assert(DataFrameRangeSuite.stageToKill > 0) } ``` 2. The `DataFrameRangeSuite.stageToKill` was overwritten by the task's thread after the reset which ended up in canceling the same stage 2 times. This caused the infinite wait. This PR solves this mentioned flakyness by removing the shared `DataFrameRangeSuite.stageToKill` and using `wait` and `CountDownLatch` for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #20888 from gaborgsomogyi/SPARK-23775. --- .../spark/sql/DataFrameRangeSuite.scala | 78 +++++++++++-------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 57a930dfaf320..a0fd74088ce8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql +import java.util.concurrent.{CountDownLatch, TimeUnit} + import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -152,39 +154,53 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - val listener = new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - eventually(timeout(10.seconds), interval(1.millis)) { - assert(DataFrameRangeSuite.stageToKill > 0) + // Save and restore the value because SparkContext is shared + val savedInterruptOnCancel = sparkContext + .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + + try { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") + + for (codegen <- Seq(true, false)) { + // This countdown latch used to make sure with all the stages cancelStage called in listener + val latch = new CountDownLatch(2) + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sparkContext.cancelStage(taskStart.stageId) + latch.countDown() + } } - sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) - } - } - sparkContext.addSparkListener(listener) - for (codegen <- Seq(true, false)) { - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - DataFrameRangeSuite.stageToKill = -1 - val ex = intercept[SparkException] { - spark.range(0, 100000000000L, 1, 1).map { x => - DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() - x - }.toDF("id").agg(sum("id")).collect() + sparkContext.addSparkListener(listener) + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + val ex = intercept[SparkException] { + sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } + x + }.toDF("id").agg(sum("id")).collect() + } + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + } } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + latch.await(20, TimeUnit.SECONDS) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sparkContext.removeSparkListener(listener) } - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) - } + } finally { + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, + savedInterruptOnCancel) } - sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -204,7 +220,3 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } - -object DataFrameRangeSuite { - @volatile var stageToKill = -1 -} From 8bb0df2c65355dfdcd28e362ff661c6c7ebc99c0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Apr 2018 10:00:57 +0800 Subject: [PATCH 0645/1161] [SPARK-24014][PYSPARK] Add onStreamingStarted method to StreamingListener ## What changes were proposed in this pull request? The `StreamingListener` in PySpark side seems to be lack of `onStreamingStarted` method. This patch adds it and a test for it. This patch also includes a trivial doc improvement for `createDirectStream`. Original PR is #21057. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21098 from viirya/SPARK-24014. --- python/pyspark/streaming/kafka.py | 3 ++- python/pyspark/streaming/listener.py | 6 ++++++ python/pyspark/streaming/tests.py | 7 +++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index fdb9308604489..ed2e0e7d10fa2 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -104,7 +104,8 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :param topics: list of topic_name to consume. :param kafkaParams: Additional params for Kafka. :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting - point of the stream. + point of the stream (a dictionary mapping `TopicAndPartition` to + integers). :param keyDecoder: A function used to decode key (default is utf8_decoder). :param valueDecoder: A function used to decode value (default is utf8_decoder). :param messageHandler: A function used to convert KafkaMessageAndMetadata. You can assess diff --git a/python/pyspark/streaming/listener.py b/python/pyspark/streaming/listener.py index b830797f5c0a0..d4ecc215aea99 100644 --- a/python/pyspark/streaming/listener.py +++ b/python/pyspark/streaming/listener.py @@ -23,6 +23,12 @@ class StreamingListener(object): def __init__(self): pass + def onStreamingStarted(self, streamingStarted): + """ + Called when the streaming has been started. + """ + pass + def onReceiverStarted(self, receiverStarted): """ Called when a receiver has been started diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 7dde7c0928c08..103940923dd4d 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -507,6 +507,10 @@ def __init__(self): self.batchInfosCompleted = [] self.batchInfosStarted = [] self.batchInfosSubmitted = [] + self.streamingStartedTime = [] + + def onStreamingStarted(self, streamingStarted): + self.streamingStartedTime.append(streamingStarted.time) def onBatchSubmitted(self, batchSubmitted): self.batchInfosSubmitted.append(batchSubmitted.batchInfo()) @@ -530,9 +534,12 @@ def func(dstream): batchInfosSubmitted = batch_collector.batchInfosSubmitted batchInfosStarted = batch_collector.batchInfosStarted batchInfosCompleted = batch_collector.batchInfosCompleted + streamingStartedTime = batch_collector.streamingStartedTime self.wait_for(batchInfosCompleted, 4) + self.assertEqual(len(streamingStartedTime), 1) + self.assertGreaterEqual(len(batchInfosSubmitted), 4) for info in batchInfosSubmitted: self.assertGreaterEqual(info.batchTime().milliseconds(), 0) From d5bec48b9cb225c19b43935c07b24090c51cacce Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 11:59:17 +0900 Subject: [PATCH 0646/1161] [SPARK-23919][SQL] Add array_position function ## What changes were proposed in this pull request? The PR adds the SQL function `array_position`. The behavior of the function is based on Presto's one. The function returns the position of the first occurrence of the element in array x (or 0 if not found) using 1-based index as BigInt. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21037 from kiszk/SPARK-23919. --- python/pyspark/sql/functions.py | 17 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 56 +++++++++++++++++++ .../CollectionExpressionsSuite.scala | 22 ++++++++ .../org/apache/spark/sql/functions.scala | 14 +++++ .../spark/sql/DataFrameFunctionsSuite.scala | 34 +++++++++++ 6 files changed, 144 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d3bb0a5d6b36a..36dcabc6766d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1845,6 +1845,23 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def array_position(col, value): + """ + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. + + .. note:: The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. + + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38c874ad948e1..74095fe697b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -402,6 +402,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 76b71f5b86074..e6a05f535cb1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -505,3 +505,59 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + + +/** + * Returns the position of the first occurrence of element in the given array as long. + * Returns 0 if the given value could not be found in the array. Returns null if either of + * the arguments are null + * + * NOTE: that this is not zero based, but 1-based index. The first element in the array has + * index 1. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, element) - Returns the (1-based) index of the first element of the array as long. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(3, 2, 1), 1); + 3 + """, + since = "2.4.0") +case class ArrayPosition(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = + Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def nullSafeEval(arr: Any, value: Any): Any = { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) { + return (i + 1).toLong + } + ) + 0L + } + + override def prettyName: String = "array_position" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val pos = ctx.freshName("arrayPosition") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, right.dataType, i) + s""" + |int $pos = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) { + | $pos = $i + 1; + | break; + | } + |} + |${ev.value} = (long) $pos; + """.stripMargin + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 517639dbc7232..916cd3bb4cca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -169,4 +169,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Reverse(as7), null) checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b"))) } + + test("Array Position") { + val a0 = Literal.create(Seq(1, null, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayPosition(a0, Literal(3)), 4L) + checkEvaluation(ArrayPosition(a0, Literal(1)), 1L) + checkEvaluation(ArrayPosition(a0, Literal(0)), 0L) + checkEvaluation(ArrayPosition(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayPosition(a1, Literal("")), 2L) + checkEvaluation(ArrayPosition(a1, Literal("a")), 0L) + checkEvaluation(ArrayPosition(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayPosition(a2, Literal(1L)), 0L) + checkEvaluation(ArrayPosition(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayPosition(a3, Literal("")), null) + checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a55a800f48245..3a09ec4f1982e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3038,6 +3038,20 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Locates the position of the first occurrence of the value in the given array as long. + * Returns null if either of the arguments are null. + * + * @note The position is not zero based, but 1 based index. Returns 0 if value + * could not be found in array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_position(column: Column, value: Any): Column = withExpr { + ArrayPosition(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 74c42f2599dca..13161e7e24cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -535,6 +535,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array position function") { + val df = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + checkAnswer( + df.select(array_position(df("a"), 1)), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.selectExpr("array_position(a, 1)"), + Seq(Row(1L), Row(0L)) + ) + + checkAnswer( + df.select(array_position(df("a"), null)), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("array_position(a, null)"), + Seq(Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("array_position(array(array(1), null)[0], 1)"), + Seq(Row(1L), Row(1L)) + ) + checkAnswer( + df.selectExpr("array_position(array(1, null), array(1, null)[0])"), + Seq(Row(1L), Row(1L)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 46bb2b5129833cc5829089bf1174a76cb7b81741 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 19 Apr 2018 21:00:10 +0900 Subject: [PATCH 0647/1161] [SPARK-23924][SQL] Add element_at function ## What changes were proposed in this pull request? The PR adds the SQL function `element_at`. The behavior of the function is based on Presto's one. This function returns element of array at given index in value if column is array, or returns value for the given key in value if column is map. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21053 from kiszk/SPARK-23924. --- python/pyspark/sql/functions.py | 24 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 104 ++++++++++++++++++ .../expressions/complexTypeExtractors.scala | 64 +++++++---- .../CollectionExpressionsSuite.scala | 48 ++++++++ .../org/apache/spark/sql/functions.scala | 11 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 48 ++++++++ 7 files changed, 276 insertions(+), 24 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 36dcabc6766d8..1be68f2a4a448 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1862,6 +1862,30 @@ def array_position(col, value): return Column(sc._jvm.functions.array_position(_to_java_column(col), value)) +@ignore_unicode_prefix +@since(2.4) +def element_at(col, extraction): + """ + Collection function: Returns element of array at given index in extraction if col is array. + Returns value for the given key in extraction if col is map. + + :param col: name of column containing array or map + :param extraction: index to check for in array or key to check for in map + + .. note:: The position is not zero based, but 1 based index. + + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) + >>> df.select(element_at(df.data, "a")).collect() + [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 74095fe697b6a..a44f2d5272b8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -405,6 +405,7 @@ object FunctionRegistry { expression[ArrayPosition]("array_position"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), + expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[Size]("size"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e6a05f535cb1c..dba426e999dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -561,3 +561,107 @@ case class ArrayPosition(left: Expression, right: Expression) }) } } + +/** + * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`. + */ +@ExpressionDescription( + usage = """ + _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0, + accesses elements from the last to the first. Returns NULL if the index exceeds the length + of the array. + + _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 2); + 2 + > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2); + "b" + """, + since = "2.4.0") +case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + + override def dataType: DataType = left.dataType match { + case ArrayType(elementType, _) => elementType + case MapType(_, valueType, _) => valueType + } + + override def inputTypes: Seq[AbstractDataType] = { + Seq(TypeCollection(ArrayType, MapType), + left.dataType match { + case _: ArrayType => IntegerType + case _: MapType => left.dataType.asInstanceOf[MapType].keyType + } + ) + } + + override def nullable: Boolean = true + + override def nullSafeEval(value: Any, ordinal: Any): Any = { + left.dataType match { + case _: ArrayType => + val array = value.asInstanceOf[ArrayData] + val index = ordinal.asInstanceOf[Int] + if (array.numElements() < math.abs(index)) { + null + } else { + val idx = if (index == 0) { + throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1") + } else if (index > 0) { + index - 1 + } else { + array.numElements() + index + } + if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + null + } else { + array.get(idx, dataType) + } + } + case _: MapType => + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + left.dataType match { + case _: ArrayType => + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val index = ctx.freshName("elementAtIndex") + val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($eval1.isNullAt($index)) { + | ${ev.isNull} = true; + |} else + """.stripMargin + } else { + "" + } + s""" + |int $index = (int) $eval2; + |if ($eval1.numElements() < Math.abs($index)) { + | ${ev.isNull} = true; + |} else { + | if ($index == 0) { + | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1"); + | } else if ($index > 0) { + | $index--; + | } else { + | $index += $eval1.numElements(); + | } + | $nullCheck + | { + | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; + | } + |} + """.stripMargin + }) + case _: MapType => + doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType]) + } + } + + override def prettyName: String = "element_at" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6cdad19168dce..3fba52d745453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `key` in Map `child`. - * - * We need to do type checking here as `key` expression maybe unresolved. + * Common base class for [[GetMapValue]] and [[ElementAt]]. */ -case class GetMapValue(child: Expression, key: Expression) - extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant { - - private def keyType = child.dataType.asInstanceOf[MapType].keyType - - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) - - override def toString: String = s"$child[$key]" - override def sql: String = s"${child.sql}[${key.sql}]" - - override def left: Expression = child - override def right: Expression = key - - /** `Null` is returned for invalid ordinals. */ - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType +abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - protected override def nullSafeEval(value: Any, ordinal: Any): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") val values = ctx.freshName("values") - val nullCheck = if (child.dataType.asInstanceOf[MapType].valueContainsNull) { + val keyType = mapType.keyType + val nullCheck = if (mapType.valueContainsNull) { s" || $values.isNullAt($index)" } else { "" @@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression) }) } } + +/** + * Returns the value of key `key` in Map `child`. + * + * We need to do type checking here as `key` expression maybe unresolved. + */ +case class GetMapValue(child: Expression, key: Expression) + extends GetMapValueUtil with ExtractValue with NullIntolerant { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + + override def toString: String = s"$child[$key]" + override def sql: String = s"${child.sql}[${key.sql}]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true + + override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + + // todo: current search is O(n), improve it. + override def nullSafeEval(value: Any, ordinal: Any): Any = { + getValueEval(value, ordinal, keyType) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + doGetValueGenCode(ctx, ev, child.dataType.asInstanceOf[MapType]) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 916cd3bb4cca5..7d8fe211858b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -191,4 +191,52 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) } + + test("elementAt") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + intercept[Exception] { + checkEvaluation(ElementAt(a0, Literal(0)), null) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } + checkEvaluation(ElementAt(a0, Literal(4)), null) + checkEvaluation(ElementAt(a0, Literal(-4)), null) + + checkEvaluation(ElementAt(a0, Literal(1)), 1) + checkEvaluation(ElementAt(a0, Literal(2)), 2) + checkEvaluation(ElementAt(a0, Literal(3)), 3) + checkEvaluation(ElementAt(a0, Literal(-3)), 1) + checkEvaluation(ElementAt(a0, Literal(-2)), 2) + checkEvaluation(ElementAt(a0, Literal(-1)), 3) + + checkEvaluation(ElementAt(a1, Literal(1)), null) + checkEvaluation(ElementAt(a1, Literal(2)), "") + checkEvaluation(ElementAt(a1, Literal(-2)), null) + checkEvaluation(ElementAt(a1, Literal(-1)), "") + + checkEvaluation(ElementAt(a2, Literal(1)), null) + + checkEvaluation(ElementAt(a3, Literal(1)), null) + + + val m0 = + Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(ElementAt(m0, Literal(1.0)), null) + + checkEvaluation(ElementAt(m0, Literal("d")), null) + + checkEvaluation(ElementAt(m1, Literal("a")), null) + + checkEvaluation(ElementAt(m0, Literal("a")), "1") + checkEvaluation(ElementAt(m0, Literal("b")), "2") + checkEvaluation(ElementAt(m0, Literal("c")), null) + + checkEvaluation(ElementAt(m2, Literal("a")), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3a09ec4f1982e..9c8580378303e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3052,6 +3052,17 @@ object functions { ArrayPosition(column.expr, Literal(value)) } + /** + * Returns element of array at given index in value if column is array. Returns value for + * the given key in value if column is map. + * + * @group collection_funcs + * @since 2.4.0 + */ + def element_at(column: Column, value: Any): Column = withExpr { + ElementAt(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 13161e7e24cfe..7c976c1b7f915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -569,6 +569,54 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("element_at function") { + val df = Seq( + (Seq[String]("1", "2", "3")), + (Seq[String](null, "")), + (Seq[String]()) + ).toDF("a") + + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 0)), + Seq(Row(null), Row(null), Row(null)) + ) + }.getMessage.contains("SQL array indices start at 1") + intercept[Exception] { + checkAnswer( + df.select(element_at(df("a"), 1.1)), + Seq(Row(null), Row(null), Row(null)) + ) + } + checkAnswer( + df.select(element_at(df("a"), 4)), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.select(element_at(df("a"), 1)), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.select(element_at(df("a"), -1)), + Seq(Row("3"), Row(""), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 4)"), + Seq(Row(null), Row(null), Row(null)) + ) + + checkAnswer( + df.selectExpr("element_at(a, 1)"), + Seq(Row("1"), Row(null), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, -1)"), + Seq(Row("3"), Row(""), Row(null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 1b08c4393cf48e21fea9914d130d8d3bf544061d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:38:26 +0200 Subject: [PATCH 0648/1161] [SPARK-23584][SQL] NewInstance should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `NewInstance`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20778 from maropu/SPARK-23584. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../expressions/objects/objects.scala | 28 +++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 36 +++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e4274aaa9727e..818cc2fb1e8a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst +import java.lang.reflect.Constructor + +import org.apache.commons.lang3.reflect.ConstructorUtils + import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Finds an accessible constructor with compatible parameters. This is a more flexible search + * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned. Otherwise, it returns `None`. + */ + def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + } + /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 72b202b3a5020..1645bd7d57b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -449,8 +449,32 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val getConstructor = (paramClazz: Seq[Class[_]]) => { + ScalaReflection.findConstructor(cls, paramClazz).getOrElse { + sys.error(s"Couldn't find a valid constructor on $cls") + } + } + outerPointer.map { p => + val outerObj = p() + val d = outerObj.getClass +: paramTypes + val c = getConstructor(outerObj.getClass +: paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(outerObj +: args: _*) + } + }.getOrElse { + val c = getConstructor(paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(args: _*) + } + } + } + + override def eval(input: InternalRow): Any = { + val argValues = arguments.map(_.eval(input)) + constructor(argValues.map(_.asInstanceOf[AnyRef])) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b0188b0098def..bf805f4f29ac5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass { override def binOp(e1: Int, e2: Double): Double = e1 - e2 } +// Tests for NewInstance +class Outer extends Serializable { + class Inner(val value: Int) { + override def hashCode(): Int = super.hashCode() + override def equals(other: Any): Boolean = { + if (other.isInstanceOf[Inner]) { + value == other.asInstanceOf[Inner].value + } else { + false + } + } + } +} + class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-16622: The returned value of the called method in Invoke can be null") { @@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23584 NewInstance should support interpreted execution") { + // Normal case test + val newInst1 = NewInstance( + cls = classOf[GenericArrayData], + arguments = Literal.fromObject(List(1, 2, 3)) :: Nil, + propagateNull = false, + dataType = ArrayType(IntegerType), + outerPointer = None) + checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3))) + + // Inner class case test + val outerObj = new Outer() + val newInst2 = NewInstance( + cls = classOf[outerObj.Inner], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[outerObj.Inner]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + } + test("LambdaVariable should support interpreted execution") { def genSchema(dt: DataType): Seq[StructType] = { Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), @@ -421,6 +456,7 @@ class TestBean extends Serializable { private var x: Int = 0 def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = assert(i != null, "this setter should not be called with null.") } From e13416502f814b04d59bb650953a0114332d163a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 19 Apr 2018 14:42:50 +0200 Subject: [PATCH 0649/1161] [SPARK-23588][SQL] CatalystToExternalMap should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `CatalystToExternalMap`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20979 from maropu/SPARK-23588. --- .../expressions/objects/objects.scala | 39 +++++++++++++++++-- .../expressions/ObjectExpressionsSuite.scala | 34 +++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1645bd7d57b1d..bc17d1229420a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -28,12 +28,12 @@ import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1033,8 +1033,39 @@ case class CatalystToExternalMap private( override def children: Seq[Expression] = keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val inputMapType = inputData.dataType.asInstanceOf[MapType] + + private lazy val keyConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.keyType) + private lazy val valueConverter = + CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + val clazz = Utils.classForName(collClass.getCanonicalName + "$") + val module = clazz.getField("MODULE$").get(null) + val method = clazz.getMethod("newBuilder") + method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + } + + override def eval(input: InternalRow): Any = { + val result = inputData.eval(input).asInstanceOf[MapData] + if (result != null) { + val builder = newMapBuilder() + builder.sizeHint(result.numElements()) + val keyArray = result.keyArray() + val valueArray = result.valueArray() + var i = 0 + while (i < result.numElements()) { + val key = keyConverter(keyArray.get(i, inputMapType.keyType)) + val value = valueConverter(valueArray.get(i, inputMapType.valueType)) + builder += Tuple2(key, value) + i += 1 + } + builder.result() + } else { + null + } + } override def dataType: DataType = ObjectType(collClass) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bf805f4f29ac5..bcd035c1eba0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -27,12 +27,14 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -162,9 +164,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "fromPrimitiveArray", ObjectType(classOf[Array[Int]]), Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))), (DateTimeUtils.getClass, ObjectType(classOf[Date]), - "toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)), + "toJavaDate", ObjectType(classOf[DateTimeUtils.SQLDate]), 77777, + DateTimeUtils.toJavaDate(77777)), (DateTimeUtils.getClass, ObjectType(classOf[Timestamp]), - "toJavaTimestamp", ObjectType(classOf[SQLTimestamp]), + "toJavaTimestamp", ObjectType(classOf[DateTimeUtils.SQLTimestamp]), 88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888)) ).foreach { case (cls, dataType, methodName, argType, arg, expected) => checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName, @@ -450,6 +453,25 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + implicit private def mapIntStrEncoder = ExpressionEncoder[Map[Int, String]]() + + test("SPARK-23588 CatalystToExternalMap should support interpreted execution") { + // To get a resolved `CatalystToExternalMap` expression, we build a deserializer plan + // with dummy input, resolve the plan by the analyzer, and replace the dummy input + // with a literal for tests. + val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer) + val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType))) + val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan) + + val analyzedPlan = SimpleAnalyzer.execute(plan) + val Alias(toMapExpr: CatalystToExternalMap, _) = analyzedPlan.expressions.head + + // Replaces the dummy input with a literal for tests here + val data = Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val deserializer = toMapExpr.copy(inputData = Literal.create(data)) + checkObjectExprEvaluation(deserializer, expected = data) + } } class TestBean extends Serializable { From 9e10f69df52abde2de5d93435bab54e97dd59d9c Mon Sep 17 00:00:00 2001 From: jinxing Date: Thu, 19 Apr 2018 21:07:21 +0800 Subject: [PATCH 0650/1161] [SPARK-22676][FOLLOW-UP] fix code style for test. ## What changes were proposed in this pull request? This pr address comments in https://github.com/apache/spark/pull/19868 ; Fix the code style for `org.apache.spark.sql.hive.QueryPartitionSuite` by using: `withTempView`, `withTempDir`, `withTable`... Author: jinxing Closes #21091 from jinxing64/SPARK-22676-FOLLOW-UP. --- .../spark/sql/hive/QueryPartitionSuite.scala | 109 +++++++----------- 1 file changed, 41 insertions(+), 68 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 78156b17fb43b..1e396553c9c52 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -33,80 +33,53 @@ import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import spark.implicits._ - test("SPARK-5068: query data when path doesn't exist") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + private def queryWhenPathNotExist(): Unit = { + withTempView("testData") { + withTable("table_with_partition", "createAndInsertTest") { + withTempDir { tmpDir => + val testData = sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.createOrReplaceTempView("testData") + + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM testData") + + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData).union(testData)) + + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } + + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.union(testData).union(testData)) + } + } + } + } - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + test("SPARK-5068: query data when path doesn't exist") { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "true") { + queryWhenPathNotExist() } } test("Replace spark.sql.hive.verifyPartitionPath by spark.files.ignoreMissingFiles") { - withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "false")) { + withSQLConf(SQLConf.HIVE_VERIFY_PARTITION_PATH.key -> "false") { sparkContext.conf.set(IGNORE_MISSING_FILES.key, "true") - val testData = sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.createOrReplaceTempView("testData") - - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") - - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) - - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } - - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - - sql("DROP TABLE IF EXISTS table_with_partition") - sql("DROP TABLE IF EXISTS createAndInsertTest") + queryWhenPathNotExist() } } From d96c3e33cc2a95de8e15e1a2ddf50a8d0cc66dd2 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 19 Apr 2018 21:21:22 +0800 Subject: [PATCH 0651/1161] [SPARK-21811][SQL] Fix the inconsistency behavior when finding the widest common type ## What changes were proposed in this pull request? Currently we find the wider common type by comparing the two types from left to right, this can be a problem when you have two data types which don't have a common type but each can be promoted to StringType. For instance, if you have a table with the schema: [c1: date, c2: string, c3: int] The following succeeds: SELECT coalesce(c1, c2, c3) FROM table While the following produces an exception: SELECT coalesce(c1, c3, c2) FROM table This is only a issue when the seq of dataTypes contains `StringType` and all the types can do string promotion. close #19033 ## How was this patch tested? Add test in `TypeCoercionSuite` Author: Xingbo Jiang Closes #21074 from jiangxb1987/typeCoercion. --- docs/sql-programming-guide.md | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 24 +++++++++++++++---- .../catalyst/analysis/TypeCoercionSuite.scala | 13 ++++++++++ 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 55d35b9dd31db..e8ff1470970f7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1810,7 +1810,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - + - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index ec7e7761dc4c2..281f206e8d59e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -175,11 +175,27 @@ object TypeCoercion { }) } + /** + * Whether the data type contains StringType. + */ + def hasStringType(dt: DataType): Boolean = dt match { + case StringType => true + case ArrayType(et, _) => hasStringType(et) + // Add StructType if we support string promotion for struct fields in the future. + case _ => false + } + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) - case None => None - }) + // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal + // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. + // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, + // (TimestampType, IntegerType, StringType) should have StringType as the wider common type. + val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) + (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => + r match { + case Some(d) => findWiderTypeForTwo(d, c) + case _ => None + }) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 8ac49dc05e3cf..fd6a3121663ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -539,6 +539,9 @@ class TypeCoercionSuite extends AnalysisTest { val floatLit = Literal.create(1.0f, FloatType) val timestampLit = Literal.create("2017-04-12", TimestampType) val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis()))) + val strArrayLit = Literal(Array("c")) + val intArrayLit = Literal(Array(1)) ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), @@ -572,6 +575,16 @@ class TypeCoercionSuite extends AnalysisTest { Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, intLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), + Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), + Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), + Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) } test("CreateArray casts") { From 0deaa5251326a32a3d2d2b8851193ca926303972 Mon Sep 17 00:00:00 2001 From: wuyi Date: Thu, 19 Apr 2018 09:00:33 -0500 Subject: [PATCH 0652/1161] [SPARK-24021][CORE] fix bug in BlacklistTracker's updateBlacklistForFetchFailure MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? There‘s a miswrite in BlacklistTracker's updateBlacklistForFetchFailure: ``` val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) blacklistedExecsOnNode += exec ``` where first **exec** should be **host**. ## How was this patch tested? adjust existed test. Author: wuyi Closes #21104 from Ngone51/SPARK-24021. --- .../scala/org/apache/spark/scheduler/BlacklistTracker.scala | 2 +- .../org/apache/spark/scheduler/BlacklistTrackerSuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 952598f6de19d..30cf75d43ee09 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -210,7 +210,7 @@ private[scheduler] class BlacklistTracker ( updateNextExpiryTime() killBlacklistedExecutor(exec) - val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(exec, HashSet[String]()) + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(host, HashSet[String]()) blacklistedExecsOnNode += exec } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 06d7afaaff55c..96c8404327e24 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -574,6 +574,9 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M verify(allocationClientMock, never).killExecutors(any(), any(), any(), any()) verify(allocationClientMock, never).killExecutorsOnHost(any()) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. conf.set(config.BLACKLIST_KILL_ENABLED, true) blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) @@ -589,6 +592,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nextExpiryTime === 1000 + blacklist.BLACKLIST_TIMEOUT_MILLIS) assert(blacklist.nodeIdToBlacklistExpiryTime.isEmpty) + assert(blacklist.nodeToBlacklistedExecs.contains("hostA")) + assert(blacklist.nodeToBlacklistedExecs("hostA").contains("1")) // Enable external shuffle service to see if all the executors on this node will be killed. conf.set(config.SHUFFLE_SERVICE_ENABLED, true) From 6e19f7683fc73fabe7cdaac4eb1982d2e3e607b7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Apr 2018 17:54:53 +0200 Subject: [PATCH 0653/1161] [SPARK-23989][SQL] exchange should copy data before non-serialized shuffle ## What changes were proposed in this pull request? In Spark SQL, we usually reuse the `UnsafeRow` instance and need to copy the data when a place buffers non-serialized objects. Shuffle may buffer objects if we don't make it to the bypass merge shuffle or unsafe shuffle. `ShuffleExchangeExec.needToCopyObjectsBeforeShuffle` misses the case that, if `spark.sql.shuffle.partitions` is large enough, we could fail to run unsafe shuffle and go with the non-serialized shuffle. This bug is very hard to hit since users wouldn't set such a large number of partitions(16 million) for Spark SQL exchange. TODO: test ## How was this patch tested? todo. Author: Wenchen Fan Closes #21101 from cloud-fan/shuffle. --- .../exchange/ShuffleExchangeExec.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 4d95ee34f30de..b89203719541b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -153,12 +153,9 @@ object ShuffleExchangeExec { * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. * * @param partitioner the partitioner for the shuffle - * @param serializer the serializer that will be used to write rows * @return true if rows should be copied before being shuffled, false otherwise */ - private def needToCopyObjectsBeforeShuffle( - partitioner: Partitioner, - serializer: Serializer): Boolean = { + private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = { // Note: even though we only use the partitioner's `numPartitions` field, we require it to be // passed instead of directly passing the number of partitions in order to guard against // corner-cases where a partitioner constructed with `numPartitions` partitions may output @@ -167,22 +164,24 @@ object ShuffleExchangeExec { val shuffleManager = SparkEnv.get.shuffleManager val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + val numParts = partitioner.numPartitions if (sortBasedShuffleOn) { - val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + if (numParts <= bypassMergeThreshold) { // If we're using the original SortShuffleManager and the number of output partitions is // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false - } else if (serializer.supportsRelocationOfSerializedObjects) { + } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records // prior to sorting them. This optimization is only applied in cases where shuffle // dependency does not specify an aggregator or ordering and the record serializer has - // certain properties. If this optimization is enabled, we can safely avoid the copy. + // certain properties and the number of partitions doesn't exceed the limitation. If this + // optimization is enabled, we can safely avoid the copy. // - // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only - // need to check whether the optimization is enabled and supported by our serializer. + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the + // serializer in Spark SQL always satisfy the properties, so we only need to check whether + // the number of partitions exceeds the limitation. false } else { // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must @@ -298,7 +297,7 @@ object ShuffleExchangeExec { rdd } - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + if (needToCopyObjectsBeforeShuffle(part)) { newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } From a471880afbeafd4ef54c15a97e72ea7ff784a88d Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Thu, 19 Apr 2018 09:40:20 -0700 Subject: [PATCH 0654/1161] [SPARK-24026][ML] Add Power Iteration Clustering to spark.ml ## What changes were proposed in this pull request? This PR adds PowerIterationClustering as a Transformer to spark.ml. In the transform method, it calls spark.mllib's PowerIterationClustering.run() method and transforms the return value assignments (the Kmeans output of the pseudo-eigenvector) as a DataFrame (id: LongType, cluster: IntegerType). This PR is copied and modified from https://github.com/apache/spark/pull/15770 The primary author is wangmiao1981 ## How was this patch tested? This PR has 2 types of tests: * Copies of tests from spark.mllib's PIC tests * New tests specific to the spark.ml APIs Author: wm624@hotmail.com Author: wangmiao1981 Author: Joseph K. Bradley Closes #21090 from jkbradley/wangmiao1981-pic. --- .../clustering/PowerIterationClustering.scala | 256 ++++++++++++++++++ .../PowerIterationClusteringSuite.scala | 238 ++++++++++++++++ 2 files changed, 494 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala new file mode 100644 index 0000000000000..2c30a1d9aa947 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +/** + * Common params for PowerIterationClustering + */ +private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter + with HasPredictionCol { + + /** + * The number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.4.0") + final val k = new IntParam(this, "k", "The number of clusters to create. " + + "Must be > 1.", ParamValidators.gt(1)) + + /** @group getParam */ + @Since("2.4.0") + def getK: Int = $(k) + + /** + * Param for the initialization algorithm. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use a normalized sum of similarities with other vertices. + * Default: random. + * @group expertParam + */ + @Since("2.4.0") + final val initMode = { + val allowedParams = ParamValidators.inArray(Array("random", "degree")) + new Param[String](this, "initMode", "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use a normalized sum " + + "of similarities with other vertices. Supported options: 'random' and 'degree'.", + allowedParams) + } + + /** @group expertGetParam */ + @Since("2.4.0") + def getInitMode: String = $(initMode) + + /** + * Param for the name of the input column for vertex IDs. + * Default: "id" + * @group param + */ + @Since("2.4.0") + val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + (value: String) => value.nonEmpty) + + setDefault(idCol, "id") + + /** @group getParam */ + @Since("2.4.0") + def getIdCol: String = getOrDefault(idCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "neighbors" + * @group param + */ + @Since("2.4.0") + val neighborsCol = new Param[String](this, "neighborsCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(neighborsCol, "neighbors") + + /** @group getParam */ + @Since("2.4.0") + def getNeighborsCol: String = $(neighborsCol) + + /** + * Param for the name of the input column for neighbors in the adjacency list representation. + * Default: "similarities" + * @group param + */ + @Since("2.4.0") + val similaritiesCol = new Param[String](this, "similaritiesCol", + "Name of the input column for neighbors in the adjacency list representation.", + (value: String) => value.nonEmpty) + + setDefault(similaritiesCol, "similarities") + + /** @group getParam */ + @Since("2.4.0") + def getSimilaritiesCol: String = $(similaritiesCol) + + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(schema, $(neighborsCol), + Seq(ArrayType(IntegerType, containsNull = false), + ArrayType(LongType, containsNull = false))) + SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), + Seq(ArrayType(FloatType, containsNull = false), + ArrayType(DoubleType, containsNull = false))) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + } +} + +/** + * :: Experimental :: + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + * Lin and Cohen. From the abstract: + * PIC finds a very low-dimensional embedding of a dataset using truncated power + * iteration on a normalized pair-wise similarity matrix of the data. + * + * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix + * is a symmetric matrix whose entries are non-negative similarities between items. + * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: + * - `idCol`: vertex ID + * - `neighborsCol`: neighbors of vertex in `idCol` + * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex + * in `idCol` and each neighbor in `neighborsCol` + * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` + * containing the cluster assignment in `[0,k)` for each row (vertex). + * + * Notes: + * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. + * Transform runs the iterative PIC algorithm to cluster the whole input dataset. + * - Input validation: This validates that similarities are non-negative but does NOT validate + * that the input matrix is symmetric. + * + * @see + * Spectral clustering (Wikipedia) + */ +@Since("2.4.0") +@Experimental +class PowerIterationClustering private[clustering] ( + @Since("2.4.0") override val uid: String) + extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 20, + initMode -> "random") + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("PowerIterationClustering")) + + /** @group setParam */ + @Since("2.4.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setK(value: Int): this.type = set(k, value) + + /** @group expertSetParam */ + @Since("2.4.0") + def setInitMode(value: String): this.type = set(initMode, value) + + /** @group setParam */ + @Since("2.4.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.4.0") + def setIdCol(value: String): this.type = set(idCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + + /** @group setParam */ + @Since("2.4.0") + def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + + @Since("2.4.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + + val sparkSession = dataset.sparkSession + val idColValue = $(idCol) + val rdd: RDD[(Long, Long, Double)] = + dataset.select( + col($(idCol)).cast(LongType), + col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), + col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) + ).rdd.flatMap { + case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => + require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + + s"equal to the the length of the neighbor similarity list. Row for ID " + + s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + + s"of length ${sims.length}.") + nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { + case (nbr, similarity) => (id, nbr, similarity) + } + } + val algorithm = new MLlibPowerIterationClustering() + .setK($(k)) + .setInitializationMode($(initMode)) + .setMaxIterations($(maxIter)) + val model = algorithm.run(rdd) + + val predictionsRDD: RDD[Row] = model.assignments.map { assignment => + Row(assignment.id, assignment.cluster) + } + + val predictionsSchema = StructType(Seq( + StructField($(idCol), LongType, nullable = false), + StructField($(predictionCol), IntegerType, nullable = false))) + val predictions = { + val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) + dataset.schema($(idCol)).dataType match { + case _: LongType => + uncastPredictions + case otherType => + uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) + } + } + + dataset.join(predictions, $(idCol)) + } + + @Since("2.4.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + @Since("2.4.0") + override def copy(extra: ParamMap): PowerIterationClustering = defaultCopy(extra) +} + +@Since("2.4.0") +object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClustering] { + + @Since("2.4.0") + override def load(path: String): PowerIterationClustering = super.load(path) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 0000000000000..65328df17baff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import scala.collection.mutable + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + + +class PowerIterationClusteringSuite extends SparkFunSuite + with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var data: Dataset[_] = _ + final val r1 = 1.0 + final val n1 = 10 + final val r2 = 4.0 + final val n2 = 40 + + override def beforeAll(): Unit = { + super.beforeAll() + + data = PowerIterationClusteringSuite.generatePICData(spark, r1, r2, n1, n2) + } + + test("default parameters") { + val pic = new PowerIterationClustering() + + assert(pic.getK === 2) + assert(pic.getMaxIter === 20) + assert(pic.getInitMode === "random") + assert(pic.getPredictionCol === "prediction") + assert(pic.getIdCol === "id") + assert(pic.getNeighborsCol === "neighbors") + assert(pic.getSimilaritiesCol === "similarities") + } + + test("parameter validation") { + intercept[IllegalArgumentException] { + new PowerIterationClustering().setK(1) + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setInitMode("no_such_a_mode") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setIdCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setNeighborsCol("") + } + intercept[IllegalArgumentException] { + new PowerIterationClustering().setSimilaritiesCol("") + } + } + + test("power iteration clustering") { + val n = n1 + n2 + + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + val result = model.transform(data) + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + result.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions(cluster) += id + } + assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + + val result2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .transform(data) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + result2.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions2(cluster) += id + } + assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + } + + test("supported input types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + + def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + val typedData = data.select( + col("id").cast(idType).alias("id"), + col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), + col("similarities").cast(ArrayType(similarityType, containsNull = false)) + .alias("similarities") + ) + model.transform(typedData).collect() + } + + for (idType <- Seq(IntegerType, LongType)) { + runTest(idType, LongType, DoubleType) + } + for (neighborType <- Seq(IntegerType, LongType)) { + runTest(LongType, neighborType, DoubleType) + } + for (similarityType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, similarityType) + } + } + + test("invalid input: wrong types") { + val model = new PowerIterationClustering() + .setK(2) + .setMaxIter(1) + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id").cast(DoubleType).alias("id"), + col("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), + col("similarities") + ) + model.transform(typedData) + } + intercept[IllegalArgumentException] { + val typedData = data.select( + col("id"), + col("neighbors"), + col("neighbors").alias("similarities") + ) + model.transform(typedData) + } + } + + test("invalid input: negative similarity") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(-1.0)), + (1, Array(0), Array(-1.0)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("Similarity must be nonnegative")) + } + + test("invalid input: mismatched lengths for neighbor and similarity arrays") { + val model = new PowerIterationClustering() + .setMaxIter(1) + val badData = spark.createDataFrame(Seq( + (0, Array(1), Array(0.5)), + (1, Array(0, 2), Array(0.5)), + (2, Array(1), Array(0.5)) + )).toDF("id", "neighbors", "similarities") + val msg = intercept[SparkException] { + model.transform(badData) + }.getCause.getMessage + assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + + "the neighbor similarity list.")) + assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + } + + test("read/write") { + val t = new PowerIterationClustering() + .setK(4) + .setMaxIter(100) + .setInitMode("degree") + .setIdCol("test_id") + .setNeighborsCol("myNeighborsCol") + .setSimilaritiesCol("mySimilaritiesCol") + .setPredictionCol("test_prediction") + testDefaultReadWrite(t) + } +} + +object PowerIterationClusteringSuite { + + /** Generates a circle of points. */ + private def genCircle(r: Double, n: Int): Array[(Double, Double)] = { + Array.tabulate(n) { i => + val theta = 2.0 * math.Pi * i / n + (r * math.cos(theta), r * math.sin(theta)) + } + } + + /** Computes Gaussian similarity. */ + private def sim(x: (Double, Double), y: (Double, Double)): Double = { + val dist2 = (x._1 - y._1) * (x._1 - y._1) + (x._2 - y._2) * (x._2 - y._2) + math.exp(-dist2 / 2.0) + } + + def generatePICData( + spark: SparkSession, + r1: Double, + r2: Double, + n1: Int, + n2: Int): DataFrame = { + // Generate two circles following the example in the PIC paper. + val n = n1 + n2 + val points = genCircle(r1, n1) ++ genCircle(r2, n2) + + val rows = for (i <- 1 until n) yield { + val neighbors = for (j <- 0 until i) yield { + j.toLong + } + val similarities = for (j <- 0 until i) yield { + sim(points(i), points(j)) + } + (i.toLong, neighbors.toArray, similarities.toArray) + } + + spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + } + +} From 9ea8d3d31b75246bf61118ac7934bc92c18b5f19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 19 Apr 2018 18:55:59 +0200 Subject: [PATCH 0655/1161] [SPARK-22362][SQL] Add unit test for Window Aggregate Functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Improving the test coverage of window functions focusing on missing test for window aggregate functions. No new UDAF test is added as it has been tested already. ## How was this patch tested? Only new tests were added, automated tests were executed. Author: “attilapiros” Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com> Closes #20046 from attilapiros/SPARK-22362. --- .../resources/sql-tests/inputs/window.sql | 10 +- .../sql-tests/results/window.sql.out | 30 +- .../sql/DataFrameWindowFunctionsSuite.scala | 266 ++++++++++++++++++ 3 files changed, 294 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c4bea34ec4cf3..cda4db4b449fe 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val; diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index 133458ae9303b..4afbcd62853dc 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile, row_number() OVER w AS row_number, var_pop(val) OVER w AS var_pop, var_samp(val) OVER w AS var_samp, -approx_count_distinct(val) OVER w AS approx_count_distinct +approx_count_distinct(val) OVER w AS approx_count_distinct, +covar_pop(val, val_long) OVER w AS covar_pop, +corr(val, val_long) OVER w AS corr, +stddev_samp(val) OVER w AS stddev_samp, +stddev_pop(val) OVER w AS stddev_pop, +collect_list(val) OVER w AS collect_list, +collect_set(val) OVER w AS collect_set, +skewness(val_double) OVER w AS skewness, +kurtosis(val_double) OVER w AS kurtosis FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val -- !query 17 schema -struct +struct,collect_set:array,skewness:double,kurtosis:double> -- !query 17 output -NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 -3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 -NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 -1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 -2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 -1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 -2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 -3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 +NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL +3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN +NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5 +2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235 +1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013 +3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984 -- !query 18 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 281147835abde..3ea398aad7375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import scala.collection.mutable + import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ @@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("requires window to be ordered")) } + test("corr, covar_pop, stddev_pop functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + corr("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. + // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + + test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + variance("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + } + + test("collect_list in ascending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_list in descending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_set in window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", "3"), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20")).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + sort_array( + collect_set("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))), + Seq( + Row("a", Array("1", "2", "3")), + Row("b", Array("1", "2", "3")), + Row("c", Array("1", "2", "3")), + Row("d", Array("1", "2", "3")), + Row("e", Array("1", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")))) + } + + test("skewness and kurtosis functions in window") { + val df = Seq( + ("a", "p1", 1.0), + ("b", "p1", 1.0), + ("c", "p1", 2.0), + ("d", "p1", 2.0), + ("e", "p1", 3.0), + ("f", "p1", 3.0), + ("g", "p1", 3.0), + ("h", "p2", 1.0), + ("i", "p2", 2.0), + ("j", "p2", 5.0)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + skewness("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + kurtosis("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() + Seq( + Row("a", -0.27238010581457267, -1.506920415224914), + Row("b", -0.27238010581457267, -1.506920415224914), + Row("c", -0.27238010581457267, -1.506920415224914), + Row("d", -0.27238010581457267, -1.506920415224914), + Row("e", -0.27238010581457267, -1.506920415224914), + Row("f", -0.27238010581457267, -1.506920415224914), + Row("g", -0.27238010581457267, -1.506920415224914), + Row("h", 0.5280049792181881, -1.5000000000000013), + Row("i", 0.5280049792181881, -1.5000000000000013), + Row("j", 0.5280049792181881, -1.5000000000000013))) + } + + test("aggregation function on invalid column") { + val df = Seq((1, "1")).toDF("key", "value") + val e = intercept[AnalysisException]( + df.select($"key", count("invalid").over())) + assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]")) + } + + test("numerical aggregate functions on string column") { + val df = Seq((1, "a", "b")).toDF("key", "value1", "value2") + checkAnswer( + df.select($"key", + var_pop("value1").over(), + variance("value1").over(), + stddev_pop("value1").over(), + stddev("value1").over(), + sum("value1").over(), + mean("value1").over(), + avg("value1").over(), + corr("value1", "value2").over(), + covar_pop("value1", "value2").over(), + covar_samp("value1", "value2").over(), + skewness("value1").over(), + kurtosis("value1").over()), + Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null))) + } + test("statistical functions") { val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). toDF("key", "value") @@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("b", 2, null, null, null, null, null, null))) } + test("last/first on descending ordered window") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, "v"), + ("b", 1, "k"), + ("b", 2, "l"), + ("b", 3, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order".desc) + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, "v", "v", "v", null, null, "x"), + Row("a", 1, "v", "v", "v", "x", "x", "x"), + Row("a", 2, "v", "v", "v", "y", "y", "y"), + Row("a", 3, "v", "v", "v", "z", "z", "z"), + Row("a", 4, "v", "v", "v", "v", "v", "v"), + Row("b", 1, null, null, "l", "k", "k", "k"), + Row("b", 2, null, null, "l", "l", "l", "l"), + Row("b", 3, null, null, null, null, null, null))) + } + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { val src = Seq((0, 3, 5)).toDF("a", "b", "c") .withColumn("Data", struct("a", "b")) From e55953b0bf2a80b34127ba123417ee54955a6064 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 19 Apr 2018 15:06:27 -0700 Subject: [PATCH 0656/1161] [SPARK-24022][TEST] Make SparkContextSuite not flaky ## What changes were proposed in this pull request? SparkContextSuite.test("Cancelling stages/jobs with custom reasons.") could stay in an infinite loop because of the problem found and fixed in [SPARK-23775](https://issues.apache.org/jira/browse/SPARK-23775). This PR solves this mentioned flakyness by removing shared variable usages when cancel happens in a loop and using wait and CountDownLatch for synhronization. ## How was this patch tested? Existing unit test. Author: Gabor Somogyi Closes #21105 from gaborgsomogyi/SPARK-24022. --- .../org/apache/spark/SparkContextSuite.scala | 61 ++++++++----------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b30bd74812b36..ce9f2be1c02dd 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io.File import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets -import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit} import scala.concurrent.duration._ @@ -498,45 +498,36 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu test("Cancelling stages/jobs with custom reasons.") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") val REASON = "You shall not pass" - val slices = 10 - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - if (SparkContextSuite.cancelStage) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + for (cancelWhat <- Seq("stage", "job")) { + // This countdown latch used to make sure stage or job canceled in listener + val latch = new CountDownLatch(1) + + val listener = cancelWhat match { + case "stage" => + new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sc.cancelStage(taskStart.stageId, REASON) + latch.countDown() + } } - sc.cancelStage(taskStart.stageId, REASON) - SparkContextSuite.cancelStage = false - SparkContextSuite.semaphore.release(slices) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - if (SparkContextSuite.cancelJob) { - eventually(timeout(10.seconds)) { - assert(SparkContextSuite.isTaskStarted) + case "job" => + new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sc.cancelJob(jobStart.jobId, REASON) + latch.countDown() + } } - sc.cancelJob(jobStart.jobId, REASON) - SparkContextSuite.cancelJob = false - SparkContextSuite.semaphore.release(slices) - } } - } - sc.addSparkListener(listener) - - for (cancelWhat <- Seq("stage", "job")) { - SparkContextSuite.semaphore.drainPermits() - SparkContextSuite.isTaskStarted = false - SparkContextSuite.cancelStage = (cancelWhat == "stage") - SparkContextSuite.cancelJob = (cancelWhat == "job") + sc.addSparkListener(listener) val ex = intercept[SparkException] { - sc.range(0, 10000L, numSlices = slices).mapPartitions { x => - SparkContextSuite.isTaskStarted = true - // Block waiting for the listener to cancel the stage or job. - SparkContextSuite.semaphore.acquire() + sc.range(0, 10000L, numSlices = 10).mapPartitions { x => + x.synchronized { + x.wait() + } x }.count() } @@ -550,9 +541,11 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } + latch.await(20, TimeUnit.SECONDS) eventually(timeout(20.seconds)) { assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) } + sc.removeSparkListener(listener) } } @@ -637,8 +630,6 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } object SparkContextSuite { - @volatile var cancelJob = false - @volatile var cancelStage = false @volatile var isTaskStarted = false @volatile var taskKilled = false @volatile var taskSucceeded = false From b3fde5a41ee625141b9d21ce32ea68c082449430 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 20 Apr 2018 12:06:41 +0800 Subject: [PATCH 0657/1161] [SPARK-23877][SQL] Use filter predicates to prune partitions in metadata-only queries ## What changes were proposed in this pull request? This updates the OptimizeMetadataOnlyQuery rule to use filter expressions when listing partitions, if there are filter nodes in the logical plan. This avoids listing all partitions for large tables on the driver. This also fixes a minor bug where the partitions returned from fsRelation cannot be serialized without hitting a stack level too deep error. This is caused by serializing a stream to executors, where the stream is a recursive structure. If the stream is too long, the serialization stack reaches the maximum level of depth. The fix is to create a LocalRelation using an Array instead of the incoming Seq. ## How was this patch tested? Existing tests for metadata-only queries. Author: Ryan Blue Closes #20988 from rdblue/SPARK-23877-metadata-only-push-filters. --- .../execution/OptimizeMetadataOnlyQuery.scala | 94 +++++++++++++------ .../OptimizeHiveMetadataOnlyQuerySuite.scala | 68 ++++++++++++++ 2 files changed, 132 insertions(+), 30 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index dc4aff9f12580..acbd4becb8549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -49,9 +49,9 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(partAttrs, relation)) => + case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(partAttrs)) { + if (a.references.subsetOf(attrs)) { val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -67,7 +67,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic }) } if (isAllDistinctAgg) { - a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, relation))) + a.withNewChildren(Seq(replaceTableScanWithPartitionMetadata(child, rel, filters))) } else { a } @@ -98,14 +98,27 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic */ private def replaceTableScanWithPartitionMetadata( child: LogicalPlan, - relation: LogicalPlan): LogicalPlan = { + relation: LogicalPlan, + partFilters: Seq[Expression]): LogicalPlan = { + // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the + // relation's schema. PartitionedRelation ensures that the filters only reference partition cols + val relFilters = partFilters.map { e => + e transform { + case a: AttributeReference => + a.withName(relation.output.find(_.semanticEquals(a)).get.name) + } + } + child transform { case plan if plan eq relation => relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(Nil, Nil) - LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) + val partitionData = fsRelation.location.listFiles(relFilters, Nil) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -113,12 +126,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic CaseInsensitiveMap(relation.tableMeta.storage.properties) val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) - val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => + val partitions = if (partFilters.nonEmpty) { + catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + } else { + catalog.listPartitions(relation.tableMeta.identifier) + } + + val partitionData = partitions.map { p => InternalRow.fromSeq(partAttrs.map { attr => Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - LocalRelation(partAttrs, partitionData) + // partition data may be a stream, which can cause serialization to hit stack level too + // deep exceptions because it is a recursive structure in memory. converting to array + // avoids the problem. + LocalRelation(partAttrs, partitionData.toArray) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -129,35 +151,47 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes and the table relation node. + * pair of the partition attributes, partition filters, and the table relation node. * * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with * deterministic expressions, and returns result after reaching the partitioned table relation * node. */ - object PartitionedRelation { - - def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match { - case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => - val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - Some((AttributeSet(partAttrs), l)) - - case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => - val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) - Some((AttributeSet(partAttrs), relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, relation) => - if (p.references.subsetOf(partAttrs)) Some((p.outputSet, relation)) else None - } + object PartitionedRelation extends PredicateHelper { + + def unapply( + plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + plan match { + case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) + if fsRelation.partitionSchema.nonEmpty => + val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) + Some((partAttrs, partAttrs, Nil, l)) + + case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + val partAttrs = AttributeSet( + getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) + Some((partAttrs, partAttrs, Nil, relation)) + + case p @ Project(projectList, child) if projectList.forall(_.deterministic) => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (p.references.subsetOf(attrs)) { + Some((partAttrs, p.outputSet, filters, relation)) + } else { + None + } + } - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, relation) => - if (f.references.subsetOf(partAttrs)) Some((partAttrs, relation)) else None - } + case f @ Filter(condition, child) if condition.deterministic => + unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => + if (f.references.subsetOf(partAttrs)) { + Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) + } else { + None + } + } - case _ => None + case _ => None + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala new file mode 100644 index 0000000000000..95f192f0e40e2 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton + with BeforeAndAfter with SQLTestUtils { + + import spark.implicits._ + + before { + sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") + (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) + } + + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the number of matching partitions + assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5) + + // verify that the partition predicate was pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5) + } + } + + test("SPARK-23877: filter on projected expression") { + withTable("metadata_only") { + val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount + + // verify the matching partitions + val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr, + Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]), + spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child))) + .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType)))) + + checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x")) + + // verify that the partition predicate was not pushed down to the metastore + assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11) + } + } +} From e6b466084c26fbb9b9e50dd5cc8b25da7533ac72 Mon Sep 17 00:00:00 2001 From: mn-mikke Date: Fri, 20 Apr 2018 14:58:11 +0900 Subject: [PATCH 0658/1161] [SPARK-23736][SQL] Extending the concat function to support array columns ## What changes were proposed in this pull request? The PR adds a logic for easy concatenation of multiple array columns and covers: - Concat expression has been extended to support array columns - A Python wrapper ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite - typeCoercion/native/concat.sql ## Codegen examples ### Primitive-type elements ``` val df = Seq( (Seq(1 ,2), Seq(3, 4)), (Seq(1, 2, 3), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 070 */ project_numElements, /* 071 */ 4); /* 072 */ if (project_size > 2147483632) { /* 073 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size + /* 074 */ " bytes of data due to exceeding the limit 2147483632 bytes" + /* 075 */ " for UnsafeArrayData."); /* 076 */ } /* 077 */ /* 078 */ byte[] project_array = new byte[(int)project_size]; /* 079 */ UnsafeArrayData project_arrayData = new UnsafeArrayData(); /* 080 */ Platform.putLong(project_array, 16, project_numElements); /* 081 */ project_arrayData.pointTo(project_array, 16, (int)project_size); /* 082 */ int project_counter = 0; /* 083 */ for (int y = 0; y < 2; y++) { /* 084 */ for (int z = 0; z < args[y].numElements(); z++) { /* 085 */ if (args[y].isNullAt(z)) { /* 086 */ project_arrayData.setNullAt(project_counter); /* 087 */ } else { /* 088 */ project_arrayData.setInt( /* 089 */ project_counter, /* 090 */ args[y].getInt(z) /* 091 */ ); /* 092 */ } /* 093 */ project_counter++; /* 094 */ } /* 095 */ } /* 096 */ return project_arrayData; /* 097 */ } /* 098 */ }.concat(project_args); /* 099 */ boolean project_isNull = project_value == null; ``` ### Non-primitive-type elements ``` val df = Seq( (Seq("aa" ,"bb"), Seq("ccc", "ddd")), (Seq("x", "y"), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ Object[] project_arrayObjects = new Object[(int)project_numElements]; /* 070 */ int project_counter = 0; /* 071 */ for (int y = 0; y < 2; y++) { /* 072 */ for (int z = 0; z < args[y].numElements(); z++) { /* 073 */ project_arrayObjects[project_counter] = args[y].getUTF8String(z); /* 074 */ project_counter++; /* 075 */ } /* 076 */ } /* 077 */ return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects); /* 078 */ } /* 079 */ }.concat(project_args); /* 080 */ boolean project_isNull = project_value == null; ``` Author: mn-mikke Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master. --- .../spark/unsafe/array/ByteArrayMethods.java | 6 +- python/pyspark/sql/functions.py | 34 +-- .../catalyst/expressions/UnsafeArrayData.java | 10 + .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../sql/catalyst/analysis/TypeCoercion.scala | 8 + .../expressions/collectionOperations.scala | 220 +++++++++++++++++- .../expressions/stringExpressions.scala | 81 ------- .../CollectionExpressionsSuite.scala | 41 ++++ .../org/apache/spark/sql/functions.scala | 20 +- .../inputs/typeCoercion/native/concat.sql | 62 +++++ .../typeCoercion/native/concat.sql.out | 78 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 74 ++++++ .../sql/execution/command/DDLSuite.scala | 4 +- 13 files changed, 529 insertions(+), 111 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 4bc9955090fd7..ef0f78d95d1ee 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) { } public static int roundNumberOfBytesToNearestWord(int numBytes) { - int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + return (int)roundNumberOfBytesToNearestWord((long)numBytes); + } + + public static long roundNumberOfBytesToNearestWord(long numBytes) { + long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` if (remainder == 0) { return numBytes; } else { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1be68f2a4a448..da32ab25cad0c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1425,21 +1425,6 @@ def hash(*cols): del _name, _doc -@since(1.5) -@ignore_unicode_prefix -def concat(*cols): - """ - Concatenates multiple input columns together into a single column. - If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat(df.s, df.d).alias('s')).collect() - [Row(s=u'abcd123')] - """ - sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) - - @since(1.5) @ignore_unicode_prefix def concat_ws(sep, *cols): @@ -1845,6 +1830,25 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): + """ + Concatenates multiple input columns together into a single column. + The function works with strings, binary and compatible array columns. + + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat(df.s, df.d).alias('s')).collect() + [Row(s=u'abcd123')] + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) + + @since(2.4) def array_position(col, value): """ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 8546c28335536..d5d934bc91cab 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -56,9 +56,19 @@ public final class UnsafeArrayData extends ArrayData { public static int calculateHeaderPortionInBytes(int numFields) { + return (int)calculateHeaderPortionInBytes((long)numFields); + } + + public static long calculateHeaderPortionInBytes(long numFields) { return 8 + ((numFields + 63)/ 64) * 8; } + public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) { + long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) + + ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize); + return size; + } + private Object baseObject; private long baseOffset; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a44f2d5272b8e..c41f16c61d7a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -308,7 +308,6 @@ object FunctionRegistry { expression[BitLength]("bit_length"), expression[Length]("char_length"), expression[Length]("character_length"), - expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), expression[Elt]("elt"), @@ -413,6 +412,7 @@ object FunctionRegistry { expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), + expression[Concat]("concat"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 281f206e8d59e..cfcbd8db559a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -520,6 +520,14 @@ object TypeCoercion { case None => a } + case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case None => c + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index dba426e999dda..c16793bda028e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.types.{ByteArray, UTF8String} /** * Given an array or map, returns its size. Returns -1 if null. @@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def prettyName: String = "element_at" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + + s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + } + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { + case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) + case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) + case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { + null + } else { + val arrayData = inputs.map(_.asInstanceOf[ArrayData]) + val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + val finalData = new Array[AnyRef](numberOfElements.toInt) + var position = 0 + for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length + } + new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val args = ctx.freshName("args") + + val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" + ${eval.code} + if (!${eval.isNull}) { + $args[$index] = ${eval.value}; + } + """ + } + + val (concatenator, initCode) = dataType match { + case BinaryType => + (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => + val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"$javaType[]", args) :: Nil) + ev.copy(s""" + $initCode + $codes + $javaType ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; + """) + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { + val numElements = ctx.freshName("numElements") + val code = s""" + |long $numElements = 0L; + |for (int z = 0; z < ${children.length}; z++) { + | $numElements += args[z].numElements(); + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (code, numElements) + } + + private def nullArgumentProtection() : String = { + if (nullable) { + s""" + |for (int z = 0; z < ${children.length}; z++) { + | if (args[z] == null) return null; + |} + """.stripMargin + } else { + "" + } + } + + private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val arrayName = ctx.freshName("array") + val arraySizeName = ctx.freshName("size") + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + val unsafeArraySizeInBytes = s""" + |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElemName, + | ${elementType.defaultSize}); + |if ($arraySizeName > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName + + | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" + + | " for UnsafeArrayData."); + |} + """.stripMargin + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | $unsafeArraySizeInBytes + | byte[] $arrayName = new byte[(int)$arraySizeName]; + | UnsafeArrayData $arrayData = new UnsafeArrayData(); + | Platform.putLong($arrayName, $baseOffset, $numElemName); + | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName); + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + | } else { + | $arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + | ); + | } + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + + val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) + + s""" + |new Object() { + | public ArrayData concat($javaType[] args) { + | ${nullArgumentProtection()} + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def toString: String = s"concat(${children.mkString(", ")})" + + override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5a02ca0d6862c..ea005a26a4c8b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} //////////////////////////////////////////////////////////////////////////////////////////////////// -/** - * An expression that concatenates multiple inputs into a single output. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * If any input is null, concat returns null. - */ -@ExpressionDescription( - usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", - examples = """ - Examples: - > SELECT _FUNC_('Spark', 'SQL'); - SparkSQL - """) -case class Concat(children: Seq[Expression]) extends Expression { - - private lazy val isBinaryMode: Boolean = dataType == BinaryType - - override def checkInputDataTypes(): TypeCheckResult = { - if (children.isEmpty) { - TypeCheckResult.TypeCheckSuccess - } else { - val childTypes = children.map(_.dataType) - if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + - childTypes.map(_.simpleString).mkString("[", ", ", "]")) - } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") - } - } - - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) - - override def nullable: Boolean = children.exists(_.nullable) - override def foldable: Boolean = children.forall(_.foldable) - - override def eval(input: InternalRow): Any = { - if (isBinaryMode) { - val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) - ByteArray.concat(inputs: _*) - } else { - val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) - UTF8String.concat(inputs : _*) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val evals = children.map(_.genCode(ctx)) - val args = ctx.freshName("args") - - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ - } - - val (concatenator, initCode) = if (isBinaryMode) { - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - } else { - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - } - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = inputs, - funcName = "valueConcat", - extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) - ev.copy(s""" - $initCode - $codes - ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) - } - - override def toString: String = s"concat(${children.mkString(", ")})" - - override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})" -} - - /** * An expression that concatenates multiple input strings or array of strings into a single string, * using a given separator (the first child). diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 7d8fe211858b2..43c5dda2e4a48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m2, Literal("a")), null) } + + test("Concat") { + // Primitive-type elements + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val ai4 = Literal.create(null, ArrayType(IntegerType)) + + checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3)) + checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5)) + checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5)) + checkEvaluation(Concat(Seq(ai4)), null) + checkEvaluation(Concat(Seq(ai0, ai4)), null) + checkEvaluation(Concat(Seq(ai4, ai0)), null) + + // Non-primitive-type elements + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) + val as4 = Literal.create(null, ArrayType(StringType)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + + checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c")) + checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e")) + checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e")) + checkEvaluation(Concat(Seq(as4)), null) + checkEvaluation(Concat(Seq(as0, as4)), null) + checkEvaluation(Concat(Seq(as4, as0)), null) + + checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9c8580378303e..bea8c0e445002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2228,16 +2228,6 @@ object functions { */ def base64(e: Column): Column = withExpr { Base64(e.expr) } - /** - * Concatenates multiple input columns together into a single column. - * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } - /** * Concatenates multiple input string columns together into a single string column, * using the given separator. @@ -3038,6 +3028,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + * + * @group collection_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } + /** * Locates the position of the first occurrence of the value in the given array as long. * Returns null if either of the arguments are null. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index 0beebec5702fd..db00a18f2e7e9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -91,3 +91,65 @@ FROM ( encode(string(id + 3), 'utf-8') col4 FROM range(10) ); + +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +); + +-- Concatenate arrays of the same type +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays; + +-- Concatenate arrays of different types +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 09729fdc2ec32..62befc5ca0f15 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -237,3 +237,81 @@ struct 78910 891011 9101112 + + +-- !query 11 +CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES ( + array(true, false), array(true), + array(2Y, 1Y), array(3Y, 4Y), + array(2S, 1S), array(3S, 4S), + array(2, 1), array(3, 4), + array(2L, 1L), array(3L, 4L), + array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809), + array(2.0D, 1.0D), array(3.0D, 4.0D), + array(float(2.0), float(1.0)), array(float(3.0), float(4.0)), + array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + array(timestamp '2016-11-11 20:54:00.000'), + array('a', 'b'), array('c', 'd'), + array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')), + array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)), + array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4)) +) AS various_arrays( + boolean_array1, boolean_array2, + tinyint_array1, tinyint_array2, + smallint_array1, smallint_array2, + int_array1, int_array2, + bigint_array1, bigint_array2, + decimal_array1, decimal_array2, + double_array1, double_array2, + float_array1, float_array2, + date_array1, data_array2, + timestamp_array1, timestamp_array2, + string_array1, string_array2, + array_array1, array_array2, + struct_array1, struct_array2, + map_array1, map_array2 +) +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +SELECT + (boolean_array1 || boolean_array2) boolean_array, + (tinyint_array1 || tinyint_array2) tinyint_array, + (smallint_array1 || smallint_array2) smallint_array, + (int_array1 || int_array2) int_array, + (bigint_array1 || bigint_array2) bigint_array, + (decimal_array1 || decimal_array2) decimal_array, + (double_array1 || double_array2) double_array, + (float_array1 || float_array2) float_array, + (date_array1 || data_array2) data_array, + (timestamp_array1 || timestamp_array2) timestamp_array, + (string_array1 || string_array2) string_array, + (array_array1 || array_array2) array_array, + (struct_array1 || struct_array2) struct_array, + (map_array1 || map_array2) map_array +FROM various_arrays +-- !query 12 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,data_array:array,timestamp_array:array,string_array:array,array_array:array>,struct_array:array>,map_array:array>> +-- !query 12 output +[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}] + + +-- !query 13 +SELECT + (tinyint_array1 || smallint_array2) ts_array, + (smallint_array1 || int_array2) si_array, + (int_array1 || bigint_array2) ib_array, + (double_array1 || float_array2) df_array, + (string_array1 || data_array2) std_array, + (timestamp_array1 || string_array2) tst_array, + (string_array1 || int_array2) sti_array +FROM various_arrays +-- !query 13 schema +struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +-- !query 13 output +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7c976c1b7f915..25e5cd60dd236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("concat function - arrays") { + val nseqi : Seq[Int] = null + val nseqs : Seq[String] = null + val df = Seq( + + (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs), + (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs) + ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn") + + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on + + // Simple test cases + checkAnswer( + df.selectExpr("array(1, 2, 3L)"), + Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L))) + ) + + checkAnswer ( + df.select(concat($"i1", $"s1")), + Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a"))) + ) + checkAnswer( + df.select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")), + Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2))) + ) + checkAnswer( + df.selectExpr("concat(array(1, null), i2, i3)"), + Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2))) + ) + checkAnswer( + df.select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.selectExpr("concat(s1, s2, s3)"), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + checkAnswer( + df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")), + Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null))) + ) + + // Null test cases + checkAnswer( + df.select(concat($"i1", $"in")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"in", $"i1")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"s1", $"sn")), + Seq(Row(null), Row(null)) + ) + checkAnswer( + df.select(concat($"sn", $"s1")), + Seq(Row(null), Row(null)) + ) + + // Type error test cases + intercept[AnalysisException] { + df.selectExpr("concat(i1, i2, null)") + } + + intercept[AnalysisException] { + df.selectExpr("concat(i1, array(i1, i2))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index cbd7f9d6f67be..3998ceca38b30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("DESCRIBE FUNCTION 'concat'"), Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") :: Row("Function: concat") :: - Row("Usage: concat(str1, str2, ..., strN) - " + - "Returns the concatenation of str1, str2, ..., strN.") :: Nil + Row("Usage: concat(col1, col2, ..., colN) - " + + "Returns the concatenation of col1, col2, ..., colN.") :: Nil ) // extended mode checkAnswer( From 074a7f90536493b607e8e74bcebf3a27ea49a49d Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 14:43:47 +0200 Subject: [PATCH 0659/1161] [SPARK-23588][SQL][FOLLOW-UP] Resolve a map builder method per execution in CatalystToExternalMap ## What changes were proposed in this pull request? This pr is a follow-up pr of #20979 and fixes code to resolve a map builder method per execution instead of per row in `CatalystToExternalMap`. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21112 from maropu/SPARK-23588-FOLLOWUP. --- .../sql/catalyst/expressions/objects/objects.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bc17d1229420a..32c1f34ef97a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1040,11 +1040,13 @@ case class CatalystToExternalMap private( private lazy val valueConverter = CatalystTypeConverters.createToScalaConverter(inputMapType.valueType) - private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + private lazy val (newMapBuilderMethod, moduleField) = { val clazz = Utils.classForName(collClass.getCanonicalName + "$") - val module = clazz.getField("MODULE$").get(null) - val method = clazz.getMethod("newBuilder") - method.invoke(module).asInstanceOf[Builder[AnyRef, AnyRef]] + (clazz.getMethod("newBuilder"), clazz.getField("MODULE$").get(null)) + } + + private def newMapBuilder(): Builder[AnyRef, AnyRef] = { + newMapBuilderMethod.invoke(moduleField).asInstanceOf[Builder[AnyRef, AnyRef]] } override def eval(input: InternalRow): Any = { From 0dd97f6ea4affde1531dec1bec004b7ab18c6965 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 20 Apr 2018 15:02:27 +0200 Subject: [PATCH 0660/1161] [SPARK-23595][SQL] ValidateExternalType should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ValidateExternalType`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20757 from maropu/SPARK-23595. --- .../spark/sql/catalyst/ScalaReflection.scala | 13 +++++++ .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../expressions/objects/objects.scala | 34 ++++++++++++++++--- .../expressions/ObjectExpressionsSuite.scala | 33 ++++++++++++++++-- 4 files changed, 74 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 818cc2fb1e8a8..f9acc208b715e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection { } } + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayType] + case _: MapType => classOf[MapType] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { if (arguments != Nil) { arguments.map(e => dataTypeJavaClass(e.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 789750fd408f2..3340789398f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 32c1f34ef97a5..f1ffcaec8a484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.Utils /** @@ -1672,13 +1673,36 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def nullable: Boolean = child.nullable - override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected) private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + private lazy val checkType: (Any) => Boolean = expected match { + case _: DecimalType => + (value: Any) => { + value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] || + value.isInstanceOf[Decimal] + } + case _: ArrayType => + (value: Any) => { + value.getClass.isArray || value.isInstanceOf[Seq[_]] + } + case _ => + val dataTypeClazz = ScalaReflection.javaBoxedType(dataType) + (value: Any) => { + dataTypeClazz.isInstance(value) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (checkType(result)) { + result + } else { + throw new RuntimeException(s"${result.getClass.getName}$errMsg") + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. @@ -1691,7 +1715,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") case _: ArrayType => - s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}" case _ => s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index bcd035c1eba0b..7136af8934486 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class InvokeTargetClass extends Serializable { def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0 @@ -296,7 +296,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0") Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) => - checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) + checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input))) } // If an input row or a field are null, a runtime exception will be thrown @@ -472,6 +472,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val deserializer = toMapExpr.copy(inputData = Literal.create(data)) checkObjectExprEvaluation(deserializer, expected = data) } + + test("SPARK-23595 ValidateExternalType should support interpreted execution") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + Seq( + (true, BooleanType), + (2.toByte, ByteType), + (5.toShort, ShortType), + (23, IntegerType), + (61L, LongType), + (1.0f, FloatType), + (10.0, DoubleType), + ("abcd".getBytes, BinaryType), + ("abcd", StringType), + (BigDecimal.valueOf(10), DecimalType.IntDecimal), + (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType), + (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal), + (Array(3, 2, 1), ArrayType(IntegerType)) + ).foreach { case (input, dt) => + val validateType = ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt) + checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input)))) + } + + checkExceptionInExpression[RuntimeException]( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType), + InternalRow.fromSeq(Seq(Row(1))), + "java.lang.Integer is not a valid external type for schema of double") + } } class TestBean extends Serializable { From 1d758dc73b54e802fdc92be204185fe7414e6553 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Apr 2018 10:23:01 -0700 Subject: [PATCH 0661/1161] Revert "[SPARK-23775][TEST] Make DataFrameRangeSuite not flaky" This reverts commit 0c94e48bc50717e1627c0d2acd5382d9adc73c97. --- .../spark/sql/DataFrameRangeSuite.scala | 78 ++++++++----------- 1 file changed, 33 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index a0fd74088ce8b..57a930dfaf320 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql -import java.util.concurrent.{CountDownLatch, TimeUnit} - import scala.concurrent.duration._ import scala.math.abs import scala.util.Random import org.scalatest.concurrent.Eventually -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -154,53 +152,39 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } test("Cancelling stage in a query with Range.") { - // Save and restore the value because SparkContext is shared - val savedInterruptOnCancel = sparkContext - .getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) - - try { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true") - - for (codegen <- Seq(true, false)) { - // This countdown latch used to make sure with all the stages cancelStage called in listener - val latch = new CountDownLatch(2) - - val listener = new SparkListener { - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { - sparkContext.cancelStage(taskStart.stageId) - latch.countDown() - } + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + eventually(timeout(10.seconds), interval(1.millis)) { + assert(DataFrameRangeSuite.stageToKill > 0) } + sparkContext.cancelStage(DataFrameRangeSuite.stageToKill) + } + } - sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { - val ex = intercept[SparkException] { - sparkContext.range(0, 10000L, numSlices = 10).mapPartitions { x => - x.synchronized { - x.wait() - } - x - }.toDF("id").agg(sum("id")).collect() - } - ex.getCause() match { - case null => - assert(ex.getMessage().contains("cancelled")) - case cause: SparkException => - assert(cause.getMessage().contains("cancelled")) - case cause: Throwable => - fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") - } + sparkContext.addSparkListener(listener) + for (codegen <- Seq(true, false)) { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) { + DataFrameRangeSuite.stageToKill = -1 + val ex = intercept[SparkException] { + spark.range(0, 100000000000L, 1, 1).map { x => + DataFrameRangeSuite.stageToKill = TaskContext.get().stageId() + x + }.toDF("id").agg(sum("id")).collect() } - latch.await(20, TimeUnit.SECONDS) - eventually(timeout(20.seconds)) { - assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + ex.getCause() match { + case null => + assert(ex.getMessage().contains("cancelled")) + case cause: SparkException => + assert(cause.getMessage().contains("cancelled")) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") } - sparkContext.removeSparkListener(listener) } - } finally { - sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, - savedInterruptOnCancel) + eventually(timeout(20.seconds)) { + assert(sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } } + sparkContext.removeSparkListener(listener) } test("SPARK-20430 Initialize Range parameters in a driver side") { @@ -220,3 +204,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + +object DataFrameRangeSuite { + @volatile var stageToKill = -1 +} From 32b4bcd6d31b92b179a15f9886779fc5f96404b5 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 21 Apr 2018 23:14:58 +0800 Subject: [PATCH 0662/1161] [SPARK-24029][CORE] Set SO_REUSEADDR on listen sockets. This allows sockets to be bound even if there are sockets from a previous application that are still pending closure. It avoids bind issues when, for example, re-starting the SHS. Don't enable the option on Windows though. The following page explains some odd behavior that this option can have there: https://msdn.microsoft.com/en-us/library/windows/desktop/ms740621%28v=vs.85%29.aspx I intentionally ignored server sockets that always bind to ephemeral ports, since those don't benefit from this option. Author: Marcelo Vanzin Closes #21110 from vanzin/SPARK-24029. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 +++- .../org/apache/spark/deploy/rest/RestSubmissionServer.scala | 1 + core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0719fa7647bcc..612750972c4bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -32,6 +32,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,7 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator); + .childOption(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index e88195d95f270..3d99d085408c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -94,6 +94,7 @@ private[spark] abstract class RestSubmissionServer( new HttpConnectionFactory()) connector.setHost(host) connector.setPort(startPort) + connector.setReuseAddress(!Utils.isWindows) server.addConnector(connector) val mainHandler = new ServletContextHandler diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 0e8a6307de6a8..d6a025a6f12da 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -344,6 +344,7 @@ private[spark] object JettyUtils extends Logging { connectionFactories: _*) connector.setPort(port) connector.setHost(hostName) + connector.setReuseAddress(!Utils.isWindows) // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads From 7bc853d08973a6bd839ad2222911eb0a0f413677 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 21 Apr 2018 10:45:12 -0700 Subject: [PATCH 0663/1161] [SPARK-24033][SQL] Fix Mismatched of Window Frame specifiedwindowframe(RowFrame, -1, -1) ## What changes were proposed in this pull request? When the OffsetWindowFunction's frame is `UnaryMinus(Literal(1))` but the specified window frame has been simplified to `Literal(-1)` by some optimizer rules e.g., `ConstantFolding`. Thus, they do not match and cause the following error: ``` org.apache.spark.sql.AnalysisException: Window Frame specifiedwindowframe(RowFrame, -1, -1) must match the required frame specifiedwindowframe(RowFrame, -1, -1); at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:41) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:91) at ``` ## How was this patch tested? Added a test Author: gatorsmile Closes #21115 from gatorsmile/fixLag. --- .../catalyst/expressions/windowExpressions.scala | 5 ++++- .../spark/sql/DataFrameWindowFramesSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 78895f1c2f6f5..9fe2fb2b95e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -342,7 +342,10 @@ abstract class OffsetWindowFunction override lazy val frame: WindowFrame = { val boundary = direction match { case Ascending => offset - case Descending => UnaryMinus(offset) + case Descending => UnaryMinus(offset) match { + case e: Expression if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + case o => o + } } SpecifiedWindowFrame(RowFrame, boundary, boundary) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 0ee9b0edc02b2..2a0b2b85e10a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -402,4 +402,18 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSQLContext { Row(7, 3000) :: Row(8, 3000) :: Row(9, 5500) :: Row(10, 6000) :: Nil) } + + test("SPARK-24033: Analysis Failure of OffsetWindowFunction") { + val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") + val res = + Row(1, 1, null) :: Row (1, 2, 1) :: Row(1, 3, 2) :: Row(2, 1, null) :: Row(2, 2, 1) :: Nil + checkAnswer( + ds.withColumn("m", + lead("i", -1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + checkAnswer( + ds.withColumn("m", + lag("i", 1).over(Window.partitionBy("n").orderBy("i").rowsBetween(-1, -1))), + res) + } } From c48085aa91c60615a4de3b391f019f46f3fcdbe3 Mon Sep 17 00:00:00 2001 From: Mykhailo Shtelma Date: Sat, 21 Apr 2018 23:33:57 -0700 Subject: [PATCH 0664/1161] [SPARK-23799][SQL] FilterEstimation.evaluateInSet produces devision by zero in a case of empty table with analyzed statistics >What changes were proposed in this pull request? During evaluation of IN conditions, if the source data frame, is represented by a plan, that uses hive table with columns, which were previously analysed, and the plan has conditions for these fields, that cannot be satisfied (which leads us to an empty data frame), FilterEstimation.evaluateInSet method produces NumberFormatException and ClassCastException. In order to fix this bug, method FilterEstimation.evaluateInSet at first checks, if distinct count is not zero, and also checks if colStat.min and colStat.max are defined, and only in this case proceeds with the calculation. If at least one of the conditions is not satisfied, zero is returned. >How was this patch tested? In order to test the PR two tests were implemented: one in FilterEstimationSuite, that tests the plan with the statistics that violates the conditions mentioned above, and another one in StatisticsCollectionSuite, that test the whole process of analysis/optimisation of the query, that leads to the problems, mentioned in the first section. Author: Mykhailo Shtelma Author: smikesh Closes #21052 from mshtelma/filter_estimation_evaluateInSet_Bugs. --- .../statsEstimation/FilterEstimation.scala | 4 +++ .../FilterEstimationSuite.scala | 11 ++++++++ .../spark/sql/StatisticsCollectionSuite.scala | 28 +++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0538c9d88584b..263c9ba60d145 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -392,6 +392,10 @@ case class FilterEstimation(plan: Filter) extends Logging { val dataType = attr.dataType var newNdv = ndv + if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) { + return Some(0.0) + } + // use [min, max] to filter the original hSet dataType match { case _: NumericType | BooleanType | DateType | TimestampType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 43440d51dede6..16cb5d032cf57 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -357,6 +357,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 3) } + test("evaluateInSet with all zeros") { + validateEstimatedStats( + Filter(InSet(attrString, Set(3, 4, 5)), + StatsTestPlan(Seq(attrString), 0, + AttributeMap(Seq(attrString -> + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))), + Seq(attrString -> ColumnStat(distinctCount = Some(0))), + expectedRowCount = 0) + } + test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 14a565863d66c..b91712f4cc25d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -382,4 +382,32 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } } + + test("Simple queries must be working, if CBO is turned on") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { + withTable("TBL1", "TBL") { + import org.apache.spark.sql.functions._ + val df = spark.range(1000L).select('id, + 'id * 2 as "FLD1", + 'id * 12 as "FLD2", + lit("aaa") + 'id as "fld3") + df.write + .mode(SaveMode.Overwrite) + .bucketBy(10, "id", "FLD1", "FLD2") + .sortBy("id", "FLD1", "FLD2") + .saveAsTable("TBL") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS ") + sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3") + val df2 = spark.sql( + """ + |SELECT t1.id, t1.fld1, t1.fld2, t1.fld3 + |FROM tbl t1 + |JOIN tbl t2 on t1.id=t2.id + |WHERE t1.fld3 IN (-123.23,321.23) + """.stripMargin) + df2.createTempView("TBL2") + sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").queryExecution.executedPlan + } + } + } } From c3a86faa53c9e49efd595802adc38a6d412ce681 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 23 Apr 2018 10:45:25 +0800 Subject: [PATCH 0665/1161] [SPARK-10399][SPARK-23879][FOLLOWUP][CORE] Free unused off-heap memory in MemoryBlockSuite ## What changes were proposed in this pull request? As viirya pointed out [here](https://github.com/apache/spark/pull/19222#discussion_r179910484), this PR explicitly frees unused off-heap memory in `MemoryBlockSuite` ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21117 from kiszk/SPARK-10399-free-offheap. --- .../java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java index 5d5fdc1c55a75..ef5ff8ee70ec0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) { } catch (Exception expected) { Assert.assertThat(expected.getMessage(), containsString("should not be larger than")); } + + memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER); } @Test @@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() { int length = 56; check(memory, obj, offset, length); + memoryAllocator.free(memory); long address = Platform.allocateMemory(112); memory = new OffHeapMemoryBlock(address, length); obj = memory.getBaseObject(); offset = memory.getBaseOffset(); check(memory, obj, offset, length); + Platform.freeMemory(address); } } From f70f46d1e5bc503e9071707d837df618b7696d32 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:18:50 +0800 Subject: [PATCH 0666/1161] [SPARK-23877][SQL][FOLLOWUP] use PhysicalOperation to simplify the handling of Project and Filter over partitioned relation ## What changes were proposed in this pull request? A followup of https://github.com/apache/spark/pull/20988 `PhysicalOperation` can collect Project and Filters over a certain plan and substitute the alias with the original attributes in the bottom plan. We can use it in `OptimizeMetadataOnlyQuery` rule to handle the Project and Filter over partitioned relation. ## How was this patch tested? existing test Author: Wenchen Fan Closes #21111 from cloud-fan/refactor. --- .../plans/logical/LocalRelation.scala | 6 ++ .../sql/execution/LocalTableScanExec.scala | 3 + .../execution/OptimizeMetadataOnlyQuery.scala | 58 ++++++------------- .../OptimizeHiveMetadataOnlyQuerySuite.scala | 16 ++++- 4 files changed, 39 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index b05508db786ad..720d42ab409a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -43,6 +43,12 @@ object LocalRelation { } } +/** + * Logical plan node for scanning data from a local collection. + * + * @param data The local collection holding the data. It doesn't need to be sent to executors + * and then doesn't need to be serializable. + */ case class LocalRelation( output: Seq[Attribute], data: Seq[InternalRow] = Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 514ad7018d8c7..448eb703eacde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. + * + * `Seq` may not be serializable and ideally we should not send `rows` and `unsafeRows` + * to the executors. Thus marking them as transient. */ case class LocalTableScanExec( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index acbd4becb8549..3ca03ab2939aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -49,9 +50,13 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic } plan.transform { - case a @ Aggregate(_, aggExprs, child @ PartitionedRelation(_, attrs, filters, rel)) => + case a @ Aggregate(_, aggExprs, child @ PhysicalOperation( + projectList, filters, PartitionedRelation(partAttrs, rel))) => // We only apply this optimization when only partitioned attributes are scanned. - if (a.references.subsetOf(attrs)) { + if (AttributeSet((projectList ++ filters).flatMap(_.references)).subsetOf(partAttrs)) { + // The project list and filters all only refer to partition attributes, which means the + // the Aggregator operator can also only refer to partition attributes, and filters are + // all partition filters. This is a metadata only query we can optimize. val aggFunctions = aggExprs.flatMap(_.collect { case agg: AggregateExpression => agg }) @@ -102,7 +107,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic partFilters: Seq[Expression]): LogicalPlan = { // this logic comes from PruneFileSourcePartitions. it ensures that the filter names match the // relation's schema. PartitionedRelation ensures that the filters only reference partition cols - val relFilters = partFilters.map { e => + val normalizedFilters = partFilters.map { e => e transform { case a: AttributeReference => a.withName(relation.output.find(_.semanticEquals(a)).get.name) @@ -114,11 +119,8 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic relation match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) => val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) - val partitionData = fsRelation.location.listFiles(relFilters, Nil) - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.map(_.values).toArray, isStreaming) + val partitionData = fsRelation.location.listFiles(normalizedFilters, Nil) + LocalRelation(partAttrs, partitionData.map(_.values), isStreaming) case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) @@ -127,7 +129,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic val timeZoneId = caseInsensitiveProperties.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(SQLConf.get.sessionLocalTimeZone) val partitions = if (partFilters.nonEmpty) { - catalog.listPartitionsByFilter(relation.tableMeta.identifier, relFilters) + catalog.listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters) } else { catalog.listPartitions(relation.tableMeta.identifier) } @@ -137,10 +139,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } - // partition data may be a stream, which can cause serialization to hit stack level too - // deep exceptions because it is a recursive structure in memory. converting to array - // avoids the problem. - LocalRelation(partAttrs, partitionData.toArray) + LocalRelation(partAttrs, partitionData) case _ => throw new IllegalStateException(s"unrecognized table scan node: $relation, " + @@ -151,44 +150,21 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic /** * A pattern that finds the partitioned table relation node inside the given plan, and returns a - * pair of the partition attributes, partition filters, and the table relation node. - * - * It keeps traversing down the given plan tree if there is a [[Project]] or [[Filter]] with - * deterministic expressions, and returns result after reaching the partitioned table relation - * node. + * pair of the partition attributes and the table relation node. */ object PartitionedRelation extends PredicateHelper { - def unapply( - plan: LogicalPlan): Option[(AttributeSet, AttributeSet, Seq[Expression], LogicalPlan)] = { + def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = { plan match { case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) - if fsRelation.partitionSchema.nonEmpty => + if fsRelation.partitionSchema.nonEmpty => val partAttrs = AttributeSet(getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)) - Some((partAttrs, partAttrs, Nil, l)) + Some((partAttrs, l)) case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => val partAttrs = AttributeSet( getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)) - Some((partAttrs, partAttrs, Nil, relation)) - - case p @ Project(projectList, child) if projectList.forall(_.deterministic) => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (p.references.subsetOf(attrs)) { - Some((partAttrs, p.outputSet, filters, relation)) - } else { - None - } - } - - case f @ Filter(condition, child) if condition.deterministic => - unapply(child).flatMap { case (partAttrs, attrs, filters, relation) => - if (f.references.subsetOf(partAttrs)) { - Some((partAttrs, attrs, splitConjunctivePredicates(condition) ++ filters, relation)) - } else { - None - } - } + Some((partAttrs, relation)) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 95f192f0e40e2..1e525c46a9cfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -32,13 +33,22 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto import spark.implicits._ - before { + override def beforeAll(): Unit = { + super.beforeAll() sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)") (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)")) } + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS metadata_only") + } finally { + super.afterAll() + } + } + test("SPARK-23877: validate metadata-only query pushes filters to metastore") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the number of matching partitions @@ -50,7 +60,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - withTable("metadata_only") { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions From d87d30e4fe9c9e91c462351e9f744a830db8d6fc Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Apr 2018 20:21:01 +0800 Subject: [PATCH 0667/1161] [SPARK-23564][SQL] infer additional filters from constraints for join's children ## What changes were proposed in this pull request? The existing query constraints framework has 2 steps: 1. propagate constraints bottom up. 2. use constraints to infer additional filters for better data pruning. For step 2, it mostly helps with Join, because we can connect the constraints from children to the join condition and infer powerful filters to prune the data of the join sides. e.g., the left side has constraints `a = 1`, the join condition is `left.a = right.a`, then we can infer `right.a = 1` to the right side and prune the right side a lot. However, the current logic of inferring filters from constraints for Join is pretty weak. It infers the filters from Join's constraints. Some joins like left semi/anti exclude output from right side and the right side constraints will be lost here. This PR propose to check the left and right constraints individually, expand the constraints with join condition and add filters to children of join directly, instead of adding to the join condition. This reverts https://github.com/apache/spark/pull/20670 , covers https://github.com/apache/spark/pull/20717 and https://github.com/apache/spark/pull/20816 This is inspired by the original PRs and the tests are all from these PRs. Thanks to the authors mgaido91 maryannxue KaiXinXiaoLei ! ## How was this patch tested? new tests Author: Wenchen Fan Closes #21083 from cloud-fan/join. --- .../sql/catalyst/optimizer/Optimizer.scala | 97 +++++++++---------- .../plans/logical/QueryPlanConstraints.scala | 95 ++++++++---------- .../InferFiltersFromConstraintsSuite.scala | 53 +++++++--- 3 files changed, 124 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 913354e4df0e6..f00d40d11f23f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -637,13 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] { * constraints. These filters are currently inserted to the existing conditions in the Filter * operators and on either side of Join operators. * - * In addition, for left/right outer joins, infer predicate from the preserved side of the Join - * operator and push the inferred filter over to the null-supplying side. For example, if the - * preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in - * which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will - * be applied to the null-supplying side. + * Note: While this optimization is applicable to a lot of types of join, it primarily benefits + * Inner and LeftSemi joins. */ -object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper { +object InferFiltersFromConstraints extends Rule[LogicalPlan] + with PredicateHelper with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (SQLConf.get.constraintPropagationEnabled) { @@ -664,53 +662,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe } case join @ Join(left, right, joinType, conditionOpt) => - // Only consider constraints that can be pushed down completely to either the left or the - // right child - val constraints = join.allConstraints.filter { c => - c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet) - } - // Remove those constraints that are already enforced by either the left or the right child - val additionalConstraints = constraints -- (left.constraints ++ right.constraints) - val newConditionOpt = conditionOpt match { - case Some(condition) => - val newFilters = additionalConstraints -- splitConjunctivePredicates(condition) - if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt - case None => - additionalConstraints.reduceOption(And) - } - // Infer filter for left/right outer joins - val newLeftOpt = joinType match { - case RightOuter if newConditionOpt.isDefined => - val inferredConstraints = left.getRelevantConstraints( - left.constraints - .union(right.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(left.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, left)) - case _ => None - } - val newRightOpt = joinType match { - case LeftOuter if newConditionOpt.isDefined => - val inferredConstraints = right.getRelevantConstraints( - right.constraints - .union(left.constraints) - .union(splitConjunctivePredicates(newConditionOpt.get).toSet)) - val newFilters = inferredConstraints - .filterNot(right.constraints.contains) - .reduceLeftOption(And) - newFilters.map(Filter(_, right)) - case _ => None - } + joinType match { + // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an + // inner join, it just drops the right side in the final output. + case _: InnerLike | LeftSemi => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + val newRight = inferNewFilter(right, allConstraints) + join.copy(left = newLeft, right = newRight) - if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt)) - || newLeftOpt.isDefined || newRightOpt.isDefined) { - Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt) - } else { - join + // For right outer join, we can only infer additional filters for left side. + case RightOuter => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newLeft = inferNewFilter(left, allConstraints) + join.copy(left = newLeft) + + // For left join, we can only infer additional filters for right side. + case LeftOuter | LeftAnti => + val allConstraints = getAllConstraints(left, right, conditionOpt) + val newRight = inferNewFilter(right, allConstraints) + join.copy(right = newRight) + + case _ => join } } + + private def getAllConstraints( + left: LogicalPlan, + right: LogicalPlan, + conditionOpt: Option[Expression]): Set[Expression] = { + val baseConstraints = left.constraints.union(right.constraints) + .union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet) + baseConstraints.union(inferAdditionalConstraints(baseConstraints)) + } + + private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = { + val newPredicates = constraints + .union(constructIsNotNullConstraints(constraints, plan.output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic + } -- plan.constraints + if (newPredicates.isEmpty) { + plan + } else { + Filter(newPredicates.reduce(And), plan) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index a29f3d29236c7..cc352c59dff80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -trait QueryPlanConstraints { self: LogicalPlan => +trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan => /** - * An [[ExpressionSet]] that contains an additional set of constraints, such as equality - * constraints and `isNotNull` constraints, etc. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. */ - lazy val allConstraints: ExpressionSet = { + lazy val constraints: ExpressionSet = { if (conf.constraintPropagationEnabled) { - ExpressionSet(validConstraints - .union(inferAdditionalConstraints(validConstraints)) - .union(constructIsNotNullConstraints(validConstraints))) + ExpressionSet( + validConstraints + .union(inferAdditionalConstraints(validConstraints)) + .union(constructIsNotNullConstraints(validConstraints, output)) + .filter { c => + c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic + } + ) } else { ExpressionSet(Set.empty) } } - /** - * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For - * example, if this set contains the expression `a = 2` then that expression is guaranteed to - * evaluate to `true` for all rows produced. - */ - lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly)) - /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints * based on the given operator's constraint propagation logic. These constraints are then @@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan => * See [[Canonicalize]] for more details. */ protected def validConstraints: Set[Expression] = Set.empty +} + +trait ConstraintHelper { /** - * Returns an [[ExpressionSet]] that contains an additional set of constraints, such as - * equality constraints and `isNotNull` constraints, etc., and that only contains references - * to this [[LogicalPlan]] node. + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5`. */ - def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = { - val allRelevantConstraints = - if (conf.constraintPropagationEnabled) { - constraints - .union(inferAdditionalConstraints(constraints)) - .union(constructIsNotNullConstraints(constraints)) - } else { - constraints - } - ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly)) + def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + val candidateConstraints = constraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case _ => // No inference + } + inferredConstraints -- constraints } + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { + case e: Expression if e.semanticEquals(source) => destination + }) + /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ - private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { + def constructIsNotNullConstraints( + constraints: Set[Expression], + output: Seq[Attribute]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) @@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan => case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - /** - * Infers an additional set of constraints from a given set of equality constraints. - * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5`. - */ - private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - var inferredConstraints = Set.empty[Expression] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) - inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) - case _ => // No inference - } - inferredConstraints -- constraints - } - - private def replaceConstraints( - constraints: Set[Expression], - source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { - case e: Expression if e.semanticEquals(source) => destination - }) - - private def selfReferenceOnly(e: Expression): Boolean = { - e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e068f51044589..e4671f0d1cce6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest { InferFiltersFromConstraints, CombineFilters, SimplifyBinaryComparison, - BooleanSimplification) :: Nil + BooleanSimplification, + PruneFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private def testConstraintsAfterJoin( + x: LogicalPlan, + y: LogicalPlan, + expectedLeft: LogicalPlan, + expectedRight: LogicalPlan, + joinType: JoinType) = { + val condition = Some("x.a".attr === "y.a".attr) + val originalQuery = x.join(y, joinType, condition).analyze + val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + test("filter: filter out constraints in condition") { val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze val correctAnswer = testRelation @@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y, LeftSemi, condition).analyze - val left = x.where(IsNotNull('a)) - val right = y.where(IsNotNull('a)) - val correctAnswer = left.join(right, LeftSemi, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) } test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { @@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("SPARK-21479: Outer join no filter push down to preserved side") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) - val condition = Some("x.a".attr === "y.a".attr) - val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze - val left = x - val right = y.where(IsNotNull('a) && 'a === 1) - val correctAnswer = left.join(right, LeftOuter, condition).analyze - val optimized = Optimize.execute(originalQuery) - comparePlans(optimized, correctAnswer) + testConstraintsAfterJoin( + x, y.where("a".attr === 1), + x, y.where(IsNotNull('a) && 'a === 1), + LeftOuter) + } + + test("SPARK-23564: left anti join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + } + + test("SPARK-23564: left outer join should filter out null join keys on right side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + } + + test("SPARK-23564: right outer join should filter out null join keys on left side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } } From afbdf427302aba858f95205ecef7667f412b2a6a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 23 Apr 2018 14:28:28 +0200 Subject: [PATCH 0668/1161] [SPARK-23589][SQL] ExternalMapToCatalyst should support interpreted execution ## What changes were proposed in this pull request? This pr supported interpreted mode for `ExternalMapToCatalyst`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Takeshi Yamamuro Closes #20980 from maropu/SPARK-23589. --- .../expressions/objects/objects.scala | 60 +++++++++- .../expressions/ObjectExpressionsSuite.scala | 108 +++++++++++++++++- 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f1ffcaec8a484..9c7e76467d153 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,8 +1255,64 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(InternalRow.fromSeq(key :: Nil)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(InternalRow.fromSeq(value :: Nil)) + } else { + null + } + i += 1 + } + (keys, values) + } + } + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result != null) { + val (keys, values) = mapCatalystConverter(result) + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } else { + null + } + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputMap = child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 7136af8934486..730b36c32333c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,12 +21,13 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.{RandomDataGenerator, Row} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders._ @@ -501,6 +502,111 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow.fromSeq(Seq(Row(1))), "java.lang.Integer is not a valid external type for schema of double") } + + private def javaMapSerializerFor( + keyClazz: Class[_], + valueClazz: Class[_])(inputObject: Expression): Expression = { + + def kvSerializerFor(inputObject: Expression, clazz: Class[_]): Expression = clazz match { + case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + } + + ExternalMapToCatalyst( + inputObject, + ObjectType(keyClazz), + kvSerializerFor(_, keyClazz), + keyNullable = true, + ObjectType(valueClazz), + kvSerializerFor(_, valueClazz), + valueNullable = true + ) + } + + private def scalaMapSerializerFor[T: TypeTag, U: TypeTag](inputObject: Expression): Expression = { + import org.apache.spark.sql.catalyst.ScalaReflection._ + + val curId = new java.util.concurrent.atomic.AtomicInteger() + + def kvSerializerFor[V: TypeTag](inputObject: Expression): Expression = + localTypeOf[V].dealias match { + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil, + returnNullable = false) + case _ => + inputObject + } + + ExternalMapToCatalyst( + inputObject, + dataTypeFor[T], + kvSerializerFor[T], + keyNullable = !localTypeOf[T].typeSymbol.asClass.isPrimitive, + dataTypeFor[U], + kvSerializerFor[U], + valueNullable = !localTypeOf[U].typeSymbol.asClass.isPrimitive + ) + } + + test("SPARK-23589 ExternalMapToCatalyst should support interpreted execution") { + // Simple test + val scalaMap = scala.collection.Map[Int, String](0 -> "v0", 1 -> "v1", 2 -> null, 3 -> "v3") + val javaMap = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(0, "v0") + put(1, "v1") + put(2, null) + put(3, "v3") + } + } + val expected = CatalystTypeConverters.convertToCatalyst(scalaMap) + + // Java Map + val serializer1 = javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMap)) + checkEvaluation(serializer1, expected) + + // Scala Map + val serializer2 = scalaMapSerializerFor[Int, String](Literal.fromObject(scalaMap)) + checkEvaluation(serializer2, expected) + + // NULL key test + val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String]( + null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1") + val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() { + { + put(null, "v0") + put(1, "v1") + } + } + + // Java Map + val serializer3 = + javaMapSerializerFor(classOf[java.lang.Integer], classOf[java.lang.String])( + Literal.fromObject(javaMapHasNullKey)) + checkExceptionInExpression[RuntimeException]( + serializer3, EmptyRow, "Cannot use null as map key!") + + // Scala Map + val serializer4 = scalaMapSerializerFor[java.lang.Integer, String]( + Literal.fromObject(scalaMapHasNullKey)) + + checkExceptionInExpression[RuntimeException]( + serializer4, EmptyRow, "Cannot use null as map key!") + } } class TestBean extends Serializable { From 293a0f29e314dc532cec2048a7c6bc00e31de472 Mon Sep 17 00:00:00 2001 From: Teng Peng Date: Mon, 23 Apr 2018 10:29:47 -0700 Subject: [PATCH 0669/1161] [Spark-24024][ML] Fix poisson deviance calculations in GLM to handle y = 0 ## What changes were proposed in this pull request? It is reported by Spark users that the deviance calculation for poisson regression does not handle y = 0. Thus, the correct model summary cannot be obtained. The user has confirmed the the issue is in ``` override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (y * math.log(y / mu) - (y - mu)) } when y = 0. ``` The user also mentioned there are many other places he believe we should check the same thing. However, no other changes are needed, including Gamma distribution. ## How was this patch tested? Add a comparison with R deviance calculation to the existing unit test. Author: Teng Peng Closes #21125 from tengpeng/Spark24024GLM. --- .../ml/regression/GeneralizedLinearRegression.scala | 10 +++++----- .../regression/GeneralizedLinearRegressionSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 9f1f2405c428e..4c3f1431d5077 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -471,6 +471,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] val epsilon: Double = 1E-16 + private[regression] def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + /** * Wrapper of family and link combination used in the model. */ @@ -725,10 +729,6 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) - private def ylogy(y: Double, mu: Double): Double = { - if (y == 0) 0.0 else y * math.log(y / mu) - } - override def deviance(y: Double, mu: Double, weight: Double): Double = { 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } @@ -783,7 +783,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu override def deviance(y: Double, mu: Double, weight: Double): Double = { - 2.0 * weight * (y * math.log(y / mu) - (y - mu)) + 2.0 * weight * (ylogy(y, mu) - (y - mu)) } override def aic( diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index d5bcbb221783e..997c50157dcda 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -493,11 +493,20 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest } [1] -0.0457441 -0.6833928 [1] 1.8121235 -0.1747493 -0.5815417 + + R code for deivance calculation: + data = cbind(y=c(0,1,0,0,0,1), x1=c(18, 12, 15, 13, 15, 16), x2=c(1,0,0,2,1,1)) + summary(glm(y~x1+x2, family=poisson, data=data.frame(data)))$deviance + [1] 3.70055 + summary(glm(y~x1+x2-1, family=poisson, data=data.frame(data)))$deviance + [1] 3.809296 */ val expected = Seq( Vectors.dense(0.0, -0.0457441, -0.6833928), Vectors.dense(1.8121235, -0.1747493, -0.5815417)) + val residualDeviancesR = Array(3.809296, 3.70055) + import GeneralizedLinearRegression._ var idx = 0 @@ -510,6 +519,7 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " + s"$link link and fitIntercept = $fitIntercept (with zero values).") + assert(model.summary.deviance ~== residualDeviancesR(idx) absTol 1E-3) idx += 1 } } From 448d248f897fa39cfc82d71a3d6b67e6470f8a02 Mon Sep 17 00:00:00 2001 From: liuzhaokun Date: Mon, 23 Apr 2018 13:56:11 -0500 Subject: [PATCH 0670/1161] [SPARK-21168] KafkaRDD should always set kafka clientId. [https://issues.apache.org/jira/browse/SPARK-21168](https://issues.apache.org/jira/browse/SPARK-21168) There are no a number of other places that a client ID should be set,and I think we should use consumer.clientId in the clientId method,because the fetch request will be used by the same consumer behind. Author: liuzhaokun Closes #19887 from liu-zhaokun/master1205. --- .../main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index 5ea52b6ad36a0..791cf0efaf888 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -191,6 +191,7 @@ class KafkaRDD[ private def fetchBatch: Iterator[MessageAndOffset] = { val req = new FetchRequestBuilder() + .clientId(consumer.clientId) .addFetch(part.topic, part.partition, requestOffset, kc.config.fetchMessageMaxBytes) .build() val resp = consumer.fetch(req) From 770add81c3474e754867d7105031a5eaf27159bd Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 23 Apr 2018 13:20:32 -0700 Subject: [PATCH 0671/1161] [SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? A structured streaming query with a streaming aggregation can throw the following error in rare cases.  ``` java.lang.IllegalStateException: Cannot commit after already committed or aborted at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643) at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359) at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251) at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188) at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.(ObjectAggregationIterator.scala:78) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114) at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336) ``` This can happen when the following conditions are accidentally hit.  - Streaming aggregation with aggregation function that is a subset of [`TypedImperativeAggregation`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).  - Query running in `update}` mode - After the shuffle, a partition has exactly 128 records.  This causes StateStore.commit to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The solution is to use `NextIterator` or `CompletionIterator`, each of which has a flag to prevent the "onCompletion" task from being called more than once. In this PR, I chose to implement using `NextIterator`. ## How was this patch tested? Added unit test that I have confirm will fail without the fix. Author: Tathagata Das Closes #21124 from tdas/SPARK-23004. --- .../streaming/statefulOperators.scala | 40 +++++++++---------- .../streaming/StreamingAggregationSuite.scala | 25 ++++++++++++ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b9b07a2e688f9..c9354ac0ec78a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -340,37 +340,35 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. case Some(Update) => - val updatesStartTimeNs = System.nanoTime - - new Iterator[InternalRow] { - + new NextIterator[InternalRow] { // Filter late date using watermark if specified private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } + private val updatesStartTimeNs = System.nanoTime - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) - - // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } - setStoreMetrics(store) - false + override protected def getNext(): InternalRow = { + if (baseIterator.hasNext) { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key, row) + numOutputRows += 1 + numUpdatedStateRows += 1 + row } else { - true + finished = true + null } } - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) - numOutputRows += 1 - numUpdatedStateRows += 1 - row + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + // Remove old aggregates if watermark specified + allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 1cae8cb8d47f1..382da13430781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } + test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error + // by ensuring the following. + // - A streaming query with a streaming aggregation. + // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate. + // - Post shuffle partition has exactly 128 records (i.e. the threshold at which + // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a + // micro-batch with 128 records that shuffle to a single partition. + // This test throws the exact error reported in SPARK-23004 without the corresponding fix. + withSQLConf("spark.sql.shuffle.partitions" -> "1") { + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From e82cb68349b785c1b35bcfb85bff3a8ec2c93fee Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 23 Apr 2018 13:23:02 -0700 Subject: [PATCH 0672/1161] [SPARK-11237][ML] Add pmml export for k-means in Spark ML ## What changes were proposed in this pull request? Adding PMML export to Spark ML's KMeans Model. ## How was this patch tested? New unit test for Spark ML PMML export based on the old Spark MLlib unit test. Author: Holden Karau Closes #20907 from holdenk/SPARK-11237-Add-PMML-Export-for-KMeans. --- .../org.apache.spark.ml.util.MLFormatRegister | 4 +- .../apache/spark/ml/clustering/KMeans.scala | 75 ++++++++++++------- .../ml/regression/LinearRegression.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 32 +++++++- 4 files changed, 83 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister index 5e5484fd8784d..f14431d50feec 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -1,2 +1,4 @@ org.apache.spark.ml.regression.InternalLinearRegressionModelWriter -org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter +org.apache.spark.ml.clustering.InternalKMeansModelWriter +org.apache.spark.ml.clustering.PMMLKMeansModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 987a4285ebad4..1ad157a695a7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -30,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -103,8 +105,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Since("1.5.0") class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) - extends Model[KMeansModel] with KMeansParams with MLWritable { + private[clustering] val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with GeneralMLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -152,14 +154,14 @@ class KMeansModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[KMeansModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * */ @Since("1.6.0") - override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) private var trainingSummary: Option[KMeansSummary] = None @@ -185,6 +187,47 @@ class KMeansModel private[ml] ( } } +/** Helper class for storing model data */ +private case class ClusterData(clusterIdx: Int, clusterCenter: Vector) + + +/** A writer for KMeans that handles the "internal" (or default) format */ +private class InternalKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "internal" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map { + case (center, idx) => + ClusterData(idx, center) + } + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for KMeans that handles the "pmml" format */ +private class PMMLKMeansModelWriter extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" + override def stageName(): String = "org.apache.spark.ml.clustering.KMeansModel" + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[KMeansModel] + val sc = sparkSession.sparkContext + instance.parentModel.toPMML(sc, path) + } +} + + @Since("1.6.0") object KMeansModel extends MLReadable[KMeansModel] { @@ -194,30 +237,12 @@ object KMeansModel extends MLReadable[KMeansModel] { @Since("1.6.0") override def load(path: String): KMeansModel = super.load(path) - /** Helper class for storing model data */ - private case class Data(clusterIdx: Int, clusterCenter: Vector) - /** * We store all cluster centers in a single row and use this class to store model data by * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. */ private case class OldData(clusterCenters: Array[OldVector]) - /** [[MLWriter]] instance for [[KMeansModel]] */ - private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: cluster centers - val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) => - Data(idx, center) - } - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) - } - } - private class KMeansModelReader extends MLReader[KMeansModel] { /** Checked against metadata when loading model */ @@ -232,7 +257,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val dataPath = new Path(path, "data").toString val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { - val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + val data: Dataset[ClusterData] = sparkSession.read.parquet(dataPath).as[ClusterData] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f67d9d831f327..9cdd3a051e719 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -746,7 +746,7 @@ private class InternalLinearRegressionModelWriter /** A writer for LinearRegression that handles the "pmml" format */ private class PMMLLinearRegressionModelWriter - extends MLWriterFormat with MLFormatRegister { + extends MLWriterFormat with MLFormatRegister { override def format(): String = "pmml" diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 32830b39407ad..77c9d482d95b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,17 +19,22 @@ package org.apache.spark.ml.clustering import scala.util.Random +import org.dmg.pmml.{ClusteringModel, PMML} + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest + with PMMLReadWriteTest { final val k = 5 @transient var dataset: Dataset[_] = _ @@ -202,6 +207,27 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, KMeansSuite.allParamSettings, checkModelData) } + + test("pmml export") { + val clusterCenters = Array( + MLlibVectors.dense(1.0, 2.0, 6.0), + MLlibVectors.dense(1.0, 3.0, 0.0), + MLlibVectors.dense(1.0, 4.0, 6.0)) + val oldKmeansModel = new MLlibKMeansModel(clusterCenters) + val kmeansModel = new KMeansModel("", oldKmeansModel) + def checkModel(pmml: PMML): Unit = { + // Check the header descripiton is what we expect + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark + // model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + testPMMLWrite(sc, kmeansModel, checkModel) + } } object KMeansSuite { From c8f3ac69d176bd10b8de1c147b6903a247943d51 Mon Sep 17 00:00:00 2001 From: wuyi Date: Mon, 23 Apr 2018 15:35:45 -0500 Subject: [PATCH 0673/1161] [SPARK-23888][CORE] correct the comment of hasAttemptOnHost() TaskSetManager.hasAttemptOnHost had a misleading comment. The comment said that it only checked for running tasks, but really it checked for any tasks that might have run in the past as well. This updates to line up with the implementation. Author: wuyi Closes #20998 from Ngone51/SPARK-23888. --- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d958658527f6d..8a96a7692f614 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -287,7 +287,7 @@ private[spark] class TaskSetManager( None } - /** Check whether a task is currently running an attempt on a given host */ + /** Check whether a task once ran an attempt on a given host */ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { taskAttempts(taskIndex).exists(_.host == host) } From 428b903859c3d8873045fdcfffdebe24fc6e027f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Apr 2018 09:10:29 +0800 Subject: [PATCH 0674/1161] [SPARK-24029][CORE] Follow up: set SO_REUSEADDR on the server socket. "childOption" is for the remote connections, not for the server socket that actually listens for incoming connections. Author: Marcelo Vanzin Closes #21132 from vanzin/SPARK-24029.2. --- .../java/org/apache/spark/network/server/TransportServer.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 612750972c4bb..60f51125c07fd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -99,8 +99,8 @@ private void init(String hostToBind, int portToBind) { .group(bossGroup, workerGroup) .channel(NettyUtils.getServerChannelClass(ioMode)) .option(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.ALLOCATOR, allocator) - .childOption(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS); + .option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS) + .childOption(ChannelOption.ALLOCATOR, allocator); this.metrics = new NettyMemoryMetrics( allocator, conf.getModuleName() + "-server", conf); From 281c1ca0dc96b0441a60c32df3d16fbb1c61e99f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 24 Apr 2018 10:11:09 +0800 Subject: [PATCH 0675/1161] [SPARK-23973][SQL] Remove consecutive Sorts ## What changes were proposed in this pull request? In SPARK-23375 we introduced the ability of removing `Sort` operation during query optimization if the data is already sorted. In this follow-up we remove also a `Sort` which is followed by another `Sort`: in this case the first sort is not needed and can be safely removed. The PR starts from henryr's comment: https://github.com/apache/spark/pull/20560#discussion_r180601594. So credit should be given to him. ## How was this patch tested? added UT Author: Marco Gaido Closes #21072 from mgaido91/SPARK-23973. --- .../sql/catalyst/optimizer/Optimizer.scala | 21 +++++++- .../optimizer/RemoveRedundantSortsSuite.scala | 51 ++++++++++++++++--- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f00d40d11f23f..45f13956a0a85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -767,12 +767,29 @@ object EliminateSorts extends Rule[LogicalPlan] { } /** - * Removes Sort operation if the child is already sorted + * Removes redundant Sort operation. This can happen: + * 1) if the child is already sorted + * 2) if there is another Sort operator separated by 0...n Project/Filter operators */ object RemoveRedundantSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => child + case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child)) + } + + def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match { + case Sort(_, _, child) => recursiveRemoveSort(child) + case other if canEliminateSort(other) => + other.withNewChildren(other.children.map(recursiveRemoveSort)) + case _ => plan + } + + def canEliminateSort(plan: LogicalPlan): Boolean = plan match { + case p: Project => p.projectList.forall(_.deterministic) + case f: Filter => f.condition.deterministic + case _: ResolvedHint => true + case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 2319ab8046e56..dae5e6f3ee3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class RemoveRedundantSortsSuite extends PlanTest { @@ -42,15 +38,15 @@ class RemoveRedundantSortsSuite extends PlanTest { test("remove redundant order by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) - val correctAnswer = orderedPlan.select('a).analyze + val correctAnswer = orderedPlan.limit(2).select('a).analyze comparePlans(Optimize.execute(optimized), correctAnswer) } test("do not remove sort if the order is different") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) val optimized = Optimize.execute(reorderedDifferently.analyze) val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) @@ -72,6 +68,14 @@ class RemoveRedundantSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("different sorts are not simplified if limit is in between") { + val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) + .orderBy('a.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = orderedPlan.analyze + comparePlans(optimized, correctAnswer) + } + test("range is already sorted") { val inputPlan = Range(1L, 1000L, 1, 10) val orderedPlan = inputPlan.orderBy('id.asc) @@ -98,4 +102,37 @@ class RemoveRedundantSortsSuite extends PlanTest { val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } + + test("remove two consecutive sorts") { + val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val optimized = Optimize.execute(orderedTwice.analyze) + val correctAnswer = testRelation.orderBy('b.desc).analyze + comparePlans(optimized, correctAnswer) + } + + test("remove sorts separated by Filter/Project operators") { + val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) + val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + comparePlans(optimizedWithProject, correctAnswerWithProject) + + val orderedTwiceWithFilter = + testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) + val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithFilter, correctAnswerWithFilter) + + val orderedTwiceWithBoth = + testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) + val correctAnswerWithBoth = + testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + comparePlans(optimizedWithBoth, correctAnswerWithBoth) + + val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val optimizedThrice = Optimize.execute(orderedThrice.analyze) + val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) + .select(('b + 1).as('c)).orderBy('c.asc).analyze + comparePlans(optimizedThrice, correctAnswerThrice) + } } From c303b1b6766a3dc5961713f98f62cd7d7ac7972a Mon Sep 17 00:00:00 2001 From: seancxmao Date: Tue, 24 Apr 2018 16:16:07 +0800 Subject: [PATCH 0676/1161] [MINOR][DOCS] Fix comments of SQLExecution#withExecutionId ## What changes were proposed in this pull request? Fix comment. Change `BroadcastHashJoin.broadcastFuture` to `BroadcastExchangeExec.relationFuture`: https://github.com/apache/spark/blob/d28d5732ae205771f1f443b15b10e64dcffb5ff0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala#L66 ## How was this patch tested? N/A Author: seancxmao Closes #21113 from seancxmao/SPARK-13136. --- .../scala/org/apache/spark/sql/execution/SQLExecution.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index e991da7df0bde..2c5102b1e5ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -88,7 +88,7 @@ object SQLExecution { /** * Wrap an action with a known executionId. When running a different action in a different * thread from the original one, this method can be used to connect the Spark jobs in this action - * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`. + * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) From 87e8a572be14381da9081365d9aa2cbf3253a32c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 24 Apr 2018 16:18:20 +0800 Subject: [PATCH 0677/1161] [SPARK-24054][R] Add array_position function / element_at functions ## What changes were proposed in this pull request? This PR proposes to add array_position and element_at in R side too. array_position: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$gear, df$am, df$carb)) head(select(mutated, array_position(mutated$v1, 1))) ``` ``` array_position(v1, 1.0) 1 2 2 2 3 2 4 3 5 0 6 3 ``` element_at: ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) head(select(mutated, element_at(mutated$v1, 1))) ``` ``` element_at(v1, 1.0) 1 21.0 2 21.0 3 22.8 4 21.4 5 18.7 6 18.1 ``` ```r df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) mutated <- mutate(df, v1 = create_map(df$model, df$cyl)) head(select(mutated, element_at(mutated$v1, "Valiant"))) ``` ``` element_at(v3, Valiant) 1 NA 2 NA 3 NA 4 NA 5 NA 6 6 ``` ## How was this patch tested? Unit tests were added in `R/pkg/tests/fulltests/test_sparkSQL.R` and manually tested. Documentation was manually built and verified. Author: hyukjinkwon Closes #21130 from HyukjinKwon/sparkr_array_position_element_at. --- R/pkg/NAMESPACE | 2 ++ R/pkg/R/functions.R | 42 +++++++++++++++++++++++++-- R/pkg/R/generics.R | 8 +++++ R/pkg/tests/fulltests/test_sparkSQL.R | 13 +++++++-- 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 190c50ea10482..55dec177ea853 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_position", "asc", "ascii", "asin", @@ -245,6 +246,7 @@ exportMethods("%<=>%", "decode", "dense_rank", "desc", + "element_at", "encode", "endsWith", "exp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a527426b19674..7b3aa05074563 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,11 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param value A value to compute on. +#' \itemize{ +#' \item \code{array_contains}: a value to be checked if contained in the column. +#' \item \code{array_position}: a value to locate in the given array. +#' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same #' options as the JSON data source. @@ -201,6 +206,7 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -208,7 +214,8 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3)))} +#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} NULL #' Window functions for Column operations @@ -2975,7 +2982,6 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. #' -#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @note array_contains since 1.6.0 @@ -2986,6 +2992,22 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_position}: Locates the position of the first occurrence of the given value +#' in the given array. Returns NA if either of the arguments are NA. +#' Note: The position is not zero based, but 1 based index. Returns 0 if the given +#' value could not be found in the array. +#' +#' @rdname column_collection_functions +#' @aliases array_position array_position,Column-method +#' @note array_position since 2.4.0 +setMethod("array_position", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' @@ -3012,6 +3034,22 @@ setMethod("map_values", column(jc) }) +#' @details +#' \code{element_at}: Returns element of array at given index in \code{extraction} if +#' \code{x} is array. Returns value for the given key in \code{extraction} if \code{x} is map. +#' Note: The position is not zero based, but 1 based index. +#' +#' @param extraction index to check for in array or key to check for in map +#' @rdname column_collection_functions +#' @aliases element_at element_at,Column-method +#' @note element_at since 2.4.0 +setMethod("element_at", + signature(x = "Column", extraction = "ANY"), + function(x, extraction) { + jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction) + column(jc) + }) + #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 974beff1a3d76..f30ac9e4295e4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -886,6 +890,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") }) #' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") }) + #' @rdname column_string_functions #' @name NULL setGeneric("encode", function(x, charset) { standardGeneric("encode") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7105469ffc242..a384997830276 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,17 +1479,23 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains() and sort_array() + # Test array_contains(), array_position(), element_at() and sort_array() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) + result <- collect(select(df, array_position(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 0)) + + result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] + expect_equal(result, c(1, 6)) + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) - # Test map_keys() and map_values() + # Test map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) @@ -1497,6 +1503,9 @@ test_that("column functions", { result <- collect(select(df, map_values(df$map)))[[1]] expect_equal(result, list(list(1, 2))) + result <- collect(select(df, element_at(df$map, "y")))[[1]] + expect_equal(result, 2) + # Test that stats::lag is working expect_equal(length(lag(ldeaths, 12)), 72) From 4926a7c2f0a47b562f99dbb4f1ca17adb3192061 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 24 Apr 2018 17:52:05 +0200 Subject: [PATCH 0678/1161] [SPARK-23589][SQL][FOLLOW-UP] Reuse InternalRow in ExternalMapToCatalyst eval ## What changes were proposed in this pull request? This pr is a follow-up of #20980 and fixes code to reuse `InternalRow` for converting input keys/values in `ExternalMapToCatalyst` eval. ## How was this patch tested? Existing tests. Author: Takeshi Yamamuro Closes #21137 from maropu/SPARK-23589-FOLLOWUP. --- .../expressions/objects/objects.scala | 92 ++++++++++--------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 9c7e76467d153..f974fd81fc788 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -1255,53 +1255,61 @@ case class ExternalMapToCatalyst private( override def dataType: MapType = MapType( keyConverter.dataType, valueConverter.dataType, valueContainsNull = valueConverter.nullable) - private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = child.dataType match { - case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[java.util.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - val iter = data.entrySet().iterator() - var i = 0 - while (iter.hasNext) { - val entry = iter.next() - val (key, value) = (entry.getKey, entry.getValue) - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + private lazy val mapCatalystConverter: Any => (Array[Any], Array[Any]) = { + val rowBuffer = InternalRow.fromSeq(Array[Any](1)) + def rowWrapper(data: Any): InternalRow = { + rowBuffer.update(0, data) + rowBuffer + } + + child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[java.util.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + val iter = data.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + val (key, value) = (entry.getKey, entry.getValue) + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } - case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => - (input: Any) => { - val data = input.asInstanceOf[scala.collection.Map[Any, Any]] - val keys = new Array[Any](data.size) - val values = new Array[Any](data.size) - var i = 0 - for ((key, value) <- data) { - keys(i) = if (key != null) { - keyConverter.eval(InternalRow.fromSeq(key :: Nil)) - } else { - throw new RuntimeException("Cannot use null as map key!") - } - values(i) = if (value != null) { - valueConverter.eval(InternalRow.fromSeq(value :: Nil)) - } else { - null + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + (input: Any) => { + val data = input.asInstanceOf[scala.collection.Map[Any, Any]] + val keys = new Array[Any](data.size) + val values = new Array[Any](data.size) + var i = 0 + for ((key, value) <- data) { + keys(i) = if (key != null) { + keyConverter.eval(rowWrapper(key)) + } else { + throw new RuntimeException("Cannot use null as map key!") + } + values(i) = if (value != null) { + valueConverter.eval(rowWrapper(value)) + } else { + null + } + i += 1 } - i += 1 + (keys, values) } - (keys, values) - } + } } override def eval(input: InternalRow): Any = { From 55c4ca88a3b093ee197a8689631be8d1fac1f10f Mon Sep 17 00:00:00 2001 From: Julien Cuquemelle Date: Tue, 24 Apr 2018 10:56:55 -0500 Subject: [PATCH 0679/1161] [SPARK-22683][CORE] Add a executorAllocationRatio parameter to throttle the parallelism of the dynamic allocation ## What changes were proposed in this pull request? By default, the dynamic allocation will request enough executors to maximize the parallelism according to the number of tasks to process. While this minimizes the latency of the job, with small tasks this setting can waste a lot of resources due to executor allocation overhead, as some executor might not even do any work. This setting allows to set a ratio that will be used to reduce the number of target executors w.r.t. full parallelism. The number of executors computed with this setting is still fenced by `spark.dynamicAllocation.maxExecutors` and `spark.dynamicAllocation.minExecutors` ## How was this patch tested? Units tests and runs on various actual workloads on a Yarn Cluster Author: Julien Cuquemelle Closes #19881 from jcuquemelle/AddTaskPerExecutorSlot. --- .../spark/ExecutorAllocationManager.scala | 24 +++++++++++--- .../spark/internal/config/package.scala | 4 +++ .../ExecutorAllocationManagerSuite.scala | 33 +++++++++++++++++++ docs/configuration.md | 18 ++++++++++ 4 files changed, 74 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 189d91333c045..aa363eeffffb8 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -26,7 +26,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} +import org.apache.spark.internal.config._ import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMaster @@ -69,6 +69,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors * spark.dynamicAllocation.initialExecutors - Number of executors to start with * + * spark.dynamicAllocation.executorAllocationRatio - + * This is used to reduce the parallelism of the dynamic allocation that can waste + * resources when tasks are small + * * spark.dynamicAllocation.schedulerBacklogTimeout (M) - * If there are backlogged tasks for this duration, add new executors * @@ -116,9 +120,12 @@ private[spark] class ExecutorAllocationManager( // TODO: The default value of 1 for spark.executor.cores works right now because dynamic // allocation is only supported for YARN and the default number of cores per executor in YARN is // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutor = + private val tasksPerExecutorForFullParallelism = conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + private val executorAllocationRatio = + conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + validateSettings() // Number of executors to add in the next round @@ -209,8 +216,13 @@ private[spark] class ExecutorAllocationManager( throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. You may enable this through spark.shuffle.service.enabled.") } - if (tasksPerExecutor == 0) { - throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.") + if (tasksPerExecutorForFullParallelism == 0) { + throw new SparkException("spark.executor.cores must not be < spark.task.cpus.") + } + + if (executorAllocationRatio > 1.0 || executorAllocationRatio <= 0.0) { + throw new SparkException( + "spark.dynamicAllocation.executorAllocationRatio must be > 0 and <= 1.0") } } @@ -273,7 +285,9 @@ private[spark] class ExecutorAllocationManager( */ private def maxNumExecutorsNeeded(): Int = { val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks - (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor + math.ceil(numRunningOrPendingTasks * executorAllocationRatio / + tasksPerExecutorForFullParallelism) + .toInt } private def totalRunningTasks(): Int = synchronized { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 99d779fb600e8..6bb98c37b4479 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -126,6 +126,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO = + ConfigBuilder("spark.dynamicAllocation.executorAllocationRatio") + .doubleConf.createWithDefault(1.0) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 9807d1269e3d4..3cfb0a9feb32b 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -145,6 +145,39 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) } + def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val conf = new SparkConf() + .setMaster("myDummyLocalExternalClusterManager") + .setAppName("test-executor-allocation-manager") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + .set("spark.dynamicAllocation.maxExecutors", "15") + .set("spark.dynamicAllocation.minExecutors", "3") + .set("spark.dynamicAllocation.executorAllocationRatio", divisor.toString) + .set("spark.executor.cores", cores.toString) + val sc = new SparkContext(conf) + contexts += sc + var manager = sc.executorAllocationManager.get + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 20))) + for (i <- 0 to 5) { + addExecutors(manager) + } + assert(numExecutorsTarget(manager) === expected) + sc.stop() + } + + test("executionAllocationRatio is correctly handled") { + testAllocationRatio(1, 0.5, 10) + testAllocationRatio(1, 1.0/3.0, 7) + testAllocationRatio(2, 1.0/3.0, 4) + testAllocationRatio(1, 0.385, 8) + + // max/min executors capping + testAllocationRatio(1, 1.0, 15) // should be 20 but capped by max + testAllocationRatio(4, 1.0/3.0, 3) // should be 2 but elevated by min + } + + test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get diff --git a/docs/configuration.md b/docs/configuration.md index 4d4d0c58dd07d..fb02d7ea1d4ea 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1753,6 +1753,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors, spark.dynamicAllocation.maxExecutors, and spark.dynamicAllocation.initialExecutors + spark.dynamicAllocation.executorAllocationRatio
    spark.dynamicAllocation.executorAllocationRatio1 + By default, the dynamic allocation will request enough executors to maximize the + parallelism according to the number of tasks to process. While this minimizes the + latency of the job, with small tasks this setting can waste a lot of resources due to + executor allocation overhead, as some executor might not even do any work. + This setting allows to set a ratio that will be used to reduce the number of + executors w.r.t. full parallelism. + Defaults to 1.0 to give maximum parallelism. + 0.5 will divide the target number of executors by 2 + The target number of executors computed by the dynamicAllocation can still be overriden + by the spark.dynamicAllocation.minExecutors and + spark.dynamicAllocation.maxExecutors settings +
    spark.dynamicAllocation.schedulerBacklogTimeout 1s
    Property NameDefaultMeaning
    spark.sql.orc.implhivenative The name of ORC implementation. It can be one of native and hive. native means the native ORC support that is built on Apache ORC 1.4. `hive` means the ORC library in Hive 1.2.1.
    spark.yarn.am.waitTime 100s - In cluster mode, time for the YARN Application Master to wait for the - SparkContext to be initialized. In client mode, time for the YARN Application Master to wait - for the driver to connect to it. + Only used in cluster mode. Time for the YARN Application Master to wait for the + SparkContext to be initialized.
    1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.2. + options are 0.12.0 through 2.3.3.
    queryTimeout + The number of seconds the driver will wait for a Statement object to execute to the given + number of seconds. Zero means there is no limit. In the write path, this option depends on + how JDBC drivers implement the API setQueryTimeout, e.g., the h2 JDBC driver + checks the timeout of each query instead of an entire JDBC batch. + It defaults to 0. +
    fetchsize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 53f44888ebaff..917f0cb221412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -257,7 +257,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. "fetchsize" can be used to control the - * number of rows per fetch. + * number of rows per fetch and "queryTimeout" can be used to wait + * for a Statement object to execute to the given number of seconds. * @since 1.4.0 */ def jdbc( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index b4e5d169066d9..a73a97c06fe5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -89,6 +89,10 @@ class JDBCOptions( // the number of partitions val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt) + // the number of seconds the driver will wait for a Statement object to execute to the given + // number of seconds. Zero means there is no limit. + val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt + // ------------------------------------------------------------ // Optional parameters only for reading // ------------------------------------------------------------ @@ -160,6 +164,7 @@ object JDBCOptions { val JDBC_LOWER_BOUND = newOption("lowerBound") val JDBC_UPPER_BOUND = newOption("upperBound") val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 05326210f3242..0bab3689e5d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -57,6 +57,7 @@ object JDBCRDD extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { + statement.setQueryTimeout(options.queryTimeout) val rs = statement.executeQuery() try { JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) @@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD( val statement = conn.prepareStatement(sql) logInfo(s"Executing sessionInitStatement: $sql") try { + statement.setQueryTimeout(options.queryTimeout) statement.execute() } finally { statement.close() @@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD( stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) + stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index cc506e51bd0c6..f8c5677ea0f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider saveTable(df, tableSchema, isCaseSensitive, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it - dropTable(conn, options.table) + dropTable(conn, options.table, options) createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e6dc2fda4eb1b..433443007cfd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -76,6 +76,7 @@ object JdbcUtils extends Logging { Try { val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) statement.executeQuery() } finally { statement.close() @@ -86,9 +87,10 @@ object JdbcUtils extends Logging { /** * Drops a table from the JDBC database. */ - def dropTable(conn: Connection, table: String): Unit = { + def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = { val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(s"DROP TABLE $table") } finally { statement.close() @@ -102,6 +104,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() @@ -254,6 +257,7 @@ object JdbcUtils extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) } catch { case _: SQLException => None @@ -596,7 +600,8 @@ object JdbcUtils extends Logging { insertStmt: String, batchSize: Int, dialect: JdbcDialect, - isolationLevel: Int): Iterator[Byte] = { + isolationLevel: Int, + options: JDBCOptions): Iterator[Byte] = { val conn = getConnection() var committed = false @@ -637,6 +642,9 @@ object JdbcUtils extends Logging { try { var rowCount = 0 + + stmt.setQueryTimeout(options.queryTimeout) + while (iterator.hasNext) { val row = iterator.next() var i = 0 @@ -819,7 +827,8 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition(iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) + getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, + options) ) } @@ -841,6 +850,7 @@ object JdbcUtils extends Logging { val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(sql) } finally { statement.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5238adce4a699..bc2aca65e803f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1190,4 +1190,20 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").schema === schema) } } + + test("SPARK-23856 Spark jdbc setQueryTimeout option") { + val numJoins = 100 + val longRunningQuery = + s"SELECT t0.NAME AS c0, ${(1 to numJoins).map(i => s"t$i.NAME AS c$i").mkString(", ")} " + + s"FROM test.people t0 ${(1 to numJoins).map(i => s"join test.people t$i").mkString(" ")}" + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbtable", s"($longRunningQuery)") + .option("queryTimeout", 1) + .load() + val errMsg = intercept[SparkException] { + df.collect() + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1985b1dc82879..1c2c92d1f0737 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -515,4 +515,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { }.getMessage assert(e.contains("NULL not allowed for column \"NAME\"")) } + + ignore("SPARK-23856 Spark jdbc setQueryTimeout option") { + // The behaviour of the option `queryTimeout` depends on how JDBC drivers implement the API + // `setQueryTimeout`. For example, in the h2 JDBC driver, `executeBatch` invokes multiple + // INSERT queries in a batch and `setQueryTimeout` means that the driver checks the timeout + // of each query. In the PostgreSQL JDBC driver, `setQueryTimeout` means that the driver + // checks the timeout of an entire batch in a driver side. So, the test below fails because + // this test suite depends on the h2 JDBC driver and the JDBC write path internally + // uses `executeBatch`. + val errMsg = intercept[SparkException] { + spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write + .mode(SaveMode.Overwrite) + .option("queryTimeout", 1) + .option("batchsize", Int.MaxValue) + .jdbc(url1, "TEST.TIMEOUTTEST", properties) + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } From 710e4e81a8efc1aacc14283fb57bc8786146f885 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Fri, 18 May 2018 14:37:01 -0700 Subject: [PATCH 0831/1161] [SPARK-24308][SQL] Handle DataReaderFactory to InputPartition rename in left over classes ## What changes were proposed in this pull request? SPARK-24073 renames DataReaderFactory -> InputPartition and DataReader -> InputPartitionReader. Some classes still reflects the old name and causes confusion. This patch renames the left over classes to reflect the new interface and fixes a few comments. ## How was this patch tested? Existing unit tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21355 from arunmahadevan/SPARK-24308. --- .../spark/sql/kafka010/KafkaContinuousReader.scala | 6 +++--- .../spark/sql/kafka010/KafkaMicroBatchReader.scala | 4 ++-- .../sql/kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../sources/v2/reader/ContinuousInputPartition.java | 4 ++-- .../spark/sql/sources/v2/reader/InputPartition.java | 6 +++--- .../sql/sources/v2/reader/InputPartitionReader.java | 6 +++--- .../sql/execution/datasources/v2/DataSourceRDD.scala | 6 +++--- .../continuous/ContinuousRateStreamSource.scala | 4 ++-- .../spark/sql/execution/streaming/memory.scala | 4 ++-- .../streaming/sources/ContinuousMemoryStream.scala | 12 ++++++------ .../sources/RateStreamMicroBatchReader.scala | 4 ++-- .../streaming/sources/RateStreamProviderSuite.scala | 2 +- 12 files changed, 30 insertions(+), 30 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 88abf8a8dd027..badaa69cc303c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -106,7 +106,7 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousDataReaderFactory( + KafkaContinuousInputPartition( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) .asInstanceOf[InputPartition[UnsafeRow]] }.asJava @@ -146,7 +146,7 @@ class KafkaContinuousReader( } /** - * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * An input partition for continuous Kafka processing. This will be serialized and transformed * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. @@ -156,7 +156,7 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousDataReaderFactory( +case class KafkaContinuousInputPartition( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a377738ea782..64ba98762788c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -143,7 +143,7 @@ private[kafka010] class KafkaMicroBatchReader( // Generate factories based on the offset ranges val factories = offsetRanges.map { range => - new KafkaMicroBatchDataReaderFactory( + new KafkaMicroBatchInputPartition( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava @@ -300,7 +300,7 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReaderFactory( +private[kafka010] case class KafkaMicroBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 871f9700cd1db..c6412eac97dba 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -679,7 +679,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) val factories = reader.planUnsafeInputPartitions().asScala - .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index c24f3b21eade1..dcb87715d0b6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -27,9 +27,9 @@ @InterfaceStability.Evolving public interface ContinuousInputPartition extends InputPartition { /** - * Create a DataReader with particular offset as its startOffset. + * Create an input partition reader with particular offset as its startOffset. * - * @param offset offset want to set as the DataReader's startOffset. + * @param offset offset want to set as the input partition reader's startOffset. */ InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index 3524481784fea..f53687e113ae0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -36,8 +36,8 @@ public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this partition can run faster, - * but Spark does not guarantee to run the data reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run faster, + * but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * @@ -53,7 +53,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work. + * Returns an input partition reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 1b7051f1ad0af..f0d808536207a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,11 +23,11 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is responsible for * outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input partition * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 1a6b32429313a..8d6fb3820d420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,12 +29,12 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: I class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[InputPartition[T]]) + @transient private val inputPartitions: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) + inputPartitions.zipWithIndex.map { + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 8d25d9ccc43d3..516a563bdcc7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -85,7 +85,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousDataReaderFactory( + RateStreamContinuousInputPartition( start.value, start.runTimeMs, i, @@ -113,7 +113,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) } -case class RateStreamContinuousDataReaderFactory( +case class RateStreamContinuousInputPartition( startValue: Long, startTimeMs: Long, partitionIndex: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index daa2963220aef..b137f98045c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]] + new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -201,7 +201,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) +class MemoryStreamInputPartition(records: Array[UnsafeRow]) extends InputPartition[UnsafeRow] { override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { new InputPartitionReader[UnsafeRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 4daafa65850de..d1c3498450096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -44,8 +44,8 @@ import org.apache.spark.util.RpcUtils * * ContinuousMemoryStream maintains a list of records for each partition. addData() will * distribute records evenly-ish across partitions. * * RecordEndpoint is set up as an endpoint for executor-side - * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified - * offset within the list, or null if that offset doesn't yet have a record. + * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at + * the specified offset within the list, or null if that offset doesn't yet have a record. */ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { @@ -106,7 +106,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa startOffset.partitionNums.map { case (part, index) => - new ContinuousMemoryStreamDataReaderFactory( + new ContinuousMemoryStreamInputPartition( endpointName, part, index): InputPartition[Row] }.toList.asJava } @@ -157,9 +157,9 @@ object ContinuousMemoryStream { } /** - * Data reader factory for continuous memory stream. + * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamDataReaderFactory( +class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, startOffset: Int) extends InputPartition[Row] { @@ -168,7 +168,7 @@ class ContinuousMemoryStreamDataReaderFactory( } /** - * Data reader for continuous memory stream. + * An input partition reader for continuous memory stream. * * Polls the driver endpoint for new records. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 723cc3ad5bb89..fbff8db987110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -167,7 +167,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: } (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( + new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) : InputPartition[Row] }.toList.asJava @@ -182,7 +182,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchDataReaderFactory( +class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 39a010f970ce5..bf72e5c99689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -309,7 +309,7 @@ class RateSourceSuite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => + case t: RateStreamContinuousInputPartition => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) From 434d74e337465d77fa49ab65e2b5461e5ff7b5c7 Mon Sep 17 00:00:00 2001 From: Efim Poberezkin Date: Fri, 18 May 2018 16:54:39 -0700 Subject: [PATCH 0832/1161] [SPARK-23503][SS] Enforce sequencing of committed epochs for Continuous Execution ## What changes were proposed in this pull request? Made changes to EpochCoordinator so that it enforces a commit order. In case a message for epoch n is lost and epoch (n + 1) is ready for commit before epoch n is, epoch (n + 1) will wait for epoch n to be committed first. ## How was this patch tested? Existing tests in ContinuousSuite and EpochCoordinatorSuite. Author: Efim Poberezkin Closes #20936 from efimpoberezkin/pr/sequence-commited-epochs. --- .../continuous/EpochCoordinator.scala | 69 +++++++++++++++---- .../continuous/EpochCoordinatorSuite.scala | 6 +- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index cc6808065c0cd..8877ebeb26735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator( private val partitionOffsets = mutable.Map[(Long, Int), PartitionOffset]() + private var lastCommittedEpoch = startEpoch - 1 + // Remembers epochs that have to wait for previous epochs to be committed first. + private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long] + private def resolveCommitsAtEpoch(epoch: Long) = { - val thisEpochCommits = - partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val thisEpochCommits = findPartitionCommitsForEpoch(epoch) val nextEpochOffsets = partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochCommits.size == numWriterPartitions && nextEpochOffsets.size == numReaderPartitions) { - logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") - // Sequencing is important here. We must commit to the writer before recording the commit - // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, thisEpochCommits.toArray) - query.commit(epoch) - - // Cleanup state from before this epoch, now that we know all partitions are forever past it. - for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) - } - for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + + // Check that last committed epoch is the previous one for sequencing of committed epochs. + // If not, add the epoch being currently processed to epochs waiting to be committed, + // otherwise commit it. + if (lastCommittedEpoch != epoch - 1) { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is waiting for epoch ${epoch - 1} to be committed first.") + epochsWaitingToBeCommitted.add(epoch) + } else { + commitEpoch(epoch, thisEpochCommits) + lastCommittedEpoch = epoch + + // Commit subsequent epochs that are waiting to be committed. + var nextEpoch = lastCommittedEpoch + 1 + while (epochsWaitingToBeCommitted.contains(nextEpoch)) { + val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch) + commitEpoch(nextEpoch, nextEpochCommits) + + epochsWaitingToBeCommitted.remove(nextEpoch) + lastCommittedEpoch = nextEpoch + nextEpoch += 1 + } + + // Cleanup state from before last committed epoch, + // now that we know all partitions are forever past it. + for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionCommits.remove(k) + } + for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionOffsets.remove(k) + } } } } + /** + * Collect per-partition commits for an epoch. + */ + private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = { + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + } + + /** + * Commit epoch to the offset log. + */ + private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is ready to be committed. Committing epoch $epoch.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. + writer.commit(epoch, messages.toArray) + query.commit(epoch) + } + override def receive: PartialFunction[Any, Unit] = { // If we just drop these messages, we won't do any writes to the query. The lame duck tasks // won't shed errors or anything. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 99e30561f81d5..82836dced9df7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -120,7 +120,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + test("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { setWriterPartitions(2) setReaderPartitions(2) @@ -141,7 +141,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) @@ -162,7 +162,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) From dd37529a8dada6ed8a49b8ce50875268f6a20cba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 19 May 2018 18:51:02 +0800 Subject: [PATCH 0833/1161] [SPARK-24250][SQL] support accessing SQLConf inside tasks ## What changes were proposed in this pull request? Previously in #20136 we decided to forbid tasks to access `SQLConf`, because it doesn't work and always give you the default conf value. In #21190 we fixed the check and all the places that violate it. Currently the pattern of accessing configs at the executor side is: read the configs at the driver side, then access the variables holding the config values in the RDD closure, so that they will be serialized to the executor side. Something like ``` val someConf = conf.getXXX child.execute().mapPartitions { if (someConf == ...) ... ... } ``` However, this pattern is hard to apply if the config needs to be propagated via a long call stack. An example is `DataType.sameType`, and see how many changes were made in #21190 . When it comes to code generation, it's even worse. I tried it locally and we need to change a ton of files to propagate configs to code generators. This PR proposes to allow tasks to access `SQLConf`. The idea is, we can save all the SQL configs to job properties when an SQL execution is triggered. At executor side we rebuild the `SQLConf` from job properties. ## How was this patch tested? a new test suite Author: Wenchen Fan Closes #21299 from cloud-fan/config. --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../spark/sql/internal/ReadOnlySQLConf.scala | 66 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../org/apache/spark/sql/SparkSession.scala | 21 +++++- .../spark/sql/execution/SQLExecution.scala | 50 ++++++++++---- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 9 files changed, 210 insertions(+), 36 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53a50305348fa..643e4c686f58d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,13 +27,12 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -107,7 +106,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1292,17 +1297,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1765,7 +1764,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..032525a08ccdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,37 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..d54bfbfc14f5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..404d6313ab92c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadonlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From 000e25ae7950ff005d4bbe4fffed410e5947075c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 20 May 2018 16:13:42 +0800 Subject: [PATCH 0834/1161] Revert "[SPARK-24250][SQL] support accessing SQLConf inside tasks" This reverts commit dd37529a8dada6ed8a49b8ce50875268f6a20cba. --- .../org/apache/spark/TaskContextImpl.scala | 2 - .../spark/sql/internal/ReadOnlySQLConf.scala | 66 ------------------- .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../org/apache/spark/sql/SparkSession.scala | 21 +----- .../spark/sql/execution/SQLExecution.scala | 50 ++++---------- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 ------------------- 9 files changed, 36 insertions(+), 210 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 0791fe856ef15..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,6 +178,4 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException - // TODO: shall we publish it and define it in `TaskContext`? - private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala deleted file mode 100644 index 19f67236c8979..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.internal - -import java.util.{Map => JMap} - -import org.apache.spark.{TaskContext, TaskContextImpl} -import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} - -/** - * A readonly SQLConf that will be created by tasks running at the executor side. It reads the - * configs from the local properties which are propagated from driver to executors. - */ -class ReadOnlySQLConf(context: TaskContext) extends SQLConf { - - @transient override val settings: JMap[String, String] = { - context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] - } - - @transient override protected val reader: ConfigReader = { - new ConfigReader(new TaskContextConfigProvider(context)) - } - - override protected def setConfWithCheck(key: String, value: String): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def unsetConf(key: String): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def unsetConf(entry: ConfigEntry[_]): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def clear(): Unit = { - throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") - } - - override def clone(): SQLConf = { - throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") - } - - override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { - throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") - } -} - -class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { - override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 643e4c686f58d..53a50305348fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,12 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -106,13 +107,7 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = { - if (TaskContext.get != null) { - new ReadOnlySQLConf(TaskContext.get()) - } else { - confGetter.get()() - } - } + def get: SQLConf = confGetter.get()() val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1297,11 +1292,17 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient protected val reader = new ConfigReader(settings) + @transient private val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1764,7 +1765,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - protected def setConfWithCheck(key: String, value: String): Unit = { + private def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,7 +898,6 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1023,20 +1022,14 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) - } + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) - } + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1069,14 +1062,6 @@ object SparkSession extends Logging { } } - private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException( - "SparkSession should only be created and accessed on the driver.") - } - } - /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 032525a08ccdb..2c5102b1e5ee7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,18 +68,16 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sc.getCallSite() + val callSite = sparkSession.sparkContext.getCallSite() - withSQLConfPropagated(sparkSession) { - sc.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sc.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) - } + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } finally { executionIdToQueryExecution.remove(executionId) @@ -92,37 +90,13 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { - val sc = sparkSession.sparkContext + def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - withSQLConfPropagated(sparkSession) { - try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) - body - } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) - } - } - } - - def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { - val sc = sparkSession.sparkContext - // Set all the specified SQL configs to local properties, so that they can be available at - // the executor side. - val allConfigs = sparkSession.sessionState.conf.getAllConfs - val originalLocalProps = allConfigs.collect { - case (key, value) if key.startsWith("spark") => - val originalValue = sc.getLocalProperty(key) - sc.setLocalProperty(key, value) - (key, originalValue) - } - try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - for ((key, value) <- originalLocalProps) { - sc.setLocalProperty(key, value) - } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d54bfbfc14f5f..1edfdc888afd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + SQLExecution.withExecutionId(sparkContext, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3b6df45e949e8..ba83df0efebd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,7 +34,6 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -105,19 +104,22 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - SQLExecution.withSQLConfPropagated(json.sparkSession) { - JsonInferSchema.infer(rdd, parsedOptions, rowParser) - } + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { + val paths = inputPaths.map(_.getPath.toString) + val textOptions = Map.empty[String, String] ++ + parsedOptions.encoding.map("encoding" -> _) ++ + parsedOptions.lineSeparator.map("lineSep" -> _) + sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = inputPaths.map(_.getPath.toString), + paths = paths, className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -163,9 +165,7 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - SQLExecution.withSQLConfPropagated(sparkSession) { - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) - } + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9e0ec9481b0de..daea6c39624d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { + SQLExecution.withExecutionId(sparkContext, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala deleted file mode 100644 index 404d6313ab92c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.internal - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.test.SQLTestUtils - -class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { - import testImplicits._ - - protected var spark: SparkSession = null - - // Create a new [[SparkSession]] running in local-cluster mode. - override def beforeAll(): Unit = { - super.beforeAll() - spark = SparkSession.builder() - .master("local-cluster[2,1,1024]") - .appName("testing") - .getOrCreate() - } - - override def afterAll(): Unit = { - spark.stop() - spark = null - } - - test("ReadonlySQLConf is correctly created at the executor side") { - SQLConf.get.setConfString("spark.sql.x", "a") - try { - val checks = spark.range(10).mapPartitions { it => - val conf = SQLConf.get - Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") - }.collect() - assert(checks.forall(_ == true)) - } finally { - SQLConf.get.unsetConf("spark.sql.x") - } - } - - test("case-sensitive config should work for json schema inference") { - withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - withTempPath { path => - val pathString = path.getCanonicalPath - spark.range(10).select('id.as("ID")).write.json(pathString) - spark.range(10).write.mode("append").json(pathString) - assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) - } - } - } -} From 8eac621229b50e15bea550a751593bba0bf8b20c Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 20 May 2018 18:15:04 -0500 Subject: [PATCH 0835/1161] [SPARK-23857][MESOS] remove keytab check in mesos cluster mode at first submit time ## What changes were proposed in this pull request? - Removes the check for the keytab when we are running in mesos cluster mode. - Keeps the check for client mode since in cluster mode we eventually launch the driver within the cluster in client mode. In the latter case we want to have the check done when the container starts, the keytab should be checked if it exists within the container's local filesystem. ## How was this patch tested? This was manually tested by running spark submit in mesos cluster mode. Author: Stavros Closes #20967 from skonto/fix_mesos_keytab_susbmit. --- core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 087e9c31a9c9a..4baf032f0e9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -310,6 +310,7 @@ private[spark] class SparkSubmit extends Logging { val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER + val isMesosClient = clusterManager == MESOS && deployMode == CLIENT if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files @@ -337,7 +338,7 @@ private[spark] class SparkSubmit extends Logging { val targetDir = Utils.createTempDir() // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient) { if (args.principal != null) { if (args.keytab != null) { require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") From f32b7faf7c4b5d2ac45a2db96935f67d1b629ca2 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 21 May 2018 09:47:52 +0800 Subject: [PATCH 0836/1161] [MINOR][PROJECT-INFRA] Check if 'original_head' variable is defined in clean_up at merge script ## What changes were proposed in this pull request? This PR proposes to check if global variable exists or not in clean_up. This can happen when it fails at: https://github.com/apache/spark/blob/7013eea11cb32b1e0038dc751c485da5c94a484b/dev/merge_spark_pr.py#L423 I found this (It was my environment problem) but the error message took me a while to debug. ## How was this patch tested? Manually tested: **Before** ``` git rev-parse --abbrev-ref HEAD fatal: Not a git repository (or any of the parent directories): .git Traceback (most recent call last): File "./dev/merge_spark_pr_jira.py", line 517, in clean_up() File "./dev/merge_spark_pr_jira.py", line 104, in clean_up print("Restoring head pointer to %s" % original_head) NameError: global name 'original_head' is not defined ``` **After** ``` git rev-parse --abbrev-ref HEAD fatal: Not a git repository (or any of the parent directories): .git Traceback (most recent call last): File "./dev/merge_spark_pr.py", line 516, in main() File "./dev/merge_spark_pr.py", line 424, in main original_head = get_current_ref() File "./dev/merge_spark_pr.py", line 412, in get_current_ref ref = run_cmd("git rev-parse --abbrev-ref HEAD").strip() File "./dev/merge_spark_pr.py", line 94, in run_cmd return subprocess.check_output(cmd.split(" ")) File "/usr/local/Cellar/python2/2.7.14_3/Frameworks/Python.framework/Versions/2.7/lib/python2.7/subprocess.py", line 219, in check_output raise CalledProcessError(retcode, cmd, output=output) subprocess.CalledProcessError: Command '['git', 'rev-parse', '--abbrev-ref', 'HEAD']' returned non-zero exit status 128 ``` Author: hyukjinkwon Closes #21349 from HyukjinKwon/minor-merge-script. --- dev/merge_spark_pr.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 5ea205fbed4aa..7f46a1c8f6a7c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -101,14 +101,15 @@ def continue_maybe(prompt): def clean_up(): - print("Restoring head pointer to %s" % original_head) - run_cmd("git checkout %s" % original_head) + if 'original_head' in globals(): + print("Restoring head pointer to %s" % original_head) + run_cmd("git checkout %s" % original_head) - branches = run_cmd("git branch").replace(" ", "").split("\n") + branches = run_cmd("git branch").replace(" ", "").split("\n") - for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): - print("Deleting local branch %s" % branch) - run_cmd("git branch -D %s" % branch) + for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): + print("Deleting local branch %s" % branch) + run_cmd("git branch -D %s" % branch) # merge the requested PR and return the merge hash From 6d7d45a1af078edd9e4ed027e735d6096482179c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 May 2018 15:39:35 +0800 Subject: [PATCH 0837/1161] [SPARK-24242][SQL] RangeExec should have correct outputOrdering and outputPartitioning ## What changes were proposed in this pull request? Logical `Range` node has been added with `outputOrdering` recently. It's used to eliminate redundant `Sort` during optimization. However, this `outputOrdering` doesn't not propagate to physical `RangeExec` node. We also add correct `outputPartitioning` to `RangeExec` node. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21291 from viirya/SPARK-24242. --- python/pyspark/sql/tests.py | 4 +-- .../execution/basicPhysicalOperators.scala | 14 ++++++++++ .../spark/sql/ConfigBehaviorSuite.scala | 4 ++- .../spark/sql/execution/PlannerSuite.scala | 27 ++++++++++++++++++- .../execution/WholeStageCodegenSuite.scala | 4 +-- .../sql/execution/debug/DebuggingSuite.scala | 7 +++-- 6 files changed, 52 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a1b6db71782bb..c7bd8f01b907f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5239,8 +5239,8 @@ def test_complex_groupby(self): expected2 = df.groupby().agg(sum(df.v)) # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) # groupby one python UDF result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..2df81d09c58e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override val output: Seq[Attribute] = range.output + override def outputOrdering: Seq[SortOrder] = range.outputOrdering + + override def outputPartitioning: Partitioning = { + if (numElements > 0) { + if (numSlices == 1) { + SinglePartition + } else { + RangePartitioning(outputOrdering, numSlices) + } + } else { + UnknownPartitioning(0) + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 949505e449fd7..276496be3d62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,9 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id.desc) + // Range has range partitioning in its output now. To have a range shuffle, we + // need to run a repartition first. + val data = spark.range(0, n, 1, 1).repartition(10).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a375f881c7d63..b2aba8e72c5db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -633,6 +633,31 @@ class PlannerSuite extends SharedSQLContext { requiredOrdering = Seq(orderingA, orderingB), shouldHaveSort = true) } + + test("SPARK-24242: RangeExec should have correct output ordering and partitioning") { + val df = spark.range(10) + val rangeExec = df.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + val range = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(rangeExec.head.outputOrdering == range.head.outputOrdering) + assert(rangeExec.head.outputPartitioning == + RangePartitioning(rangeExec.head.outputOrdering, df.rdd.getNumPartitions)) + + val rangeInOnePartition = spark.range(1, 10, 1, 1) + val rangeExecInOnePartition = rangeInOnePartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInOnePartition.head.outputPartitioning == SinglePartition) + + val rangeInZeroPartition = spark.range(-10, -9, -20, 1) + val rangeExecInZeroPartition = rangeInZeroPartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f1..b714dcd5269fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -51,12 +51,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = spark.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) - assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) } test("BroadcastHashJoin should be included in WholeStageCodegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index adcaf2d76519f..8251ff159e05f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.TestData @@ -33,14 +34,16 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } test("debugCodegenStringSeq") { - val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.length == 2) assert(res.forall{ case (subtree, code) => subtree.contains("Range") && code.contains("Object[]")}) From e480eccd9754b4900c3e2c2036d69130a262cffe Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 21 May 2018 15:42:04 +0800 Subject: [PATCH 0838/1161] [SPARK-24323][SQL] Fix lint-java errors ## What changes were proposed in this pull request? This PR fixes the following errors reported by `lint-java` ``` % dev/lint-java Using `mvn` from path: /usr/bin/mvn Checkstyle checks failed at following occurrences: [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java:[39] (sizes) LineLength: Line is longer than 100 characters (found 104). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:[26] (sizes) LineLength: Line is longer than 100 characters (found 110). [ERROR] src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:[30] (sizes) LineLength: Line is longer than 100 characters (found 104). ``` ## How was this patch tested? Run `lint-java` manually. Author: Kazuaki Ishizaki Closes #21374 from kiszk/SPARK-24323. --- .../spark/sql/sources/v2/reader/InputPartition.java | 4 ++-- .../spark/sql/sources/v2/reader/InputPartitionReader.java | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index f53687e113ae0..f2038d0de3ffe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -36,8 +36,8 @@ public interface InputPartition extends Serializable { /** - * The preferred locations where the input partition reader returned by this partition can run faster, - * but Spark does not guarantee to run the input partition reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run + * faster, but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index f0d808536207a..33fa7be4c1b20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,12 +23,12 @@ import org.apache.spark.annotation.InterfaceStability; /** - * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is responsible for - * outputting data for a RDD partition. + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is + * responsible for outputting data for a RDD partition. * * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input - * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input partition - * readers that mix in {@link SupportsScanUnsafeRow}. + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input + * partition readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface InputPartitionReader extends Closeable { From a6e883feb3b78232ad5cf636f7f7d5e825183041 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 21 May 2018 23:14:03 +0900 Subject: [PATCH 0839/1161] [SPARK-23935][SQL] Adding map_entries function ## What changes were proposed in this pull request? This PR adds `map_entries` function that returns an unordered array of all entries in the given map. ## How was this patch tested? New tests added into: - `CollectionExpressionSuite` - `DataFrameFunctionsSuite` ## CodeGen examples ### Primitive types ``` val df = Seq(Map(1 -> 5, 2 -> 6)).toDF("m") df.filter('m.isNotNull).select(map_entries('m)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ /* 044 */ ArrayData project_value_0 = null; /* 045 */ /* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements(); /* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray(); /* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray(); /* 049 */ /* 050 */ final long project_size_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 051 */ project_numElements_0, /* 052 */ 32); /* 053 */ if (project_size_0 > 2147483632) { /* 054 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0]; /* 055 */ for (int z = 0; z < project_numElements_0; z++) { /* 056 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getInt(z), project_values_0.getInt(z)}); /* 057 */ } /* 058 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0); /* 059 */ /* 060 */ } else { /* 061 */ final byte[] project_arrayBytes_0 = new byte[(int)project_size_0]; /* 062 */ UnsafeArrayData project_unsafeArrayData_0 = new UnsafeArrayData(); /* 063 */ Platform.putLong(project_arrayBytes_0, 16, project_numElements_0); /* 064 */ project_unsafeArrayData_0.pointTo(project_arrayBytes_0, 16, (int)project_size_0); /* 065 */ /* 066 */ final int project_structsOffset_0 = UnsafeArrayData.calculateHeaderPortionInBytes(project_numElements_0) + project_numElements_0 * 8; /* 067 */ UnsafeRow project_unsafeRow_0 = new UnsafeRow(2); /* 068 */ for (int z = 0; z < project_numElements_0; z++) { /* 069 */ long offset = project_structsOffset_0 + z * 24L; /* 070 */ project_unsafeArrayData_0.setLong(z, (offset << 32) + 24L); /* 071 */ project_unsafeRow_0.pointTo(project_arrayBytes_0, 16 + offset, 24); /* 072 */ project_unsafeRow_0.setInt(0, project_keys_0.getInt(z)); /* 073 */ project_unsafeRow_0.setInt(1, project_values_0.getInt(z)); /* 074 */ } /* 075 */ project_value_0 = project_unsafeArrayData_0; /* 076 */ /* 077 */ } ``` ### Non-primitive types ``` val df = Seq(Map("a" -> "foo", "b" -> null)).toDF("m") df.filter('m.isNotNull).select(map_entries('m)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ /* 044 */ ArrayData project_value_0 = null; /* 045 */ /* 046 */ final int project_numElements_0 = inputadapter_value_0.numElements(); /* 047 */ final ArrayData project_keys_0 = inputadapter_value_0.keyArray(); /* 048 */ final ArrayData project_values_0 = inputadapter_value_0.valueArray(); /* 049 */ /* 050 */ final Object[] project_internalRowArray_0 = new Object[project_numElements_0]; /* 051 */ for (int z = 0; z < project_numElements_0; z++) { /* 052 */ project_internalRowArray_0[z] = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{project_keys_0.getUTF8String(z), project_values_0.getUTF8String(z)}); /* 053 */ } /* 054 */ project_value_0 = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_internalRowArray_0); ``` Author: Marek Novotny Closes #21236 from mn-mikke/feature/array-api-map_entries-to-master. --- python/pyspark/sql/functions.py | 20 +++ .../sql/catalyst/expressions/UnsafeRow.java | 2 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 34 ++++ .../expressions/collectionOperations.scala | 153 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 23 +++ .../expressions/ExpressionEvalHelper.scala | 3 + .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 44 +++++ 9 files changed, 287 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8490081facc5a..fbc8a2d038f8f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2344,6 +2344,26 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@since(2.4) +def map_entries(col): + """ + Collection function: Returns an unordered array of all entries in the given map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_entries + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_entries("data").alias("entries")).show() + +----------------+ + | entries| + +----------------+ + |[[1, a], [2, b]]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 29a1411241cf6..469b0e60cc9a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -62,6 +62,8 @@ */ public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { + public static final int WORD_SIZE = 8; + ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 867c2d5eab53d..1134a8866dc13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -419,6 +419,7 @@ object FunctionRegistry { expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), + expression[MapEntries]("map_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dda525294259..d382d9aace109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -764,6 +764,40 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. The generated code executes + * a provided fallback when the size of backing array would exceed the array size limit. + * @param arrayName a name of the array to create + * @param numElements a piece of code representing the number of elements the array should contain + * @param elementSize a size of an element in bytes + * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] + * and getting the backing array as a parameter + * @param fallbackCode a piece of code executed when the array size limit is exceeded + */ + def createUnsafeArrayWithFallback( + arrayName: String, + numElements: String, + elementSize: Int, + bodyCode: String => String, + fallbackCode: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + s""" + |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | $elementSize); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | $fallbackCode + |} else { + | final byte[] $arrayBytes = new byte[(int)$arraySize]; + | UnsafeArrayData $arrayName = new UnsafeArrayData(); + | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + | ${bodyCode(arrayBytes)} + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c82db839438ed..8d763dca5243e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -154,6 +155,158 @@ case class MapValues(child: Expression) override def prettyName: String = "map_values" } +/** + * Returns an unordered array of all entries in the given map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b')); + [(1,"a"),(2,"b")] + """, + since = "2.4.0") +case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + + override def dataType: DataType = { + ArrayType( + StructType( + StructField("key", childDataType.keyType, false) :: + StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: + Nil), + false) + } + + override protected def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val length = childMap.numElements() + val resultData = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val key = keys.get(i, childDataType.keyType) + val value = values.get(i, childDataType.valueType) + val row = new GenericInternalRow(Array[Any](key, value)) + resultData.update(i, row) + i += 1 + } + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + } + s""" + |final int $numElements = $c.numElements(); + |final ArrayData $keys = $c.keyArray(); + |final ArrayData $values = $c.valueArray(); + |$code + """.stripMargin + }) + } + + private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + + private def getValue(varName: String) = { + CodeGenerator.getValue(varName, childDataType.valueType, "z") + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val unsafeRow = ctx.freshName("unsafeRow") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val structSizeAsLong = structSize + "L" + val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + + val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignmentChecked = if (childDataType.valueContainsNull) { + s""" + |if ($values.isNullAt(z)) { + | $unsafeRow.setNullAt(1); + |} else { + | $valueAssignment + |} + """.stripMargin + } else { + valueAssignment + } + + val assignmentLoop = (byteArray: String) => + s""" + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSizeAsLong; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $valueAssignmentChecked + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + + ctx.createUnsafeArrayWithFallback( + unsafeArrayData, + numElements, + structSize + wordSize, + assignmentLoop, + genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val data = ctx.freshName("internalRowArray") + + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { + s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + } else { + getValue(values) + } + + s""" + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { + | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin + } + + override def prettyName: String = "map_entries" +} + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6ae1ac18c4dc4..71ff96bb722e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapValues(m2), null) } + test("MapEntries") { + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys/values + val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) + val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) + val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + + checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) + checkEvaluation(MapEntries(mi1), Seq.empty) + checkEvaluation(MapEntries(mi2), null) + + // Non-primitive-type keys/values + val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) + val ms1 = Literal.create(Map[Int, Int](), MapType(StringType, StringType)) + val ms2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) + checkEvaluation(MapEntries(ms1), Seq.empty) + checkEvaluation(MapEntries(ms2), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a22e9d4655e8c..c2a44e0d33b18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: UnsafeRow, expected: GenericInternalRow) => + val structType = exprDataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2a8fe583b83bc..5ab9cb3fb86a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3492,6 +3492,13 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + /** + * Returns an unordered array of all entries in the given map. + * @group collection_funcs + * @since 2.4.0 + */ + def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d08982a138bc5..df23e07e441a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -405,6 +405,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_entries") { + val dummyFilter = (c: Column) => c.isNotNull || c.isNull + + // Primitive-type elements + val idf = Seq( + Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), + Map[Int, Int](), + null + ).toDF("m") + val iExpected = Seq( + Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected) + checkAnswer( + spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + checkAnswer( + spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + + // Non-primitive-type elements + val sdf = Seq( + Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"), + Map[String, String]("a" -> null, "b" -> null), + Map[String, String](), + null + ).toDF("m") + val sExpected = Seq( + Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))), + Row(Seq(Row("a", null), Row("b", null))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), From 03e90f65bfdad376400a4ae4df31a82c05ed4d4b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 May 2018 00:19:18 +0800 Subject: [PATCH 0840/1161] [SPARK-24250][SQL] support accessing SQLConf inside tasks re-submit https://github.com/apache/spark/pull/21299 which broke build. A few new commits are added to fix the SQLConf problem in `JsonSchemaInference.infer`, and prevent us to access `SQLConf` in DAGScheduler event loop thread. ## What changes were proposed in this pull request? Previously in #20136 we decided to forbid tasks to access `SQLConf`, because it doesn't work and always give you the default conf value. In #21190 we fixed the check and all the places that violate it. Currently the pattern of accessing configs at the executor side is: read the configs at the driver side, then access the variables holding the config values in the RDD closure, so that they will be serialized to the executor side. Something like ``` val someConf = conf.getXXX child.execute().mapPartitions { if (someConf == ...) ... ... } ``` However, this pattern is hard to apply if the config needs to be propagated via a long call stack. An example is `DataType.sameType`, and see how many changes were made in #21190 . When it comes to code generation, it's even worse. I tried it locally and we need to change a ton of files to propagate configs to code generators. This PR proposes to allow tasks to access `SQLConf`. The idea is, we can save all the SQL configs to job properties when an SQL execution is triggered. At executor side we rebuild the `SQLConf` from job properties. ## How was this patch tested? a new test suite Author: Wenchen Fan Closes #21376 from cloud-fan/config. --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../org/apache/spark/util/EventLoop.scala | 3 +- .../spark/sql/internal/ReadOnlySQLConf.scala | 66 +++++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 33 ++++++---- .../org/apache/spark/sql/SparkSession.scala | 21 +++++- .../spark/sql/execution/SQLExecution.scala | 54 +++++++++++---- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 16 ++--- .../datasources/json/JsonInferSchema.scala | 15 +++-- .../exchange/BroadcastExchangeExec.scala | 2 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 12 files changed, 239 insertions(+), 43 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 78b6b34b5d2bb..5f2d16d03165f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -206,7 +206,7 @@ class DAGScheduler( private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") - private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) /** diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 3ea9139e11027..651ea4996f6cb 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { private val stopped = new AtomicBoolean(false) - private val eventThread = new Thread(name) { + // Exposed for testing. + private[spark] val eventThread = new Thread(name) { setDaemon(true) override def run(): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 53a50305348fa..a2fb3c64844b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -95,7 +95,9 @@ object SQLConf { /** * Returns the active config object within the current scope. If there is an active SparkSession, - * the proper SQLConf associated with the thread's session is used. + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side. * * The way this works is a little bit convoluted, due to the fact that config was added initially * only for physical plans (and as a result not in sql/catalyst module). @@ -107,7 +109,22 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + if (Utils.isTesting && SparkContext.getActive.isDefined) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we prevent it from happening. + val schedulerEventLoopThread = + SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread + if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } + } + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1292,17 +1309,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1765,7 +1776,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..439932b0cc3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,41 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + /** + * Wrap an action with specified SQL configs. These configs will be propagated to the executor + * side via job local properties. + */ + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2df81d09c58e7..9434ceb7cd16c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -643,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..e7eed95a560a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -45,8 +45,9 @@ private[sql] object JsonInferSchema { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -66,9 +67,13 @@ private[sql] object JsonInferSchema { s"Parse Mode: ${FailFastMode.name}.", e) } } - } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + }.reduceOption(typeMerger).toIterator + } + + // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because + // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have + // active SparkSession and `SQLConf.get` may point to the wrong configs. + val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) canonicalizeType(rootType) match { case Some(st: StructType) => st diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..3dd0712e02448 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadOnlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From a33dcf4a0bbe20dce6f1e1e6c2e1c3828291fb3d Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Mon, 21 May 2018 12:58:05 -0700 Subject: [PATCH 0841/1161] [SPARK-24234][SS] Reader for continuous processing shuffle ## What changes were proposed in this pull request? Read RDD for continuous processing shuffle, as well as the initial RPC-based row receiver. https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii ## How was this patch tested? new unit tests Author: Jose Torres Closes #21337 from jose-torres/readerRddMaster. --- .../shuffle/ContinuousShuffleReadRDD.scala | 61 ++++++ .../shuffle/ContinuousShuffleReader.scala | 32 +++ .../shuffle/UnsafeRowReceiver.scala | 75 +++++++ .../shuffle/ContinuousShuffleReadSuite.scala | 184 ++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..270b1a5c28dee --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(queueSize, env) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent. + */ +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..42631c90ebc55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala new file mode 100644 index 0000000000000..b8adbb743c6c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +/** + * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + */ +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable +private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. + */ +private[shuffle] class UnsafeRowReceiver( + queueSize: Int, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: UnsafeRowReceiverMessage => + queue.put(r) + context.reply(()) + } + + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = queue.take() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala new file mode 100644 index 0000000000000..b25e75b3b37a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleReadSuite extends StreamTest { + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + TaskContext.unset() + ctx = null + super.afterEach() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) + } + + test("multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) + } + + test("empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverEpochMarker() + ) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } + + test("multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index for identification. + send( + part.endpoint, + ReceiverRow(unsafeRow(part.index)), + ReceiverEpochMarker() + ) + } + + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } + + test("blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + + val readRowThread = new Thread { + override def run(): Unit = { + // set the non-inheritable thread local + TaskContext.setTaskContext(ctx) + val epoch = rdd.compute(rdd.partitions(0), ctx) + epoch.next().getInt(0) + } + } + + try { + readRowThread.start() + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.WAITING) + } + } finally { + readRowThread.interrupt() + readRowThread.join() + } + } +} From ffaefe755e20cb94e27f07b233615a4bbb476679 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 21 May 2018 13:05:17 -0700 Subject: [PATCH 0842/1161] [SPARK-7132][ML] Add fit with validation set to spark.ml GBT ## What changes were proposed in this pull request? Add fit with validation set to spark.ml GBT ## How was this patch tested? Will add later. Author: WeichenXu Closes #21129 from WeichenXu123/gbt_fit_validation. --- .../ml/classification/GBTClassifier.scala | 38 ++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 5 +- .../spark/ml/param/shared/sharedParams.scala | 17 +++++++ .../spark/ml/regression/GBTRegressor.scala | 31 ++++++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 41 +++++++++++----- .../classification/GBTClassifierSuite.scala | 46 ++++++++++++++++++ .../ml/regression/GBTRegressorSuite.scala | 48 ++++++++++++++++++- project/MimaExcludes.scala | 13 ++++- 8 files changed, 213 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3fb6d1e4e4f3e..337133a2e2326 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -146,12 +146,21 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. - val oldDataset: RDD[LabeledPoint] = + val convert2LabeledPoint = (dataset: Dataset[_]) => { dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + @@ -159,7 +168,18 @@ class GBTClassifier @Since("1.4.0") ( s" GBTClassifier currently only supports binary classification.") LabeledPoint(label, features) } - val numFeatures = oldDataset.first().features.size + } + + val (trainDataset, validationDataset) = if (withValidation) { + ( + convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))), + convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (convert2LabeledPoint(dataset), null) + } + + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -169,15 +189,21 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, + validationIndicatorCol) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) + } + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b9c3170cc3c28..7e08675f834da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -95,7 +95,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), isValid = "(value: String) => " + - "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"), + ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " + + "each row is for training or for validation. False indicates training; true indicates " + + "validation.") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 282ea6ebcbf7f..5928a0749f738 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -523,4 +523,21 @@ trait HasDistanceMeasure extends Params { /** @group getParam */ final def getDistanceMeasure: String = $(distanceMeasure) } + +/** + * Trait for shared param validationIndicatorCol. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasValidationIndicatorCol extends Params { + + /** + * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.. + * @group param + */ + final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.") + + /** @group getParam */ + final def getValidationIndicatorCol: String = $(validationIndicatorCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index d7e054bf55ef6..eb8b3c001436a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -145,21 +145,42 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val numFeatures = oldDataset.first().features.size + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + + val (trainDataset, validationDataset) = if (withValidation) { + ( + extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))), + extractLabeledPoints(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (extractLabeledPoints(dataset), null) + } + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index ec8868bb42cbb..00157fe63af41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.util.Try +import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { - - /* TODO: Add this doc when we add this param. SPARK-7132 - * Threshold for stopping early when runWithValidation is used. - * If the error rate on the validation input changes by less than the validationTol, - * then learning will stop early (before [[numIterations]]). - * This parameter is ignored when run is used. - * (default = 1e-5) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize + with HasValidationIndicatorCol { + + /** + * Threshold for stopping early when fit with validation is used. + * (This parameter is ignored when fit without validation is used.) + * The decision to stop early is decided based on this logic: + * If the current loss on the validation set is greater than 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is less than or equal to 0.01, + * the diff of validation error is compared to absolute tolerance which is + * validationTol * 0.01. * @group param + * @see validationIndicatorCol */ - // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") - // validationTol -> 1e-5 + @Since("2.4.0") + final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", + "Threshold for stopping early when fit with validation is used." + + "If the error rate on the validation input changes by less than the validationTol," + + "then learning will stop early (before `maxIter`)." + + "This parameter is ignored when fit without validation is used.", + ParamValidators.gtEq(0.0) + ) + + /** @group getParam */ + @Since("2.4.0") + final def getValidationTol: Double = $(validationTol) /** * @deprecated This method is deprecated and will be removed in 3.0.0. @@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") @@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) // NOTE: The old API does not support "seed" so we ignore it. - new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol) } /** Get old Gradient Boosting Loss type */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e20de196d65ca..e6d2a8e2b900e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -392,6 +393,51 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(evalArr(2) ~== lossErr3 relTol 1E-3) } + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTClassifier.supportedLossTypes) { + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val (errorWithoutValidation, errorWithValidation) = { + val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType), + GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees, + modelWithValidation.treeWeights, modelWithValidation.getOldLossType)) + } + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Classification) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 773f6d2c542fe..b145c7a3dc952 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -231,7 +232,52 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } - ///////////////////////////////////////////////////////////////////////////// + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, + modelWithoutValidation.trees, modelWithoutValidation.treeWeights, + modelWithoutValidation.getOldLossType) + val errorWithValidation = GradientBoostedTrees.computeError(validationData, + modelWithValidation.trees, modelWithValidation.treeWeights, + modelWithValidation.getOldLossType) + + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Regression) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7d0e88ee20c3f..6bae4d147d4ac 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -73,7 +73,18 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), + + // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol") ) // Exclude rules for 2.3.x From b550b2a1a159941c7327973182f16004a6bf179d Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 21 May 2018 14:21:05 -0700 Subject: [PATCH 0843/1161] [SPARK-24325] Tests for Hadoop's LinesReader ## What changes were proposed in this pull request? The tests cover basic functionality of [Hadoop LinesReader](https://github.com/apache/spark/blob/8d79113b812a91073d2c24a3a9ad94cc3b90b24a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala#L42). In particular, the added tests check: - A split slices a line or delimiter - A split slices two consecutive lines and cover a delimiter between the lines - Two splits slice a line and there are no duplicates - Internal buffer size (`io.file.buffer.size`) is less than line length - Constrain of maximum line length - `mapreduce.input.linerecordreader.line.maxlength` Author: Maxim Gekk Closes #21377 from MaxGekk/line-reader-tests. --- .../HadoopFileLinesReaderSuite.scala | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala new file mode 100644 index 0000000000000..a39a25be262a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext + +class HadoopFileLinesReaderSuite extends SharedSQLContext { + def getLines( + path: File, + text: String, + ranges: Seq[(Long, Long)], + delimiter: Option[String] = None, + conf: Option[Configuration] = None): Seq[String] = { + val delimOpt = delimiter.map(_.getBytes(StandardCharsets.UTF_8)) + Files.write(path.toPath, text.getBytes(StandardCharsets.UTF_8)) + + val lines = ranges.map { case (start, length) => + val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) + val hadoopConf = conf.getOrElse(spark.sparkContext.hadoopConfiguration) + val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) + + reader.map(_.toString) + }.flatten + + lines + } + + test("A split ends at the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 1), (1, 3))) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 2), (2, 2))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the end of the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 3), (3, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split covers two lines") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 4), (4, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 1), (1, 4)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split slices the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 2), (2, 3)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("The first split covers the first line and the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 4), (4, 1)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the first line") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((0, 1)), Some(",")) + assert(lines == Seq("abc")) + } + } + + test("The split cuts both lines") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((2, 2)), Some(",")) + assert(lines == Seq("def")) + } + } + + test("io.file.buffer.size is less than line length") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("io.file.buffer.size", "2") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } + } + + test("line cannot be longer than line.maxlength") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("mapreduce.input.linerecordreader.line.maxlength", "5") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } + } + + test("default delimiter is 0xd or 0xa or 0xd0xa") { + withTempPath { path => + val lines = getLines(path, text = "1\r2\n3\r\n4", ranges = Seq((0, 3), (3, 5))) + assert(lines == Seq("1", "2", "3", "4")) + } + } +} From 32447079e9d0fa9f7e180b94ecac19091b6af1ab Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 May 2018 16:26:39 -0700 Subject: [PATCH 0844/1161] [SPARK-24309][CORE] AsyncEventQueue should stop on interrupt. EventListeners can interrupt the event queue thread. In particular, when the EventLoggingListener writes to hdfs, hdfs can interrupt the thread. When there is an interrupt, the queue should be removed and stop accepting any more events. Before this change, the queue would continue to take more events (till it was full), and then would not stop when the application was complete because the PoisonPill couldn't be added. Added a unit test which failed before this change. Author: Imran Rashid Closes #21356 from squito/SPARK-24309. --- .../spark/scheduler/AsyncEventQueue.scala | 41 ++++++++------ .../spark/scheduler/LiveListenerBus.scala | 2 +- .../org/apache/spark/util/ListenerBus.scala | 18 +++++++ .../spark/scheduler/SparkListenerSuite.scala | 54 +++++++++++++++++++ 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index c1fedd63f6a90..e2b6df4600590 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -34,7 +34,11 @@ import org.apache.spark.util.Utils * Delivery will only begin when the `start()` method is called. The `stop()` method should be * called when no more events need to be delivered. */ -private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) +private class AsyncEventQueue( + val name: String, + conf: SparkConf, + metrics: LiveListenerBusMetrics, + bus: LiveListenerBus) extends SparkListenerBus with Logging { @@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi } private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { - try { - var next: SparkListenerEvent = eventQueue.take() - while (next != POISON_PILL) { - val ctx = processingTime.time() - try { - super.postToAll(next) - } finally { - ctx.stop() - } - eventCount.decrementAndGet() - next = eventQueue.take() + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() } eventCount.decrementAndGet() - } catch { - case ie: InterruptedException => - logInfo(s"Stopping listener queue $name.", ie) + next = eventQueue.take() } + eventCount.decrementAndGet() } override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { @@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi eventCount.incrementAndGet() eventQueue.put(POISON_PILL) } - dispatchThread.join() + // this thread might be trying to stop itself as part of error handling -- we can't join + // in that case. + if (Thread.currentThread() != dispatchThread) { + dispatchThread.join() + } } def post(event: SparkListenerEvent): Unit = { @@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi true } + override def removeListenerOnError(listener: SparkListenerInterface): Unit = { + // the listener failed in an unrecoverably way, we want to remove it from the entire + // LiveListenerBus (potentially stopping a queue if it is empty) + bus.removeListener(listener) + } + } private object AsyncEventQueue { diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index ba6387a8f08ad..d135190d1e919 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { queue.addListener(listener) case None => - val newQueue = new AsyncEventQueue(queue, conf, metrics) + val newQueue = new AsyncEventQueue(queue, conf, metrics, this) newQueue.addListener(listener) if (started.get()) { newQueue.start(sparkContext) diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index b25a731401f23..d4474a90b26f1 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } } + /** + * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. + */ + def removeListenerOnError(listener: L): Unit = { + removeListener(listener) + } + + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } try { doPostEvent(listener, event) + if (Thread.interrupted()) { + // We want to throw the InterruptedException right away so we can associate the interrupt + // with this listener, as opposed to waiting for a queue.take() etc. to detect it. + throw new InterruptedException() + } } catch { + case ie: InterruptedException => + logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " + + s"Removing that listener.", ie) + removeListenerOnError(listener) case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index fa47a52bbbc47..6ffd1e84f7adb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -489,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) } + Seq(true, false).foreach { throwInterruptedException => + val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted" + test(s"interrupt within listener is handled correctly: $suffix") { + val conf = new SparkConf(false) + .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5) + val bus = new LiveListenerBus(conf) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val interruptingListener1 = new InterruptingListener(throwInterruptedException) + val interruptingListener2 = new InterruptingListener(throwInterruptedException) + bus.addToSharedQueue(counter1) + bus.addToSharedQueue(interruptingListener1) + bus.addToStatusQueue(counter2) + bus.addToEventLogQueue(interruptingListener2) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 2) + + bus.start(mockSparkContext, mockMetricsSystem) + + // after we post one event, both interrupting listeners should get removed, and the + // event log queue should be removed + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 0) + assert(counter1.count === 1) + assert(counter2.count === 1) + + // posting more events should be fine, they'll just get processed from the OK queue. + (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter1.count === 6) + assert(counter2.count === 6) + + // Make sure stopping works -- this requires putting a poison pill in all active queues, which + // would fail if our interrupted queue was still active, as its queue would be full. + bus.stop() + } + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -547,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } + /** + * A simple listener that interrupts on job end. + */ + private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (throwInterruptedException) { + throw new InterruptedException("got interrupted") + } else { + Thread.currentThread().interrupt() + } + } + } } // These classes can't be declared inside of the SparkListenerSuite class because we don't want From 84d31aa5d453620d462f1fdd90206c676a8395cd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 21 May 2018 18:11:05 -0700 Subject: [PATCH 0845/1161] [SPARK-24209][SHS] Automatic retrieve proxyBase from Knox headers ## What changes were proposed in this pull request? The PR retrieves the proxyBase automatically from the header `X-Forwarded-Context` (if available). This is the header used by Knox to inform the proxied service about the base path. This provides 0-configuration support for Knox gateway (instead of having to properly set `spark.ui.proxyBase`) and it allows to access directly SHS when it is proxied by Knox. In the previous scenario, indeed, after setting `spark.ui.proxyBase`, direct access to SHS was not working fine (due to bad link generated). ## How was this patch tested? added UT + manual tests Author: Marco Gaido Closes #21268 from mgaido91/SPARK-24209. --- .../spark/deploy/history/HistoryPage.scala | 17 +-- .../spark/deploy/history/HistoryServer.scala | 2 +- .../deploy/master/ui/ApplicationPage.scala | 4 +- .../spark/deploy/master/ui/MasterPage.scala | 2 +- .../spark/deploy/worker/ui/LogPage.scala | 2 +- .../spark/deploy/worker/ui/WorkerPage.scala | 2 +- .../scala/org/apache/spark/ui/UIUtils.scala | 109 ++++++++++-------- .../apache/spark/ui/env/EnvironmentPage.scala | 2 +- .../ui/exec/ExecutorThreadDumpPage.scala | 2 +- .../apache/spark/ui/exec/ExecutorsTab.scala | 6 +- .../apache/spark/ui/jobs/AllJobsPage.scala | 4 +- .../apache/spark/ui/jobs/AllStagesPage.scala | 4 +- .../org/apache/spark/ui/jobs/JobPage.scala | 5 +- .../org/apache/spark/ui/jobs/PoolPage.scala | 4 +- .../org/apache/spark/ui/jobs/PoolTable.scala | 9 +- .../org/apache/spark/ui/jobs/StagePage.scala | 8 +- .../org/apache/spark/ui/jobs/StageTable.scala | 12 +- .../org/apache/spark/ui/storage/RDDPage.scala | 7 +- .../apache/spark/ui/storage/StoragePage.scala | 20 +++- .../deploy/history/HistoryServerSuite.scala | 24 ++++ .../spark/ui/storage/StoragePageSuite.scala | 7 +- .../spark/deploy/mesos/ui/DriverPage.scala | 6 +- .../deploy/mesos/ui/MesosClusterPage.scala | 2 +- .../sql/execution/ui/AllExecutionsPage.scala | 38 +++--- .../sql/execution/ui/ExecutionPage.scala | 30 ++--- .../thriftserver/ui/ThriftServerPage.scala | 17 +-- .../ui/ThriftServerSessionPage.scala | 9 +- .../apache/spark/streaming/ui/BatchPage.scala | 21 +++- .../spark/streaming/ui/StreamingPage.scala | 12 +- 29 files changed, 232 insertions(+), 155 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6fc12d721e6f1..32667ddf5c7ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,8 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - ++ - + ++ +
    - UIUtils.basicSparkPage(content, "History Server", true) + UIUtils.basicSparkPage(request, content, "History Server", true) } - private def makePageLink(showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) + private def makePageLink(request: HttpServletRequest, showIncomplete: Boolean): String = { + UIUtils.prependBaseUri(request, "/?" + "showIncomplete=" + showIncomplete) } private def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 611fa563a7cd9..a9a4d5a4ec6a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -87,7 +87,7 @@ class HistoryServer( if (!loadAppUi(appId, None) && (!attemptId.isDefined || !loadAppUi(appId, attemptId))) { val msg =
    Application {appId} not found.
    res.setStatus(HttpServletResponse.SC_NOT_FOUND) - UIUtils.basicSparkPage(msg, "Not Found").foreach { n => + UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n => res.getWriter().write(n.toString) } return diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f699c75085fe1..fad4e46dc035d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { val msg =
    No running application with ID {appId}
    - return UIUtils.basicSparkPage(msg, "Not Found") + return UIUtils.basicSparkPage(request, msg, "Not Found") } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") @@ -127,7 +127,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } ; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) + UIUtils.basicSparkPage(request, content, "Application: " + app.desc.name) } private def executorRow(executor: ExecutorDesc): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index c629937606b51..b8afe203fbfa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -215,7 +215,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) + UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } private def workerRow(worker: WorkerInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 2f5a5642d3cab..4fca9342c0378 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -118,7 +118,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with - UIUtils.basicSparkPage(content, logType + " log page for " + pageName) + UIUtils.basicSparkPage(request, content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8b98ae56fc108..aa4e28d213e2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -135,7 +135,7 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( + UIUtils.basicSparkPage(request, content, "Spark Worker at %s:%s".format( workerState.host, workerState.port)) } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 02cf19e00ecde..5d015b0531ef6 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} +import javax.servlet.http.HttpServletRequest import scala.util.control.NonFatal import scala.xml._ @@ -148,60 +149,71 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - def uiRoot: String = { + def uiRoot(request: HttpServletRequest): String = { + // Knox uses X-Forwarded-Context to notify the application the base path + val knoxBasePath = Option(request.getHeader("X-Forwarded-Context")) // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. sys.props.get("spark.ui.proxyBase") .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .orElse(knoxBasePath) .getOrElse("") } - def prependBaseUri(basePath: String = "", resource: String = ""): String = { - uiRoot + basePath + resource + def prependBaseUri( + request: HttpServletRequest, + basePath: String = "", + resource: String = ""): String = { + uiRoot(request) + basePath + resource } - def commonHeaderNodes: Seq[Node] = { + def commonHeaderNodes(request: HttpServletRequest): Seq[Node] = { - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + } - def vizHeaderNodes: Seq[Node] = { - - - - - + def vizHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + + + + } - def dataTablesHeaderNodes: Seq[Node] = { + def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> - - - - - - - + href={prependBaseUri(request, "/static/jsonFormatter.min.css")} type="text/css"/> + + + + + + } /** Returns a spark page with correctly formatted headers */ def headerSparkPage( + request: HttpServletRequest, title: String, content: => Seq[Node], activeTab: SparkUITab, @@ -214,25 +226,26 @@ private[spark] object UIUtils extends Logging { val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } val helpButton: Seq[Node] = helpText.map(tooltip(_, "bottom")).getOrElse(Seq.empty) - {commonHeaderNodes} - {if (showVisualization) vizHeaderNodes else Seq.empty} - {if (useDataTables) dataTablesHeaderNodes else Seq.empty} - + {commonHeaderNodes(request)} + {if (showVisualization) vizHeaderNodes(request) else Seq.empty} + {if (useDataTables) dataTablesHeaderNodes(request) else Seq.empty} + {appName} - {title}
    }.getOrElse(Text("Error fetching thread dump")) - UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + UIUtils.headerSparkPage(request, s"Thread dump for executor $executorId", content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 843486f4a70d2..d5a60f52cbb0f 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -49,12 +49,12 @@ private[ui] class ExecutorsPage(
    {
    ++ - ++ - ++ + ++ + ++ }
    - UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) + UIUtils.headerSparkPage(request, "Executors", content, parent, useDataTables = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 2b0f4acbac72a..f651fe97c2cd5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -248,7 +248,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobs, tableHeaderId, jobTag, - UIUtils.prependBaseUri(parent.basePath), + UIUtils.prependBaseUri(request, parent.basePath), "jobs", // subPath parameterOtherTable, killEnabled, @@ -407,7 +407,7 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We val helpText = """A job is triggered by an action, like count() or saveAsTextFile().""" + " Click on a job to see information about the stages of tasks inside it." - UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) + UIUtils.headerSparkPage(request, "Spark Jobs", content, parent, helpText = Some(helpText)) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 4658aa1cea3f1..f672ce0ec6a68 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -66,7 +66,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { ++
    - {poolTable.toNodeSeq} + {poolTable.toNodeSeq(request)}
    } else { Seq.empty[Node] @@ -74,7 +74,7 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { val content = summary ++ poolsDescription ++ tables.flatten.flatten - UIUtils.headerSparkPage("Stages for All Jobs", content, parent) + UIUtils.headerSparkPage(request, "Stages for All Jobs", content, parent) } private def summaryAndTableForStatus( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 46f2a76cc651b..55444a2c0c9ab 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -195,7 +195,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP

    No information to display for job {jobId}

    return UIUtils.headerSparkPage( - s"Details for Job $jobId", content, parent) + request, s"Details for Job $jobId", content, parent) } val isComplete = jobData.status != JobExecutionStatus.RUNNING val stages = jobData.stageIds.map { stageId => @@ -413,6 +413,7 @@ private[ui] class JobPage(parent: JobsTab, store: AppStatusStore) extends WebUIP {failedStagesTable.toNodeSeq} } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + UIUtils.headerSparkPage( + request, s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index a3e1f13782e30..22a40101e33df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -49,7 +49,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { "stages/pool", parent.isFairScheduler, parent.killEnabled, false) val poolTable = new PoolTable(Map(pool -> uiPool), parent) - var content =

    Summary

    ++ poolTable.toNodeSeq + var content =

    Summary

    ++ poolTable.toNodeSeq(request) if (activeStages.nonEmpty) { content ++= } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + UIUtils.headerSparkPage(request, "Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 5dfce858dec07..96b5f72393070 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder +import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -28,7 +29,7 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = { @@ -39,15 +40,15 @@ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab - {pools.map { case (s, p) => poolRow(s, p) }} + {pools.map { case (s, p) => poolRow(request, s, p) }}
    Pool NameSchedulingMode
    } - private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + private def poolRow(request: HttpServletRequest, s: Schedulable, p: PoolData): Seq[Node] = { val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) + .format(UIUtils.prependBaseUri(request, parent.basePath), URLEncoder.encode(p.name, "UTF-8"))
    {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ac83de10f9237..2575914121c39 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,7 +112,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) @@ -125,7 +125,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( stageData, - UIUtils.prependBaseUri(parent.basePath) + + UIUtils.prependBaseUri(request, parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, @@ -498,7 +498,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    {taskTableHTML ++ jsForScrollingDownToTaskTable}
    - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) } def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 18a4926f2f6c0..b8b20db1fa407 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -92,7 +92,8 @@ private[ui] class StageTableBase( stageSortColumn, stageSortDesc, isFailedStage, - parameterOtherTable + parameterOtherTable, + request ).table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -147,7 +148,8 @@ private[ui] class StagePagedTable( sortColumn: String, desc: Boolean, isFailedStage: Boolean, - parameterOtherTable: Iterable[String]) extends PagedTable[StageTableRowData] { + parameterOtherTable: Iterable[String], + request: HttpServletRequest) extends PagedTable[StageTableRowData] { override def tableId: String = stageTag + "-table" @@ -161,7 +163,7 @@ private[ui] class StagePagedTable( override def pageNumberFormField: String = stageTag + ".page" - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + + val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + parameterOtherTable.mkString("&") override val dataSource = new StageDataSource( @@ -288,7 +290,7 @@ private[ui] class StagePagedTable( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(request, basePath), data.schedulingPool)}> {data.schedulingPool} @@ -346,7 +348,7 @@ private[ui] class StagePagedTable( } private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { - val basePathUri = UIUtils.prependBaseUri(basePath) + val basePathUri = UIUtils.prependBaseUri(request, basePath) val killLink = if (killEnabled) { val confirm = diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 2674b9291203a..238cd31433660 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -53,7 +53,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } catch { case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) + return UIUtils.headerSparkPage(request, "RDD Not Found", Seq.empty[Node], parent) } // Worker table @@ -72,7 +72,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } val blockTableHTML = try { val _blockTable = new BlockPagedTable( - UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, @@ -145,7 +145,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web {blockTableHTML ++ jsForScrollingDownToBlockTable} ; - UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) + UIUtils.headerSparkPage( + request, "RDD Storage Info for " + rddStorageInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 68d946574a37b..3eb546e336e99 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -31,11 +31,14 @@ import org.apache.spark.util.Utils private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) - UIUtils.headerSparkPage("Storage", content, parent) + val content = rddTable(request, store.rddList()) ++ + receiverBlockTables(store.streamBlocksList()) + UIUtils.headerSparkPage(request, "Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { + private[storage] def rddTable( + request: HttpServletRequest, + rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. Nil @@ -49,7 +52,11 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends
    - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + {UIUtils.listingTable( + rddHeader, + rddRow(request, _: v1.RDDStorageInfo), + rdds, + id = Some("storage-by-rdd-table"))}
    } @@ -66,12 +73,13 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { + private def rddRow(request: HttpServletRequest, rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} - + {rdd.name} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a871b1c717837..11b29121739a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -36,6 +36,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito._ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} @@ -281,6 +282,29 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } + test("automatically retrieve uiRoot from request through Knox") { + assert(sys.props.get("spark.ui.proxyBase").isEmpty, + "spark.ui.proxyBase is defined but it should not for this UT") + assert(sys.env.get("APPLICATION_WEB_PROXY_BASE").isEmpty, + "APPLICATION_WEB_PROXY_BASE is defined but it should not for this UT") + val page = new HistoryPage(server) + val requestThroughKnox = mock[HttpServletRequest] + val knoxBaseUrl = "/gateway/default/sparkhistoryui" + when(requestThroughKnox.getHeader("X-Forwarded-Context")).thenReturn(knoxBaseUrl) + val responseThroughKnox = page.render(requestThroughKnox) + + val urlsThroughKnox = responseThroughKnox \\ "@href" map (_.toString) + val siteRelativeLinksThroughKnox = urlsThroughKnox filter (_.startsWith("/")) + all (siteRelativeLinksThroughKnox) should startWith (knoxBaseUrl) + + val directRequest = mock[HttpServletRequest] + val directResponse = page.render(directRequest) + + val directUrls = directResponse \\ "@href" map (_.toString) + val directSiteRelativeLinks = directUrls filter (_.startsWith("/")) + all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) + } + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index a71521c91d2f2..cdc7f541b9552 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.storage +import javax.servlet.http.HttpServletRequest + import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -29,6 +31,7 @@ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") val storagePage = new StoragePage(storageTab, null) + val request = mock(classOf[HttpServletRequest]) test("rddTable") { val rdd1 = new RDDStorageInfo(1, @@ -61,7 +64,7 @@ class StoragePageSuite extends SparkFunSuite { None, None) - val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + val xmlNodes = storagePage.rddTable(request, Seq(rdd1, rdd2, rdd3)) val headers = Seq( "ID", @@ -94,7 +97,7 @@ class StoragePageSuite extends SparkFunSuite { } test("empty rddTable") { - assert(storagePage.rddTable(Seq.empty).isEmpty) + assert(storagePage.rddTable(request, Seq.empty).isEmpty) } test("streamBlockStorageLevelDescriptionAndSize") { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index 022191d0070fd..91f64141e5318 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")

    Cannot find driver {driverId}

    - return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } val driverState = state.get val driverHeaders = Seq("Driver property", "Value") @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

    Driver state information for driver id {driverId}

    - Back to Drivers + Back to Drivers

    Driver state: {driverState.state}

    @@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
    ; - UIUtils.basicSparkPage(content, s"Details for Job $driverId") + UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 88a6614d51384..c53285331ea68 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {retryTable} ; - UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster") } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 582528777f90e..bf46bc4cf904d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -58,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L _content ++= new RunningExecutionTable( parent, s"Running Queries (${running.size})", currentTime, - running.sortBy(_.submissionTime).reverse).toNodeSeq + running.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( parent, s"Completed Queries (${completed.size})", currentTime, - completed.sortBy(_.submissionTime).reverse).toNodeSeq + completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (failed.nonEmpty) { _content ++= new FailedExecutionTable( parent, s"Failed Queries (${failed.size})", currentTime, - failed.sortBy(_.submissionTime).reverse).toNodeSeq + failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } _content } @@ -111,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } - UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) } } @@ -133,7 +133,10 @@ private[ui] abstract class ExecutionTable( protected def header: Seq[String] - protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + protected def row( + request: HttpServletRequest, + currentTime: Long, + executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - submissionTime @@ -141,7 +144,7 @@ private[ui] abstract class ExecutionTable( def jobLinks(status: JobExecutionStatus): Seq[Node] = { executionUIData.jobs.flatMap { case (jobId, jobStatus) => if (jobStatus == status) { - [{jobId.toString}] + [{jobId.toString}] } else { None } @@ -153,7 +156,7 @@ private[ui] abstract class ExecutionTable( {executionUIData.executionId.toString} - {descriptionCell(executionUIData)} + {descriptionCell(request, executionUIData)} {UIUtils.formatDate(submissionTime)} @@ -179,7 +182,9 @@ private[ui] abstract class ExecutionTable( } - private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + private def descriptionCell( + request: HttpServletRequest, + execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details != null && execution.details.nonEmpty) { +details @@ -192,27 +197,28 @@ private[ui] abstract class ExecutionTable( } val desc = if (execution.description != null && execution.description.nonEmpty) { - {execution.description} + {execution.description} } else { - {execution.executionId} + {execution.executionId} }
    {desc} {details}
    } - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = {

    {tableName}

    {UIUtils.listingTable[SQLExecutionUIData]( - header, row(currentTime, _), executionUIDatas, id = Some(tableId))} + header, row(request, currentTime, _), executionUIDatas, id = Some(tableId))}
    } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) - private def executionURL(executionID: Long): String = - s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" + private def executionURL(request: HttpServletRequest, executionID: Long): String = + s"${UIUtils.prependBaseUri( + request, parent.basePath)}/${parent.prefix}/execution?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e0554f0c4d337..282f7b4bb5a58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -49,7 +49,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
  • {label} {jobs.toSeq.sorted.map { jobId => - {jobId.toString}  + {jobId.toString}  }}
  • } else { @@ -77,27 +77,31 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, graph) ++ + planVisualization(request, metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
    No information to display for query {executionId}
    } - UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + UIUtils.headerSparkPage( + request, s"Details for Query $executionId", content, parent, Some(5000)) } - private def planVisualizationResources: Seq[Node] = { + private def planVisualizationResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - - - + + + + + // scalastyle:on } - private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { + private def planVisualization( + request: HttpServletRequest, + metrics: Map[Long, String], + graph: SparkPlanGraph): Seq[Node] = { val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
    {node.desc}
    @@ -112,13 +116,13 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    {graph.allNodes.size.toString}
    {metadata} - {planVisualizationResources} + {planVisualizationResources(request)} } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index f517bffccdf31..0950b30126773 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,10 +47,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {listener.getOnlineSessionNum} session(s) are online, running {listener.getTotalRunning} SQL statement(s) ++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + generateSessionStatsTable(request) ++ + generateSQLStatsTable(request) } - UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -67,7 +67,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", @@ -76,7 +76,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } @@ -138,7 +139,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { + private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { @@ -146,8 +147,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) + val sessionLink = "%s/%s/session?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 5cd2fdf6437c2..c884aa0ecbdf8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -56,9 +56,9 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Session created at {formatDate(sessionStat.startTimestamp)}, Total run {sessionStat.totalExecution} SQL ++ - generateSQLStatsTable(sessionStat.sessionId) + generateSQLStatsTable(request, sessionStat.sessionId) } - UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -75,7 +75,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { val executionList = listener.getExecutionList .filter(_.sessionId == sessionID) val numStatement = executionList.size @@ -86,7 +86,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 6748dd4ec48e3..ca9da6139649a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -47,6 +47,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -54,7 +55,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { isFirstRow: Boolean, jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { if (jobIdWithData.jobData.isDefined) { - generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(request, outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, @@ -89,6 +90,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * one cell, we use "rowspan" for the first row of an output op. */ private def generateNormalJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -106,7 +108,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { dropWhile(_.failureReason == None).take(1). // get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") - val detailUrl = s"${SparkUIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + val detailUrl = s"${SparkUIUtils.prependBaseUri( + request, parent.basePath)}/jobs/job?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". @@ -196,6 +199,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = @@ -212,6 +216,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { val firstRow = generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -221,6 +226,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val tailRows = sparkJobs.tail.map { sparkJob => generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -278,7 +284,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { /** * Generate the job table for the batch. */ - private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + private def generateJobTable( + request: HttpServletRequest, + batchUIData: BatchUIData): Seq[Node] = { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId @@ -301,7 +309,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { outputOpWithJobs.map { case (outputOpData, sparkJobs) => - generateOutputOpIdRow(outputOpData, sparkJobs) + generateOutputOpIdRow(request, outputOpData, sparkJobs) } } @@ -364,9 +372,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
    - val content = summary ++ generateJobTable(batchUIData) + val content = summary ++ generateJobTable(request, batchUIData) - SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + SparkUIUtils.headerSparkPage( + request, s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 3a176f64cdd60..4ce661bc1144e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -148,7 +148,7 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val resources = generateLoadResources() + val resources = generateLoadResources(request) val basicInfo = generateBasicInfo() val content = resources ++ basicInfo ++ @@ -156,17 +156,17 @@ private[ui] class StreamingPage(parent: StreamingTab) generateStatTable() ++ generateBatchListTables() } - SparkUIUtils.headerSparkPage("Streaming Statistics", content, parent, Some(5000)) + SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent, Some(5000)) } /** * Generate html that will load css/js files for StreamingPage */ - private def generateLoadResources(): Seq[Node] = { + private def generateLoadResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - + + + // scalastyle:on } From 952e4d1c830c4eb3dfd522be3d292dd02d8c9065 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 22 May 2018 19:12:30 +0800 Subject: [PATCH 0846/1161] [SPARK-24321][SQL] Extract common code from Divide/Remainder to a base trait ## What changes were proposed in this pull request? Extract common code from `Divide`/`Remainder` to a new base trait, `DivModLike`. Further refactoring to make `Pmod` work with `DivModLike` is to be done as a separate task. ## How was this patch tested? Existing tests in `ArithmeticExpressionSuite` covers the functionality. Author: Kris Mok Closes #21367 from rednaxelafx/catalyst-divmod. --- .../sql/catalyst/expressions/arithmetic.scala | 145 ++++++------------ 1 file changed, 51 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d23b95b..efd4e992c8eec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -220,30 +220,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", - examples = """ - Examples: - > SELECT 3 _FUNC_ 2; - 1.5 - > SELECT 2L _FUNC_ 2L; - 1.0 - """) -// scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) +// Common base trait for Divide and Remainder, since these two classes are almost identical +trait DivModLike extends BinaryArithmetic { - override def symbol: String = "/" - override def decimalMethod: String = "$div" override def nullable: Boolean = true - private lazy val div: (Any, Any) => Any = dataType match { - case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - } - - override def eval(input: InternalRow): Any = { + final override def eval(input: InternalRow): Any = { val input2 = right.eval(input) if (input2 == null || input2 == 0) { null @@ -252,13 +234,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (input1 == null) { null } else { - div(input1, input2) + evalOperation(input1, input2) } } } + def evalOperation(left: Any, right: Any): Any + /** - * Special case handling due to division by 0 => null. + * Special case handling due to division/remainder by 0 => null. */ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) @@ -269,7 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val divide = if (dataType.isInstanceOf[DecimalType]) { + val operation = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" @@ -283,7 +267,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.isNull} = true; } else { ${eval1.code} - ${ev.value} = $divide; + ${ev.value} = $operation; }""") } else { ev.copy(code = s""" @@ -297,13 +281,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.value} = $divide; + ${ev.value} = $operation; } }""") } } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1.5 + > SELECT 2L _FUNC_ 2L; + 1.0 + """) +// scalastyle:on line.size.limit +case class Divide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) + + override def symbol: String = "/" + override def decimalMethod: String = "$div" + + private lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -313,82 +322,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic > SELECT MOD(2, 1.8); 0.2 """) -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true - private lazy val integral = dataType match { - case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] - case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + private lazy val mod: (Any, Any) => Any = dataType match { + // special cases to make float/double primitive types faster + case DoubleType => + (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double] + case FloatType => + (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float] + + // catch-all cases + case i: IntegralType => + val integral = i.integral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) + case i: FractionalType => // should only be DecimalType for now + val integral = i.asIntegral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) } - override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { - null - } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case d: Double => d % input2.asInstanceOf[java.lang.Double] - case f: Float => f % input2.asInstanceOf[java.lang.Float] - case _ => integral.rem(input1, input2) - } - } - } - } - - /** - * Special case handling for x % 0 ==> null. - */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) - val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" - } else { - s"${eval2.value} == 0" - } - val javaType = CodeGenerator.javaType(dataType) - val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" - } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" - } - if (!left.nullable && !right.nullable) { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if ($isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - ${ev.value} = $remainder; - }""") - } else { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { - ${ev.isNull} = true; - } else { - ${ev.value} = $remainder; - } - }""") - } - } + override def evalOperation(left: Any, right: Any): Any = mod(left, right) } @ExpressionDescription( From 82fb5bfa770b0325d4f377dd38d89869007c6111 Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 22 May 2018 21:02:17 +0800 Subject: [PATCH 0847/1161] [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason ## What changes were proposed in this pull request? The ultimate goal is for listeners to onTaskEnd to receive metrics when a task is killed intentionally, since the data is currently just thrown away. This is already done for ExceptionFailure, so this just copies the same approach. ## How was this patch tested? Updated existing tests. This is a rework of https://github.com/apache/spark/pull/17422, all credits should go to noodle-fb Author: Xianjin YE Author: Charles Lewis Closes #21165 from advancedxy/SPARK-20087. --- .../org/apache/spark/TaskEndReason.scala | 8 ++- .../org/apache/spark/executor/Executor.scala | 55 ++++++++++++------- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../spark/scheduler/TaskSetManager.scala | 8 ++- .../org/apache/spark/util/JsonProtocol.scala | 9 ++- .../spark/scheduler/DAGSchedulerSuite.scala | 18 ++++-- project/MimaExcludes.scala | 5 ++ 7 files changed, 78 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a76283e33fa65..33901bc8380e9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case class TaskKilled(reason: String) extends TaskFailedReason { +case class TaskKilled( + reason: String, + accumUpdates: Seq[AccumulableInfo] = Seq.empty, + private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil) + extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false + } /** diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c325222b764b8..b1856ff0f3247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -287,6 +287,28 @@ private[spark] class Executor( notifyAll() } + /** + * Utility function to: + * 1. Report executor runtime and JVM gc time if possible + * 2. Collect accumulator updates + * 3. Set the finished flag to true and clear current thread's interrupt status + */ + private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = { + // Report executor runtime and JVM gc time + Option(task).foreach(t => { + t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime) + t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + }) + + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + + setTaskFinishedAndClearInterruptStatus() + (accums, accUpdates) + } + override def run(): Unit = { threadId = Thread.currentThread.getId Thread.currentThread.setName(threadName) @@ -300,7 +322,7 @@ private[spark] class Executor( val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var taskStart: Long = 0 + var taskStartTime: Long = 0 var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() @@ -336,7 +358,7 @@ private[spark] class Executor( } // Run the actual task and measure its runtime. - taskStart = System.currentTimeMillis() + taskStartTime = System.currentTimeMillis() taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L @@ -396,11 +418,11 @@ private[spark] class Executor( // Deserialization happens in two parts: first, we deserialize a Task object, which // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. task.metrics.setExecutorDeserializeTime( - (taskStart - deserializeStartTime) + task.executorDeserializeTime) + (taskStartTime - deserializeStartTime) + task.executorDeserializeTime) task.metrics.setExecutorDeserializeCpuTime( (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting - task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime) task.metrics.setExecutorCpuTime( (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) @@ -482,16 +504,19 @@ private[spark] class Executor( } catch { case t: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case _: InterruptedException | NonFatal(_) if task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -524,17 +549,7 @@ private[spark] class Executor( // the task failure would not be ignored if the shutdown happened because of premption, // instead of an app issue). if (!ShutdownHookManager.inShutdown()) { - // Collect latest accumulator values to report back to the driver - val accums: Seq[AccumulatorV2[_, _]] = - if (task != null) { - task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) - task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty - } - - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) val serializedTaskEndReason = { try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5f2d16d03165f..ea7bfd7d7a68d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1210,7 +1210,7 @@ class DAGScheduler( case _ => updateAccumulators(event) } - case _: ExceptionFailure => updateAccumulators(event) + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) case _ => } postTaskEnd(event) @@ -1414,13 +1414,13 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case exceptionFailure: ExceptionFailure => + case _: ExceptionFailure | _: TaskKilled => // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => + case _: ExecutorLostFailure | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 195fc8025e4b5..a18c66596852a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -851,13 +851,19 @@ private[spark] class TaskSetManager( } ef.exception + case tk: TaskKilled => + // TaskKilled might have accumulator updates + accumUpdates = tk.accums + logWarning(failureReason) + None + case e: ExecutorLostFailure if !e.exitCausedByApp => logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None - case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others + case e: TaskFailedReason => // TaskResultLost and others logWarning(failureReason) None } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 40383fe05026b..50c6461373dee 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -407,7 +407,9 @@ private[spark] object JsonProtocol { ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => - ("Kill Reason" -> taskKilled.reason) + val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList) + ("Kill Reason" -> taskKilled.reason) ~ + ("Accumulator Updates" -> accumUpdates) case _ => emptyJson } ("Reason" -> reason) ~ json @@ -917,7 +919,10 @@ private[spark] object JsonProtocol { case `taskKilled` => val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") - TaskKilled(killReason) + val accumUpdates = jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(Seq[AccumulableInfo]()) + TaskKilled(killReason, accumUpdates) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8b6ec37625eec..2987170bf5026 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } - test("accumulators are updated on exception failures") { + test("accumulators are updated on exception failures and task killed") { val acc1 = AccumulatorSuite.createLongAccum("ingenieur") val acc2 = AccumulatorSuite.createLongAccum("boulanger") val acc3 = AccumulatorSuite.createLongAccum("agriculteur") @@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accUpdate3 = new LongAccumulator accUpdate3.metadata = acc3.metadata accUpdate3.setValue(18) - val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3) - val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo) + + val accumUpdates1 = Seq(accUpdate1, accUpdate2) + val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo) val exceptionFailure = new ExceptionFailure( new SparkException("fondue?"), - accumInfo).copy(accums = accumUpdates) + accumInfo1).copy(accums = accumUpdates1) submit(new MyRDD(sc, 1, Nil), Array(0)) runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(AccumulatorContext.get(acc1.id).get.value === 15L) assert(AccumulatorContext.get(acc2.id).get.value === 13L) + + val accumUpdates2 = Seq(accUpdate3) + val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo) + + val taskKilled = new TaskKilled( "test", accumInfo2, accums = accumUpdates2) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result")) + assert(AccumulatorContext.get(acc3.id).get.value === 18L) } @@ -2497,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accumUpdates = reason match { case Success => task.metrics.accumulators() case ef: ExceptionFailure => ef.accums + case tk: TaskKilled => tk.accums case _ => Seq.empty } CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6bae4d147d4ac..4f6d5ff898681 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), From a4470bc78ca5f5a090b6831a7cdca88274eb9afc Mon Sep 17 00:00:00 2001 From: Jake Charland Date: Tue, 22 May 2018 08:06:15 -0500 Subject: [PATCH 0848/1161] [SPARK-21673] Use the correct sandbox environment variable set by Mesos ## What changes were proposed in this pull request? This change changes spark behavior to use the correct environment variable set by Mesos in the container on startup. Author: Jake Charland Closes #18894 from jakecharland/MesosSandbox. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 8 ++++---- docs/configuration.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 13adaa921dc23..f9191a59c1655 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -810,15 +810,15 @@ private[spark] object Utils extends Logging { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { conf.getenv("SPARK_LOCAL_DIRS").split(",") - } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + } else if (conf.getenv("MESOS_SANDBOX") != null && !shuffleServiceEnabled) { // Mesos already creates a directory per Mesos task. Spark should use that directory // instead so all temporary files are automatically cleaned up when the Mesos task ends. // Note that we don't want this if the shuffle service is enabled because we want to // continue to serve shuffle files after the executors that wrote them have already exited. - Array(conf.getenv("MESOS_DIRECTORY")) + Array(conf.getenv("MESOS_SANDBOX")) } else { - if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { - logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + if (conf.getenv("MESOS_SANDBOX") != null && shuffleServiceEnabled) { + logInfo("MESOS_SANDBOX available but not using provided Mesos sandbox because " + "spark.shuffle.service.enabled is enabled.") } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user diff --git a/docs/configuration.md b/docs/configuration.md index 8a1aacef85760..fd2670cba2125 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -208,7 +208,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone), MESOS_SANDBOX (Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. From d3d18073152cab4408464d1417ec644d939cfdf7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 22 May 2018 21:08:49 +0800 Subject: [PATCH 0849/1161] [SPARK-24313][SQL] Fix collection operations' interpreted evaluation for complex types ## What changes were proposed in this pull request? The interpreted evaluation of several collection operations works only for simple datatypes. For complex data types, for instance, `array_contains` it returns always `false`. The list of the affected functions is `array_contains`, `array_position`, `element_at` and `GetMapValue`. The PR fixes the behavior for all the datatypes. ## How was this patch tested? added UT Author: Marco Gaido Closes #21361 from mgaido91/SPARK-24313. --- .../expressions/collectionOperations.scala | 41 ++++++++++++---- .../expressions/complexTypeExtractors.scala | 19 +++++-- .../CollectionExpressionsSuite.scala | 49 ++++++++++++++++++- .../optimizer/complexTypesSuite.scala | 13 +++++ .../org/apache/spark/sql/DataFrameSuite.scala | 5 ++ 5 files changed, 113 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8d763dca5243e..7da4c3cc6b9fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -657,6 +657,9 @@ case class ArrayContains(left: Expression, right: Expression) override def dataType: DataType = BooleanType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def inputTypes: Seq[AbstractDataType] = right.dataType match { case NullType => Seq.empty case _ => left.dataType match { @@ -673,7 +676,7 @@ case class ArrayContains(left: Expression, right: Expression) TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } @@ -686,7 +689,7 @@ case class ArrayContains(left: Expression, right: Expression) arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => if (v == null) { hasNull = true - } else if (v == value) { + } else if (ordering.equiv(v, value)) { return true } ) @@ -735,11 +738,7 @@ case class ArraysOverlap(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - if (RowOrdering.isOrderable(elementType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") - } + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") case failure => failure } @@ -1391,13 +1390,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + override def nullSafeEval(arr: Any, value: Any): Any = { arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) { + if (v != null && ordering.equiv(v, value)) { return (i + 1).toLong } ) @@ -1446,6 +1456,9 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + override def dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType @@ -1460,6 +1473,16 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti ) } + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => + TypeUtils.checkForOrderingExpr( + left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + } + } + override def nullable: Boolean = true override def nullSafeEval(value: Any, ordinal: Any): Any = { @@ -1484,7 +1507,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3fba52d745453..99671d5b863c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy var i = 0 var found = false while (i < length && !found) { - if (keys.get(i, keyType) == ordinal) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil with ExtractValue with NullIntolerant { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + private def keyType = child.dataType.asInstanceOf[MapType].keyType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") + } + } + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) @@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression) // todo: current search is O(n), improve it. override def nullSafeEval(value: Any, ordinal: Any): Any = { - getValueEval(value, ordinal, keyType) + getValueEval(value, ordinal, keyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 71ff96bb722e2..3fc0b08c56e02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -157,6 +157,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + // binary + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val be = Literal.create(Array[Byte](1, 2), BinaryType) + val nullBinary = Literal.create(null, BinaryType) + + checkEvaluation(ArrayContains(b0, be), true) + checkEvaluation(ArrayContains(b1, be), false) + checkEvaluation(ArrayContains(b0, nullBinary), null) + checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b3, be), true) + + // complex data types + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(aa0, aae), true) + checkEvaluation(ArrayContains(aa1, aae), false) } test("ArraysOverlap") { @@ -372,6 +399,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayPosition(aa0, aae), 1L) + checkEvaluation(ArrayPosition(aa1, aae), 0L) } test("elementAt") { @@ -409,7 +444,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(null, MapType(StringType, StringType)) - checkEvaluation(ElementAt(m0, Literal(1.0)), null) + assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) checkEvaluation(ElementAt(m0, Literal("d")), null) @@ -420,6 +455,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m0, Literal("c")), null) checkEvaluation(ElementAt(m2, Literal("a")), null) + + // test binary type as keys + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) } test("Concat") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 633d86d495581..5452e72b38647 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select('c as 'sCol2, 'a as 'sCol1) checkRule(originalQuery, correctAnswer) } + + test("SPARK-24313: support binary type as map keys in GetMapValue") { + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee7504..1cc8cb3874c9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2265,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-24313: access map with binary keys") { + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) + } } From fc743f7b30902bad1da36131087bb922c17a048e Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 22 May 2018 08:20:59 -0500 Subject: [PATCH 0850/1161] [SPARK-20120][SQL][FOLLOW-UP] Better way to support spark-sql silent mode. ## What changes were proposed in this pull request? `spark-sql` silent mode will broken if`SPARK_HOME/jars` missing `kubernetes-model-2.0.0.jar`. This pr use `sc.setLogLevel ()` to implement silent mode. ## How was this patch tested? manual tests ``` build/sbt -Phive -Phive-thriftserver package export SPARK_PREPEND_CLASSES=true ./bin/spark-sql -S ``` Author: Yuming Wang Closes #20274 from wangyum/SPARK-20120-FOLLOW-UP. --- .../spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 084f8200102ba..d9fd3ebd3c65d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.log4j.{Level, Logger} +import org.apache.log4j.Level import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf @@ -300,10 +300,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) - if (sessionState.getIsSilent) { - Logger.getRootLogger.setLevel(Level.WARN) - } - private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } @@ -315,6 +311,9 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // because the Hive unit tests do not go through the main() code path. if (!isRemoteMode) { SparkSQLEnv.init() + if (sessionState.getIsSilent) { + SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString) + } } else { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") From 8086acc2f676a04ce6255a621ffae871bd09ceea Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 22 May 2018 22:07:32 +0800 Subject: [PATCH 0851/1161] [SPARK-24244][SQL] Passing only required columns to the CSV parser ## What changes were proposed in this pull request? uniVocity parser allows to specify only required column names or indexes for [parsing](https://www.univocity.com/pages/parsers-tutorial) like: ``` // Here we select only the columns by their indexes. // The parser just skips the values in other columns parserSettings.selectIndexes(4, 0, 1); CsvParser parser = new CsvParser(parserSettings); ``` In this PR, I propose to extract indexes from required schema and pass them into the CSV parser. Benchmarks show the following improvements in parsing of 1000 columns: ``` Select 100 columns out of 1000: x1.76 Select 1 column out of 1000: x2 ``` **Note**: Comparing to current implementation, the changes can return different result for malformed rows in the `DROPMALFORMED` and `FAILFAST` modes if only subset of all columns is requested. To have previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## How was this patch tested? It was tested by new test which selects 3 columns out of 15, by existing tests and by new benchmarks. Author: Maxim Gekk Closes #21296 from MaxGekk/csv-column-pruning. --- docs/sql-programming-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 7 +++ .../datasources/csv/CSVOptions.scala | 3 ++ .../datasources/csv/UnivocityParser.scala | 26 ++++++----- .../datasources/csv/CSVBenchmarks.scala | 42 ++++++++++++++++++ .../execution/datasources/csv/CSVSuite.scala | 43 ++++++++++++++++--- 6 files changed, 104 insertions(+), 18 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f1ed316341b95..fc26562ff33da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a2fb3c64844b5..d0478d6ad250b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1295,6 +1295,13 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1066d156acd74..dd41aee0f2ebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,6 +25,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -80,6 +81,8 @@ class CSVOptions( } } + private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 99557a1ceb0c8..4f00cc5eb3f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + private val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -73,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -210,9 +215,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index d442ba7e59c61..ec788df00aa92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,7 +74,49 @@ object CSVBenchmarks { } } + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 76910 / 78065 0.0 76909.8 1.0X + Select 100 columns 28625 / 32884 0.0 28625.1 2.7X + Select one column 22498 / 22669 0.0 22497.8 3.4X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 07e6c74b14d0d..5f9f799a6c466 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,14 +260,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -1368,4 +1370,31 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + checkAnswer( + idf, + List(Row(15, 10, 5), Row(-15, -10, -5)) + ) + } + } } From f9f055afa47412eec8228c843b34a90decb9be43 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 01:50:22 +0800 Subject: [PATCH 0852/1161] [SPARK-24121][SQL] Add API for handling expression code generation ## What changes were proposed in this pull request? This patch tries to implement this [proposal](https://github.com/apache/spark/pull/19813#issuecomment-354045400) to add an API for handling expression code generation. It should allow us to manipulate how to generate codes for expressions. In details, this adds an new abstraction `CodeBlock` to `JavaCode`. `CodeBlock` holds the code snippet and inputs for generating actual java code. For example, in following java code: ```java int ${variable} = 1; boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; ``` `variable`, `isNull` are two `VariableValue` and `CodeGenerator.defaultValue(BooleanType)` is a string. They are all inputs to this code block and held by `CodeBlock` representing this code. For codegen, we provide a specified string interpolator `code`, so you can define a code like this: ```scala val codeBlock = code""" |int ${variable} = 1; |boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; """.stripMargin // Generates actual java code. codeBlock.toString ``` Because those inputs are held separately in `CodeBlock` before generating code, we can safely manipulate them, e.g., replacing statements to aliased variables, etc.. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21193 from viirya/SPARK-24121. --- .../catalyst/expressions/BoundAttribute.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 10 +- .../sql/catalyst/expressions/Expression.scala | 26 ++-- .../MonotonicallyIncreasingID.scala | 3 +- .../sql/catalyst/expressions/ScalaUDF.scala | 3 +- .../sql/catalyst/expressions/SortOrder.scala | 3 +- .../expressions/SparkPartitionID.scala | 3 +- .../sql/catalyst/expressions/TimeWindow.scala | 3 +- .../sql/catalyst/expressions/arithmetic.scala | 13 +- .../expressions/codegen/CodeGenerator.scala | 25 +-- .../expressions/codegen/CodegenFallback.scala | 5 +- .../codegen/GenerateSafeProjection.scala | 7 +- .../codegen/GenerateUnsafeProjection.scala | 5 +- .../expressions/codegen/javaCode.scala | 145 +++++++++++++++++- .../expressions/collectionOperations.scala | 19 +-- .../expressions/complexTypeCreator.scala | 7 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 23 +-- .../expressions/decimalExpressions.scala | 5 +- .../sql/catalyst/expressions/generators.scala | 3 +- .../spark/sql/catalyst/expressions/hash.scala | 5 +- .../catalyst/expressions/inputFileBlock.scala | 14 +- .../expressions/mathExpressions.scala | 5 +- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 9 +- .../expressions/objects/objects.scala | 48 +++--- .../sql/catalyst/expressions/predicates.scala | 15 +- .../expressions/randomExpressions.scala | 5 +- .../expressions/regexpExpressions.scala | 9 +- .../expressions/stringExpressions.scala | 25 +-- .../ExpressionEvalHelperSuite.scala | 3 +- .../expressions/codegen/CodeBlockSuite.scala | 136 ++++++++++++++++ .../sql/execution/ColumnarBatchScan.scala | 9 +- .../spark/sql/execution/ExpandExec.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 15 +- .../aggregate/HashAggregateExec.scala | 7 +- .../aggregate/HashMapGenerator.scala | 3 +- .../joins/BroadcastHashJoinExec.scala | 3 +- .../execution/joins/SortMergeJoinExec.scala | 5 +- .../spark/sql/GeneratorFunctionSuite.scala | 4 +- 41 files changed, 479 insertions(+), 172 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b27d9eb0..df3ab05e02c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bfa55ab9..699ea53b5df0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) + + ev.copy(code = + code""" + ${eval.code} + // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} + ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} + """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae88299..9b9fa41a47d0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 9f0779642271d..f1da592a76845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258469a97..3e7ca88249737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -1030,7 +1031,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..2ce9d072c71c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf5e81de..9856b37e53fbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a3601c1730..84e38a8b2711e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,7 +165,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index efd4e992c8eec..fe91e520169b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic { ${ev.value} = $operation; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d382d9aace109..66315e5906253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -330,9 +331,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -1056,7 +1057,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1084,7 +1085,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} @@ -1141,7 +1142,7 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = false): String = { + force: Boolean = false): Block = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this @@ -1160,9 +1161,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e129664..3f4704d287cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e9dbf69..39778661d1c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c4529bd..8f2a5a0dce943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions @@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018488863..250ce48d059e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -114,6 +115,147 @@ object JavaCode { } } +/** + * A trait representing a block of java code. + */ +trait Block extends JavaCode { + + // The expressions to be evaluated inside this block. + def exprValues: Set[ExprValue] + + // Returns java code string for this code block. + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c).trim + case _ => code.trim + } + + def length: Int = toString.length + + def nonEmpty: Boolean = toString.nonEmpty + + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. + var _marginChar: Option[Char] = Some('|') + + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } + + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + // Concatenates this block with other block. + def + (other: Block): Block +} + +object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + sc.checkLengths(args) + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue => + case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Block => + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass.getName} into code block.") + } + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[JavaCode] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf.append(strings.next) + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: Block => + codeParts += buf.toString + buf.clear + blockInputs += input.asInstanceOf[JavaCode] + case _ => + buf.append(input) + } + buf.append(strings.next) + } + if (buf.nonEmpty) { + codeParts += buf.toString + } + + (codeParts.toSeq, blockInputs.toSeq) + } +} + +/** + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. + */ +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { + override lazy val exprValues: Set[ExprValue] = { + blockInputs.flatMap { + case b: Block => b.exprValues + case e: ExprValue => Set(e) + }.toSet + } + + override lazy val code: String = { + val strings = codeParts.iterator + val inputs = blockInputs.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf.append(StringContext.treatEscapes(strings.next)) + while (strings.hasNext) { + buf.append(inputs.next) + buf.append(StringContext.treatEscapes(strings.next)) + } + buf.toString + } + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this + } +} + +case class Blocks(blocks: Seq[Block]) extends Block { + override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet + override lazy val code: String = blocks.map(_.toString).mkString("\n") + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(blocks :+ c) + case b: Blocks => Blocks(blocks ++ b.blocks) + case EmptyBlock => this + } +} + +object EmptyBlock extends Block with Serializable { + override val code: String = "" + override val exprValues: Set[ExprValue] = Set.empty + + override def + (other: Block): Block = other +} + /** * A typed java fragment that must be a valid java expression. */ @@ -123,10 +265,9 @@ trait ExprValue extends JavaCode { } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } - /** * A java expression fragment. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7da4c3cc6b9fa..c28eab71b84fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -91,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : @@ -1177,14 +1178,14 @@ case class ArrayJoin( } if (nullable) { ev.copy( - s""" + code""" |boolean ${ev.isNull} = true; |UTF8String ${ev.value} = null; |$code """.stripMargin) } else { ev.copy( - s""" + code""" |UTF8String ${ev.value} = null; |$code """.stripMargin, FalseLiteral) @@ -1269,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1334,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1653,7 +1654,7 @@ case class Concat(children: Seq[Expression]) extends Expression { expressions = inputs, funcName = "valueConcat", extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" $initCode $codes $javaType ${ev.value} = $concatenator.concat($args); @@ -1963,7 +1964,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = false; |${leftGen.code} |${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8565488..a9867aaeb0cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = "Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f6a9acf..77ac6c088022e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03422fecb3209..e8d85f72f7a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio val tz = ctx.addReferenceObj("timeZone", timeZone) val longOpt = ctx.freshName("longOpt") val eval = child.genCode(ctx) - val code = s""" + val code = code""" |${eval.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; @@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1287,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1301,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1444,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579ba28671..04de83343be71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45e..b7c52f1d7b40a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef790338bdd27..cec00b66f873c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc580273ee..3b0141ad52cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfcec47425..c2e1720259b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index b7834696cafc3..5d98dac46cf17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342bce6bc..2eeed3bbb2d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f974fd81fc788..2bf4203d0fec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -269,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -385,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -492,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -532,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -935,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1709,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f8c6dc4e6adc9..f54103c4fbfba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with "" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 2653b28f6c3bd..926c2f00d430d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c0791d895f..7b68bb771faf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a26a4c8b..9823b2fc5ad97 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} $codes @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +655,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +672,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +755,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +772,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +857,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +874,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2070ed6..7c7c4cccee253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala new file mode 100644 index 0000000000000..d2c6420eadb20 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class CodeBlockSuite extends SparkFunSuite { + + test("Block interpolates string and ExprValue inputs") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val stringLiteral = "false" + val code = code"boolean $isNull = $stringLiteral;" + assert(code.toString == "boolean expr1_isNull = false;") + } + + test("Literals are folded into string code parts instead of block inputs") { + val value = JavaCode.variable("expr1", IntegerType) + val intLiteral = 1 + val code = code"int $value = $intLiteral;" + assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) + } + + test("Block.stripMargin") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code1 = + code""" + |boolean $isNull = false; + |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + val expected = + s""" + |boolean expr1_isNull = false; + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim + assert(code1.toString == expected) + + val code2 = + code""" + >boolean $isNull = false; + >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>') + assert(code2.toString == expected) + } + + test("Block can capture input expr values") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code = + code""" + |boolean $isNull = false; + |int $value = -1; + """.stripMargin + val exprValues = code.exprValues + assert(exprValues.size == 2) + assert(exprValues === Set(value, isNull)) + } + + test("concatenate blocks") { + val isNull1 = JavaCode.isNullVariable("expr1_isNull") + val value1 = JavaCode.variable("expr1", IntegerType) + val isNull2 = JavaCode.isNullVariable("expr2_isNull") + val value2 = JavaCode.variable("expr2", IntegerType) + val literal = JavaCode.literal("100", IntegerType) + + val code = + code""" + |boolean $isNull1 = false; + |int $value1 = -1;""".stripMargin + + code""" + |boolean $isNull2 = true; + |int $value2 = $literal;""".stripMargin + + val expected = + """ + |boolean expr1_isNull = false; + |int expr1 = -1; + |boolean expr2_isNull = true; + |int expr2 = 100;""".stripMargin.trim + + assert(code.toString == expected) + + val exprValues = code.exprValues + assert(exprValues.size == 5) + assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) + } + + test("Throws exception when interpolating unexcepted object in code block") { + val obj = Tuple2(1, 1) + val e = intercept[IllegalArgumentException] { + code"$obj" + } + assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + } + + test("replace expr values in code block") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val code = + code""" + |callFunc(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { + case _: SimpleExprValue => aliasedParam + case other => other + } + val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin + val expected = + code""" + |callFunc(int $aliasedParam) { + | boolean $isNull = false; + | int $exprInFunc = $aliasedParam + 1; + |}""".stripMargin + assert(aliasedCode.toString == expected.toString) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591b..48abad9078650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { - s""" + val code = code"${ctx.registerComment(str)}" + (if (nullable) { + code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { - s"$javaType $valueVar = $value;" - }).trim + code"$javaType $valueVar = $value;" + }) ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index e4812f3d338fb..5b4edf5136e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,7 +153,7 @@ case class ExpandExec( } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val code = s""" + val code = code""" |boolean $isNull = true; |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f40c50df74ccb..2549b9e1537a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -313,13 +314,13 @@ case class GenerateExec( if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) + ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 828b51fa199de..372dc3db36ce6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" + val code = code""" |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim + |${ev.code} + """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { // There are no columns @@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") - variables.foreach(_.code = "") + val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") + variables.foreach(_.code = EmptyBlock) evaluate } @@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars.append(ev.code.trim + "\n") - ev.code = "" + evaluateVars.append(ev.code.toString + "\n") + ev.code = EmptyBlock } } evaluateVars.toString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6a8ec4f722aea..8c7b2c187cccd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -190,7 +191,7 @@ case class HashAggregateExec( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) - val initVars = s""" + val initVars = code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin @@ -773,8 +774,8 @@ case class HashAggregateExec( val findOrInsertRegularHashMap: String = s""" |// generate grouping key - |${unsafeRowKeyCode.code.trim} - |${hashEval.code.trim} + |${unsafeRowKeyCode.code} + |${hashEval.code} |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index de2d630de3fdb..e1c85823259b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ abstract class HashMapGenerator( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = - s""" + code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fadee..0da0e8610c392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -183,7 +184,7 @@ case class BroadcastHashJoinExec( val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val javaType = CodeGenerator.javaType(a.dataType) - val code = s""" + val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8261f0f33b61..f4b9d132122e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -521,7 +522,7 @@ case class SortMergeJoinExec( if (a.nullable) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin @@ -533,7 +534,7 @@ case class SortMergeJoinExec( (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), leftVarsDecl) } else { - val code = s"$value = $valueCode;" + val code = code"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec9..8280a3ce39845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} @@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator { override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iteratorClass = classOf[Iterator[_]].getName - ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + ev.copy(code = + code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } } From bc6ea614ad4c6a323c78f209120287b256a458d3 Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Tue, 22 May 2018 13:01:07 -0700 Subject: [PATCH 0853/1161] [SPARK-24348][SQL] "element_at" error fix ## What changes were proposed in this pull request? ### Fixes a `scala.MatchError` in the `element_at` operation - [SPARK-24348](https://issues.apache.org/jira/browse/SPARK-24348) When calling `element_at` with a wrong first operand type an `AnalysisException` should be thrown instead of `scala.MatchError` *Example:* ```sql select element_at('foo', 1) ``` results in: ``` scala.MatchError: StringType (of class org.apache.spark.sql.types.StringType$) at org.apache.spark.sql.catalyst.expressions.ElementAt.inputTypes(collectionOperations.scala:1469) at org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes$class.checkInputDataTypes(ExpectsInputTypes.scala:44) at org.apache.spark.sql.catalyst.expressions.ElementAt.checkInputDataTypes(collectionOperations.scala:1478) at org.apache.spark.sql.catalyst.expressions.Expression.resolved$lzycompute(Expression.scala:168) at org.apache.spark.sql.catalyst.expressions.Expression.resolved(Expression.scala:168) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:256) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:252) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:288) ``` ## How was this patch tested? unit tests Author: Vayda, Oleksandr: IT (PRG) Closes #21395 from wajda/SPARK-24348-element_at-error-fix. --- .../sql/catalyst/expressions/collectionOperations.scala | 1 + .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c28eab71b84fd..03b3b21a16617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1470,6 +1470,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti left.dataType match { case _: ArrayType => IntegerType case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _ => AnyDataType // no match for a wrong 'left' expression type } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index df23e07e441a0..ec2a569f900d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -756,6 +756,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("element_at(a, -1)"), Seq(Row("3"), Row(""), Row(null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") + } + assert(e.message.contains( + "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } test("concat function - arrays") { From 79e06faa4ef6596c9e2d4be09c74b935064021bb Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 22 May 2018 13:43:45 -0700 Subject: [PATCH 0854/1161] [SPARK-19185][DSTREAMS] Avoid concurrent use of cached consumers in CachedKafkaConsumer ## What changes were proposed in this pull request? `CachedKafkaConsumer` in the project streaming-kafka-0-10 is designed to maintain a pool of KafkaConsumers that can be reused. However, it was built with the assumption there will be only one thread trying to read the same Kafka TopicPartition at the same time. This assumption is not true all the time and this can inadvertently lead to ConcurrentModificationException. Here is a better way to design this. The consumer pool should be smart enough to avoid concurrent use of a cached consumer. If there is another request for the same TopicPartition as a currently in-use consumer, the pool should automatically return a fresh consumer. - There are effectively two kinds of consumer that may be generated - Cached consumer - this should be returned to the pool at task end - Non-cached consumer - this should be closed at task end - A trait called `KafkaDataConsumer` is introduced to hide this difference from the users of the consumer so that the client code does not have to reason about whether to stop and release. They simply call `val consumer = KafkaDataConsumer.acquire` and then `consumer.release`. - If there is request for a consumer that is in-use, then a new consumer is generated. - If there is request for a consumer which is a task reattempt, then already existing cached consumer will be invalidated and a new consumer is generated. This could fix potential issues if the source of the reattempt is a malfunctioning consumer. - In addition, I renamed the `CachedKafkaConsumer` class to `KafkaDataConsumer` because is a misnomer given that what it returns may or may not be cached. ## How was this patch tested? A new stress test that verifies it is safe to concurrently get consumers for the same TopicPartition from the consumer pool. Author: Gabor Somogyi Closes #20997 from gaborgsomogyi/SPARK-19185. --- .../sql/kafka010/KafkaDataConsumer.scala | 2 +- .../kafka010/CachedKafkaConsumer.scala | 226 ----------- .../kafka010/KafkaDataConsumer.scala | 359 ++++++++++++++++++ .../spark/streaming/kafka010/KafkaRDD.scala | 20 +- .../kafka010/KafkaDataConsumerSuite.scala | 131 +++++++ 5 files changed, 496 insertions(+), 242 deletions(-) delete mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala create mode 100644 external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala create mode 100644 external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 48508d057a540..941f0ab177e48 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging { // likely running on a beefy machine that can handle a large number of simultaneously // active consumers. - if (entry.getValue.inUse == false && this.size > capacity) { + if (!entry.getValue.inUse && this.size > capacity) { logWarning( s"KafkaConsumer cache hitting max capacity of $capacity, " + s"removing consumer for ${entry.getKey}") diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala deleted file mode 100644 index aeb8c1dc342b3..0000000000000 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.kafka010 - -import java.{ util => ju } - -import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } -import org.apache.kafka.common.{ KafkaException, TopicPartition } - -import org.apache.spark.internal.Logging - -/** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. - */ -private[kafka010] -class CachedKafkaConsumer[K, V] private( - val groupId: String, - val topic: String, - val partition: Int, - val kafkaParams: ju.Map[String, Object]) extends Logging { - - require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), - "groupId used for cache key must match the groupId in kafkaParams") - - val topicPartition = new TopicPartition(topic, partition) - - protected val consumer = { - val c = new KafkaConsumer[K, V](kafkaParams) - val tps = new ju.ArrayList[TopicPartition]() - tps.add(topicPartition) - c.assign(tps) - c - } - - // TODO if the buffer was kept around as a random-access structure, - // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() - protected var nextOffset = -2L - - def close(): Unit = consumer.close() - - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. - */ - def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { - logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset") - if (offset != nextOffset) { - logInfo(s"Initial fetch for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - - if (!buffer.hasNext()) { poll(timeout) } - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - var record = buffer.next() - - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - record = buffer.next() - require(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + - s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + - "spark.streaming.kafka.allowNonConsecutiveOffsets" - ) - } - - nextOffset = offset + 1 - record - } - - /** - * Start a batch on a compacted topic - */ - def compactedStart(offset: Long, timeout: Long): Unit = { - logDebug(s"compacted start $groupId $topic $partition starting $offset") - // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics - if (offset != nextOffset) { - logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - } - - /** - * Get the next record in the batch from a compacted topic. - * Assumes compactedStart has been called first, and ignores gaps. - */ - def compactedNext(timeout: Long): ConsumerRecord[K, V] = { - if (!buffer.hasNext()) { - poll(timeout) - } - require(buffer.hasNext(), - s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") - val record = buffer.next() - nextOffset = record.offset + 1 - record - } - - /** - * Rewind to previous record in the batch from a compacted topic. - * @throws NoSuchElementException if no previous element - */ - def compactedPrevious(): ConsumerRecord[K, V] = { - buffer.previous() - } - - private def seek(offset: Long): Unit = { - logDebug(s"Seeking to $topicPartition $offset") - consumer.seek(topicPartition, offset) - } - - private def poll(timeout: Long): Unit = { - val p = consumer.poll(timeout) - val r = p.records(topicPartition) - logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.listIterator - } - -} - -private[kafka010] -object CachedKafkaConsumer extends Logging { - - private case class CacheKey(groupId: String, topic: String, partition: Int) - - // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap - private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null - - /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */ - def init( - initialCapacity: Int, - maxCapacity: Int, - loadFactor: Float): Unit = CachedKafkaConsumer.synchronized { - if (null == cache) { - logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") - cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]]( - initialCapacity, loadFactor, true) { - override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = { - if (this.size > maxCapacity) { - try { - entry.getValue.consumer.close() - } catch { - case x: KafkaException => - logError("Error closing oldest Kafka consumer", x) - } - true - } else { - false - } - } - } - } - } - - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def get[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - CachedKafkaConsumer.synchronized { - val k = CacheKey(groupId, topic, partition) - val v = cache.get(k) - if (null == v) { - logInfo(s"Cache miss for $k") - logDebug(cache.keySet.toString) - val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - cache.put(k, c) - c - } else { - // any given topicpartition should have a consistent key and value type - v.asInstanceOf[CachedKafkaConsumer[K, V]] - } - } - - /** - * Get a fresh new instance, unassociated with the global cache. - * Caller is responsible for closing - */ - def getUncached[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - - /** remove consumer for given groupId, topic, and partition, if it exists */ - def remove(groupId: String, topic: String, partition: Int): Unit = { - val k = CacheKey(groupId, topic, partition) - logInfo(s"Removing $k from cache") - val v = CachedKafkaConsumer.synchronized { - cache.remove(k) - } - if (null != v) { - v.close() - logInfo(s"Removed $k from cache") - } - } -} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala new file mode 100644 index 0000000000000..68c5fe9ab066a --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging + +private[kafka010] sealed trait KafkaDataConsumer[K, V] { + /** + * Get the record for the given offset if available. + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.get(offset, pollTimeoutMs) + } + + /** + * Start a batch on a compacted topic + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + internalConsumer.compactedStart(offset, pollTimeoutMs) + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + * + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.compactedNext(pollTimeoutMs) + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + internalConsumer.compactedPrevious() + } + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + def internalConsumer: InternalKafkaConsumer[K, V] +} + + +/** + * A wrapper around Kafka's KafkaConsumer. + * This is not for direct use outside this file. + */ +private[kafka010] class InternalKafkaConsumer[K, V]( + val topicPartition: TopicPartition, + val kafkaParams: ju.Map[String, Object]) extends Logging { + + private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG) + .asInstanceOf[String] + + private val consumer = createConsumer + + /** indicates whether this consumer is in use or not */ + var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + var markedForClose = false + + // TODO if the buffer was kept around as a random-access structure, + // could possibly optimize re-calculating of an RDD in the same batch + @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() + @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET + + override def toString: String = { + "InternalKafkaConsumer(" + + s"hash=${Integer.toHexString(hashCode)}, " + + s"groupId=$groupId, " + + s"topicPartition=$topicPartition)" + } + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[K, V] = { + val c = new KafkaConsumer[K, V](kafkaParams) + val topics = ju.Arrays.asList(topicPartition) + c.assign(topics) + c + } + + def close(): Unit = consumer.close() + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset") + if (offset != nextOffset) { + logInfo(s"Initial fetch for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + } + + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + var record = buffer.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + record = buffer.next() + require(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) + } + + nextOffset = offset + 1 + record + } + + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + logDebug(s"compacted start $groupId $topicPartition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(pollTimeoutMs) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topicPartition " + + s"after polling for $pollTimeoutMs") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(timeout: Long): Unit = { + val p = consumer.poll(timeout) + val r = p.records(topicPartition) + logDebug(s"Polled ${p.partitions()} ${r.size}") + buffer = r.listIterator + } + +} + +private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) + +private[kafka010] object KafkaDataConsumer extends Logging { + + private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + assert(internalConsumer.inUse) + override def release(): Unit = KafkaDataConsumer.release(internalConsumer) + } + + private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + override def release(): Unit = internalConsumer.close() + } + + // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap + private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null + + /** + * Must be called before acquire, once per JVM, to configure the cache. + * Further calls are ignored. + */ + def init( + initialCapacity: Int, + maxCapacity: Int, + loadFactor: Float): Unit = synchronized { + if (null == cache) { + logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") + cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]]( + initialCapacity, loadFactor, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > maxCapacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $maxCapacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case x: KafkaException => + logError("Error closing oldest Kafka consumer", x) + } + true + } else { + false + } + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by anyone + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. + */ + def acquire[K, V]( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + context: TaskContext, + useCache: Boolean): KafkaDataConsumer[K, V] = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = cache.get(key) + + lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams) + + if (context != null && context.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumers if any and + // start with a new one. If prior attempt failures were cache related then this way old + // problematic consumers can be removed. + logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer") + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + // Remove the consumer from cache only if it's closed. + // Marked for close consumers will be removed in release function. + cache.remove(key) + } + } + + logDebug("Reattempt detected, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (!useCache) { + // If consumer reuse turned off, then do not use it, return a new consumer + logDebug("Cache usage turned off, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + logDebug("No cached consumer, new cached consumer will be allocated " + + s"$newInternalConsumer") + cache.put(key, newInternalConsumer) + CachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + logDebug("Used cached consumer found, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else { + // If consumer is already cached and is currently not in use, then return that consumer + logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer") + existingInternalConsumer.inUse = true + // Any given TopicPartition should have a consistent key and value type + CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]]) + } + } + + private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized { + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition) + val cachedInternalConsumer = cache.get(key) + if (internalConsumer.eq(cachedInternalConsumer)) { + // The released consumer is the same object as the cached one. + if (internalConsumer.markedForClose) { + internalConsumer.close() + cache.remove(key) + } else { + internalConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + internalConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache " + + s"$internalConsumer") + } + } +} + +private[kafka010] object InternalKafkaConsumer { + private val UNKNOWN_OFFSET = -2L +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 07239eda64d2e..81abc9860bfc3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } -import scala.collection.mutable.ArrayBuffer - import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord } import org.apache.kafka.common.TopicPartition @@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - context.addTaskCompletionListener(_ => closeIfNeeded()) - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + val consumer = { + KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache) } var requestOffset = part.fromOffset def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close() + if (consumer != null) { + consumer.release() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..d934c64962adb --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ + +class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll { + private var testUtils: KafkaTestUtils = _ + private val topic = "topic" + Random.nextInt() + private val topicPartition = new TopicPartition(topic, 0) + private val groupId = "groupId" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + KafkaDataConsumer.init(16, 64, 0.75f) + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + private def getKafkaParams() = Map[String, Object]( + GROUP_ID_CONFIG -> groupId, + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ).asJava + + test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") { + KafkaDataConsumer.cache.clear() + + val kafkaParams = getKafkaParams() + + val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer1.release() + + val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer2.release() + + assert(KafkaDataConsumer.cache.size() == 1) + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = KafkaDataConsumer.cache.get(key) + assert(existingInternalConsumer.eq(consumer1.internalConsumer)) + assert(existingInternalConsumer.eq(consumer2.internalConsumer)) + } + + test("concurrent use of KafkaDataConsumer") { + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic) + testUtils.sendMessages(topic, data.toArray) + + val kafkaParams = getKafkaParams() + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, taskContext, useCache) + try { + val rcvd = (0 until data.length).map { offset => + val bytes = consumer.get(offset, 10000).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadPool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadPool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadPool.shutdown() + } + } +} From 00c13cfad78607fde0787c9d494f0df8ab7051ba Mon Sep 17 00:00:00 2001 From: Seth Fitzsimmons Date: Wed, 23 May 2018 09:14:03 +0800 Subject: [PATCH 0855/1161] Correct reference to Offset class This is a documentation-only correction; `org.apache.spark.sql.sources.v2.reader.Offset` is actually `org.apache.spark.sql.sources.v2.reader.streaming.Offset`. Author: Seth Fitzsimmons Closes #21387 from mojodna/patch-1. --- .../org/apache/spark/sql/execution/streaming/Offset.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java index 80aa5505db991..43ad4b3384ec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java @@ -19,8 +19,8 @@ /** * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported - * in the long term. + * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be + * supported in the long term. * * This class will be removed in a future release. */ From a40ffc656d62372da85e0fa932b67207839e7fde Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 23 May 2018 22:40:52 +0800 Subject: [PATCH 0856/1161] [SPARK-23711][SQL] Add fallback generator for UnsafeProjection ## What changes were proposed in this pull request? Add fallback logic for `UnsafeProjection`. In production we can try to create unsafe projection using codegen implementation. Once any compile error happens, it fallbacks to interpreted implementation. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21106 from viirya/SPARK-23711. --- ...CodeGeneratorWithInterpretedFallback.scala | 73 +++++++++++++++++++ .../InterpretedUnsafeProjection.scala | 5 +- .../sql/catalyst/expressions/Projection.scala | 60 +++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 12 +++ ...eneratorWithInterpretedFallbackSuite.scala | 43 +++++++++++ .../expressions/ExpressionEvalHelper.scala | 52 ++++++------- .../expressions/ObjectExpressionsSuite.scala | 18 +---- .../expressions/UnsafeRowConverterSuite.scala | 52 ++++++++----- .../UnsafeKVExternalSorterSuite.scala | 3 +- 9 files changed, 230 insertions(+), 88 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala new file mode 100644 index 0000000000000..fb25e781e72e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.InternalCompilerException + +import org.apache.spark.TaskContext +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * Catches compile error during code generation. + */ +object CodegenError { + def unapply(throwable: Throwable): Option[Exception] = throwable match { + case e: InternalCompilerException => Some(e) + case e: CompileException => Some(e) + case _ => None + } +} + +/** + * Defines values for `SQLConf` config of fallback mode. Use for test only. + */ +object CodegenObjectFactoryMode extends Enumeration { + val FALLBACK, CODEGEN_ONLY, NO_CODEGEN = Value +} + +/** + * A codegen object generator which creates objects with codegen path first. Once any compile + * error happens, it can fallbacks to interpreted implementation. In tests, we can use a SQL config + * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior. + */ +abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] { + + def createObject(in: IN): OUT = { + // We are allowed to choose codegen-only or no-codegen modes if under tests. + val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE) + val fallbackMode = CodegenObjectFactoryMode.withName(config) + + fallbackMode match { + case CodegenObjectFactoryMode.CODEGEN_ONLY if Utils.isTesting => + createCodeGeneratedObject(in) + case CodegenObjectFactoryMode.NO_CODEGEN if Utils.isTesting => + createInterpretedObject(in) + case _ => + try { + createCodeGeneratedObject(in) + } catch { + case CodegenError(_) => createInterpretedObject(in) + } + } + } + + protected def createCodeGeneratedObject(in: IN): OUT + protected def createInterpretedObject(in: IN): OUT +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6d69d69b1c802..55a5bd380859e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe /** * Helper functions for creating an [[InterpretedUnsafeProjection]]. */ -object InterpretedUnsafeProjection extends UnsafeProjectionCreator { - +object InterpretedUnsafeProjection { /** * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. */ - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { + def createProjection(exprs: Seq[Expression]): UnsafeProjection = { // We need to make sure that we do not reuse stateful expressions. val cleanedExpressions = exprs.map(_.transform { case s: Stateful => s.freshCopy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 3cd73682188bc..6493f09100577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -108,7 +108,32 @@ abstract class UnsafeProjection extends Projection { override def apply(row: InternalRow): UnsafeRow } -trait UnsafeProjectionCreator { +/** + * The factory object for `UnsafeProjection`. + */ +object UnsafeProjection + extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] { + + override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = { + GenerateUnsafeProjection.generate(in) + } + + override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = { + InterpretedUnsafeProjection.createProjection(in) + } + + protected def toBoundExprs( + exprs: Seq[Expression], + inputSchema: Seq[Attribute]): Seq[Expression] = { + exprs.map(BindReferences.bindReference(_, inputSchema)) + } + + protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = { + exprs.map(_ transform { + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + } + /** * Returns an UnsafeProjection for given StructType. * @@ -129,10 +154,7 @@ trait UnsafeProjectionCreator { * Returns an UnsafeProjection for given sequence of bound Expressions. */ def create(exprs: Seq[Expression]): UnsafeProjection = { - val unsafeExprs = exprs.map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - createProjection(unsafeExprs) + createObject(toUnsafeExprs(exprs)) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) @@ -142,34 +164,24 @@ trait UnsafeProjectionCreator { * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(exprs.map(BindReferences.bindReference(_, inputSchema))) - } - - /** - * Returns an [[UnsafeProjection]] for given sequence of bound Expressions. - */ - protected def createProjection(exprs: Seq[Expression]): UnsafeProjection -} - -object UnsafeProjection extends UnsafeProjectionCreator { - - override protected def createProjection(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + create(toBoundExprs(exprs, inputSchema)) } /** * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. + * The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example, + * when fallbacking to interpreted execution, it is not supported. */ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - val e = exprs.map(BindReferences.bindReference(_, inputSchema)) - .map(_ transform { - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema)) + try { + GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled) + } catch { + case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d0478d6ad250b..fb98df587129b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.Utils @@ -703,6 +704,17 @@ object SQLConf { .intConf .createWithDefault(100) + val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .doc("This config determines the fallback behavior of several codegen generators " + + "during tests. `FALLBACK` means trying codegen first and then fallbacking to " + + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + + "`NO_CODEGEN` skips codegen and goes interpreted path always. Note that " + + "this config works only for tests.") + .internal() + .stringConf + .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() .doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala new file mode 100644 index 0000000000000..531ca9a87370a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, LongType} + +class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase { + + test("UnsafeProjection with codegen factory mode") { + val input = Seq(LongType, IntegerType) + .zipWithIndex.map(x => BoundReference(x._2, x._1, true)) + + val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) { + val obj = UnsafeProjection.createObject(input) + assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection")) + } + + val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) { + val obj = UnsafeProjection.createObject(input) + assert(obj.isInstanceOf[InterpretedUnsafeProjection]) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index c2a44e0d33b18..14bfa212b5496 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. */ -trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { +trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBase { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { @@ -205,39 +206,34 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection) - checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection) - } - - protected def checkEvaluationWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow, - factory: UnsafeProjectionCreator): Unit = { - val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow, factory) - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - - if (expected == null) { - if (!unsafeRow.isNullAt(0)) { - val expectedRow = InternalRow(expected, expected) - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") - } - } else { - val lit = InternalRow(expected, expected) - val expectedRow = - factory.create(Array(expression.dataType, expression.dataType)).apply(lit) - if (unsafeRow != expectedRow) { - fail("Incorrect evaluation in unsafe mode: " + - s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + val unsafeRow = evaluateWithUnsafeProjection(expression, inputRow) + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected, expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected, expected) + val expectedRow = + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } } } } protected def evaluateWithUnsafeProjection( expression: Expression, - inputRow: InternalRow = EmptyRow, - factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = { + inputRow: InternalRow = EmptyRow): InternalRow = { // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 77ca640f2e0bd..20d568c44258f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -81,10 +81,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val structExpected = new GenericArrayData( Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) checkEvaluationWithUnsafeProjection( - structEncoder.serializer.head, - structExpected, - structInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + structEncoder.serializer.head, structExpected, structInputRow) // test UnsafeArray-backed data val arrayEncoder = ExpressionEncoder[Array[Array[Int]]] @@ -92,10 +89,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val arrayExpected = new GenericArrayData( Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) checkEvaluationWithUnsafeProjection( - arrayEncoder.serializer.head, - arrayExpected, - arrayInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) // test UnsafeMap-backed data val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]] @@ -109,10 +103,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new GenericArrayData(Array(3, 4)), new GenericArrayData(Array(300, 400))))) checkEvaluationWithUnsafeProjection( - mapEncoder.serializer.head, - mapExpected, - mapInputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + mapEncoder.serializer.head, mapExpected, mapInputRow) } test("SPARK-23582: StaticInvoke should support interpreted execution") { @@ -286,8 +277,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluationWithUnsafeProjection( expr, expected, - inputRow, - UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed + inputRow) } checkEvaluationWithOptimization(expr, expected, inputRow) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c07da122cd7b8..5a646d9a850ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,25 +24,30 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testWithFactory( - name: String)( - f: UnsafeProjectionCreator => Unit): Unit = { - test(name) { - f(UnsafeProjection) - f(InterpretedUnsafeProjection) + private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } } } - testWithFactory("basic conversion with only primitive types") { factory => + testBothCodegenAndInterpreted("basic conversion with only primitive types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = factory.create(fieldTypes) val row = new SpecificInternalRow(fieldTypes) @@ -79,7 +84,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow2.getInt(2) === 2) } - testWithFactory("basic conversion with primitive, string and binary types") { factory => + testBothCodegenAndInterpreted("basic conversion with primitive, string and binary types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = factory.create(fieldTypes) @@ -98,7 +104,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getBinary(2) === "World".getBytes(StandardCharsets.UTF_8)) } - testWithFactory("basic conversion with primitive, string, date and timestamp types") { factory => + testBothCodegenAndInterpreted( + "basic conversion with primitive, string, date and timestamp types") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) val converter = factory.create(fieldTypes) @@ -127,7 +135,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { (Timestamp.valueOf("2015-06-22 08:10:25")) } - testWithFactory("null handling") { factory => + testBothCodegenAndInterpreted("null handling") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( NullType, BooleanType, @@ -248,7 +257,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testWithFactory("NaN canonicalization") { factory => + testBothCodegenAndInterpreted("NaN canonicalization") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) val row1 = new SpecificInternalRow(fieldTypes) @@ -263,7 +273,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - testWithFactory("basic conversion with struct type") { factory => + testBothCodegenAndInterpreted("basic conversion with struct type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("i", IntegerType), new StructType().add("nest", new StructType().add("l", LongType)) @@ -325,7 +336,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes) } - testWithFactory("basic conversion with array type") { factory => + testBothCodegenAndInterpreted("basic conversion with array type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(IntegerType), ArrayType(ArrayType(IntegerType)) @@ -355,7 +367,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) } - testWithFactory("basic conversion with map type") { factory => + testBothCodegenAndInterpreted("basic conversion with map type") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( MapType(IntegerType, IntegerType), MapType(IntegerType, MapType(IntegerType, IntegerType)) @@ -401,7 +414,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size) } - testWithFactory("basic conversion with struct and array") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and array") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("arr", ArrayType(IntegerType)), ArrayType(new StructType().add("l", LongType)) @@ -440,7 +454,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with struct and map") { factory => + testBothCodegenAndInterpreted("basic conversion with struct and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( new StructType().add("map", MapType(IntegerType, IntegerType)), MapType(IntegerType, new StructType().add("l", LongType)) @@ -486,7 +501,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes)) } - testWithFactory("basic conversion with array and map") { factory => + testBothCodegenAndInterpreted("basic conversion with array and map") { + val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( ArrayType(MapType(IntegerType, IntegerType)), MapType(IntegerType, ArrayType(IntegerType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index bf588d3bb7841..c882a9dd2148c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -231,7 +231,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { // Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap` // which has duplicated keys and the number of entries exceeds its capacity. try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, null, null)) + val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null) + TaskContext.setTaskContext(context) new UnsafeKVExternalSorter( schema, schema, From df125062c8dac9fee3328d67dd438a456b7a3b74 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 23 May 2018 11:00:23 -0700 Subject: [PATCH 0857/1161] [SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? Change `PrefixSpan` into a class with param setter/getters. This address issues mentioned here: https://github.com/apache/spark/pull/20973#discussion_r186931806 ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21393 from WeichenXu123/fix_prefix_span. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 127 ++++++++++++++---- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 28 ++-- 2 files changed, 119 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 02168fee16dbf..41716c621ca98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.fpm import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col @@ -29,13 +31,97 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth * (see here). + * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to + * run the PrefixSpan algorithm. * * @see Sequential Pattern Mining * (Wikipedia) */ @Since("2.4.0") @Experimental -object PrefixSpan { +final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("prefixSpan")) + + /** + * Param for the minimal support level (default: `0.1`). + * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are + * identified as frequent sequential patterns. + * @group param + */ + @Since("2.4.0") + val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset)." + + "times will be output.", ParamValidators.gtEq(0.0)) + + /** @group getParam */ + @Since("2.4.0") + def getMinSupport: Double = $(minSupport) + + /** @group setParam */ + @Since("2.4.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** + * Param for the maximal pattern length (default: `10`). + * @group param + */ + @Since("2.4.0") + val maxPatternLength = new IntParam(this, "maxPatternLength", + "The maximal length of the sequential pattern.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxPatternLength: Int = $(maxPatternLength) + + /** @group setParam */ + @Since("2.4.0") + def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) + + /** + * Param for the maximum number of items (including delimiters used in the internal storage + * format) allowed in a projected database before local processing (default: `32000000`). + * If a projected database exceeds this size, another iteration of distributed prefix growth + * is run. + * @group param + */ + @Since("2.4.0") + val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the internal storage format) " + + "allowed in a projected database before local processing. If a projected database exceeds " + + "this size, another iteration of distributed prefix growth is run.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) + + /** @group setParam */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + + /** + * Param for the name of the sequence column in dataset (default "sequence"), rows with + * nulls in this column are ignored. + * @group param + */ + @Since("2.4.0") + val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.") + + /** @group getParam */ + @Since("2.4.0") + def getSequenceCol: String = $(sequenceCol) + + /** @group setParam */ + @Since("2.4.0") + def setSequenceCol(value: String): this.type = set(sequenceCol, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000, + sequenceCol -> "sequence") /** * :: Experimental :: @@ -43,54 +129,39 @@ object PrefixSpan { * * @param dataset A dataset or a dataframe containing a sequence column which is * {{{Seq[Seq[_]]}}} type - * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column - * are ignored - * @param minSupport the minimal support level of the sequential pattern, any pattern that - * appears more than (minSupport * size-of-the-dataset) times will be output - * (recommended value: `0.1`). - * @param maxPatternLength the maximal length of the sequential pattern - * (recommended value: `10`). - * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the - * internal storage format) allowed in a projected database before - * local processing. If a projected database exceeds this size, another - * iteration of distributed prefix growth is run - * (recommended value: `32000000`). * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: * - `sequence: Seq[Seq[T]]` (T is the item type) * - `freq: Long` */ @Since("2.4.0") - def findFrequentSequentialPatterns( - dataset: Dataset[_], - sequenceCol: String, - minSupport: Double, - maxPatternLength: Int, - maxLocalProjDBSize: Long): DataFrame = { - - val inputType = dataset.schema(sequenceCol).dataType + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + val sequenceColParam = $(sequenceCol) + val inputType = dataset.schema(sequenceColParam).dataType require(inputType.isInstanceOf[ArrayType] && inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], s"The input column must be ArrayType and the array element type must also be ArrayType, " + s"but got $inputType.") - - val data = dataset.select(sequenceCol) - val sequences = data.where(col(sequenceCol).isNotNull).rdd + val data = dataset.select(sequenceColParam) + val sequences = data.where(col(sequenceColParam).isNotNull).rdd .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) val mllibPrefixSpan = new mllibPrefixSpan() - .setMinSupport(minSupport) - .setMaxPatternLength(maxPatternLength) - .setMaxLocalProjDBSize(maxLocalProjDBSize) + .setMinSupport($(minSupport)) + .setMaxPatternLength($(maxPatternLength)) + .setMaxLocalProjDBSize($(maxLocalProjDBSize)) val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) val schema = StructType(Seq( - StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) freqSequences } + @Since("2.4.0") + override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala index 9e538696cbcf7..2252151af306b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan projections with multiple partial starts") { val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", - minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(smallDataset) .as[(Seq[Seq[Int]], Long)].collect() val expected = Array( (Seq(Seq(1)), 1L), @@ -90,8 +93,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan Integer type, variable-size itemsets") { val df = smallTestData.toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -99,8 +105,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan input row with nulls") { val df = (smallTestData :+ null).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest { val df = smallTestData .map(seq => seq.map(itemSet => itemSet.map(intToString))) .toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[String]], Long)].collect() val expected = smallTestDataExpectedResult.map { case (seq, freq) => From 5a5a868dc410ad8c97851d7f3f0ea1c9fc1db90c Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Wed, 23 May 2018 11:51:13 -0700 Subject: [PATCH 0858/1161] Revert "[SPARK-24244][SQL] Passing only required columns to the CSV parser" This reverts commit 8086acc2f676a04ce6255a621ffae871bd09ceea. --- docs/sql-programming-guide.md | 1 - .../apache/spark/sql/internal/SQLConf.scala | 7 --- .../datasources/csv/CSVOptions.scala | 3 -- .../datasources/csv/UnivocityParser.scala | 26 +++++------ .../datasources/csv/CSVBenchmarks.scala | 42 ------------------ .../execution/datasources/csv/CSVSuite.scala | 43 +++---------------- 6 files changed, 18 insertions(+), 104 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fc26562ff33da..f1ed316341b95 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,7 +1825,6 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fb98df587129b..15ba10f604510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1307,13 +1307,6 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } - - val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") - .internal() - .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + - "Other column values can be ignored during parsing even if they are malformed.") - .booleanConf - .createWithDefault(true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index dd41aee0f2ebc..1066d156acd74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,7 +25,6 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -81,8 +80,6 @@ class CSVOptions( } } - private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) - val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 4f00cc5eb3f39..99557a1ceb0c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - dataSchema: StructType, + schema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(dataSchema.toSet), + require(requiredSchema.toSet.subsetOf(schema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,17 +45,9 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = { - val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) - parserSetting.selectIndexes(tokenIndexArr: _*) - } - new CsvParser(parserSetting) - } - private val schema = if (options.columnPruning) requiredSchema else dataSchema + private val tokenizer = new CsvParser(options.asParserSettings) - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -81,8 +73,11 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = { + private val valueConverters: Array[ValueConverter] = schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + + private val tokenIndexArr: Array[Int] = { + requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -215,8 +210,9 @@ class UnivocityParser( } else { try { var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + val from = tokenIndexArr(i) + row(i) = valueConverters(from).apply(tokens(from)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index ec788df00aa92..d442ba7e59c61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,49 +74,7 @@ object CSVBenchmarks { } } - def multiColumnsBenchmark(rowsNum: Int): Unit = { - val colsNum = 1000 - val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) - - withTempPath { path => - val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) - val schema = StructType(fields) - val values = (0 until colsNum).map(i => i.toString).mkString(",") - val columnNames = schema.fieldNames - - spark.range(rowsNum) - .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) - .write.option("header", true) - .csv(path.getAbsolutePath) - - val ds = spark.read.schema(schema).csv(path.getAbsolutePath) - - benchmark.addCase(s"Select $colsNum columns", 3) { _ => - ds.select("*").filter((row: Row) => true).count() - } - val cols100 = columnNames.take(100).map(Column(_)) - benchmark.addCase(s"Select 100 columns", 3) { _ => - ds.select(cols100: _*).filter((row: Row) => true).count() - } - benchmark.addCase(s"Select one column", 3) { _ => - ds.select($"col1").filter((row: Row) => true).count() - } - - /* - Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz - - Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - -------------------------------------------------------------------------------------------- - Select 1000 columns 76910 / 78065 0.0 76909.8 1.0X - Select 100 columns 28625 / 32884 0.0 28625.1 2.7X - Select one column 22498 / 22669 0.0 22497.8 3.4X - */ - benchmark.run() - } - } - def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) - multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 5f9f799a6c466..07e6c74b14d0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,16 +260,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) - } + assert(cars.select("year").collect().size === 2) } } @@ -1370,31 +1368,4 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } - - test("SPARK-24244: Select a subset of all columns") { - withTempPath { path => - import collection.JavaConverters._ - val schema = new StructType() - .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) - .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) - .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) - .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) - .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) - - val odf = spark.createDataFrame(List( - Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) - ).asJava, schema) - odf.write.csv(path.getCanonicalPath) - val idf = spark.read - .schema(schema) - .csv(path.getCanonicalPath) - .select('f15, 'f10, 'f5) - - checkAnswer( - idf, - List(Row(15, 10, 5), Row(-15, -10, -5)) - ) - } - } } From 84557bc9f87885577f22d7e7f2220c0b988014cd Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 23 May 2018 13:02:32 -0700 Subject: [PATCH 0859/1161] [SPARK-24206][SQL] Improve DataSource read benchmark code ## What changes were proposed in this pull request? This pr added benchmark code `DataSourceReadBenchmark` for `orc`, `paruqet`, `csv`, and `json` based on the existing `ParquetReadBenchmark` and `OrcReadBenchmark`. ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #21266 from maropu/DataSourceReadBenchmark. --- .../benchmark/DataSourceReadBenchmark.scala | 828 ++++++++++++++++++ .../parquet/ParquetReadBenchmark.scala | 339 ------- 2 files changed, 828 insertions(+), 339 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala new file mode 100644 index 0000000000000..fc6d8abc03c09 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -0,0 +1,828 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnVector +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure data source read performance. + * To run this: + * spark-submit --class + */ +object DataSourceReadBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceReadBenchmark") + .setIfMissing("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key, "true") + spark.conf.set(SQLConf.ORC_COPY_BATCH_TO_SPARK.key, "false") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + val testDf = if (partition.isDefined) { + df.write.partitionBy(partition.get) + } else { + df.write + } + + saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") + saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") + saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") + saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") + } + + private def saveAsCsvTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").option("header", true).csv(dir) + spark.read.option("header", true).csv(dir).createOrReplaceTempView("csvTable") + } + + private def saveAsJsonTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "gzip").json(dir) + spark.read.json(dir).createOrReplaceTempView("jsonTable") + } + + private def saveAsParquetTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + private def saveAsOrcTable(df: DataFrameWriter[Row], dir: String): Unit = { + df.mode("overwrite").option("compression", "snappy").orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + def numericScanBenchmark(values: Int, dataType: DataType): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values) + + // Benchmarks driving reader component directly. + val parquetReaderBenchmark = new Benchmark( + s"Parquet Reader Single ${dataType.sql} Column Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + + sqlBenchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + sqlBenchmark.addCase("SQL Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + sqlBenchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + sqlBenchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + sqlBenchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 15231 / 15267 1.0 968.3 1.0X + SQL Json 8476 / 8498 1.9 538.9 1.8X + SQL Parquet Vectorized 121 / 127 130.0 7.7 125.9X + SQL Parquet MR 1515 / 1543 10.4 96.3 10.1X + SQL ORC Vectorized 164 / 171 95.9 10.4 92.9X + SQL ORC Vectorized with copy 228 / 234 69.0 14.5 66.8X + SQL ORC MR 1297 / 1309 12.1 82.5 11.7X + + + SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 16344 / 16374 1.0 1039.1 1.0X + SQL Json 8634 / 8648 1.8 548.9 1.9X + SQL Parquet Vectorized 172 / 177 91.5 10.9 95.1X + SQL Parquet MR 1744 / 1746 9.0 110.9 9.4X + SQL ORC Vectorized 189 / 194 83.1 12.0 86.4X + SQL ORC Vectorized with copy 244 / 250 64.5 15.5 67.0X + SQL ORC MR 1341 / 1386 11.7 85.3 12.2X + + + SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 17874 / 17875 0.9 1136.4 1.0X + SQL Json 9190 / 9204 1.7 584.3 1.9X + SQL Parquet Vectorized 141 / 160 111.2 9.0 126.4X + SQL Parquet MR 1930 / 2049 8.2 122.7 9.3X + SQL ORC Vectorized 259 / 264 60.7 16.5 69.0X + SQL ORC Vectorized with copy 265 / 272 59.4 16.8 67.5X + SQL ORC MR 1528 / 1569 10.3 97.2 11.7X + + + SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 22812 / 22839 0.7 1450.4 1.0X + SQL Json 12026 / 12054 1.3 764.6 1.9X + SQL Parquet Vectorized 222 / 227 70.8 14.1 102.6X + SQL Parquet MR 2199 / 2204 7.2 139.8 10.4X + SQL ORC Vectorized 331 / 335 47.6 21.0 69.0X + SQL ORC Vectorized with copy 338 / 343 46.6 21.5 67.6X + SQL ORC MR 1618 / 1622 9.7 102.9 14.1X + + + SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 18703 / 18740 0.8 1189.1 1.0X + SQL Json 11779 / 11869 1.3 748.9 1.6X + SQL Parquet Vectorized 143 / 145 110.1 9.1 130.9X + SQL Parquet MR 1954 / 1963 8.0 124.2 9.6X + SQL ORC Vectorized 347 / 355 45.3 22.1 53.8X + SQL ORC Vectorized with copy 356 / 359 44.1 22.7 52.5X + SQL ORC MR 1570 / 1598 10.0 99.8 11.9X + + + SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 23832 / 23838 0.7 1515.2 1.0X + SQL Json 16204 / 16226 1.0 1030.2 1.5X + SQL Parquet Vectorized 242 / 306 65.1 15.4 98.6X + SQL Parquet MR 2462 / 2482 6.4 156.5 9.7X + SQL ORC Vectorized 419 / 451 37.6 26.6 56.9X + SQL ORC Vectorized with copy 426 / 447 36.9 27.1 55.9X + SQL ORC MR 1885 / 1931 8.3 119.8 12.6X + */ + sqlBenchmark.run() + + // Driving the parquet reader in batch mode directly. + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + parquetReaderBenchmark.addCase("ParquetReader Vectorized") { _ => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) + case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) + case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) + case LongType => (col: ColumnVector, i: Int) => longSum += col.getLong(i) + case FloatType => (col: ColumnVector, i: Int) => doubleSum += col.getFloat(i) + case DoubleType => (col: ColumnVector, i: Int) => doubleSum += col.getDouble(i) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + val col = batch.column(0) + while (reader.nextBatch()) { + val numRows = batch.numRows() + var i = 0 + while (i < numRows) { + if (!col.isNullAt(i)) aggregateValue(col, i) + i += 1 + } + } + } finally { + reader.close() + } + } + } + + // Decoding in vectorized but having the reader return rows. + parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => + var longSum = 0L + var doubleSum = 0.0 + val aggregateValue: (InternalRow) => Unit = dataType match { + case ByteType => (col: InternalRow) => longSum += col.getByte(0) + case ShortType => (col: InternalRow) => longSum += col.getShort(0) + case IntegerType => (col: InternalRow) => longSum += col.getInt(0) + case LongType => (col: InternalRow) => longSum += col.getLong(0) + case FloatType => (col: InternalRow) => doubleSum += col.getFloat(0) + case DoubleType => (col: InternalRow) => doubleSum += col.getDouble(0) + } + + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("id" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val it = batch.rowIterator() + while (it.hasNext) { + val record = it.next() + if (!record.isNullAt(0)) aggregateValue(record) + } + } + } finally { + reader.close() + } + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 187 / 201 84.2 11.9 1.0X + ParquetReader Vectorized -> Row 101 / 103 156.4 6.4 1.9X + + + Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 272 / 288 57.8 17.3 1.0X + ParquetReader Vectorized -> Row 213 / 219 73.7 13.6 1.3X + + + Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 252 / 288 62.5 16.0 1.0X + ParquetReader Vectorized -> Row 232 / 246 67.7 14.8 1.1X + + + Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 415 / 454 37.9 26.4 1.0X + ParquetReader Vectorized -> Row 407 / 432 38.6 25.9 1.0X + + + Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 251 / 302 62.7 16.0 1.0X + ParquetReader Vectorized -> Row 220 / 234 71.5 14.0 1.1X + + + Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + ParquetReader Vectorized 432 / 436 36.4 27.5 1.0X + ParquetReader Vectorized -> Row 414 / 422 38.0 26.4 1.0X + */ + parquetReaderBenchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Int and String Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(c1), sum(length(c2)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(c1), sum(length(c2)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(c1), sum(length(c2)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(c1), sum(length(c2)) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 19172 / 19173 0.5 1828.4 1.0X + SQL Json 12799 / 12873 0.8 1220.6 1.5X + SQL Parquet Vectorized 2558 / 2564 4.1 244.0 7.5X + SQL Parquet MR 4514 / 4583 2.3 430.4 4.2X + SQL ORC Vectorized 2561 / 2697 4.1 244.3 7.5X + SQL ORC Vectorized with copy 3076 / 3110 3.4 293.4 6.2X + SQL ORC MR 4197 / 4283 2.5 400.2 4.6X + */ + benchmark.run() + } + } + } + + def repeatedStringScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Repeated String", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql("select cast((value % 200) + 10000 as STRING) as c1 from t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c1)) from csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c1)) from jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("select sum(length(c1)) from orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c1)) from orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 10889 / 10924 1.0 1038.5 1.0X + SQL Json 7903 / 7931 1.3 753.7 1.4X + SQL Parquet Vectorized 777 / 799 13.5 74.1 14.0X + SQL Parquet MR 1682 / 1708 6.2 160.4 6.5X + SQL ORC Vectorized 532 / 534 19.7 50.7 20.5X + SQL ORC Vectorized with copy 742 / 743 14.1 70.7 14.7X + SQL ORC MR 1996 / 2002 5.3 190.4 5.5X + */ + benchmark.run() + } + } + } + + def partitionTableScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark("Partitioned Table", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p")) + + benchmark.addCase("Data column - CSV") { _ => + spark.sql("select sum(id) from csvTable").collect() + } + + benchmark.addCase("Data column - Json") { _ => + spark.sql("select sum(id) from jsonTable").collect() + } + + benchmark.addCase("Data column - Parquet Vectorized") { _ => + spark.sql("select sum(id) from parquetTable").collect() + } + + benchmark.addCase("Data column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(id) from parquetTable").collect() + } + } + + benchmark.addCase("Data column - ORC Vectorized") { _ => + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Data column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Data column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - CSV") { _ => + spark.sql("select sum(p) from csvTable").collect() + } + + benchmark.addCase("Partition column - Json") { _ => + spark.sql("select sum(p) from jsonTable").collect() + } + + benchmark.addCase("Partition column - Parquet Vectorized") { _ => + spark.sql("select sum(p) from parquetTable").collect() + } + + benchmark.addCase("Partition column - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p) from parquetTable").collect() + } + } + + benchmark.addCase("Partition column - ORC Vectorized") { _ => + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + + benchmark.addCase("Partition column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Partition column - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - CSV") { _ => + spark.sql("select sum(p), sum(id) from csvTable").collect() + } + + benchmark.addCase("Both columns - Json") { _ => + spark.sql("select sum(p), sum(id) from jsonTable").collect() + } + + benchmark.addCase("Both columns - Parquet Vectorized") { _ => + spark.sql("select sum(p), sum(id) from parquetTable").collect() + } + + benchmark.addCase("Both columns - Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(p), sum(id) from parquetTable").collect + } + } + + benchmark.addCase("Both columns - ORC Vectorized") { _ => + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + + benchmark.addCase("Both column - ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + benchmark.addCase("Both columns - ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT sum(p), sum(id) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Data column - CSV 25428 / 25454 0.6 1616.7 1.0X + Data column - Json 12689 / 12774 1.2 806.7 2.0X + Data column - Parquet Vectorized 222 / 231 70.7 14.1 114.3X + Data column - Parquet MR 3355 / 3397 4.7 213.3 7.6X + Data column - ORC Vectorized 332 / 338 47.4 21.1 76.6X + Data column - ORC Vectorized with copy 338 / 341 46.5 21.5 75.2X + Data column - ORC MR 2329 / 2356 6.8 148.0 10.9X + Partition column - CSV 17465 / 17502 0.9 1110.4 1.5X + Partition column - Json 10865 / 10876 1.4 690.8 2.3X + Partition column - Parquet Vectorized 48 / 52 325.4 3.1 526.1X + Partition column - Parquet MR 1695 / 1696 9.3 107.8 15.0X + Partition column - ORC Vectorized 49 / 54 319.9 3.1 517.2X + Partition column - ORC Vectorized with copy 49 / 52 324.1 3.1 524.0X + Partition column - ORC MR 1548 / 1549 10.2 98.4 16.4X + Both columns - CSV 25568 / 25595 0.6 1625.6 1.0X + Both columns - Json 13658 / 13673 1.2 868.4 1.9X + Both columns - Parquet Vectorized 270 / 296 58.3 17.1 94.3X + Both columns - Parquet MR 3501 / 3521 4.5 222.6 7.3X + Both columns - ORC Vectorized 377 / 380 41.7 24.0 67.4X + Both column - ORC Vectorized with copy 447 / 448 35.2 28.4 56.9X + Both columns - ORC MR 2440 / 2446 6.4 155.2 10.4X + */ + benchmark.run() + } + } + } + + def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { + val benchmark = new Benchmark("String with Nulls Scan", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + spark.range(values).createOrReplaceTempView("t1") + + prepareTable( + dir, + spark.sql( + s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " + + s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql("select sum(length(c2)) from csvTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql("select sum(length(c2)) from jsonTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(length(c2)) from parquetTable where c1 is " + + "not NULL and c2 is not NULL").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(new File(dir, "parquet")).toArray + val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled + val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize + benchmark.addCase("ParquetReader Vectorized") { num => + var sum = 0 + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new VectorizedParquetRecordReader( + null, enableOffHeapColumnVector, vectorizedReaderBatchSize) + try { + reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) + val batch = reader.resultBatch() + while (reader.nextBatch()) { + val rowIterator = batch.rowIterator() + while (rowIterator.hasNext) { + val row = rowIterator.next() + val value = row.getUTF8String(0) + if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() + } + } + } finally { + reader.close() + } + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(LENGTH(c2)) FROM orcTable " + + "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 13518 / 13529 0.8 1289.2 1.0X + SQL Json 10895 / 10926 1.0 1039.0 1.2X + SQL Parquet Vectorized 1539 / 1581 6.8 146.8 8.8X + SQL Parquet MR 3746 / 3811 2.8 357.3 3.6X + ParquetReader Vectorized 1070 / 1112 9.8 102.0 12.6X + SQL ORC Vectorized 1389 / 1408 7.6 132.4 9.7X + SQL ORC Vectorized with copy 1736 / 1750 6.0 165.6 7.8X + SQL ORC MR 3799 / 3892 2.8 362.3 3.6X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 10854 / 10892 1.0 1035.2 1.0X + SQL Json 8129 / 8138 1.3 775.3 1.3X + SQL Parquet Vectorized 1053 / 1104 10.0 100.4 10.3X + SQL Parquet MR 2840 / 2854 3.7 270.8 3.8X + ParquetReader Vectorized 978 / 1008 10.7 93.2 11.1X + SQL ORC Vectorized 1312 / 1387 8.0 125.1 8.3X + SQL ORC Vectorized with copy 1764 / 1772 5.9 168.2 6.2X + SQL ORC MR 3435 / 3445 3.1 327.6 3.2X + + + String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 8043 / 8048 1.3 767.1 1.0X + SQL Json 4911 / 4923 2.1 468.4 1.6X + SQL Parquet Vectorized 206 / 209 51.0 19.6 39.1X + SQL Parquet MR 1528 / 1537 6.9 145.8 5.3X + ParquetReader Vectorized 216 / 219 48.6 20.6 37.2X + SQL ORC Vectorized 462 / 466 22.7 44.1 17.4X + SQL ORC Vectorized with copy 568 / 572 18.5 54.2 14.2X + SQL ORC MR 1647 / 1649 6.4 157.1 4.9X + */ + benchmark.run() + } + } + } + + def columnsBenchmark(values: Int, width: Int): Unit = { + val benchmark = new Benchmark(s"Single Column Scan from $width columns", values) + + withTempPath { dir => + withTempTable("t1", "csvTable", "jsonTable", "parquetTable", "orcTable") { + import spark.implicits._ + val middle = width / 2 + val selectExpr = (1 to width).map(i => s"value as c$i") + spark.range(values).map(_ => Random.nextLong).toDF() + .selectExpr(selectExpr: _*).createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql("SELECT * FROM t1")) + + benchmark.addCase("SQL CSV") { _ => + spark.sql(s"SELECT sum(c$middle) FROM csvTable").collect() + } + + benchmark.addCase("SQL Json") { _ => + spark.sql(s"SELECT sum(c$middle) FROM jsonTable").collect() + } + + benchmark.addCase("SQL Parquet Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM parquetTable").collect() + } + } + + benchmark.addCase("SQL ORC Vectorized") { _ => + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + + benchmark.addCase("SQL ORC Vectorized with copy") { _ => + withSQLConf(SQLConf.ORC_COPY_BATCH_TO_SPARK.key -> "true") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql(s"SELECT sum(c$middle) FROM orcTable").collect() + } + } + + /* + Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 3663 / 3665 0.3 3493.2 1.0X + SQL Json 3122 / 3160 0.3 2977.5 1.2X + SQL Parquet Vectorized 40 / 42 26.2 38.2 91.5X + SQL Parquet MR 189 / 192 5.5 180.2 19.4X + SQL ORC Vectorized 48 / 51 21.6 46.2 75.6X + SQL ORC Vectorized with copy 49 / 52 21.4 46.7 74.9X + SQL ORC MR 280 / 289 3.7 267.1 13.1X + + + Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 11420 / 11505 0.1 10891.1 1.0X + SQL Json 11905 / 12120 0.1 11353.6 1.0X + SQL Parquet Vectorized 50 / 54 20.9 47.8 227.7X + SQL Parquet MR 195 / 199 5.4 185.8 58.6X + SQL ORC Vectorized 61 / 65 17.3 57.8 188.3X + SQL ORC Vectorized with copy 62 / 65 17.0 58.8 185.2X + SQL ORC MR 847 / 865 1.2 807.4 13.5X + + + Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + SQL CSV 21278 / 21404 0.0 20292.4 1.0X + SQL Json 22455 / 22625 0.0 21414.7 0.9X + SQL Parquet Vectorized 73 / 75 14.4 69.3 292.8X + SQL Parquet MR 220 / 226 4.8 209.7 96.8X + SQL ORC Vectorized 82 / 86 12.8 78.2 259.4X + SQL ORC Vectorized with copy 82 / 90 12.7 78.7 258.0X + SQL ORC MR 1568 / 1582 0.7 1495.4 13.6X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => + numericScanBenchmark(1024 * 1024 * 15, dataType) + } + intStringScanBenchmark(1024 * 1024 * 10) + repeatedStringScanBenchmark(1024 * 1024 * 10) + partitionTableScanBenchmark(1024 * 1024 * 15) + for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { + stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) + } + for (columnWidth <- List(10, 50, 100)) { + columnsBenchmark(1024 * 1024 * 1, columnWidth) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala deleted file mode 100644 index e43336d947364..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.parquet - -import java.io.File - -import scala.collection.JavaConverters._ -import scala.util.Try - -import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - -/** - * Benchmark to measure parquet read performance. - * To run this: - * spark-submit --class --jars - */ -object ParquetReadBenchmark { - val conf = new SparkConf() - conf.set("spark.sql.parquet.compression.codec", "snappy") - - val spark = SparkSession.builder - .master("local[1]") - .appName("test-sql-context") - .config(conf) - .getOrCreate() - - // Set default configs. Individual cases will change them if necessary. - spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - def intScanBenchmark(values: Int): Unit = { - // Benchmarks running through spark sql. - val sqlBenchmark = new Benchmark("SQL Single Int Column Scan", values) - // Benchmarks driving reader component directly. - val parquetReaderBenchmark = new Benchmark("Parquet Reader Single Int Column Scan", values) - - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as id from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(id) from tempTable").collect() - } - - sqlBenchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from tempTable").collect() - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - // Driving the parquet reader in batch mode directly. - parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - val col = batch.column(0) - while (reader.nextBatch()) { - val numRows = batch.numRows() - var i = 0 - while (i < numRows) { - if (!col.isNullAt(i)) sum += col.getInt(i) - i += 1 - } - } - } finally { - reader.close() - } - } - } - - // Decoding in vectorized but having the reader return rows. - parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("id" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val it = batch.rowIterator() - while (it.hasNext) { - val record = it.next() - if (!record.isNullAt(0)) sum += record.getInt(0) - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - SQL Single Int Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X - SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X - */ - sqlBenchmark.run() - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - ParquetReader Vectorized 123 / 152 127.8 7.8 1.0X - ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 0.7X - */ - parquetReaderBenchmark.run() - } - } - } - - def intStringScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Int and String Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X - SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X - */ - benchmark.run() - } - } - } - - def stringDictionaryScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String Dictionary", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c1)) from tempTable").collect - } - - benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(length(c1)) from tempTable").collect - } - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - String Dictionary: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 329 / 337 31.9 31.4 1.0X - SQL Parquet MR 1131 / 1325 9.3 107.8 0.3X - */ - benchmark.run() - } - } - } - - def partitionTableScanBenchmark(values: Int): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - spark.range(values).createOrReplaceTempView("t1") - spark.sql("select id % 2 as p, cast(id as INT) as id from t1") - .write.partitionBy("p").parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("Partitioned Table", values) - - benchmark.addCase("Read data column") { iter => - spark.sql("select sum(id) from tempTable").collect - } - - benchmark.addCase("Read partition column") { iter => - spark.sql("select sum(p) from tempTable").collect - } - - benchmark.addCase("Read both columns") { iter => - spark.sql("select sum(p), sum(id) from tempTable").collect - } - - /* - Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz - Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Read data column 191 / 250 82.1 12.2 1.0X - Read partition column 82 / 86 192.4 5.2 2.3X - Read both columns 220 / 248 71.5 14.0 0.9X - */ - benchmark.run() - } - } - } - - def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { - withTempPath { dir => - withTempTable("t1", "tempTable") { - val enableOffHeapColumnVector = spark.sessionState.conf.offHeapColumnVectorEnabled - val vectorizedReaderBatchSize = spark.sessionState.conf.parquetVectorizedReaderBatchSize - spark.range(values).createOrReplaceTempView("t1") - spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + - s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") - .write.parquet(dir.getCanonicalPath) - spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") - - val benchmark = new Benchmark("String with Nulls Scan", values) - - benchmark.addCase("SQL Parquet Vectorized") { iter => - spark.sql("select sum(length(c2)) from tempTable where c1 is " + - "not NULL and c2 is not NULL").collect() - } - - val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - benchmark.addCase("PR Vectorized") { num => - var sum = 0 - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new VectorizedParquetRecordReader( - null, enableOffHeapColumnVector, vectorizedReaderBatchSize) - try { - reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) - val batch = reader.resultBatch() - while (reader.nextBatch()) { - val rowIterator = batch.rowIterator() - while (rowIterator.hasNext) { - val row = rowIterator.next() - val value = row.getUTF8String(0) - if (!row.isNullAt(0) && !row.isNullAt(1)) sum += value.numBytes() - } - } - } finally { - reader.close() - } - } - } - - /* - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 1229 / 1648 8.5 117.2 1.0X - PR Vectorized 833 / 846 12.6 79.4 1.5X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (50%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 995 / 1053 10.5 94.9 1.0X - PR Vectorized 732 / 772 14.3 69.8 1.4X - - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - String with Nulls Scan (95%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - SQL Parquet Vectorized 326 / 333 32.2 31.1 1.0X - PR Vectorized 190 / 200 55.1 18.2 1.7X - */ - - benchmark.run() - } - } - } - - def main(args: Array[String]): Unit = { - intScanBenchmark(1024 * 1024 * 15) - intStringScanBenchmark(1024 * 1024 * 10) - stringDictionaryScanBenchmark(1024 * 1024 * 10) - partitionTableScanBenchmark(1024 * 1024 * 15) - for (fractionOfNulls <- List(0.0, 0.50, 0.95)) { - stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls) - } - } -} From b7a036b75b8a1d287ac014b85e90d555753064c9 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 23 May 2018 13:12:05 -0700 Subject: [PATCH 0860/1161] [SPARK-24294] Throw SparkException when OOM in BroadcastExchangeExec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When OutOfMemoryError thrown from BroadcastExchangeExec, scala.concurrent.Future will hit scala bug – https://github.com/scala/bug/issues/9554, and hang until future timeout: We could wrap the OOM inside SparkException to resolve this issue. ## How was this patch tested? Manually tested. Author: jinxing Closes #21342 from jinxing64/SPARK-24294. --- .../spark/util/SparkFatalException.scala | 27 +++++++++++++++++++ .../util/SparkUncaughtExceptionHandler.scala | 13 ++++++--- .../org/apache/spark/util/ThreadUtils.scala | 2 ++ .../exchange/BroadcastExchangeExec.scala | 13 ++++++--- 4 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/SparkFatalException.scala diff --git a/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala new file mode 100644 index 0000000000000..1aa2009fa9b5b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +/** + * SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we catch + * fatal throwable in {@link scala.concurrent.Future}'s body, and re-throw + * SparkFatalException, which wraps the fatal throwable inside. + * Note that SparkFatalException should only be thrown from a {@link scala.concurrent.Future}, + * which is run by using ThreadUtils.awaitResult. ThreadUtils.awaitResult will catch + * it and re-throw the original exception/error. + */ +private[spark] final class SparkFatalException(val throwable: Throwable) extends Exception diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index e0f5af5250e7f..1b34fbde38cd6 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -39,10 +39,15 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException: // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) if (!ShutdownHookManager.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(SparkExitCode.OOM) - } else if (exitOnUncaughtException) { - System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + exception match { + case _: OutOfMemoryError => + System.exit(SparkExitCode.OOM) + case e: SparkFatalException if e.throwable.isInstanceOf[OutOfMemoryError] => + // SPARK-24294: This is defensive code, in case that SparkFatalException is + // misused and uncaught. + System.exit(SparkExitCode.OOM) + case _ if exitOnUncaughtException => + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 81aaf79db0c13..165a15c73e7ca 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -200,6 +200,8 @@ private[spark] object ThreadUtils { val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] awaitable.result(atMost)(awaitPermission) } catch { + case e: SparkFatalException => + throw e.throwable // TimeoutException is thrown in the current thread, so not need to warp the exception. case NonFatal(t) if !t.isInstanceOf[TimeoutException] => throw new SparkException("Exception thrown in awaitResult: ", t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9e0ec9481b0de..c55f9b8f1a7fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkException} import org.apache.spark.launcher.SparkLauncher @@ -30,7 +31,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SparkFatalException, ThreadUtils} /** * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of @@ -111,12 +112,18 @@ case class BroadcastExchangeExec( SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { + // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw + // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult + // will catch this exception and re-throw the wrapped fatal throwable. case oe: OutOfMemoryError => - throw new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + + throw new SparkFatalException( + new OutOfMemoryError(s"Not enough memory to build and broadcast the table to " + s"all worker nodes. As a workaround, you can either disable broadcast by setting " + s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark driver " + s"memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value") - .initCause(oe.getCause) + .initCause(oe.getCause)) + case e if !NonFatal(e) => + throw new SparkFatalException(e) } } }(BroadcastExchangeExec.executionContext) From f4579332931c9bf424d0b6147fad89bd63da26f6 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 23 May 2018 17:21:29 -0700 Subject: [PATCH 0861/1161] [SPARK-23416][SS] Add a specific stop method for ContinuousExecution. ## What changes were proposed in this pull request? Add a specific stop method for ContinuousExecution. The previous StreamExecution.stop() method had a race condition as applied to continuous processing: if the cancellation was round-tripped to the driver too quickly, the generic SparkException it caused would be reported as the query death cause. We earlier decided that SparkException should not be added to the StreamExecution.isInterruptionException() whitelist, so we need to ensure this never happens instead. ## How was this patch tested? Existing tests. I could consistently reproduce the previous flakiness by putting Thread.sleep(1000) between the first job cancellation and thread interruption in StreamExecution.stop(). Author: Jose Torres Closes #21384 from jose-torres/fixKafka. --- .../streaming/MicroBatchExecution.scala | 18 ++++++++++++++++++ .../execution/streaming/StreamExecution.scala | 18 ------------------ .../continuous/ContinuousExecution.scala | 16 ++++++++++++++++ 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 6709e7052f005..7817360810bde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -126,6 +126,24 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signals to the thread executing micro-batches that it should stop running after the next + * batch. This method blocks until the thread stops running. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) + queryExecutionThread.interrupt() + queryExecutionThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) + } + logInfo(s"Query $prettyIdString was stopped") + } + /** * Repeatedly attempts to run batches as data arrives. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3fc8c7887896a..290de873c5cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -378,24 +378,6 @@ abstract class StreamExecution( } } - /** - * Signals to the thread executing micro-batches that it should stop running after the next - * batch. This method blocks until the thread stops running. - */ - override def stop(): Unit = { - // Set the state to TERMINATED so that the batching thread knows that it was interrupted - // intentionally - state.set(TERMINATED) - if (queryExecutionThread.isAlive) { - sparkSession.sparkContext.cancelJobGroup(runId.toString) - queryExecutionThread.interrupt() - queryExecutionThread.join() - // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak - sparkSession.sparkContext.cancelJobGroup(runId.toString) - } - logInfo(s"Query $prettyIdString was stopped") - } - /** * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 0e7d1019b9c8f..d16b24c89ebef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -356,6 +356,22 @@ class ContinuousExecution( } } } + + /** + * Stops the query execution thread to terminate the query. + */ + override def stop(): Unit = { + // Set the state to TERMINATED so that the batching thread knows that it was interrupted + // intentionally + state.set(TERMINATED) + if (queryExecutionThread.isAlive) { + // The query execution thread will clean itself up in the finally clause of runContinuous. + // We just need to interrupt the long running job. + queryExecutionThread.interrupt() + queryExecutionThread.join() + } + logInfo(s"Query $prettyIdString was stopped") + } } object ContinuousExecution { From 230f1441978e14062d3e0b6ba1524acf36e3eafe Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Wed, 23 May 2018 17:22:52 -0700 Subject: [PATCH 0862/1161] [SPARK-24350][SQL] Fixes ClassCastException in the "array_position" function ## What changes were proposed in this pull request? ### Fixes `ClassCastException` in the `array_position` function - [SPARK-24350](https://issues.apache.org/jira/browse/SPARK-24350) When calling `array_position` function with a wrong type of the 1st argument an `AnalysisException` should be thrown instead of `ClassCastException` Example: ```sql select array_position('foo', 'bar') ``` ``` java.lang.ClassCastException: org.apache.spark.sql.types.StringType$ cannot be cast to org.apache.spark.sql.types.ArrayType at org.apache.spark.sql.catalyst.expressions.ArrayPosition.inputTypes(collectionOperations.scala:1398) at org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes$class.checkInputDataTypes(ExpectsInputTypes.scala:44) at org.apache.spark.sql.catalyst.expressions.ArrayPosition.checkInputDataTypes(collectionOperations.scala:1401) at org.apache.spark.sql.catalyst.expressions.Expression.resolved$lzycompute(Expression.scala:168) at org.apache.spark.sql.catalyst.expressions.Expression.resolved(Expression.scala:168) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:256) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases$$anonfun$org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveAliases$$assignAliases$1$$anonfun$apply$3.applyOrElse(Analyzer.scala:252) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:289) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:288) ``` ## How was this patch tested? unit test Author: Vayda, Oleksandr: IT (PRG) Closes #21401 from wajda/SPARK-24350-array_position-error-fix. --- .../catalyst/expressions/collectionOperations.scala | 10 ++++++++-- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 5 +++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 03b3b21a16617..8a877b02c8191 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1395,8 +1395,14 @@ case class ArrayPosition(left: Expression, right: Expression) TypeUtils.getInterpretedOrdering(right.dataType) override def dataType: DataType = LongType - override def inputTypes: Seq[AbstractDataType] = - Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } override def checkInputDataTypes(): TypeCheckResult = { super.checkInputDataTypes() match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ec2a569f900d1..79e743d961af8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -708,6 +708,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(array(1, null), array(1, null)[0])"), Seq(Row(1L), Row(1L)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } test("element_at function") { From 888340151f737bb68d0e419b1e949f11469881f9 Mon Sep 17 00:00:00 2001 From: sychen Date: Thu, 24 May 2018 11:02:09 +0800 Subject: [PATCH 0863/1161] [SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be wrong LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to `oldSize * 2`, while the new record may be larger than `oldSize * 2`. Then we may have a malformed UnsafeRow when querying this map, whose actual data is smaller than its declared size, and the data is corrupted. Author: sychen Closes #21311 from cxzl25/fix_LongToUnsafeRowMap_page_size. --- .../sql/execution/joins/HashedRelation.scala | 38 +++++++++++-------- .../execution/joins/HashedRelationSuite.scala | 26 ++++++++++++- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 1465346eb802d..20ce01f4ce8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap def append(key: Long, row: UnsafeRow): Unit = { val sizeInBytes = row.getSizeInBytes if (sizeInBytes >= (1 << SIZE_BITS)) { - sys.error("Does not support row that is larger than 256M") + throw new UnsupportedOperationException("Does not support row that is larger than 256M") } if (key < minKey) { @@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = key } - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) { - val used = page.length - if (used >= (1 << 30)) { - sys.error("Can not build a HashedRelation that is larger than 8G") - } - ensureAcquireMemory(used * 8L * 2) - val newPage = new Array[Long](used * 2) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - page = newPage - freeMemory(used * 8L) - } + grow(row.getSizeInBytes) // copy the bytes of UnsafeRow val offset = cursor @@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap growArray() } else if (numKeys > array.length / 2 * 0.75) { // The fill ratio should be less than 0.75 - sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys") + throw new UnsupportedOperationException( + "Cannot build HashedRelation with more than 1/3 billions unique keys") } } } else { @@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } + private def grow(inputRowSize: Int): Unit = { + // There is 8 bytes for the pointer to next value + val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.length) { + if (neededNumWords > (1 << 30)) { + throw new UnsupportedOperationException( + "Can not build a HashedRelation that is larger than 8G") + } + val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) + ensureAcquireMemory(newNumWords * 8L) + val newPage = new Array[Long](newNumWords.toInt) + Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, + cursor - Platform.LONG_ARRAY_OFFSET) + val used = page.length + page = newPage + freeMemory(used * 8L) + } + } + private def growArray(): Unit = { var old_array = array val n = array.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 51f8c3325fdff..037cc2e3ccad7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer @@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24257: insert big values into LongToUnsafeRowMap") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Array[DataType](StringType)) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + + val key = 0L + // the page array is initialized with length 1 << 17 (1M bytes), + // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug + val bigStr = UTF8String.fromString("x" * (1 << 19)) + + map.append(key, unsafeProj(InternalRow(bigStr))) + map.optimize() + + val resultRow = new UnsafeRow(1) + assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr) + map.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() From 486ecc680e9a0e7b6b3c3a45fb883a61072096fc Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 24 May 2018 11:34:13 +0800 Subject: [PATCH 0864/1161] [SPARK-24322][BUILD] Upgrade Apache ORC to 1.4.4 ## What changes were proposed in this pull request? ORC 1.4.4 includes [nine fixes](https://issues.apache.org/jira/issues/?filter=12342568&jql=project%20%3D%20ORC%20AND%20resolution%20%3D%20Fixed%20AND%20fixVersion%20%3D%201.4.4). One of the issues is about `Timestamp` bug (ORC-306) which occurs when `native` ORC vectorized reader reads ORC column vector's sub-vector `times` and `nanos`. ORC-306 fixes this according to the [original definition](https://github.com/apache/hive/blob/master/storage-api/src/java/org/apache/hadoop/hive/ql/exec/vector/TimestampColumnVector.java#L45-L46) and this PR includes the updated interpretation on ORC column vectors. Note that `hive` ORC reader and ORC MR reader is not affected. ```scala scala> spark.version res0: String = 2.3.0 scala> spark.sql("set spark.sql.orc.impl=native") scala> Seq(java.sql.Timestamp.valueOf("1900-05-05 12:34:56.000789")).toDF().write.orc("/tmp/orc") scala> spark.read.orc("/tmp/orc").show(false) +--------------------------+ |value | +--------------------------+ |1900-05-05 12:34:55.000789| +--------------------------+ ``` This PR aims to update Apache Spark to use it. **FULL LIST** ID | TITLE -- | -- ORC-281 | Fix compiler warnings from clang 5.0 ORC-301 | `extractFileTail` should open a file in `try` statement ORC-304 | Fix TestRecordReaderImpl to not fail with new storage-api ORC-306 | Fix incorrect workaround for bug in java.sql.Timestamp ORC-324 | Add support for ARM and PPC arch ORC-330 | Remove unnecessary Hive artifacts from root pom ORC-332 | Add syntax version to orc_proto.proto ORC-336 | Remove avro and parquet dependency management entries ORC-360 | Implement error checking on subtype fields in Java ## How was this patch tested? Pass the Jenkins. Author: Dongjoon Hyun Closes #21372 from dongjoon-hyun/SPARK_ORC144. --- dev/deps/spark-deps-hadoop-2.6 | 4 ++-- dev/deps/spark-deps-hadoop-2.7 | 4 ++-- dev/deps/spark-deps-hadoop-3.1 | 4 ++-- pom.xml | 2 +- .../sql/execution/datasources/orc/OrcColumnVector.java | 2 +- .../datasources/orc/OrcColumnarBatchReader.java | 2 +- .../sql/execution/datasources/orc/OrcSourceSuite.scala | 9 +++++++++ 7 files changed, 18 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e710e26348117..723180a14febb 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -157,8 +157,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 97ad17a9ff7b1..ea08a001a1c9b 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -158,8 +158,8 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e21bfef8c4291..da874026d7d10 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -176,8 +176,8 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.3-nohive.jar -orc-mapreduce-1.4.3-nohive.jar +orc-core-1.4.4-nohive.jar +orc-mapreduce-1.4.4-nohive.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 6e37e518d86e4..883c096ae1ae9 100644 --- a/pom.xml +++ b/pom.xml @@ -130,7 +130,7 @@ 1.2.1 10.12.1.1 1.10.0 - 1.4.3 + 1.4.4 nohive 1.6.0 9.3.20.v20170531 diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java index 12f4d658b1868..9bfad1e83ee7b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -136,7 +136,7 @@ public int getInt(int rowId) { public long getLong(int rowId) { int index = getRowIndex(rowId); if (isTimestamp) { - return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000; + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000; } else { return longData.vector[index]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index dcebdc39f0aa2..a0d9578a377b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -497,7 +497,7 @@ private void putValues( * Returns the number of micros since epoch from an element of TimestampColumnVector. */ private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) { - return vector.time[index] * 1000L + vector.nanos[index] / 1000L; + return vector.time[index] * 1000 + (vector.nanos[index] / 1000 % 1000); } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8a3bbd03a26dc..02bfb7197ffc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import java.sql.Timestamp import java.util.Locale import org.apache.orc.OrcConf.COMPRESS @@ -169,6 +170,14 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { } } } + + test("SPARK-24322 Fix incorrect workaround for bug in java.sql.Timestamp") { + withTempPath { path => + val ts = Timestamp.valueOf("1900-05-05 12:34:56.000789") + Seq(ts).toDF.write.orc(path.getCanonicalPath) + checkAnswer(spark.read.orc(path.getCanonicalPath), Row(ts)) + } + } } class OrcSourceSuite extends OrcSuite with SharedSQLContext { From e108f84f5cd562f070872651bdcf6c02e80dd585 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 24 May 2018 11:42:25 +0800 Subject: [PATCH 0865/1161] [MINOR][CORE] Cleanup unused vals in `DAGScheduler.handleTaskCompletion` ## What changes were proposed in this pull request? Cleanup unused vals in `DAGScheduler.handleTaskCompletion` to reduce the code complexity slightly. ## How was this patch tested? Existing test cases. Author: Xingbo Jiang Closes #21406 from jiangxb1987/handleTaskCompletion. --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ea7bfd7d7a68d..041eade82d3ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1167,9 +1167,7 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task - val taskId = event.taskInfo.id val stageId = task.stageId - val taskType = Utils.getFormattedClassName(task) outputCommitCoordinator.taskCompleted( stageId, @@ -1323,7 +1321,7 @@ class DAGScheduler( "tasks in ShuffleMapStages.") } - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => + case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1411,7 +1409,7 @@ class DAGScheduler( } } - case commitDenied: TaskCommitDenied => + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case _: ExceptionFailure | _: TaskKilled => From 8a545822d0cc3a866ef91a94e58ea5c8b1014007 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 24 May 2018 13:21:02 +0800 Subject: [PATCH 0866/1161] [SPARK-24364][SS] Prevent InMemoryFileIndex from failing if file path doesn't exist ## What changes were proposed in this pull request? This PR proposes to follow up https://github.com/apache/spark/pull/15153 and complete SPARK-17599. `FileSystem` operation (`fs.getFileBlockLocations`) can still fail if the file path does not exist. For example see the exception message below: ``` Error occurred while processing: File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... java.io.FileNotFoundException: File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... org.apache.hadoop.hdfs.DistributedFileSystem.getFileBlockLocations(DistributedFileSystem.java:249) at org.apache.hadoop.hdfs.DistributedFileSystem.getFileBlockLocations(DistributedFileSystem.java:229) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles$3.apply(InMemoryFileIndex.scala:314) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles$3.apply(InMemoryFileIndex.scala:297) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33) at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$.org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$listLeafFiles(InMemoryFileIndex.scala:297) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles$1.apply(InMemoryFileIndex.scala:174) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$$anonfun$org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles$1.apply(InMemoryFileIndex.scala:173) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.AbstractTraversable.map(Traversable.scala:104) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex$.org$apache$spark$sql$execution$datasources$InMemoryFileIndex$$bulkListLeafFiles(InMemoryFileIndex.scala:173) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.listLeafFiles(InMemoryFileIndex.scala:126) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.refresh0(InMemoryFileIndex.scala:91) at org.apache.spark.sql.execution.datasources.InMemoryFileIndex.(InMemoryFileIndex.scala:67) at org.apache.spark.sql.execution.datasources.DataSource.tempFileIndex$lzycompute$1(DataSource.scala:161) at org.apache.spark.sql.execution.datasources.DataSource.org$apache$spark$sql$execution$datasources$DataSource$$tempFileIndex$1(DataSource.scala:152) at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:166) at org.apache.spark.sql.execution.datasources.DataSource.sourceSchema(DataSource.scala:261) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo$lzycompute(DataSource.scala:94) at org.apache.spark.sql.execution.datasources.DataSource.sourceInfo(DataSource.scala:94) at org.apache.spark.sql.execution.streaming.StreamingRelation$.apply(StreamingRelation.scala:33) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:196) at org.apache.spark.sql.streaming.DataStreamReader.load(DataStreamReader.scala:206) at com.hwx.StreamTest$.main(StreamTest.scala:97) at com.hwx.StreamTest.main(StreamTest.scala) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at org.apache.spark.deploy.JavaMainApplication.start(SparkApplication.scala:52) at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:906) at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:197) at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:227) at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:136) at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala) Caused by: org.apache.hadoop.ipc.RemoteException(java.io.FileNotFoundException): File does not exist: /rel/00171151/input/PJ/part-00136-b6403bac-a240-44f8-a792-fc2e174682b7-c000.csv ... ``` So, it fixes it to make a warning instead. ## How was this patch tested? It's hard to write a test. Manually tested multiple times. Author: hyukjinkwon Closes #21408 from HyukjinKwon/missing-files. --- .../datasources/InMemoryFileIndex.scala | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 739d1f456e3ec..9d9f8bd5bb58e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -294,9 +294,12 @@ object InMemoryFileIndex extends Logging { if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles } - allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { + val missingFiles = mutable.ArrayBuffer.empty[String] + val filteredLeafStatuses = allLeafStatuses.filterNot( + status => shouldFilterOut(status.getPath.getName)) + val resolvedLeafStatuses = filteredLeafStatuses.flatMap { case f: LocatedFileStatus => - f + Some(f) // NOTE: // @@ -311,14 +314,27 @@ object InMemoryFileIndex extends Logging { // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), // which is very slow on some file system (RawLocalFileSystem, which is launch a // subprocess and parse the stdout). - val locations = fs.getFileBlockLocations(f, 0, f.getLen) - val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, - f.getModificationTime, 0, null, null, null, null, f.getPath, locations) - if (f.isSymlink) { - lfs.setSymlink(f.getSymlink) + try { + val locations = fs.getFileBlockLocations(f, 0, f.getLen) + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) + } + Some(lfs) + } catch { + case _: FileNotFoundException => + missingFiles += f.getPath.toString + None } - lfs } + + if (missingFiles.nonEmpty) { + logWarning( + s"the following files were missing during file scan:\n ${missingFiles.mkString("\n ")}") + } + + resolvedLeafStatuses } /** Checks if we should filter out this path name. */ From 4a14dc0aff9cac85390cab94bc183271fa95beef Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 24 May 2018 14:19:32 +0800 Subject: [PATCH 0867/1161] [SPARK-22269][BUILD] Run Java linter via SBT for Jenkins ## What changes were proposed in this pull request? This PR proposes to check Java lint via SBT for Jenkins. It uses the SBT wrapper for checkstyle. I manually tested. If we build the codes once, running this script takes 2 mins at maximum in my local: Test codes: ``` Checkstyle failed at following occurrences: [error] Checkstyle error found in /.../spark/core/src/test/java/test/org/apache/spark/JavaAPISuite.java:82: Line is longer than 100 characters (found 103). [error] 1 issue(s) found in Checkstyle report: /.../spark/core/target/checkstyle-test-report.xml [error] Checkstyle error found in /.../spark/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java:84: Line is longer than 100 characters (found 115). [error] 1 issue(s) found in Checkstyle report: /.../spark/sql/hive/target/checkstyle-test-report.xml ... ``` Main codes: ``` Checkstyle failed at following occurrences: [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java:39: Line is longer than 100 characters (found 104). [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:26: Line is longer than 100 characters (found 110). [error] Checkstyle error found in /.../spark/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java:30: Line is longer than 100 characters (found 104). ... ``` ## How was this patch tested? Manually tested. Jenkins build should test this. Author: hyukjinkwon Closes #21399 from HyukjinKwon/SPARK-22269. --- dev/run-tests.py | 5 ++--- dev/sbt-checkstyle | 42 ++++++++++++++++++++++++++++++++++++++++ project/SparkBuild.scala | 14 +++++++++++++- project/plugins.sbt | 8 ++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) create mode 100755 dev/sbt-checkstyle diff --git a/dev/run-tests.py b/dev/run-tests.py index 164c1e2200aa9..5e8c8590b5c34 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -204,7 +204,7 @@ def run_scala_style_checks(): def run_java_style_checks(): set_title_and_block("Running Java style checks", "BLOCK_JAVA_STYLE") - run_cmd([os.path.join(SPARK_HOME, "dev", "lint-java")]) + run_cmd([os.path.join(SPARK_HOME, "dev", "sbt-checkstyle")]) def run_python_style_checks(): @@ -574,8 +574,7 @@ def main(): or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - # run_java_style_checks() - pass + run_java_style_checks() if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") diff --git a/dev/sbt-checkstyle b/dev/sbt-checkstyle new file mode 100755 index 0000000000000..8821a7c0e4ccf --- /dev/null +++ b/dev/sbt-checkstyle @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# NOTE: echo "q" is needed because SBT prompts the user for input on encountering a build file +# with failure (either resolution or compilation); the "q" makes SBT quit. +ERRORS=$(echo -e "q\n" \ + | build/sbt \ + -Pkinesis-asl \ + -Pmesos \ + -Pkafka-0-8 \ + -Pkubernetes \ + -Pyarn \ + -Pflume \ + -Phive \ + -Phive-thriftserver \ + checkstyle test:checkstyle \ + | awk '{if($1~/error/)print}' \ +) + +if test ! -z "$ERRORS"; then + echo -e "Checkstyle failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Checkstyle checks passed." +fi + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7469f11df0294..4cb6495a33b61 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -27,6 +27,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.etsy.sbt.checkstyle.CheckstylePlugin.autoImport._ import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -317,7 +318,7 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings)) + ExcludedDependencies.settings ++ Checkstyle.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -740,6 +741,17 @@ object Unidoc { ) } +object Checkstyle { + lazy val settings = Seq( + checkstyleSeverityLevel := Some(CheckstyleSeverityLevel.Error), + javaSource in (Compile, checkstyle) := baseDirectory.value / "src/main/java", + javaSource in (Test, checkstyle) := baseDirectory.value / "src/test/java", + checkstyleConfigLocation := CheckstyleConfigLocation.File("dev/checkstyle.xml"), + checkstyleOutputFile := baseDirectory.value / "target/checkstyle-output.xml", + checkstyleOutputFile in Test := baseDirectory.value / "target/checkstyle-output.xml" + ) +} + object CopyDependencies { val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.") diff --git a/project/plugins.sbt b/project/plugins.sbt index 96bdb9067ae59..ffbd417b0f145 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,11 @@ +addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") + +// sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's. +libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2" + +// checkstyle uses guava 23.0. +libraryDependencies += "com.google.guava" % "guava" % "23.0" + // need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") From 3469f5c989e686866051382a3a28b2265619cab9 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 24 May 2018 20:55:26 +0800 Subject: [PATCH 0868/1161] [SPARK-24230][SQL] Fix SpecificParquetRecordReaderBase with dictionary filters. ## What changes were proposed in this pull request? I missed this commit when preparing #21070. When Parquet is able to filter blocks with dictionary filtering, the expected total value count to be too high in Spark, leading to an error when there were fewer than expected row groups to process. Spark should get the row groups from Parquet to pick up new filter schemes in Parquet like dictionary filtering. ## How was this patch tested? Using in production at Netflix. Added test case for dictionary-filtered blocks. Author: Ryan Blue Closes #21295 from rdblue/SPARK-24230-fix-parquet-block-tracking. --- .../parquet/SpecificParquetRecordReaderBase.java | 6 ++++-- .../datasources/parquet/ParquetQuerySuite.scala | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index daedfd7e78f5f..c975e52734e01 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -146,7 +146,8 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); this.reader = new ParquetFileReader( configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } @@ -224,7 +225,8 @@ protected void initialize(String path, List columns) throws IOException this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader( config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns()); - for (BlockMetaData block : blocks) { + // use the blocks from the reader in case some do not match filters and will not be read + for (BlockMetaData block : reader.getRowGroups()) { this.totalRowCount += block.getRowCount(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e1f094d0a7af3..2b1227faf48a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -879,6 +879,18 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-24230: filter row group using dictionary") { + withSQLConf(("parquet.filter.dictionary.enabled", "true")) { + // create a table with values from 0, 2, ..., 18 that will be dictionary-encoded + withParquetTable((0 until 100).map(i => ((i * 2) % 20, s"data-$i")), "t") { + // search for a key that is not present so the dictionary filter eliminates all row groups + // Fails without SPARK-24230: + // java.io.IOException: expecting more rows but reached last block. Read 0 out of 50 + checkAnswer(sql("SELECT _2 FROM t WHERE t._1 = 5"), Seq.empty) + } + } + } } object TestingUDT { From 13bedc05c28fcc6e739fb472bd2ee3035fa11648 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 24 May 2018 22:18:58 +0800 Subject: [PATCH 0869/1161] [SPARK-24329][SQL] Test for skipping multi-space lines ## What changes were proposed in this pull request? The PR is a continue of https://github.com/apache/spark/pull/21380 . It checks cases that are handled by the code: https://github.com/apache/spark/blob/e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala#L303-L304 Basically the code skips lines with one or many whitespaces, and lines with comments (see [filterCommentAndEmpty](https://github.com/apache/spark/blob/e3de6ab30d52890eb08578e55eb4a5d2b4e7aa35/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala#L47)) ```scala iter.filter { line => line.trim.nonEmpty && !line.startsWith(options.comment.toString) } ``` Closes #21380 ## How was this patch tested? Added a test for the case described above. Author: Maxim Gekk Author: Maxim Gekk Closes #21394 from MaxGekk/test-for-multi-space-lines. --- .../resources/test-data/comments-whitespaces.csv | 8 ++++++++ .../sql/execution/datasources/csv/CSVSuite.scala | 15 +++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 sql/core/src/test/resources/test-data/comments-whitespaces.csv diff --git a/sql/core/src/test/resources/test-data/comments-whitespaces.csv b/sql/core/src/test/resources/test-data/comments-whitespaces.csv new file mode 100644 index 0000000000000..2737978f83a5e --- /dev/null +++ b/sql/core/src/test/resources/test-data/comments-whitespaces.csv @@ -0,0 +1,8 @@ +# The file contains comments, whitespaces and empty lines +colA +# empty line + +# the line with a few whitespaces + +# int value with leading and trailing whitespaces + "a" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 07e6c74b14d0d..2bac1a3e2d4c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1368,4 +1368,19 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(computed, expected) } } + + test("SPARK-24329: skip lines with comments, and one or multiple whitespaces") { + val schema = new StructType().add("colA", StringType) + val ds = spark + .read + .schema(schema) + .option("multiLine", false) + .option("header", true) + .option("comment", "#") + .option("ignoreLeadingWhiteSpace", false) + .option("ignoreTrailingWhiteSpace", false) + .csv(testFile("test-data/comments-whitespaces.csv")) + + checkAnswer(ds, Seq(Row(""" "a" """))) + } } From 0d89943449764e8a578edf4ceb6245158421eb96 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 24 May 2018 23:38:50 +0800 Subject: [PATCH 0870/1161] [SPARK-24378][SQL] Fix date_trunc function incorrect examples ## What changes were proposed in this pull request? Fix `date_trunc` function incorrect examples. ## How was this patch tested? N/A Author: Yuming Wang Closes #21423 from wangyum/SPARK-24378. --- .../expressions/datetimeExpressions.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index e8d85f72f7a7a..08838d2b2c612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1533,14 +1533,14 @@ case class TruncDate(date: Expression, format: Expression) """, examples = """ Examples: - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'YEAR'); - 2015-01-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'MM'); - 2015-03-01T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'DD'); - 2015-03-05T00:00:00 - > SELECT _FUNC_('2015-03-05T09:32:05.359', 'HOUR'); - 2015-03-05T09:00:00 + > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359'); + 2015-01-01 00:00:00 + > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359'); + 2015-03-01 00:00:00 + > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359'); + 2015-03-05 00:00:00 + > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359'); + 2015-03-05 09:00:00 """, since = "2.3.0") // scalastyle:on line.size.limit From 53c06ddabbdf689f8823807445849ad63173676f Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 24 May 2018 13:00:24 -0700 Subject: [PATCH 0871/1161] [SPARK-24332][SS][MESOS] Fix places reading 'spark.network.timeout' as milliseconds ## What changes were proposed in this pull request? This PR replaces `getTimeAsMs` with `getTimeAsSeconds` to fix the issue that reading "spark.network.timeout" using a wrong time unit when the user doesn't specify a time out. ## How was this patch tested? Jenkins Author: Shixiong Zhu Closes #21382 from zsxwing/fix-network-timeout-conf. --- .../org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala | 2 +- .../scala/org/apache/spark/sql/kafka010/KafkaRelation.scala | 4 +++- .../scala/org/apache/spark/sql/kafka010/KafkaSource.scala | 2 +- .../scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala | 2 +- .../cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala | 2 +- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 64ba98762788c..737da2e51b125 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -68,7 +68,7 @@ private[kafka010] class KafkaMicroBatchReader( private val pollTimeoutMs = options.getLong( "kafkaConsumer.pollTimeoutMs", - SparkEnv.get.conf.getTimeAsMs("spark.network.timeout", "120s")) + SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index 7103709969c18..c31e6ed3e0903 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -48,7 +48,9 @@ private[kafka010] class KafkaRelation( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sqlContext.sparkContext.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sqlContext.sparkContext.conf.getTimeAsSeconds( + "spark.network.timeout", + "120s") * 1000L).toString ).toLong override def schema: StructType = KafkaOffsetReader.kafkaSchema diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 1c7b3a29a861f..101e649727fcf 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -84,7 +84,7 @@ private[kafka010] class KafkaSource( private val pollTimeoutMs = sourceOptions.getOrElse( "kafkaConsumer.pollTimeoutMs", - sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString + (sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L).toString ).toLong private val maxOffsetsPerTrigger = diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 81abc9860bfc3..3efc90fe466b2 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -65,7 +65,7 @@ private[spark] class KafkaRDD[K, V]( // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", - conf.getTimeAsMs("spark.network.timeout", "120s")) + conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) private val cacheInitialCapacity = conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) private val cacheMaxCapacity = diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 9b75e4c98344a..d35bea4aca311 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -634,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsMs("spark.network.timeout", "120s")}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L}ms"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } From 0fd68cb7278e5fdf106e73b580ee7dd829006386 Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 24 May 2018 17:08:52 -0700 Subject: [PATCH 0872/1161] [SPARK-24234][SS] Support multiple row writers in continuous processing shuffle reader. ## What changes were proposed in this pull request? https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii Support multiple different row writers in continuous processing shuffle reader. Note that having multiple read-side buffers ended up being the natural way to do this. Otherwise it's hard to express the constraint of sending an epoch marker only when all writers have sent one. ## How was this patch tested? new unit tests Author: Jose Torres Closes #21385 from jose-torres/multipleWrite. --- .../shuffle/ContinuousShuffleReadRDD.scala | 21 ++- .../shuffle/UnsafeRowReceiver.scala | 87 ++++++++-- .../shuffle/ContinuousShuffleReadSuite.scala | 163 +++++++++++++++--- 3 files changed, 227 insertions(+), 44 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 270b1a5c28dee..801b28b751bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -25,11 +25,16 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator -case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { +case class ContinuousShuffleReadPartition( + index: Int, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, env) + val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -42,16 +47,24 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks * poll from their receiver until an epoch marker is sent. + * + * @param sc the RDD context + * @param numPartitions the number of read partitions for this RDD + * @param queueSize the size of the row buffers to use + * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD + * @param epochIntervalMs the checkpoint interval of the streaming query */ class ContinuousShuffleReadRDD( sc: SparkContext, numPartitions: Int, - queueSize: Int = 1024) + queueSize: Int = 1024, + numShuffleWriters: Int = 1, + epochIntervalMs: Long = 1000) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize) + ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala index b8adbb743c6c2..d81f552d56626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -27,10 +29,17 @@ import org.apache.spark.util.NextIterator /** * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * + * Each message comes tagged with writerId, identifying which writer the message is coming + * from. The receiver will only begin the next epoch once all writers have sent an epoch + * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable -private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { + def writerId: Int +} +private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) + extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -41,11 +50,15 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa */ private[shuffle] class UnsafeRowReceiver( queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. - private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + private val queues = Array.fill(numShuffleWriters) { + new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + } // Exposed for testing to determine if the endpoint gets stopped on task end. private[shuffle] val stopped = new AtomicBoolean(false) @@ -56,20 +69,70 @@ private[shuffle] class UnsafeRowReceiver( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: UnsafeRowReceiverMessage => - queue.put(r) + queues(r.writerId).put(r) context.reply(()) } override def read(): Iterator[UnsafeRow] = { new NextIterator[UnsafeRow] { - override def getNext(): UnsafeRow = queue.take() match { - case ReceiverRow(r) => r - case ReceiverEpochMarker() => - finished = true - null + // An array of flags for whether each writer ID has gotten an epoch marker. + private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) + + private val executor = Executors.newFixedThreadPool(numShuffleWriters) + private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + + private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { + override def call(): UnsafeRowReceiverMessage = queues(writerId).take() } - override def close(): Unit = {} + // Initialize by submitting tasks to read the first row from each writer. + (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId))) + + /** + * In each call to getNext(), we pull the next row available in the completion queue, and then + * submit another task to read the next row from the writer which returned it. + * + * When a writer sends an epoch marker, we note that it's finished and don't submit another + * task for it in this epoch. The iterator is over once all writers have sent an epoch marker. + */ + override def getNext(): UnsafeRow = { + var nextRow: UnsafeRow = null + while (!finished && nextRow == null) { + completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match { + case null => + // Try again if the poll didn't wait long enough to get a real result. + // But we should be getting at least an epoch marker every checkpoint interval. + val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect { + case (flag, idx) if !flag => idx + } + logWarning( + s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + + s"for writers $writerIdsUncommitted to send epoch markers.") + + // The completion service guarantees this future will be available immediately. + case future => future.get() match { + case ReceiverRow(writerId, r) => + // Start reading the next element in the queue we just took from. + completion.submit(completionTask(writerId)) + nextRow = r + case ReceiverEpochMarker(writerId) => + // Don't read any more from this queue. If all the writers have sent epoch markers, + // the epoch is over; otherwise we need to loop again to poll from the remaining + // writers. + writerEpochMarkersReceived(writerId) = true + if (writerEpochMarkersReceived.forall(_ == true)) { + finished = true + } + } + } + } + + nextRow + } + + override def close(): Unit = { + executor.shutdownNow() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala index b25e75b3b37a6..2e4d607a403ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -21,7 +21,8 @@ import org.apache.spark.{TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String class ContinuousShuffleReadSuite extends StreamTest { @@ -30,6 +31,11 @@ class ContinuousShuffleReadSuite extends StreamTest { new GenericInternalRow(Array(value: Any))) } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) + } + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { messages.foreach(endpoint.askSync[Unit](_)) } @@ -57,8 +63,8 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(111)) + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) ) ctx.markTaskCompleted(None) @@ -71,8 +77,11 @@ class ContinuousShuffleReadSuite extends StreamTest { test("receiver stopped with marker last") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) - endpoint.askSync[Unit](ReceiverEpochMarker()) + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) ctx.markTaskCompleted(None) val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader @@ -86,10 +95,10 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverRow(unsafeRow(111)), - ReceiverRow(unsafeRow(222)), - ReceiverRow(unsafeRow(333)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(111)), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) ) val iter = rdd.compute(rdd.partitions(0), ctx) @@ -101,11 +110,11 @@ class ContinuousShuffleReadSuite extends StreamTest { val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( endpoint, - ReceiverRow(unsafeRow(111)), - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(222)), - ReceiverRow(unsafeRow(333)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(222)), + ReceiverRow(0, unsafeRow(333)), + ReceiverEpochMarker(0) ) val firstEpoch = rdd.compute(rdd.partitions(0), ctx) @@ -118,14 +127,15 @@ class ContinuousShuffleReadSuite extends StreamTest { test("empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( endpoint, - ReceiverEpochMarker(), - ReceiverEpochMarker(), - ReceiverRow(unsafeRow(111)), - ReceiverEpochMarker(), - ReceiverEpochMarker(), - ReceiverEpochMarker() + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0) ) assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) @@ -146,8 +156,8 @@ class ContinuousShuffleReadSuite extends StreamTest { // Send index for identification. send( part.endpoint, - ReceiverRow(unsafeRow(part.index)), - ReceiverEpochMarker() + ReceiverRow(0, unsafeRow(part.index)), + ReceiverEpochMarker(0) ) } @@ -160,25 +170,122 @@ class ContinuousShuffleReadSuite extends StreamTest { } test("blocks waiting for new rows") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) + val epoch = rdd.compute(rdd.partitions(0), ctx) val readRowThread = new Thread { override def run(): Unit = { - // set the non-inheritable thread local - TaskContext.setTaskContext(ctx) - val epoch = rdd.compute(rdd.partitions(0), ctx) - epoch.next().getInt(0) + try { + epoch.next().getInt(0) + } catch { + case _: InterruptedException => // do nothing - expected at test ending + } } } try { readRowThread.start() eventually(timeout(streamingTimeout)) { - assert(readRowThread.getState == Thread.State.WAITING) + assert(readRowThread.getState == Thread.State.TIMED_WAITING) } } finally { readRowThread.interrupt() readRowThread.join() } } + + test("multiple writers") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(1), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + } + + test("epoch only ends when all writers send markers") { + val rdd = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(2) + ) + + val epoch = rdd.compute(rdd.partitions(0), ctx) + val rows = (0 until 3).map(_ => epoch.next()).toSet + assert(rows.map(_.getUTF8String(0).toString) == + Set("writer0-row0", "writer1-row0", "writer2-row0")) + + // After checking the right rows, block until we get an epoch marker indicating there's no next. + // (Also fail the assertion if for some reason we get a row.) + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!epoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + // Send the last epoch marker - now the epoch should finish. + send(endpoint, ReceiverEpochMarker(1)) + eventually(timeout(streamingTimeout)) { + !readEpochMarkerThread.isAlive + } + + // Join to pick up assertion failures. + readEpochMarkerThread.join() + } + + test("writer epochs non aligned") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should + // collate them as though the markers were aligned in the first place. + send( + endpoint, + ReceiverRow(0, unsafeRow("writer0-row0")), + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow("writer0-row1")), + ReceiverEpochMarker(0), + ReceiverEpochMarker(0), + + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row0")), + ReceiverEpochMarker(1), + ReceiverRow(1, unsafeRow("writer1-row1")), + ReceiverEpochMarker(1), + + ReceiverEpochMarker(2), + ReceiverEpochMarker(2), + ReceiverRow(2, unsafeRow("writer2-row0")), + ReceiverEpochMarker(2) + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(firstEpoch == Set("writer0-row0")) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(secondEpoch == Set("writer0-row1", "writer1-row0")) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet + assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) + } } From 3b20b34ab72c92d9d20188ed430955e1a94eac9c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 25 May 2018 11:16:35 +0800 Subject: [PATCH 0873/1161] [SPARK-24367][SQL] Parquet: use JOB_SUMMARY_LEVEL instead of deprecated flag ENABLE_JOB_SUMMARY ## What changes were proposed in this pull request? In current parquet version,the conf ENABLE_JOB_SUMMARY is deprecated. When writing to Parquet files, the warning message ```WARN org.apache.parquet.hadoop.ParquetOutputFormat: Setting parquet.enable.summary-metadata is deprecated, please use parquet.summary.metadata.level``` keeps showing up. From https://github.com/apache/parquet-mr/blame/master/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/ParquetOutputFormat.java#L164 we can see that we should use JOB_SUMMARY_LEVEL. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21411 from gengliangwang/summaryLevel. --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../datasources/parquet/ParquetFileFormat.scala | 10 ++++++---- .../datasources/parquet/ParquetCommitterSuite.scala | 7 ++++++- .../execution/datasources/parquet/ParquetIOSuite.scala | 2 +- .../parquet/ParquetPartitionDiscoverySuite.scala | 2 +- .../datasources/parquet/ParquetQuerySuite.scala | 2 +- .../sql/sources/ParquetHadoopFsRelationSuite.scala | 2 +- 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 15ba10f604510..93d356fef07af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -395,7 +395,7 @@ object SQLConf { .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + - "will never be created, irrespective of the value of parquet.enable.summary-metadata") + "will never be created, irrespective of the value of parquet.summary.metadata.level") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d1f9e11ed4225..60fc9ec7e1f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -34,6 +34,7 @@ import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType @@ -125,16 +126,17 @@ class ParquetFileFormat conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName) // SPARK-15719: Disables writing Parquet summary files by default. - if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { - conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (conf.get(ParquetOutputFormat.JOB_SUMMARY_LEVEL) == null + && conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setEnum(ParquetOutputFormat.JOB_SUMMARY_LEVEL, JobSummaryLevel.NONE) } - if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + if (ParquetOutputFormat.getJobSummaryLevel(conf) == JobSummaryLevel.NONE && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { // output summary is requested, but the class is not a Parquet Committer logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + s" create job summaries. " + - s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") } new OutputWriterFactory { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala index f3ecc5ced689f..4b2437803d645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -91,9 +91,14 @@ class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils summary: Boolean, check: Boolean): Option[FileStatus] = { var result: Option[FileStatus] = None + val summaryLevel = if (summary) { + "ALL" + } else { + "NONE" + } withSQLConf( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> summaryLevel) { withTempPath { dest => val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) val destPath = new Path(dest.toURI) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0b3e8ca060d87..002c42f23bd64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -543,7 +543,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) - withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + withSQLConf(ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part-r-0.parquet" spark.range(1 << 16).selectExpr("(id % 4) AS i") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index e887c9734a8b8..9966ed94a8392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1014,7 +1014,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val path = dir.getCanonicalPath withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", "spark.sql.sources.commitProtocolClass" -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { spark.range(3).write.parquet(s"$path/p0=0/p1=0") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2b1227faf48a0..dbf637783e6d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -275,7 +275,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName, SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true", - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true" + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL" ) { testSchemaMerging(2) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index dce5bb7ddba66..6858bbc441721 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -124,7 +124,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { test("SPARK-8604: Parquet data source should write summary file while doing appending") { withSQLConf( - ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true", + ParquetOutputFormat.JOB_SUMMARY_LEVEL -> "ALL", SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key -> classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) { withTempPath { dir => From 64fad0b519cf35b8c0a0dec18dd3df9488a5ed25 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 24 May 2018 21:38:04 -0700 Subject: [PATCH 0874/1161] [SPARK-24244][SPARK-24368][SQL] Passing only required columns to the CSV parser ## What changes were proposed in this pull request? uniVocity parser allows to specify only required column names or indexes for [parsing](https://www.univocity.com/pages/parsers-tutorial) like: ``` // Here we select only the columns by their indexes. // The parser just skips the values in other columns parserSettings.selectIndexes(4, 0, 1); CsvParser parser = new CsvParser(parserSettings); ``` In this PR, I propose to extract indexes from required schema and pass them into the CSV parser. Benchmarks show the following improvements in parsing of 1000 columns: ``` Select 100 columns out of 1000: x1.76 Select 1 column out of 1000: x2 ``` **Note**: Comparing to current implementation, the changes can return different result for malformed rows in the `DROPMALFORMED` and `FAILFAST` modes if only subset of all columns is requested. To have previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## How was this patch tested? It was tested by new test which selects 3 columns out of 15, by existing tests and by new benchmarks. Author: Maxim Gekk Author: Maxim Gekk Closes #21415 from MaxGekk/csv-column-pruning2. --- docs/sql-programming-guide.md | 1 + .../apache/spark/sql/internal/SQLConf.scala | 9 ++++ .../apache/spark/sql/DataFrameReader.scala | 1 + .../datasources/csv/CSVFileFormat.scala | 18 ++++++-- .../datasources/csv/CSVOptions.scala | 3 ++ .../datasources/csv/UnivocityParser.scala | 26 ++++++----- .../datasources/csv/CSVBenchmarks.scala | 46 +++++++++++++++++++ .../datasources/csv/CSVInferSchemaSuite.scala | 22 ++++----- .../execution/datasources/csv/CSVSuite.scala | 41 ++++++++++++++--- .../csv/UnivocityParserSuite.scala | 37 +++++++-------- 10 files changed, 152 insertions(+), 52 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f1ed316341b95..fc26562ff33da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1825,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 93d356fef07af..8d2320d8a6ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1307,6 +1307,13 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) } /** @@ -1664,6 +1671,8 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 917f0cb221412..ac4580a0919ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -480,6 +480,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { def csv(csvDataset: Dataset[String]): DataFrame = { val parsedOptions: CSVOptions = new CSVOptions( extraOptions.toMap, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone) val filteredLines: Dataset[String] = CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index e20977a4ec79f..21279d6daf7ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -41,8 +41,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) val csvDataSource = CSVDataSource(parsedOptions) csvDataSource.isSplitable && super.isSplitable(sparkSession, options, path) } @@ -51,8 +53,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = - new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) CSVDataSource(parsedOptions).inferSchema(sparkSession, files, parsedOptions) } @@ -64,7 +68,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { dataSchema: StructType): OutputWriterFactory = { CSVUtils.verifySchema(dataSchema) val conf = job.getConfiguration - val csvOptions = new CSVOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val csvOptions = new CSVOptions( + options, + columnPruning = sparkSession.sessionState.conf.csvColumnPruning, + sparkSession.sessionState.conf.sessionLocalTimeZone) csvOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -97,6 +104,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val parsedOptions = new CSVOptions( options, + sparkSession.sessionState.conf.csvColumnPruning, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 1066d156acd74..7119189a4e131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -28,16 +28,19 @@ import org.apache.spark.sql.catalyst.util._ class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], + val columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { def this( parameters: Map[String, String], + columnPruning: Boolean, defaultTimeZoneId: String, defaultColumnNameOfCorruptRecord: String = "") = { this( CaseInsensitiveMap(parameters), + columnPruning, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 99557a1ceb0c8..4f00cc5eb3f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + private val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -73,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -210,9 +215,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala index d442ba7e59c61..1a3dacb8398e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -74,7 +74,53 @@ object CSVBenchmarks { } } + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + benchmark.addCase(s"count()", 3) { _ => + ds.count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 81091 / 81692 0.0 81090.7 1.0X + Select 100 columns 30003 / 34448 0.0 30003.0 2.7X + Select one column 24792 / 24855 0.0 24792.0 3.3X + count() 24344 / 24642 0.0 24343.8 3.3X + */ + benchmark.run() + } + } + def main(args: Array[String]): Unit = { quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 661742087112f..842251be92c18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { test("String fields types are inferred correctly from null types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(NullType, "", options) == NullType) assert(CSVInferSchema.inferField(NullType, null, options) == NullType) assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType) @@ -41,7 +41,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("String fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType) assert(CSVInferSchema.inferField(LongType, "test", options) == StringType) assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType) @@ -60,21 +60,21 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), "GMT") + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("timestampFormat" -> "yyyy"), "GMT") + options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } test("Timestamp field types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType) assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType) } test("Boolean fields types are inferred correctly from other types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType) assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType) } @@ -92,12 +92,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("Null fields are handled properly when a nullValue is specified") { - var options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "null", options) == NullType) assert(CSVInferSchema.inferField(StringType, "null", options) == StringType) assert(CSVInferSchema.inferField(LongType, "null", options) == LongType) - options = new CSVOptions(Map("nullValue" -> "\\N"), "GMT") + options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT") assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType) assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) @@ -111,12 +111,12 @@ class CSVInferSchemaSuite extends SparkFunSuite { } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { - val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), "GMT") + val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, "GMT") assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) } test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == @@ -134,7 +134,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { test("DoubleType should be infered when user defined nan/inf are provided") { val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> "-inf", - "positiveInf" -> "inf"), "GMT") + "positiveInf" -> "inf"), false, "GMT") assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType) assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 2bac1a3e2d4c6..afe10bdc4de26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,14 +260,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -1383,4 +1385,29 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(ds, Seq(Row(""" "a" """))) } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + assert(idf.count() == 2) + checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index efbf73534bd19..458edb253fb33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParserSuite extends SparkFunSuite { - private val parser = - new UnivocityParser(StructType(Seq.empty), new CSVOptions(Map.empty[String, String], "GMT")) + private val parser = new UnivocityParser( + StructType(Seq.empty), + new CSVOptions(Map.empty[String, String], false, "GMT")) private def assertNull(v: Any) = assert(v == null) @@ -38,7 +39,7 @@ class UnivocityParserSuite extends SparkFunSuite { stringValues.zip(decimalValues).foreach { case (strVal, decimalVal) => val decimalValue = new BigDecimal(decimalVal.toString) - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", decimalType, options = options).apply(strVal) === Decimal(decimalValue, decimalType.precision, decimalType.scale)) } @@ -51,21 +52,21 @@ class UnivocityParserSuite extends SparkFunSuite { // Nullable field with nullValue option. types.foreach { t => // Tests that a custom nullValue. - val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val nullValueOptions = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = true, options = nullValueOptions) assertNull(converter.apply("-")) assertNull(converter.apply(null)) // Tests that the default nullValue is empty string. - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assertNull(parser.makeConverter("_1", t, nullable = true, options = options).apply("")) } // Not nullable field with nullValue option. types.foreach { t => // Casts a null to not nullable field should throw an exception. - val options = new CSVOptions(Map("nullValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "-"), false, "GMT") val converter = parser.makeConverter("_1", t, nullable = false, options = options) var message = intercept[RuntimeException] { @@ -81,7 +82,7 @@ class UnivocityParserSuite extends SparkFunSuite { // If nullValue is different with empty string, then, empty string should not be casted into // null. Seq(true, false).foreach { b => - val options = new CSVOptions(Map("nullValue" -> "null"), "GMT") + val options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT") val converter = parser.makeConverter("_1", StringType, nullable = b, options = options) assert(converter.apply("") == UTF8String.fromString("")) @@ -89,7 +90,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for empty string with non null type") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val exception = intercept[RuntimeException]{ parser.makeConverter("_1", IntegerType, nullable = false, options = options).apply("") } @@ -97,7 +98,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Types are cast correctly") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") assert(parser.makeConverter("_1", ByteType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", ShortType, options = options).apply("10") == 10) assert(parser.makeConverter("_1", IntegerType, options = options).apply("10") == 10) @@ -107,7 +108,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(parser.makeConverter("_1", BooleanType, options = options).apply("true") == true) val timestampsOptions = - new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), "GMT") + new CSVOptions(Map("timestampFormat" -> "dd/MM/yyyy hh:mm"), false, "GMT") val customTimestamp = "31/01/2015 00:00" val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime val castedTimestamp = @@ -116,7 +117,7 @@ class UnivocityParserSuite extends SparkFunSuite { assert(castedTimestamp == expectedTime * 1000L) val customDate = "31/01/2015" - val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), "GMT") + val dateOptions = new CSVOptions(Map("dateFormat" -> "dd/MM/yyyy"), false, "GMT") val expectedDate = dateOptions.dateFormat.parse(customDate).getTime val castedDate = parser.makeConverter("_1", DateType, nullable = true, options = dateOptions) @@ -131,7 +132,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Throws exception for casting an invalid string to Float and Double Types") { - val options = new CSVOptions(Map.empty[String, String], "GMT") + val options = new CSVOptions(Map.empty[String, String], false, "GMT") val types = Seq(DoubleType, FloatType) val input = Seq("10u000", "abc", "1 2/3") types.foreach { dt => @@ -145,7 +146,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "nn"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "nn"), false, "GMT") val floatVal: Float = parser.makeConverter( "_1", FloatType, nullable = true, options = options ).apply("nn").asInstanceOf[Float] @@ -156,7 +157,7 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double NaN values are parsed correctly") { - val options = new CSVOptions(Map("nanValue" -> "-"), "GMT") + val options = new CSVOptions(Map("nanValue" -> "-"), false, "GMT") val doubleVal: Double = parser.makeConverter( "_1", DoubleType, nullable = true, options = options ).apply("-").asInstanceOf[Double] @@ -165,14 +166,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Float infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val floatVal1 = parser.makeConverter( "_1", FloatType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Float] assert(floatVal1 == Float.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val floatVal2 = parser.makeConverter( "_1", FloatType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Float] @@ -181,14 +182,14 @@ class UnivocityParserSuite extends SparkFunSuite { } test("Double infinite values can be parsed") { - val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), "GMT") + val negativeInfOptions = new CSVOptions(Map("negativeInf" -> "max"), false, "GMT") val doubleVal1 = parser.makeConverter( "_1", DoubleType, nullable = true, options = negativeInfOptions ).apply("max").asInstanceOf[Double] assert(doubleVal1 == Double.NegativeInfinity) - val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), "GMT") + val positiveInfOptions = new CSVOptions(Map("positiveInf" -> "max"), false, "GMT") val doubleVal2 = parser.makeConverter( "_1", DoubleType, nullable = true, options = positiveInfOptions ).apply("max").asInstanceOf[Double] From fd315f5884c03c6dd21abca178897584dee83f1a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 25 May 2018 12:49:06 -0700 Subject: [PATCH 0875/1161] [MINOR] Add port SSL config in toString and scaladoc ## What changes were proposed in this pull request? SPARK-17874 introduced a new configuration to set the port where SSL services bind to. We missed to update the scaladoc and the `toString` method, though. The PR adds it in the missing places ## How was this patch tested? checked the `toString` output in the logs Author: Marco Gaido Closes #21429 from mgaido91/minor_ssl. --- core/src/main/scala/org/apache/spark/SSLOptions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 477b01968c6ef..04c38f12acc78 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -128,7 +128,7 @@ private[spark] case class SSLOptions( } /** Returns a string representation of this SSLOptions with all the passwords masked. */ - override def toString: String = s"SSLOptions{enabled=$enabled, " + + override def toString: String = s"SSLOptions{enabled=$enabled, port=$port, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " + s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}" @@ -142,6 +142,7 @@ private[spark] object SSLOptions extends Logging { * * The following settings are allowed: * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].port` - the port where to bind the SSL server * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory * $ - `[ns].keyStorePassword` - a password to the key-store file * $ - `[ns].keyPassword` - a password to the private key From 1b1528a504febfadf6fe41fd72e657689da50525 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 25 May 2018 15:42:46 -0700 Subject: [PATCH 0876/1161] [SPARK-24366][SQL] Improving of error messages for type converting ## What changes were proposed in this pull request? Currently, users are getting the following error messages on type conversions: ``` scala.MatchError: test (of class java.lang.String) ``` The message doesn't give any clues to the users where in the schema the error happened. In this PR, I would like to improve the error message like: ``` The value (test) of the type (java.lang.String) cannot be converted to struct ``` ## How was this patch tested? Added tests for converting of wrong values to `struct`, `map`, `array`, `string` and `decimal`. Author: Maxim Gekk Closes #21410 from MaxGekk/type-conv-error. --- .../sql/catalyst/CatalystTypeConverters.scala | 16 +++++++ .../CatalystTypeConvertersSuite.scala | 45 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 474ec592201d9..9e9105a157abe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -170,6 +170,9 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to an array of ${elementType.catalogString}") } } @@ -206,6 +209,10 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + "cannot be converted to a map type with " + + s"key type (${keyType.catalogString}) and value type (${valueType.catalogString})") } } @@ -252,6 +259,9 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${structType.catalogString}") } override def toScala(row: InternalRow): Row = { @@ -276,6 +286,9 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to the string type") } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString @@ -309,6 +322,9 @@ object CatalystTypeConverters { case d: JavaBigDecimal => Decimal(d) case d: JavaBigInteger => Decimal(d) case d: Decimal => d + case other => throw new IllegalArgumentException( + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + + s"cannot be converted to ${dataType.catalogString}") } decimal.toPrecision(dataType.precision, dataType.scale) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f3702ec92b425..f99af9b84d959 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -94,4 +94,49 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) == doubleGenericArray) } + + test("converting a wrong value to the struct type") { + val structType = new StructType().add("f1", IntegerType) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(structType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to struct")) + } + + test("converting a wrong value to the map type") { + val mapType = MapType(StringType, IntegerType, false) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(mapType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to a map type with key " + + "type (string) and value type (int)")) + } + + test("converting a wrong value to the array type") { + val arrayType = ArrayType(IntegerType, true) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(arrayType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to an array of int")) + } + + test("converting a wrong value to the decimal type") { + val decimalType = DecimalType(10, 0) + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(decimalType)("test") + } + assert(exception.getMessage.contains("The value (test) of the type " + + "(java.lang.String) cannot be converted to decimal(10,0)")) + } + + test("converting a wrong value to the string type") { + val exception = intercept[IllegalArgumentException] { + CatalystTypeConverters.createToCatalystConverter(StringType)(0.1) + } + assert(exception.getMessage.contains("The value (0.1) of the type " + + "(java.lang.Double) cannot be converted to the string type")) + } } From ed1a65448f228776afe2e5c6b1ac4228d2ed2854 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 26 May 2018 20:26:00 +0800 Subject: [PATCH 0877/1161] [SPARK-19112][CORE][FOLLOW-UP] Add missing shortCompressionCodecNames to configuration. ## What changes were proposed in this pull request? Spark provides four codecs: `lz4`, `lzf`, `snappy`, and `zstd`. This pr add missing shortCompressionCodecNames to configuration. ## How was this patch tested? manually tested Author: Yuming Wang Closes #21431 from wangyum/SPARK-19112. --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index fd2670cba2125..64af0e98a82f5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -910,8 +910,8 @@ Apart from these, the following properties are also available, and may be useful lz4 The codec used to compress internal data such as RDD partitions, event log, broadcast variables - and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, - and snappy. You can also use fully qualified class names to specify the codec, + and shuffle outputs. By default, Spark provides four codecs: lz4, lzf, + snappy, and zstd. You can also use fully qualified class names to specify the codec, e.g. org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, From d440699192f21b14dfb8ec0dc5673537e1003b55 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Sat, 26 May 2018 20:42:23 -0700 Subject: [PATCH 0878/1161] [SPARK-24381][TESTING] Add unit tests for NOT IN subquery around null values ## What changes were proposed in this pull request? This PR adds several unit tests along the `cols NOT IN (subquery)` pathway. There are a scattering of tests here and there which cover this codepath, but there doesn't seem to be a unified unit test of the correctness of null-aware anti joins anywhere. I have also added a brief explanation of how this expression behaves in SubquerySuite. Lastly, I made some clarifying changes in the NOT IN pathway in RewritePredicateSubquery. ## How was this patch tested? Added unit tests! There should be no behavioral change in this PR. Author: Miles Yucht Closes #21425 from mgyucht/spark-24381. --- .../sql/catalyst/optimizer/subquery.scala | 9 +- ...not-in-unit-tests-multi-column-literal.sql | 39 +++++ .../not-in-unit-tests-multi-column.sql | 98 ++++++++++++ ...ot-in-unit-tests-single-column-literal.sql | 42 +++++ .../not-in-unit-tests-single-column.sql | 123 +++++++++++++++ ...in-unit-tests-multi-column-literal.sql.out | 54 +++++++ .../not-in-unit-tests-multi-column.sql.out | 134 ++++++++++++++++ ...n-unit-tests-single-column-literal.sql.out | 69 ++++++++ .../not-in-unit-tests-single-column.sql.out | 149 ++++++++++++++++++ .../typeCoercion/native/concat.sql.out | 2 +- 10 files changed, 714 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 709db6d8bec7d..de89e17e51f1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -116,15 +116,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // (a1,a2,...) = (b1,b2,...) // to // (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ... - val joinConds = splitConjunctivePredicates(joinCond.get) + val baseJoinConds = splitConjunctivePredicates(joinCond.get) + val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c))) // After that, add back the correlated join predicate(s) in the subquery // Example: // SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1) // will have the final conditions in the LEFT ANTI as - // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) - val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And) + // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1 + val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And) // Deduplicate conflicting attributes if any. - dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs))) + dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond))) case (p, predicate) => val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) Project(p.output, Filter(newCond.get, inputPlan)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql new file mode 100644 index 0000000000000..8eea84f4f5272 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql @@ -0,0 +1,39 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- This file has the same test cases as not-in-unit-tests-multi-column.sql with literals instead of +-- subqueries. Small changes have been made to the literals to make them typecheck. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +-- Case 1 (not possible to write a literal with no rows, so we ignore it.) +-- (subquery is empty -> row is returned) + +-- Cases 2, 3 and 4 are currently broken, so I have commented them out here. +-- Filed https://issues.apache.org/jira/browse/SPARK-24395 to fix and restore these test cases. + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)); + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql new file mode 100644 index 0000000000000..9f8dc7fca3b94 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-multi-column.sql @@ -0,0 +1,98 @@ +-- Unit tests for simple NOT IN predicate subquery across multiple columns. +-- +-- See not-in-single-column-unit-tests.sql for an introduction. +-- +-- Test cases for multi-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | do filter columns contain null? | a = c? | b = d? | row included in result? | +-- | 1 | empty | * | * | * | yes | +-- | 2 | 1+ row has null for all columns | * | * | * | no | +-- | 3 | no row has null for all columns | (yes, yes) | * | * | no | +-- | 4 | no row has null for all columns | (no, yes) | yes | * | no | +-- | 5 | no row has null for all columns | (no, yes) | no | * | yes | +-- | 6 | no | (no, no) | yes | yes | no | +-- | 7 | no | (no, no) | _ | _ | yes | +-- +-- This can be generalized to include more tests for more columns, but it covers the main cases +-- when there is more than one column. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d); + + -- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +; + + -- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +; + + -- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +; + + -- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; + + -- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql new file mode 100644 index 0000000000000..b261363d1dde7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql @@ -0,0 +1,42 @@ +-- Unit tests for simple NOT IN with a literal expression of a single column +-- +-- More information can be found in not-in-unit-tests-single-column.sql. +-- This file has the same test cases as not-in-unit-tests-single-column.sql with literals instead of +-- subqueries. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null); + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2); + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2); + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql new file mode 100644 index 0000000000000..2cc08e10acf67 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-unit-tests-single-column.sql @@ -0,0 +1,123 @@ +-- Unit tests for simple NOT IN predicate subquery across a single column. +-- +-- ``col NOT IN expr'' is quite difficult to reason about. There are many edge cases, some of the +-- rules are confusing to the uninitiated, and precedence and treatment of null values is plain +-- unintuitive. To make this simpler to understand, I've come up with a plain English way of +-- describing the expected behavior of this query. +-- +-- - If the subquery is empty (i.e. returns no rows), the row should be returned, regardless of +-- whether the filtered columns include nulls. +-- - If the subquery contains a result with all columns null, then the row should not be returned. +-- - If for all non-null filter columns there exists a row in the subquery in which each column +-- either +-- 1. is equal to the corresponding filter column or +-- 2. is null +-- then the row should not be returned. (This includes the case where all filter columns are +-- null.) +-- - Otherwise, the row should be returned. +-- +-- Using these rules, we can come up with a set of test cases for single-column and multi-column +-- NOT IN test cases. +-- +-- Test cases for single-column ``WHERE a NOT IN (SELECT c FROM r ...)'': +-- | # | does subquery include null? | is a null? | a = c? | row with a included in result? | +-- | 1 | empty | | | yes | +-- | 2 | yes | | | no | +-- | 3 | no | yes | | no | +-- | 4 | no | no | yes | no | +-- | 5 | no | no | no | yes | +-- +-- There are also some considerations around correlated subqueries. Correlated subqueries can +-- cause cases 2, 3, or 4 to be reduced to case 1 by limiting the number of rows returned by the +-- subquery, so the row from the parent table should always be included in the output. + +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b); + +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d); + + -- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +; + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +; + + -- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +; + + -- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +; + + -- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; + + -- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out new file mode 100644 index 0000000000000..a16e98af9a417 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column-literal.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 1 schema +struct +-- !query 1 output +NULL 1 + + +-- !query 2 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN ((2, 3.0)) +-- !query 3 schema +struct +-- !query 3 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out new file mode 100644 index 0000000000000..aa5f64b8ebf55 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-multi-column.sql.out @@ -0,0 +1,134 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 9 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, null), + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, null), + (0, 1.0), + (2, 3.0), + (4, null) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Case 1 + -- (subquery is empty -> row is returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE d > 5.0) -- Matches no rows +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 +NULL NULL + + +-- !query 3 +-- Case 2 + -- (subquery contains a row with null in all columns -> row not returned) +SELECT * +FROM m +WHERE (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NULL AND d IS NULL) -- Matches only (null, null) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe-side columns are all null -> row not returned) +SELECT * +FROM m +WHERE a IS NULL AND b IS NULL -- Matches only (null, null) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (one column null, other column matches a row in the subquery result -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c IS NOT NULL) -- Matches (0, 1.0), (2, 3.0), (4, null) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (one null column with no match -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Matches (null, 1.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 6 schema +struct +-- !query 6 output +NULL 1 + + +-- !query 7 +-- Case 6 + -- (no null columns with match -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Matches (2, 3.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 7 schema +struct +-- !query 7 output + + + +-- !query 8 +-- Case 7 + -- (no null columns with no match -> row is returned) +SELECT * +FROM m +WHERE b = 5.0 -- Matches (4, 5.0) + AND (a, b) NOT IN (SELECT * + FROM s + WHERE c = 2) -- Matches (2, 3.0) +-- !query 8 schema +struct +-- !query 8 output +4 5 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out new file mode 100644 index 0000000000000..446447e890449 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column-literal.sql.out @@ -0,0 +1,69 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 (not possible to write a literal with no rows, so we ignore it.) + -- (empty subquery -> all rows returned) + + -- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (null) +-- !query 1 schema +struct +-- !query 1 output + + + +-- !query 2 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (2) +-- !query 2 schema +struct +-- !query 2 output + + + +-- !query 3 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (2) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (6) +-- !query 4 schema +struct +-- !query 4 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out new file mode 100644 index 0000000000000..f58ebeacc2872 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-unit-tests-single-column.sql.out @@ -0,0 +1,149 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +CREATE TEMPORARY VIEW m AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (4, 5.0) + AS m(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW s AS SELECT * FROM VALUES + (null, 1.0), + (2, 3.0), + (6, 7.0) + AS s(c, d) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +-- Uncorrelated NOT IN Subquery test cases + -- Case 1 + -- (empty subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d > 10.0) -- (empty subquery) +-- !query 2 schema +struct +-- !query 2 output +2 3 +4 5 +NULL 1 + + +-- !query 3 +-- Case 2 + -- (subquery includes null -> no rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = 1.0) -- Only matches (null, 1.0) +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +-- Case 3 + -- (probe column is null -> row not returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +-- Case 4 + -- (probe column matches subquery row -> row not returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 3.0) -- Matches (2, 3.0) +-- !query 5 schema +struct +-- !query 5 output + + + +-- !query 6 +-- Case 5 + -- (probe column does not match subquery row -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = 7.0) -- Matches (6, 7.0) +-- !query 6 schema +struct +-- !query 6 output +2 3 + + +-- !query 7 +-- Correlated NOT IN subquery test cases + -- Case 2->1 + -- (subquery had nulls but they are removed by correlated subquery -> all rows returned) +SELECT * +FROM m +WHERE a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 7 schema +struct +-- !query 7 output +2 3 +4 5 +NULL 1 + + +-- !query 8 +-- Case 3->1 + -- (probe column is null but subquery returns no rows -> row is returned) +SELECT * +FROM m +WHERE b = 1.0 -- Only matches (null, 1.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 8 schema +struct +-- !query 8 output +NULL 1 + + +-- !query 9 +-- Case 4->1 + -- (probe column matches row which is filtered out by correlated subquery -> row is returned) +SELECT * +FROM m +WHERE b = 3.0 -- Only matches (2, 3.0) + AND a NOT IN (SELECT c + FROM s + WHERE d = b + 10) -- Matches no row +-- !query 9 schema +struct +-- !query 9 output +2 3 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index 62befc5ca0f15..be637b66abc86 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 14 -- !query 0 From 672209f2909a95e891f3c779bfb2f0e534239851 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Mon, 28 May 2018 10:50:17 +0800 Subject: [PATCH 0879/1161] [SPARK-24334] Fix race condition in ArrowPythonRunner causes unclean shutdown of Arrow memory allocator ## What changes were proposed in this pull request? There is a race condition of closing Arrow VectorSchemaRoot and Allocator in the writer thread of ArrowPythonRunner. The race results in memory leak exception when closing the allocator. This patch removes the closing routine from the TaskCompletionListener and make the writer thread responsible for cleaning up the Arrow memory. This issue be reproduced by this test: ``` def test_memory_leak(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType, array, lit, explode # Have all data in a single executor thread so it can trigger the race condition easier with self.sql_conf({'spark.sql.shuffle.partitions': 1}): df = self.spark.range(0, 1000) df = df.withColumn('id', array([lit(i) for i in range(0, 300)])) \ .withColumn('id', explode(col('id'))) \ .withColumn('v', array([lit(i) for i in range(0, 1000)])) pandas_udf(df.schema, PandasUDFType.GROUPED_MAP) def foo(pdf): xxx return pdf result = df.groupby('id').apply(foo) with QuietTest(self.sc): with self.assertRaises(py4j.protocol.Py4JJavaError) as context: result.count() self.assertTrue('Memory leaked' not in str(context.exception)) ``` Note: Because of the race condition, the test case cannot reproduce the issue reliably so it's not added to test cases. ## How was this patch tested? Because of the race condition, the bug cannot be unit test easily. So far it has only happens on large amount of data. This is currently tested manually. Author: Li Jin Closes #21397 from icexelloss/SPARK-24334-arrow-memory-leak. --- .../execution/python/ArrowPythonRunner.scala | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5fcdcddca7d51..01e19bddbfb66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -70,19 +70,13 @@ class ArrowPythonRunner( val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val arrowWriter = ArrowWriter.create(root) - - context.addTaskCompletionListener { _ => - root.close() - allocator.close() - } - - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + while (inputIterator.hasNext) { val nextBatch = inputIterator.next() @@ -94,8 +88,21 @@ class ArrowPythonRunner( writer.writeBatch() arrowWriter.reset() } - } { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. root.close() allocator.close() } From de01a8d50c9c3e196591db057d544f5d7b24d95f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 28 May 2018 12:09:44 +0800 Subject: [PATCH 0880/1161] [SPARK-24373][SQL] Add AnalysisBarrier to RelationalGroupedDataset's and KeyValueGroupedDataset's child ## What changes were proposed in this pull request? When we create a `RelationalGroupedDataset` or a `KeyValueGroupedDataset` we set its child to the `logicalPlan` of the `DataFrame` we need to aggregate. Since the `logicalPlan` is already analyzed, we should not analyze it again. But this happens when the new plan of the aggregate is analyzed. The current behavior in most of the cases is likely to produce no harm, but in other cases re-analyzing an analyzed plan can change it, since the analysis is not idempotent. This can cause issues like the one described in the JIRA (missing to find a cached plan). The PR adds an `AnalysisBarrier` to the `logicalPlan` which is used as child of `RelationalGroupedDataset` or a `KeyValueGroupedDataset`. ## How was this patch tested? added UT Author: Marco Gaido Closes #21432 from mgaido91/SPARK-24373. --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../spark/sql/KeyValueGroupedDataset.scala | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 12 +-- .../spark/sql/GroupedDatasetSuite.scala | 96 +++++++++++++++++++ 4 files changed, 104 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 32267eb0300f5..abb5ae53f4d73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -196,7 +196,7 @@ class Dataset[T] private[sql]( } // Wraps analyzed logical plans with an analysis barrier so we won't traverse/resolve it again. - @transient private val planWithBarrier = AnalysisBarrier(logicalPlan) + @transient private[sql] val planWithBarrier = AnalysisBarrier(logicalPlan) /** * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bab21dca0cbd..36f6038aa9485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -49,7 +49,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( private implicit val kExprEnc = encoderFor(kEncoder) private implicit val vExprEnc = encoderFor(vEncoder) - private def logicalPlan = queryExecution.analyzed + private def logicalPlan = AnalysisBarrier(queryExecution.analyzed) private def sparkSession = queryExecution.sparkSession /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 6c2be3610ae30..c6449cd5a16b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -63,17 +63,17 @@ class RelationalGroupedDataset protected[sql]( groupType match { case RelationalGroupedDataset.GroupByType => Dataset.ofRows( - df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.RollupType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.CubeType => Dataset.ofRows( - df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.planWithBarrier)) case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.ofRows( - df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan)) + df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.planWithBarrier)) } } @@ -433,7 +433,7 @@ class RelationalGroupedDataset protected[sql]( df.exprEnc.schema, groupingAttributes, df.logicalPlan.output, - df.logicalPlan)) + df.planWithBarrier)) } /** @@ -459,7 +459,7 @@ class RelationalGroupedDataset protected[sql]( case other => Alias(other, other.toString)() } val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) - val child = df.logicalPlan + val child = df.planWithBarrier val project = Project(groupingNamedExpressions ++ child.output, child) val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala new file mode 100644 index 0000000000000..147c0b61f5017 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} + +class GroupedDatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + private val scalaUDF = udf((x: Long) => { x + 1 }) + private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s")) + + private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = { + assert(atLevel >= 0) + var children = Seq(ds.queryExecution.logical) + (1 to atLevel).foreach { _ => + children = children.flatMap(_.children) + } + val barriers = children.collect { + case ab: AnalysisBarrier => ab + } + assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" + + ds.queryExecution.logical) + } + + test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") { + val groupByDataset = datasetWithUDF.groupBy() + val rollupDataset = datasetWithUDF.rollup("s") + val cubeDataset = datasetWithUDF.cube("s") + val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2)) + datasetWithUDF.cache() + Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS => + val df = rgDS.count() + assertContainsAnalysisBarrier(df) + assertCached(df) + } + + val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR( + Array.emptyByteArray, + Array.emptyByteArray, + Array.empty, + StructType(Seq(StructField("s", LongType)))) + val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF( + "pyUDF", + null, + StructType(Seq(StructField("s", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true)) + Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df => + assertContainsAnalysisBarrier(df, 2) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } + + test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") { + val kvDasaset = datasetWithUDF.groupByKey(_.getLong(0)) + datasetWithUDF.cache() + val mapValuesKVDataset = kvDasaset.mapValues(_.getLong(0)).reduceGroups(_ + _) + val keysKVDataset = kvDasaset.keys + val flatMapGroupsKVDataset = kvDasaset.flatMapGroups((k, _) => Seq(k)) + val aggKVDataset = kvDasaset.count() + val otherKVDataset = spark.range(1).groupByKey(_ + 1) + val cogroupKVDataset = kvDasaset.cogroup(otherKVDataset)((k, _, _) => Seq(k)) + Seq((mapValuesKVDataset, 1), + (keysKVDataset, 2), + (flatMapGroupsKVDataset, 2), + (aggKVDataset, 1), + (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) => + assertContainsAnalysisBarrier(df, analysisBarrierDepth) + assertCached(df) + } + datasetWithUDF.unpersist(true) + } +} From fa2ae9d2019f839647d17932d8fea769e7622777 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 28 May 2018 12:56:05 +0800 Subject: [PATCH 0881/1161] [SPARK-24392][PYTHON] Label pandas_udf as Experimental ## What changes were proposed in this pull request? The pandas_udf functionality was introduced in 2.3.0, but is not completely stable and still evolving. This adds a label to indicate it is still an experimental API. ## How was this patch tested? NA Author: Bryan Cutler Closes #21435 from BryanCutler/arrow-pandas_udf-experimental-SPARK-24392. --- docs/sql-programming-guide.md | 4 ++++ python/pyspark/sql/dataframe.py | 2 ++ python/pyspark/sql/functions.py | 2 ++ python/pyspark/sql/group.py | 2 ++ python/pyspark/sql/session.py | 2 ++ 5 files changed, 12 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index fc26562ff33da..50600861912b1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1827,6 +1827,10 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. +## Upgrading From Spark SQL 2.3.0 to 2.3.1 and above + + - As of version 2.3.1 Arrow functionality, including `pandas_udf` and `toPandas()`/`createDataFrame()` with `spark.sql.execution.arrow.enabled` set to `True`, has been marked as experimental. These are still evolving and not currently recommended for use in production. + ## Upgrading From Spark SQL 2.2 to 2.3 - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 213dc158f9328..808235ab25440 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1975,6 +1975,8 @@ def toPandas(self): .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> df.toPandas() # doctest: +SKIP age name 0 2 Alice diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fbc8a2d038f8f..efcce25a08e04 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2456,6 +2456,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. Default: SCALAR. + .. note:: Experimental + The function type of the UDF can be one of the following: 1. SCALAR diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 3505065b648f2..0906c9c6b329a 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -236,6 +236,8 @@ def apply(self, udf): into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. + .. note:: Experimental + :param udf: a grouped map user-defined function returned by :func:`pyspark.sql.functions.pandas_udf`. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 13d6e2e53dbd0..d675a240172a7 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -584,6 +584,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr .. versionchanged:: 2.1 Added verifySchema. + .. note:: Usage with spark.sql.execution.arrow.enabled=True is experimental. + >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() [Row(_1=u'Alice', _2=1)] From b31b587cd091010337378cf448fd598c37757053 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 29 May 2018 10:35:30 +0800 Subject: [PATCH 0882/1161] [SPARK-19613][SS][TEST] Random.nextString is not safe for directory namePrefix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? `Random.nextString` is good for generating random string data, but it's not proper for directory name prefix in `Utils.createDirectory(tempDir, Random.nextString(10))`. This PR uses more safe directory namePrefix. ```scala scala> scala.util.Random.nextString(10) res0: String = 馨쭔ᎰႻ穚䃈兩㻞藑並 ``` ```scala StateStoreRDDSuite: - versioning and immutability - recovering from files - usage with iterators - only gets and only puts - preferred locations using StateStoreCoordinator *** FAILED *** java.io.IOException: Failed to create a temp directory (under /.../spark/sql/core/target/tmp/StateStoreRDDSuite8712796397908632676) after 10 attempts! at org.apache.spark.util.Utils$.createDirectory(Utils.scala:295) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13$$anonfun$apply$6.apply(StateStoreRDDSuite.scala:152) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13$$anonfun$apply$6.apply(StateStoreRDDSuite.scala:149) at org.apache.spark.sql.catalyst.util.package$.quietly(package.scala:42) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13.apply(StateStoreRDDSuite.scala:149) at org.apache.spark.sql.execution.streaming.state.StateStoreRDDSuite$$anonfun$13.apply(StateStoreRDDSuite.scala:149) ... - distributed test *** FAILED *** java.io.IOException: Failed to create a temp directory (under /.../spark/sql/core/target/tmp/StateStoreRDDSuite8712796397908632676) after 10 attempts! at org.apache.spark.util.Utils$.createDirectory(Utils.scala:295) ``` ## How was this patch tested? Pass the existing tests.StateStoreRDDSuite: Author: Dongjoon Hyun Closes #21446 from dongjoon-hyun/SPARK-19613. --- .../execution/streaming/state/StateStoreRDDSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 65b39f0fbd73d..579a364ebc3e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -55,7 +55,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("versioning and immutability") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)( increment) @@ -73,7 +73,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("recovering from files") { - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString def makeStoreRDD( spark: SparkSession, @@ -101,7 +101,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("usage with iterators - only gets and only puts") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 // Returns an iterator of the incremented value made into the store @@ -149,7 +149,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn quietly { val queryRunId = UUID.randomUUID val opId = 0 - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext @@ -189,7 +189,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) .getOrCreate()) { spark => implicit val sqlContext = spark.sqlContext - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment) From 2ced6193b39dc63e5f74138859f2a9d69d3cfd11 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Tue, 29 May 2018 10:48:48 +0800 Subject: [PATCH 0883/1161] [SPARK-24377][SPARK SUBMIT] make --py-files work in non pyspark application ## What changes were proposed in this pull request? For some Spark applications, though they're a java program, they require not only jar dependencies, but also python dependencies. One example is Livy remote SparkContext application, this application is actually an embedded REPL for Scala/Python/R, it will not only load in jar dependencies, but also python and R deps, so we should specify not only "--jars", but also "--py-files". Currently for a Spark application, --py-files can only be worked for a pyspark application, so it will not be worked in the above case. So here propose to remove such restriction. Also we tested that "spark.submit.pyFiles" only supports quite limited scenario (client mode with local deps), so here also expand the usage of "spark.submit.pyFiles" to be alternative of --py-files. ## How was this patch tested? UT added. Author: jerryshao Closes #21420 from jerryshao/SPARK-24377. --- .../org/apache/spark/deploy/SparkSubmit.scala | 11 ++--- .../spark/deploy/SparkSubmitArguments.scala | 4 +- .../spark/deploy/SparkSubmitSuite.scala | 46 ++++++++++++++++++- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4baf032f0e9c6..a46af26feb061 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -430,18 +430,15 @@ private[spark] class SparkSubmit extends Logging { // Usage: PythonAppRunner
    [app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs - if (clusterManager != YARN) { - // The YARN backend distributes the primary file differently, so don't merge it. - args.files = mergeFileLists(args.files, args.primaryResource) - } } if (clusterManager != YARN) { // The YARN backend handles python files differently, so don't merge the lists. args.files = mergeFileLists(args.files, args.pyFiles) } - if (localPyFiles != null) { - sparkConf.set("spark.submit.pyFiles", localPyFiles) - } + } + + if (localPyFiles != null) { + sparkConf.set("spark.submit.pyFiles", localPyFiles) } // In YARN mode for an R app, add the SparkR package archive and the R package diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index fed4e0a5069c3..fb232101114b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -182,6 +182,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull files = Option(files).orElse(sparkProperties.get("spark.files")).orNull + pyFiles = Option(pyFiles).orElse(sparkProperties.get("spark.submit.pyFiles")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull ivySettingsPath = sparkProperties.get("spark.jars.ivySettings") packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull @@ -280,9 +281,6 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } - if (pyFiles != null && !isPython) { - error("--py-files given but primary resource is not a Python script") - } if (master.startsWith("yarn")) { val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 43286953e4383..545c8d0423dc3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -771,9 +771,13 @@ class SparkSubmitSuite PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val pyFile1 = File.createTempFile("file1", ".py", tmpDir) + val pyFile2 = File.createTempFile("file2", ".py", tmpDir) val writer4 = new PrintWriter(f4) - val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + val remotePyFiles = s"s3a://${pyFile1.getAbsolutePath},s3a://${pyFile2.getAbsolutePath}" writer4.println("spark.submit.pyFiles " + remotePyFiles) writer4.close() val clArgs4 = Seq( @@ -783,7 +787,7 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4) + val (_, _, conf4, _) = submit.prepareSubmitEnvironment(appArgs4, conf = Some(hadoopConf)) // Should not format python path for yarn cluster mode conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } @@ -1093,6 +1097,44 @@ class SparkSubmitSuite assert(exception.getMessage() === "hello") } + test("support --py-files/spark.submit.pyFiles in non pyspark application") { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + + val tmpDir = Utils.createTempDir() + val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir) + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--py-files", s"s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs = new SparkSubmitArguments(args) + val (_, _, conf, _) = submit.prepareSubmitEnvironment(appArgs, conf = Some(hadoopConf)) + + conf.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.submit.pyFiles") should (startWith("/")) + + // Verify "spark.submit.pyFiles" + val args1 = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--conf", s"spark.submit.pyFiles=s3a://${pyFile.getAbsolutePath}", + "spark-internal" + ) + + val appArgs1 = new SparkSubmitArguments(args1) + val (_, _, conf1, _) = submit.prepareSubmitEnvironment(appArgs1, conf = Some(hadoopConf)) + + conf1.get(PY_FILES.key) should be (s"s3a://${pyFile.getAbsolutePath}") + conf1.get("spark.submit.pyFiles") should (startWith("/")) + } } object SparkSubmitSuite extends SparkFunSuite with TimeLimits { From 23db600c956cff4d0b20c38ddd2d746de2b535a0 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 28 May 2018 23:23:22 -0700 Subject: [PATCH 0884/1161] [SPARK-24250][SQL][FOLLOW-UP] support accessing SQLConf inside tasks ## What changes were proposed in this pull request? We should not stop users from calling `getActiveSession` and `getDefaultSession` in executors. To not break the existing behaviors, we should simply return None. ## How was this patch tested? N/A Author: Xiao Li Closes #21436 from gatorsmile/followUpSPARK-24250. --- .../org/apache/spark/sql/SparkSession.scala | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..565042fcf762e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1021,21 +1021,33 @@ object SparkSession extends Logging { /** * Returns the active SparkSession for the current thread, returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(activeThreadSession.get) + } } /** * Returns the default SparkSession that is returned by the builder. * + * @note Return None, when calling this function on executors + * * @since 2.2.0 */ def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) + if (TaskContext.get != null) { + // Return None when running on executors. + None + } else { + Option(defaultSession.get) + } } /** From aca65c63cb12073eb193fe08998994c60acb8b58 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Tue, 29 May 2018 20:10:59 +0800 Subject: [PATCH 0885/1161] [SPARK-23991][DSTREAMS] Fix data loss when WAL write fails in allocateBlocksToBatch When blocks tried to get allocated to a batch and WAL write fails then the blocks will be removed from the received block queue. This fact simply produces data loss because the next allocation will not find the mentioned blocks in the queue. In this PR blocks will be removed from the received queue only if WAL write succeded. Additional unit test. Author: Gabor Somogyi Closes #21430 from gaborgsomogyi/SPARK-23991. Change-Id: I5ead84f0233f0c95e6d9f2854ac2ff6906f6b341 --- .../scheduler/ReceivedBlockTracker.scala | 3 +- .../streaming/ReceivedBlockTrackerSuite.scala | 47 ++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index dacff69d55dd2..cf4324578ea87 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -112,10 +112,11 @@ private[streaming] class ReceivedBlockTracker( def allocateBlocksToBatch(batchTime: Time): Unit = synchronized { if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) { val streamIdToBlocks = streamIds.map { streamId => - (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) + (streamId, getReceivedBlockQueue(streamId).clone()) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + streamIds.foreach(getReceivedBlockQueue(_).clear()) timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } else { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index 4fa236bd39663..fd7e00b1de25f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -26,10 +26,12 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration +import org.mockito.Matchers.any +import org.mockito.Mockito.{doThrow, reset, spy} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult @@ -115,6 +117,47 @@ class ReceivedBlockTrackerSuite tracker2.stop() } + test("block allocation to batch should not loose blocks from received queue") { + val tracker1 = spy(createTracker()) + tracker1.isWriteAheadLogEnabled should be (true) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + + // Add blocks + val blockInfos = generateBlockInfos() + blockInfos.map(tracker1.addBlock) + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + + // Try to allocate the blocks to a batch and verify that it's failing + // The blocks should stay in the received queue when WAL write failing + doThrow(new RuntimeException("Not able to write BatchAllocationEvent")) + .when(tracker1).writeToLog(any(classOf[BatchAllocationEvent])) + val errMsg = intercept[RuntimeException] { + tracker1.allocateBlocksToBatch(1) + } + assert(errMsg.getMessage === "Not able to write BatchAllocationEvent") + tracker1.getUnallocatedBlocks(streamId) shouldEqual blockInfos + tracker1.getBlocksOfBatch(1) shouldEqual Map.empty + tracker1.getBlocksOfBatchAndStream(1, streamId) shouldEqual Seq.empty + + // Allocate the blocks to a batch and verify that all of them have been allocated + reset(tracker1) + tracker1.allocateBlocksToBatch(2) + tracker1.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker1.hasUnallocatedReceivedBlocks should be (false) + tracker1.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker1.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + + tracker1.stop() + + // Recover from WAL to see the correctness + val tracker2 = createTracker(recoverFromWriteAheadLog = true) + tracker2.getUnallocatedBlocks(streamId) shouldEqual Seq.empty + tracker2.hasUnallocatedReceivedBlocks should be (false) + tracker2.getBlocksOfBatch(2) shouldEqual Map(streamId -> blockInfos) + tracker2.getBlocksOfBatchAndStream(2, streamId) shouldEqual blockInfos + tracker2.stop() + } + test("recovery and cleanup with write ahead logs") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates @@ -312,7 +355,7 @@ class ReceivedBlockTrackerSuite recoverFromWriteAheadLog: Boolean = false, clock: Clock = new SystemClock): ReceivedBlockTracker = { val cpDirOption = if (setCheckpointDir) Some(checkpointDirectory.toString) else None - val tracker = new ReceivedBlockTracker( + var tracker = new ReceivedBlockTracker( conf, hadoopConf, Seq(streamId), clock, recoverFromWriteAheadLog, cpDirOption) allReceivedBlockTrackers += tracker tracker From 900bc1f7dc5a7c013b473fceab1c4052ade74a2f Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 29 May 2018 10:22:18 -0700 Subject: [PATCH 0886/1161] [SPARK-24371][SQL] Added isInCollection in DataFrame API for Scala and Java. ## What changes were proposed in this pull request? Implemented **`isInCollection `** in DataFrame API for both Scala and Java, so users can do ```scala val profileDF = Seq( Some(1), Some(2), Some(3), Some(4), Some(5), Some(6), Some(7), None ).toDF("profileID") val validUsers: Seq[Any] = Seq(6, 7.toShort, 8L, "3") val result = profileDF.withColumn("isValid", $"profileID". isInCollection(validUsers)) result.show(10) """ +---------+-------+ |profileID|isValid| +---------+-------+ | 1| false| | 2| false| | 3| true| | 4| false| | 5| false| | 6| true| | 7| true| | null| null| +---------+-------+ """.stripMargin ``` ## How was this patch tested? Several unit tests are added. Author: DB Tsai Closes #21416 from dbtsai/optimize-set. --- .../sql/catalyst/optimizer/expressions.scala | 1 - .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 19 ++++++ .../spark/sql/ColumnExpressionSuite.scala | 62 ++++++++++++++++++- 4 files changed, 81 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd806801..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e487693927ab6..c486ad700f362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -78,7 +78,7 @@ abstract class LogicalPlan schema.map { field => resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a - case other => sys.error(s"can not handle nested schema yet... plan $this") + case _ => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { throw new AnalysisException( s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad0efbae89830..b3e59f53ee3de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability @@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging { @scala.annotation.varargs def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group expr_ops + * @since 2.4.0 + */ + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group java_expr_ops + * @since 2.4.0 + */ + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 7c45be21961d3..2182bd7eadd63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.Locale + +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ @@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isin(1.toShort, "2")), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin("3", 2.toLong)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, "1")), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - intercept[AnalysisException] { + val e = intercept[AnalysisException] { df2.filter($"a".isin($"b")) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Scala Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Java Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b").asJava)) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") { From f48938800e6dc3880441f160dd93856b9f86874e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 30 May 2018 09:32:33 +0800 Subject: [PATCH 0887/1161] [SPARK-24365][SQL] Add Data Source write benchmark ## What changes were proposed in this pull request? Add Data Source write benchmark. So that it would be easier to measure the writer performance. Author: Gengliang Wang Closes #21409 from gengliangwang/parquetWriteBenchmark. --- .../benchmark/DataSourceWriteBenchmark.scala | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala new file mode 100644 index 0000000000000..2d2cdebd067c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure data source write performance. + * By default it measures 4 data source format: Parquet, ORC, JSON, CSV: + * spark-submit --class + * To measure specified formats, run it with arguments: + * spark-submit --class format1 [format2] [...] + */ +object DataSourceWriteBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceWriteBenchmark") + .setIfMissing("spark.master", "local[1]") + .set("spark.sql.parquet.compression.codec", "snappy") + .set("spark.sql.orc.compression.codec", "snappy") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary. + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + val tempTable = "temp" + val numRows = 1024 * 1024 * 15 + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + def writeNumeric(table: String, format: String, benchmark: Benchmark, dataType: String): Unit = { + spark.sql(s"create table $table(id $dataType) using $format") + benchmark.addCase(s"Output Single $dataType Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS $dataType) AS c1 FROM $tempTable") + } + } + + def writeIntString(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 STRING) USING $format") + benchmark.addCase("Output Int and String Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS STRING) AS c2 FROM $tempTable") + } + } + + def writePartition(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(p INT, id INT) USING $format PARTITIONED BY (p)") + benchmark.addCase("Output Partitions") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS id," + + s" CAST(id % 2 AS INT) AS p FROM $tempTable") + } + } + + def writeBucket(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 INT) USING $format CLUSTERED BY (c2) INTO 2 BUCKETS") + benchmark.addCase("Output Buckets") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS INT) AS c2 FROM $tempTable") + } + } + + def main(args: Array[String]): Unit = { + val tableInt = "tableInt" + val tableDouble = "tableDouble" + val tableIntString = "tableIntString" + val tablePartition = "tablePartition" + val tableBucket = "tableBucket" + val formats: Seq[String] = if (args.isEmpty) { + Seq("Parquet", "ORC", "JSON", "CSV") + } else { + args + } + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1815 / 1932 8.7 115.4 1.0X + Output Single Double Column 1877 / 1878 8.4 119.3 1.0X + Output Int and String Column 6265 / 6543 2.5 398.3 0.3X + Output Partitions 4067 / 4457 3.9 258.6 0.4X + Output Buckets 5608 / 5820 2.8 356.6 0.3X + + ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1201 / 1239 13.1 76.3 1.0X + Output Single Double Column 1542 / 1600 10.2 98.0 0.8X + Output Int and String Column 6495 / 6580 2.4 412.9 0.2X + Output Partitions 3648 / 3842 4.3 231.9 0.3X + Output Buckets 5022 / 5145 3.1 319.3 0.2X + + JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1988 / 2093 7.9 126.4 1.0X + Output Single Double Column 2854 / 2911 5.5 181.4 0.7X + Output Int and String Column 6467 / 6653 2.4 411.1 0.3X + Output Partitions 4548 / 5055 3.5 289.1 0.4X + Output Buckets 5664 / 5765 2.8 360.1 0.4X + + CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 3025 / 3190 5.2 192.3 1.0X + Output Single Double Column 3575 / 3634 4.4 227.3 0.8X + Output Int and String Column 7313 / 7399 2.2 464.9 0.4X + Output Partitions 5105 / 5190 3.1 324.6 0.6X + Output Buckets 6986 / 6992 2.3 444.1 0.4X + */ + withTempTable(tempTable) { + spark.range(numRows).createOrReplaceTempView(tempTable) + formats.foreach { format => + withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { + val benchmark = new Benchmark(s"$format writer benchmark", numRows) + writeNumeric(tableInt, format, benchmark, "Int") + writeNumeric(tableDouble, format, benchmark, "Double") + writeIntString(tableIntString, format, benchmark) + writePartition(tablePartition, format, benchmark) + writeBucket(tableBucket, format, benchmark) + benchmark.run() + } + } + } + } +} From a4be981c0476bb613c660b70a370f671c8b3ffee Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 29 May 2018 23:26:39 -0700 Subject: [PATCH 0888/1161] [SPARK-24331][SPARKR][SQL] Adding arrays_overlap, array_repeat, map_entries to SparkR ## What changes were proposed in this pull request? The PR adds functions `arrays_overlap`, `array_repeat`, `map_entries` to SparkR. ## How was this patch tested? Tests added into R/pkg/tests/fulltests/test_sparkSQL.R ## Examples ### arrays_overlap ``` df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), list(list(1L, 2L), list(3L, 4L)), list(list(1L, NA), list(3L, 4L)))) collect(select(df, arrays_overlap(df[[1]], df[[2]]))) ``` ``` arrays_overlap(_1, _2) 1 TRUE 2 FALSE 3 NA ``` ### array_repeat ``` df <- createDataFrame(list(list("a", 3L), list("b", 2L))) collect(select(df, array_repeat(df[[1]], df[[2]]))) ``` ``` array_repeat(_1, _2) 1 a, a, a 2 b, b ``` ``` collect(select(df, array_repeat(df[[1]], 2L))) ``` ``` array_repeat(_1, 2) 1 a, a 2 b, b ``` ### map_entries ``` df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) collect(select(df, map_entries(df$map))) ``` ``` map_entries(map) 1 x, 1, y, 2 ``` Author: Marek Novotny Closes #21434 from mn-mikke/SPARK-24331. --- R/pkg/NAMESPACE | 3 ++ R/pkg/R/DataFrame.R | 2 + R/pkg/R/functions.R | 58 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 12 ++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 22 +++++++++- 5 files changed, 91 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c575fe255f57a..73a33af4dd48b 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -204,7 +204,9 @@ exportMethods("%<=>%", "array_max", "array_min", "array_position", + "array_repeat", "array_sort", + "arrays_overlap", "asc", "ascii", "asin", @@ -302,6 +304,7 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_entries", "map_keys", "map_values", "max", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1c9495b0795e..70eb7a874b75c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,8 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) +setClassUnion("numericOrColumn", c("numeric", "Column")) + #' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index fcb3521f901ea..abc91aeeb4825 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,7 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param y Column to compute on. #' @param value A value to compute on. #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. @@ -207,7 +208,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -216,11 +217,10 @@ NULL #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) -#' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) -#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) -#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL @@ -3048,6 +3048,26 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times +#' given by \code{count}. +#' +#' @param count a Column or constant determining the number of repetitions. +#' @rdname column_collection_functions +#' @aliases array_repeat array_repeat,Column,numericOrColumn-method +#' @note array_repeat since 2.4.0 +setMethod("array_repeat", + signature(x = "Column", count = "numericOrColumn"), + function(x, count) { + if (class(count) == "Column") { + count <- count@jc + } else { + count <- as.integer(count) + } + jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count) + column(jc) + }) + #' @details #' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array #' must be orderable. NA elements will be placed at the end of the returned array. @@ -3062,6 +3082,21 @@ setMethod("array_sort", column(jc) }) +#' @details +#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in +#' common. If not and both arrays are non-empty and any of them contains a null, it returns null. +#' It returns false otherwise. +#' +#' @rdname column_collection_functions +#' @aliases arrays_overlap arrays_overlap,Column-method +#' @note arrays_overlap since 2.4.0 +setMethod("arrays_overlap", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3076,6 +3111,19 @@ setMethod("flatten", column(jc) }) +#' @details +#' \code{map_entries}: Returns an unordered array of all entries in the given map. +#' +#' @rdname column_collection_functions +#' @aliases map_entries map_entries,Column-method +#' @note map_entries since 2.4.0 +setMethod("map_entries", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3ea181157b644..8894cb1c5b92f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -769,10 +769,18 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1034,6 +1042,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 13b55ac6e6e3c..16c1fd5a065eb 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,21 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_repeat() + df <- createDataFrame(list(list("a", 3L), list("b", 2L))) + result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list("a", "a", "a"), list("b", "b"))) + + result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]] + expect_equal(result, list(list("a", "a"), list("b", "b"))) + + # Test arrays_overlap() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, NA), list(3L, 4L)))) + result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] + expect_equal(result, c(TRUE, FALSE, NA)) + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1531,8 +1546,13 @@ test_that("column functions", { result <- collect(select(df, flatten(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) - # Test map_keys(), map_values() and element_at() + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_entries(df$map)))[[1]] + expected_entries <- list(listToStruct(list(key = "x", value = 1)), + listToStruct(list(key = "y", value = 2))) + expect_equal(result, list(expected_entries)) + result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) From 0ebb0c0d4dd3e192464dc5e0e6f01efa55b945ed Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Wed, 30 May 2018 18:11:33 +0800 Subject: [PATCH 0889/1161] [SPARK-23754][PYTHON] Re-raising StopIteration in client code ## What changes were proposed in this pull request? Make sure that `StopIteration`s raised in users' code do not silently interrupt processing by spark, but are raised as exceptions to the users. The users' functions are wrapped in `safe_iter` (in `shuffle.py`), which re-raises `StopIteration`s as `RuntimeError`s ## How was this patch tested? Unit tests, making sure that the exceptions are indeed raised. I am not sure how to check whether a `Py4JJavaError` contains my exception, so I simply looked for the exception message in the java exception's `toString`. Can you propose a better way? ## License This is my original work, licensed in the same way as spark Author: e-dorigatti Author: edorigatti Closes #21383 from e-dorigatti/fix_spark_23754. --- python/pyspark/rdd.py | 18 ++++++++++--- python/pyspark/shuffle.py | 7 ++--- python/pyspark/sql/tests.py | 16 +++++++++++ python/pyspark/sql/udf.py | 14 ++++++++-- python/pyspark/tests.py | 53 +++++++++++++++++++++++++++++++++++++ python/pyspark/util.py | 28 +++++++++++++++++--- 6 files changed, 125 insertions(+), 11 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d5a237a5b2855..14d9128502ab0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -417,7 +418,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -798,6 +799,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -847,6 +850,8 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -918,6 +923,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -950,6 +957,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1643,6 +1653,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c7bd8f01b907f..a2450932e303d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,22 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..c8fb49d7c2b65 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -157,7 +157,17 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + func = fail_on_stopiteration(self.func) + + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + func._argspec = _get_argspec(self.func) + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 498d6b57e4353..3b37cc028c1b7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1246,6 +1277,28 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_client_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..e95a9b523393f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,11 +53,16 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - if sys.version_info[0] < 3: + + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here + argspec = f._argspec + elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec @@ -89,6 +94,23 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From 9e7bad0edd9f6c59c0af21c95e5df98cf82150d3 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 30 May 2018 05:18:18 -0700 Subject: [PATCH 0890/1161] [SPARK-24419][BUILD] Upgrade SBT to 0.13.17 with Scala 2.10.7 for JDK9+ ## What changes were proposed in this pull request? Upgrade SBT to 0.13.17 with Scala 2.10.7 for JDK9+ ## How was this patch tested? Existing tests Author: DB Tsai Closes #21458 from dbtsai/sbt. --- project/build.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/build.properties b/project/build.properties index b19518fd7aa1c..d03985d980ec8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.16 +sbt.version=0.13.17 From 1e46f92f956a00d04d47340489b6125d44dbd47b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 31 May 2018 00:23:25 +0800 Subject: [PATCH 0891/1161] [SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set ## What changes were proposed in this pull request? This pr fixed an issue when having multiple distinct aggregations having the same argument set, e.g., ``` scala>: paste val df = sql( s"""SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) | FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) """.stripMargin) java.lang.RuntimeException You hit a query analyzer bug. Please report your query to Spark user mailing list. ``` The root cause is that `RewriteDistinctAggregates` can't detect multiple distinct aggregations if they have the same argument set. This pr modified code so that `RewriteDistinctAggregates` could count the number of aggregate expressions with `isDistinct=true`. ## How was this patch tested? Added tests in `DataFrameAggregateSuite`. Author: Takeshi Yamamuro Closes #21443 from maropu/SPARK-24369. --- .../optimizer/RewriteDistinctAggregates.scala | 7 ++++--- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../src/test/resources/sql-tests/inputs/group-by.sql | 6 +++++- .../test/resources/sql-tests/results/group-by.sql.out | 11 ++++++++++- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 4448ace7105a4..bc898ab0dc723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -115,7 +115,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => + val distincgAggExpressions = aggExpressions.filter(_.isDistinct) + val distinctAggGroups = distincgAggExpressions.groupBy { e => val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children @@ -132,7 +133,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distinctAggGroups.size > 1) { + if (distincgAggExpressions.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -151,7 +152,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b97a87a122406..b9452b58657a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -386,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions.partition(_.isDistinct) if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 From b142157dcc7f595eea93d66dda8b1d169a38d95c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 30 May 2018 10:33:34 -0700 Subject: [PATCH 0892/1161] [SPARK-24384][PYTHON][SPARK SUBMIT] Add .py files correctly into PythonRunner in submit with client mode in spark-submit ## What changes were proposed in this pull request? In client side before context initialization specifically, .py file doesn't work in client side before context initialization when the application is a Python file. See below: ``` $ cat /home/spark/tmp.py def testtest(): return 1 ``` This works: ``` $ cat app.py import pyspark pyspark.sql.SparkSession.builder.getOrCreate() import tmp print("************************%s" % tmp.testtest()) $ ./bin/spark-submit --master yarn --deploy-mode client --py-files /home/spark/tmp.py app.py ... ************************1 ``` but this doesn't: ``` $ cat app.py import pyspark import tmp pyspark.sql.SparkSession.builder.getOrCreate() print("************************%s" % tmp.testtest()) $ ./bin/spark-submit --master yarn --deploy-mode client --py-files /home/spark/tmp.py app.py Traceback (most recent call last): File "/home/spark/spark/app.py", line 2, in import tmp ImportError: No module named tmp ``` ### How did it happen? In client mode specifically, the paths are being added into PythonRunner as are: https://github.com/apache/spark/blob/628c7b517969c4a7ccb26ea67ab3dd61266073ca/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L430 https://github.com/apache/spark/blob/628c7b517969c4a7ccb26ea67ab3dd61266073ca/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala#L49-L88 The problem here is, .py file shouldn't be added as are since `PYTHONPATH` expects a directory or an archive like zip or egg. ### How does this PR fix? We shouldn't simply just add its parent directory because other files in the parent directory could also be added into the `PYTHONPATH` in client mode before context initialization. Therefore, we copy .py files into a temp directory for .py files and add it to `PYTHONPATH`. ## How was this patch tested? Unit tests are added and manually tested in both standalond and yarn client modes with submit. Author: hyukjinkwon Closes #21426 from HyukjinKwon/SPARK-24384. --- .../apache/spark/deploy/PythonRunner.scala | 29 ++++++++++++++++++- .../spark/deploy/yarn/YarnClusterSuite.scala | 15 ++++------ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 1b7e031ee0678..ccb30e205ca40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy import java.io.File import java.net.{InetAddress, URI} +import java.nio.file.Files import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -48,7 +49,7 @@ object PythonRunner { // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) - val formattedPyFiles = formatPaths(pyFiles) + val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such @@ -153,4 +154,30 @@ object PythonRunner { .map { p => formatPath(p, testWindows) } } + /** + * Resolves the ".py" files. ".py" file should not be added as is because PYTHONPATH does + * not expect a file. This method creates a temporary directory and puts the ".py" files + * if exist in the given paths. + */ + private def resolvePyFiles(pyFiles: Array[String]): Array[String] = { + lazy val dest = Utils.createTempDir(namePrefix = "localPyFiles") + pyFiles.flatMap { pyFile => + // In case of client with submit, the python paths should be set before context + // initialization because the context initialization can be done later. + // We will copy the local ".py" files because ".py" file shouldn't be added + // alone but its parent directory in PYTHONPATH. See SPARK-24384. + if (pyFile.endsWith(".py")) { + val source = new File(pyFile) + if (source.exists() && source.isFile && source.canRead) { + Files.copy(source.toPath, new File(dest, source.getName).toPath) + Some(dest.getAbsolutePath) + } else { + // Don't have to add it if it doesn't exist or isn't readable. + None + } + } else { + Some(pyFile) + } + }.distinct + } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 59b0f29e37d84..3b78b88de778d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -271,16 +271,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv - val moduleDir = - if (clientMode) { - // In client-mode, .py files added with --py-files are not visible in the driver. - // This is something that the launcher library would have to handle. - tempDir - } else { - val subdir = new File(tempDir, "pyModules") - subdir.mkdir() - subdir - } + val moduleDir = { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } val pyModule = new File(moduleDir, "mod1.py") Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) From ec6f971dc57bcdc0ad65ac1987b6f0c1801157f4 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 30 May 2018 11:04:09 -0700 Subject: [PATCH 0893/1161] [SPARK-23161][PYSPARK][ML] Add missing APIs to Python GBTClassifier ## What changes were proposed in this pull request? Add featureSubsetStrategy in GBTClassifier and GBTRegressor. Also make GBTClassificationModel inherit from JavaClassificationModel instead of prediction model so it will have numClasses. ## How was this patch tested? Add tests in doctest Author: Huaxin Gao Closes #21413 from huaxingao/spark-23161. --- python/pyspark/ml/classification.py | 35 ++++++++++++--- python/pyspark/ml/regression.py | 70 ++++++++++++++++++----------- 2 files changed, 74 insertions(+), 31 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 424ecfd89b060..1754c48937a62 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1131,6 +1131,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): @@ -1193,6 +1200,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) + >>> gbt.getFeatureSubsetStrategy() + 'all' >>> model = gbt.fit(td) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1226,6 +1235,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol ... ["indexed", "features"]) >>> model.evaluateEachIteration(validation) [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] + >>> model.numClasses + 2 .. versionadded:: 1.4.0 """ @@ -1244,19 +1255,22 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0, + featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1265,12 +1279,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Classification. """ kwargs = self._input_kwargs @@ -1293,8 +1309,15 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + -class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, +class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dd0b62f184d26..dba0e57b01a0b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -602,6 +602,14 @@ class TreeEnsembleParams(DecisionTreeParams): "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat) + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", + typeConverter=TypeConverters.toString) + def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -619,6 +627,22 @@ def getSubsamplingRate(self): """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + + .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0. + """ + return self._set(featureSubsetStrategy=value) + + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + class TreeRegressorParams(Params): """ @@ -654,14 +678,8 @@ class RandomForestParams(TreeEnsembleParams): Private class to track supported random forest parameters. """ - supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt) - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() @@ -680,20 +698,6 @@ def getNumTrees(self): """ return self.getOrDefault(self.numTrees) - @since("1.4.0") - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - return self._set(featureSubsetStrategy=value) - - @since("1.4.0") - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class GBTParams(TreeEnsembleParams): """ @@ -981,6 +985,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): @@ -1029,6 +1040,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> print(gbt.getImpurity()) variance + >>> print(gbt.getFeatureSubsetStrategy()) + all >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1079,20 +1092,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impurity="variance"): + impurity="variance", featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - impurity="variance") + impurity="variance", featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1102,13 +1115,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impuriy="variance"): + impuriy="variance", featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Regression. """ kwargs = self._input_kwargs @@ -1131,6 +1144,13 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ From 1b36f148891ac41ef36a40366f87dd5405cb3751 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 30 May 2018 11:18:04 -0700 Subject: [PATCH 0894/1161] [SPARK-23901][SQL] Add masking functions ## What changes were proposed in this pull request? The PR adds the masking function as they are described in Hive's documentation: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-DataMaskingFunctions. This means that only `string`s are accepted as parameter for the masking functions. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21246 from mgaido91/SPARK-23901. --- .../expressions/MaskExpressionsUtils.java | 80 +++ .../catalyst/analysis/FunctionRegistry.scala | 8 + .../expressions/maskExpressions.scala | 569 ++++++++++++++++++ .../expressions/MaskExpressionsSuite.scala | 236 ++++++++ .../org/apache/spark/sql/functions.scala | 119 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 107 ++++ 6 files changed, 1119 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java new file mode 100644 index 0000000000000..05879902a4ed9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +/** + * Contains all the Utils methods used in the masking expressions. + */ +public class MaskExpressionsUtils { + static final int UNMASKED_VAL = -1; + + /** + * Returns the masking character for {@param c} or {@param c} is it should not be masked. + * @param c the character to transform + * @param maskedUpperChar the character to use instead of a uppercase letter + * @param maskedLowerChar the character to use instead of a lowercase letter + * @param maskedDigitChar the character to use instead of a digit + * @param maskedOtherChar the character to use instead of a any other character + * @return masking character for {@param c} + */ + public static int transformChar( + final int c, + int maskedUpperChar, + int maskedLowerChar, + int maskedDigitChar, + int maskedOtherChar) { + switch(Character.getType(c)) { + case Character.UPPERCASE_LETTER: + if(maskedUpperChar != UNMASKED_VAL) { + return maskedUpperChar; + } + break; + + case Character.LOWERCASE_LETTER: + if(maskedLowerChar != UNMASKED_VAL) { + return maskedLowerChar; + } + break; + + case Character.DECIMAL_DIGIT_NUMBER: + if(maskedDigitChar != UNMASKED_VAL) { + return maskedDigitChar; + } + break; + + default: + if(maskedOtherChar != UNMASKED_VAL) { + return maskedOtherChar; + } + break; + } + + return c; + } + + /** + * Returns the replacement char to use according to the {@param rep} specified by the user and + * the {@param def} default. + */ + public static int getReplacementChar(String rep, int def) { + if (rep != null && rep.length() > 0) { + return rep.codePointAt(0); + } + return def; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1134a8866dc13..23a4a440fac23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -432,6 +432,14 @@ object FunctionRegistry { expression[ArrayRepeat]("array_repeat"), CreateStruct.registryEntry, + // mask functions + expression[Mask]("mask"), + expression[MaskFirstN]("mask_first_n"), + expression[MaskLastN]("mask_last_n"), + expression[MaskShowFirstN]("mask_show_first_n"), + expression[MaskShowLastN]("mask_show_last_n"), + expression[MaskHash]("mask_hash"), + // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala new file mode 100644 index 0000000000000..276a57266a6e0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ +import org.apache.spark.sql.catalyst.expressions.MaskLike._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait MaskLike { + def upper: String + def lower: String + def digit: String + + protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) + protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) + protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) + + protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName + + def inputStringLengthCode(inputString: String, length: String): String = { + s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" + } + + def appendMaskedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, + | $upperReplacement, $lowerReplacement, + | $digitReplacement, $defaultMaskedOther)); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendUnchangedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($codePoint); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendMaskedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(transformChar( + codePoint, + upperReplacement, + lowerReplacement, + digitReplacement, + defaultMaskedOther)) + offset += Character.charCount(codePoint) + } + offset + } + + def appendUnchangedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(codePoint) + offset += Character.charCount(codePoint) + } + offset + } +} + +trait MaskLikeWithN extends MaskLike { + def n: Int + protected lazy val charCount: Int = if (n < 0) 0 else n +} + +/** + * Utils for mask operations. + */ +object MaskLike { + val defaultCharCount = 4 + val defaultMaskedUppercase: Int = 'X' + val defaultMaskedLowercase: Int = 'x' + val defaultMaskedDigit: Int = 'n' + val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL + + def extractCharCount(e: Expression): Int = e match { + case Literal(i, IntegerType | NullType) => + if (i == null) defaultCharCount else i.asInstanceOf[Int] + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } + + def extractReplacement(e: Expression): String = e match { + case Literal(s, StringType | NullType) => if (s == null) null else s.toString + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${StringType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } +} + +/** + * Masks the input string. Additional parameters can be set to change the masking chars for + * uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); + llll-UUUU-####-#### + """) +// scalastyle:on line.size.limit +case class Mask(child: Expression, upper: String, lower: String, digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLike { + + def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) + + def this(child: Expression, upper: Expression) = + this(child, extractReplacement(upper), null, null) + + def this(child: Expression, upper: Expression, lower: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), null) + + def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val sb = new java.lang.StringBuilder(length) + appendMaskedToStringBuilder(sb, str, 0, length) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |StringBuilder $sb = new StringBuilder($length); + |${CodeGenerator.JAVA_INT} $offset = 0; + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} + |${ev.value} = UTF8String.fromString($sb.toString()); + """.stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) +} + +/** + * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-5678-8765-4321 + """) +// scalastyle:on line.size.limit +case class MaskFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_first_n" +} + +/** + * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-5678-8765-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? + | 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_last_n" +} + +/** + * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-nnnn-nnnn-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_first_n" +} + +/** + * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-nnnn-nnnn-4321 + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_last_n" +} + +/** + * Returns a hashed value based on str. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321"); + 60c713f5ec6912229d2060df1c322776 + """) +// scalastyle:on line.size.limit +case class MaskHash(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def nullSafeEval(input: Any): Any = { + UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") + s""" + |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_hash" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala new file mode 100644 index 0000000000000..4d69dc32ace82 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.{IntegerType, StringType} + +class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("mask") { + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") + checkEvaluation( + new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-####-####") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), + "llll-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal(null, StringType)), null) + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") + checkEvaluation(new Mask( + Literal("abcd-EFGH-8765-4321"), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), + "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("")), "") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") + checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), + "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") + // scalastyle:on nonascii + intercept[AnalysisException] { + checkEvaluation(new Mask(Literal(""), Literal(1)), "") + } + } + + test("mask_first_n") { + checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UFGH-8765") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UFGH-8765-4321") + checkEvaluation( + new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XFGH-8765-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-EFGH-8765-4321") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("")), "") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-EFGH-8765-4321") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xxxxo World") + // scalastyle:on nonascii + } + + test("mask_last_n") { + checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), + "abcd-EFGU-lU#l") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EFGU-####") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), + "abcd-EFGX-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") + } + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal(null, StringType)), null) + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), + "abcd-EFUU-nnnn-nnnn") + checkEvaluation(new MaskLastN(Literal("")), "") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), + "abcx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), + "あ, 𠀋, Hello Xxxxx あ 𠀋") + // scalastyle:on nonascii + } + + test("mask_show_first_n") { + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), + "abcd-EUUU-####-lU#l") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EUUU-####-####") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "abcd-EXXX-nnnn-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-UUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("")), "") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "abcd-XXXX-nnnn-nnnn") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Hellx Xxxxx") + // scalastyle:on nonascii + } + + test("mask_show_last_n") { + checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UUUH-8765") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-###5-4321") + checkEvaluation( + new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XXXX-nnn5-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-UUUU-nnnn-4321") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("")), "") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-XXXX-nnnn-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xello World") + // scalastyle:on nonascii + } + + test("mask_hash") { + checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") + checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") + checkEvaluation(MaskHash(Literal(null, StringType)), null) + // scalastyle:off nonascii + checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") + // scalastyle:on nonascii + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5ab9cb3fb86a5..443ba2aa3757d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3499,6 +3499,125 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Mask functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns a string which is the masked representation of the input. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column): Column = withExpr { new Mask(e.expr) } + + /** + * Returns a string which is the masked representation of the input, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { + Mask(e.expr, upper, lower, digit) + } + + /** + * Returns a string with the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } + + /** + * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } + + /** + * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n(e: Column, n: Int): Column = withExpr { + new MaskShowFirstN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n(e: Column, n: Int): Column = withExpr { + new MaskShowLastN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a hashed value based on the input column. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 79e743d961af8..cc8bad4ded53e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,6 +276,113 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("mask functions") { + val df = Seq("TestString-123", "", null).toDF("a") + checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4)), + Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4)), + Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_hash($"a")), + Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), + Row("d41d8cd98f00b204e9800998ecf8427e"), + Row(null))) + + checkAnswer(df.select(mask($"a", "U", "l", "#")), + Seq(Row("UlllUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), + Seq(Row("TestString-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), + Seq(Row("TestUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllUlllll-123"), Row(""), Row(null))) + + checkAnswer( + df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), + Row("", "", "", ""), + Row(null, null, null, null))) + checkAnswer(sql("select mask(null)"), Row(null)) + checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) + intercept[AnalysisException] { + checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + } + + checkAnswer( + df.selectExpr( + "mask_first_n(a)", + "mask_first_n(a, 6)", + "mask_first_n(a, 6, 'U')", + "mask_first_n(a, 6, 'U', 'l')", + "mask_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", + "UlllUlring-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) + } + + checkAnswer( + df.selectExpr( + "mask_last_n(a)", + "mask_last_n(a, 6)", + "mask_last_n(a, 6, 'U')", + "mask_last_n(a, 6, 'U', 'l')", + "mask_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", + "TestStrill-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_first_n(a)", + "mask_show_first_n(a, 6)", + "mask_show_first_n(a, 6, 'U')", + "mask_show_first_n(a, 6, 'U', 'l')", + "mask_show_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", + "TestStllll-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_last_n(a)", + "mask_show_last_n(a, 6)", + "mask_show_last_n(a, 6, 'U')", + "mask_show_last_n(a, 6, 'U', 'l')", + "mask_show_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", + "UlllUlllng-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) + } + + checkAnswer(sql("select mask_hash(null)"), Row(null)) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From 24ef7fbfa94fc85663eaf49cc574f81387c66c62 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 30 May 2018 15:31:40 -0700 Subject: [PATCH 0895/1161] [SPARK-24276][SQL] Order of literals in IN should not affect semantic equality ## What changes were proposed in this pull request? When two `In` operators are created with the same list of values, but different order, we are considering them as semantically different. This is wrong, since they have the same semantic meaning. The PR adds a canonicalization rule which orders the literals in the `In` operator so the semantic equality works properly. ## How was this patch tested? added UT Author: Marco Gaido Closes #21331 from mgaido91/SPARK-24276. --- .../catalyst/expressions/Canonicalize.scala | 6 +++ .../expressions/CanonicalizeSuite.scala | 53 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index d848ba18356d3..7541f527a52a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -30,6 +30,7 @@ package org.apache.spark.sql.catalyst.expressions * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + * - Elements in [[In]] are reordered by `hashCode`. */ object Canonicalize { def execute(e: Expression): Expression = { @@ -85,6 +86,11 @@ object Canonicalize { case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + // order the list in the In operator + // In subqueries contain only one element of type ListQuery. So checking that the length > 1 + // we are not reordering In subqueries. + case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case _ => e } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala new file mode 100644 index 0000000000000..28e6940f3cca3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Range + +class CanonicalizeSuite extends SparkFunSuite { + + test("SPARK-24276: IN expression with different order are semantically equal") { + val range = Range(1, 1, 1, 1) + val idAttr = range.output.head + + val in1 = In(idAttr, Seq(Literal(1), Literal(2))) + val in2 = In(idAttr, Seq(Literal(2), Literal(1))) + val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + + assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) + assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) + + assert(range.where(in1).sameResult(range.where(in2))) + assert(!range.where(in1).sameResult(range.where(in3))) + + val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(2), Literal(1))))) + val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + CreateArray(Seq(Literal(1), Literal(2))))) + val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(3), Literal(1))))) + + assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) + assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash()) + + assert(range.where(arrays1).sameResult(range.where(arrays2))) + assert(!range.where(arrays1).sameResult(range.where(arrays3))) + } +} From 0053e153faaa76ea38c845adab137d5be970e5af Mon Sep 17 00:00:00 2001 From: William Sheu Date: Wed, 30 May 2018 22:37:27 -0700 Subject: [PATCH 0896/1161] [SPARK-24337][CORE] Improve error messages for Spark conf values ## What changes were proposed in this pull request? Improve the exception messages when retrieving Spark conf values to include the key name when the value is invalid. ## How was this patch tested? Unit tests for all get* operations in SparkConf that require a specific value format Author: William Sheu Closes #21454 from PenguinToast/SPARK-24337-spark-config-errors. --- .../scala/org/apache/spark/SparkConf.scala | 85 ++++++++++++++----- .../org/apache/spark/SparkConfSuite.scala | 32 +++++++ 2 files changed, 96 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index dab409572646f..6c4c5c94cfa28 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -265,16 +265,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String): Long = { + def getTimeAsSeconds(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key)) } /** * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as seconds */ - def getTimeAsSeconds(key: String, defaultValue: String): Long = { + def getTimeAsSeconds(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsSeconds(get(key, defaultValue)) } @@ -282,16 +284,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. * @throws java.util.NoSuchElementException If the time parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String): Long = { + def getTimeAsMs(key: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key)) } /** * Get a time parameter as milliseconds, falling back to a default if not set. If no * suffix is provided then milliseconds are assumed. + * @throws NumberFormatException If the value cannot be interpreted as milliseconds */ - def getTimeAsMs(key: String, defaultValue: String): Long = { + def getTimeAsMs(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.timeStringAsMs(get(key, defaultValue)) } @@ -299,23 +303,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String): Long = { + def getSizeAsBytes(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key)) } /** * Get a size parameter as bytes, falling back to a default if not set. If no * suffix is provided then bytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: String): Long = { + def getSizeAsBytes(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue)) } /** * Get a size parameter as bytes, falling back to a default if not set. + * @throws NumberFormatException If the value cannot be interpreted as bytes */ - def getSizeAsBytes(key: String, defaultValue: Long): Long = { + def getSizeAsBytes(key: String, defaultValue: Long): Long = catchIllegalValue(key) { Utils.byteStringAsBytes(get(key, defaultValue + "B")) } @@ -323,16 +330,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String): Long = { + def getSizeAsKb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key)) } /** * Get a size parameter as Kibibytes, falling back to a default if not set. If no * suffix is provided then Kibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Kibibytes */ - def getSizeAsKb(key: String, defaultValue: String): Long = { + def getSizeAsKb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsKb(get(key, defaultValue)) } @@ -340,16 +349,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String): Long = { + def getSizeAsMb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key)) } /** * Get a size parameter as Mebibytes, falling back to a default if not set. If no * suffix is provided then Mebibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Mebibytes */ - def getSizeAsMb(key: String, defaultValue: String): Long = { + def getSizeAsMb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsMb(get(key, defaultValue)) } @@ -357,16 +368,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. * @throws java.util.NoSuchElementException If the size parameter is not set + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String): Long = { + def getSizeAsGb(key: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key)) } /** * Get a size parameter as Gibibytes, falling back to a default if not set. If no * suffix is provided then Gibibytes are assumed. + * @throws NumberFormatException If the value cannot be interpreted as Gibibytes */ - def getSizeAsGb(key: String, defaultValue: String): Long = { + def getSizeAsGb(key: String, defaultValue: String): Long = catchIllegalValue(key) { Utils.byteStringAsGb(get(key, defaultValue)) } @@ -394,23 +407,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } - /** Get a parameter as an integer, falling back to a default if not set */ - def getInt(key: String, defaultValue: Int): Int = { + /** + * Get a parameter as an integer, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as an integer + */ + def getInt(key: String, defaultValue: Int): Int = catchIllegalValue(key) { getOption(key).map(_.toInt).getOrElse(defaultValue) } - /** Get a parameter as a long, falling back to a default if not set */ - def getLong(key: String, defaultValue: Long): Long = { + /** + * Get a parameter as a long, falling back to a default if not set + * @throws NumberFormatException If the value cannot be interpreted as a long + */ + def getLong(key: String, defaultValue: Long): Long = catchIllegalValue(key) { getOption(key).map(_.toLong).getOrElse(defaultValue) } - /** Get a parameter as a double, falling back to a default if not set */ - def getDouble(key: String, defaultValue: Double): Double = { + /** + * Get a parameter as a double, falling back to a default if not ste + * @throws NumberFormatException If the value cannot be interpreted as a double + */ + def getDouble(key: String, defaultValue: Double): Double = catchIllegalValue(key) { getOption(key).map(_.toDouble).getOrElse(defaultValue) } - /** Get a parameter as a boolean, falling back to a default if not set */ - def getBoolean(key: String, defaultValue: Boolean): Boolean = { + /** + * Get a parameter as a boolean, falling back to a default if not set + * @throws IllegalArgumentException If the value cannot be interpreted as a boolean + */ + def getBoolean(key: String, defaultValue: Boolean): Boolean = catchIllegalValue(key) { getOption(key).map(_.toBoolean).getOrElse(defaultValue) } @@ -448,6 +473,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria */ private[spark] def getenv(name: String): String = System.getenv(name) + /** + * Wrapper method for get() methods which require some specific value format. This catches + * any [[NumberFormatException]] or [[IllegalArgumentException]] and re-raises it with the + * incorrectly configured key in the exception message. + */ + private def catchIllegalValue[T](key: String)(getValue: => T): T = { + try { + getValue + } catch { + case e: NumberFormatException => + // NumberFormatException doesn't have a constructor that takes a cause for some reason. + throw new NumberFormatException(s"Illegal value for config key $key: ${e.getMessage}") + .initCause(e) + case e: IllegalArgumentException => + throw new IllegalArgumentException(s"Illegal value for config key $key: ${e.getMessage}", e) + } + } + /** * Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index bff808eb540ac..0d06b02e74e34 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -339,6 +339,38 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + val defaultIllegalValue = "SomeIllegalValue" + val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map( + "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)), + "getTimeAsSeconds with default" -> (_.getTimeAsSeconds(_, defaultIllegalValue)), + "getTimeAsMs" -> (_.getTimeAsMs(_)), + "getTimeAsMs with default" -> (_.getTimeAsMs(_, defaultIllegalValue)), + "getSizeAsBytes" -> (_.getSizeAsBytes(_)), + "getSizeAsBytes with default string" -> (_.getSizeAsBytes(_, defaultIllegalValue)), + "getSizeAsBytes with default long" -> (_.getSizeAsBytes(_, 0L)), + "getSizeAsKb" -> (_.getSizeAsKb(_)), + "getSizeAsKb with default" -> (_.getSizeAsKb(_, defaultIllegalValue)), + "getSizeAsMb" -> (_.getSizeAsMb(_)), + "getSizeAsMb with default" -> (_.getSizeAsMb(_, defaultIllegalValue)), + "getSizeAsGb" -> (_.getSizeAsGb(_)), + "getSizeAsGb with default" -> (_.getSizeAsGb(_, defaultIllegalValue)), + "getInt" -> (_.getInt(_, 0)), + "getLong" -> (_.getLong(_, 0L)), + "getDouble" -> (_.getDouble(_, 0.0)), + "getBoolean" -> (_.getBoolean(_, false)) + ) + + illegalValueTests.foreach { case (name, getValue) => + test(s"SPARK-24337: $name throws an useful error message with key name") { + val key = "SomeKey" + val conf = new SparkConf() + conf.set(key, "SomeInvalidValue") + val thrown = intercept[IllegalArgumentException] { + getValue(conf, key) + } + assert(thrown.getMessage.contains(key)) + } + } } class Class1 {} From 90ae98d1accb3e4b7d381de072257bdece8dd7e0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 31 May 2018 06:53:10 -0700 Subject: [PATCH 0897/1161] [SPARK-24146][PYSPARK][ML] spark.ml parity for sequential pattern mining - PrefixSpan: Python API ## What changes were proposed in this pull request? spark.ml parity for sequential pattern mining - PrefixSpan: Python API ## How was this patch tested? doctests Author: WeichenXu Closes #21265 from WeichenXu123/prefix_span_py. --- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 6 +- python/pyspark/ml/fpm.py | 104 +++++++++++++++++- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 41716c621ca98..bd1c1a8885201 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -53,7 +53,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params @Since("2.4.0") val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + "sequential pattern. Sequential pattern that appears more than " + - "(minSupport * size-of-the-dataset)." + + "(minSupport * size-of-the-dataset) " + "times will be output.", ParamValidators.gtEq(0.0)) /** @group getParam */ @@ -128,10 +128,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. * * @param dataset A dataset or a dataframe containing a sequence column which is - * {{{Seq[Seq[_]]}}} type + * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset. * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: - * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `sequence: ArrayType(ArrayType(T))` (T is the item type) * - `freq: Long` */ @Since("2.4.0") diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index b8dafd49d354d..fd19fd96c4df6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,8 +16,9 @@ # from pyspark import keyword_only, since +from pyspark.sql import DataFrame from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * __all__ = ["FPGrowth", "FPGrowthModel"] @@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", def _create_model(self, java_model): return FPGrowthModel(java_model) + + +class PrefixSpan(JavaParams): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + Efficiently by Prefix-Projected Pattern Growth + (see here). + This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` + method to run the PrefixSpan algorithm. + + @see Sequential Pattern Mining + (Wikipedia) + .. versionadded:: 2.4.0 + + """ + + minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.", + typeConverter=TypeConverters.toFloat) + + maxPatternLength = Param(Params._dummy(), "maxPatternLength", + "The maximal length of the sequential pattern. Must be > 0.", + typeConverter=TypeConverters.toInt) + + maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the " + + "internal storage format) allowed in a projected database before " + + "local processing. If a projected database exceeds this size, " + + "another iteration of distributed prefix growth is run. " + + "Must be > 0.", + typeConverter=TypeConverters.toInt) + + sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + super(PrefixSpan, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) + self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def findFrequentSequentialPatterns(self, dataset): + """ + .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param dataset: A dataframe containing a sequence column which is + `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. + :return: A `DataFrame` that contains columns of sequence and corresponding frequency. + The schema of it will be: + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` + + >>> from pyspark.ml.fpm import PrefixSpan + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]), + ... Row(sequence=[[1], [3, 2], [1, 2]]), + ... Row(sequence=[[1, 2], [5]]), + ... Row(sequence=[[6]])]).toDF() + >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5) + >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False) + +----------+----+ + |sequence |freq| + +----------+----+ + |[[1]] |3 | + |[[1], [3]]|2 | + |[[1, 2]] |3 | + |[[2]] |3 | + |[[3]] |2 | + +----------+----+ + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) From 698b9a0981f0ec322e15d6ac89cc38c8f49ed33d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 31 May 2018 09:34:39 -0700 Subject: [PATCH 0898/1161] [WEBUI] Avoid possibility of script in query param keys As discussed separately, this avoids the possibility of XSS on certain request param keys. CC vanzin Author: Sean Owen Closes #21464 from srowen/XSS2. --- .../src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 4 +++- core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index f651fe97c2cd5..178d2c8d1a10a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -206,7 +206,9 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index b8b20db1fa407..56e4d6838a99a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -43,7 +43,9 @@ private[ui] class StageTableBase( killEnabled: Boolean, isFailedStage: Boolean) { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) From 7a82e93b349b4f414f2075dd5add8e4ed72fe357 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 31 May 2018 10:05:20 -0700 Subject: [PATCH 0899/1161] [SPARK-24414][UI] Calculate the correct number of tasks for a stage. This change takes into account all non-pending tasks when calculating the number of tasks to be shown. This also means that when the stage is pending, the task table (or, in fact, most of the data in the stage page) will not be rendered. I also fixed the label when the known number of tasks is larger than the recorded number of tasks (it was inverted). Author: Marcelo Vanzin Closes #21457 from vanzin/SPARK-24414. --- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2575914121c39..d4e6a7bc3effa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -117,8 +117,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks + val totalTasks = taskCount(stageData) if (totalTasks == 0) { val content =
    @@ -133,7 +132,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${storedTasks}" + s"$storedTasks, showing ${totalTasks}" } val summary = @@ -686,7 +685,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numTasks + override def dataSize: Int = taskCount(stage) override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { @@ -1052,4 +1051,9 @@ private[ui] object ApiHelper { (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } + def taskCount(stageData: StageData): Int = { + stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks + + stageData.numKilledTasks + } + } From 223df5d9d4fbf48db017edb41f9b7e4033679f35 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 31 May 2018 11:23:57 -0700 Subject: [PATCH 0900/1161] [SPARK-24397][PYSPARK] Added TaskContext.getLocalProperty(key) in Python ## What changes were proposed in this pull request? This adds a new API `TaskContext.getLocalProperty(key)` to the Python TaskContext. It mirrors the Java TaskContext API of returning a string value if the key exists, or None if the key does not exist. ## How was this patch tested? New test added. Author: Tathagata Das Closes #21437 from tdas/SPARK-24397. --- .../org/apache/spark/api/python/PythonRunner.scala | 7 +++++++ python/pyspark/taskcontext.py | 7 +++++++ python/pyspark/tests.py | 14 ++++++++++++++ python/pyspark/worker.py | 6 ++++++ 4 files changed, 34 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f075a7e0eb0b4..41eac10d9b267 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) + val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index e5218d9e75e78..63ae1f30e17ca 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -34,6 +34,7 @@ class TaskContext(object): _partitionId = None _stageId = None _taskAttemptId = None + _localProperties = None def __new__(cls): """Even if users construct TaskContext instead of using get, give them the singleton.""" @@ -88,3 +89,9 @@ def taskAttemptId(self): TaskAttemptID. """ return self._taskAttemptId + + def getLocalProperty(self, key): + """ + Get a local property set upstream in the driver, or None if it is missing. + """ + return self._localProperties.get(key, None) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3b37cc028c1b7..30723b8e15b36 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -574,6 +574,20 @@ def test_tc_on_driver(self): tc = TaskContext.get() self.assertTrue(tc is None) + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + class RDDTests(ReusedPySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5d2e58bef6466..fbcb8af8bfb24 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -222,6 +222,12 @@ def main(infile, outfile): taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) taskContext._taskAttemptId = read_long(infile) + taskContext._localProperties = dict() + for i in range(read_int(infile)): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + taskContext._localProperties[k] = v + shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() From cc976f6cb858adb5f52987b56dda54769915ce50 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 31 May 2018 11:38:23 -0700 Subject: [PATCH 0901/1161] [SPARK-23900][SQL] format_number support user specifed format as argument ## What changes were proposed in this pull request? `format_number` support user specifed format as argument. For example: ```sql spark-sql> SELECT format_number(12332.123456, '##################.###'); 12332.123 ``` ## How was this patch tested? unit test Author: Yuming Wang Closes #21010 from wangyum/SPARK-23900. --- .../expressions/stringExpressions.scala | 142 ++++++++++++------ .../expressions/StringExpressionsSuite.scala | 24 +++ 2 files changed, 116 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9823b2fc5ad97..bedad7da334ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1916,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression) usage = """ _FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` decimal places. If `expr2` is 0, the result has no decimal point or fractional part. + `expr2` also accept a user specified format. This is supposed to function like MySQL's FORMAT. """, examples = """ Examples: > SELECT _FUNC_(12332.123456, 4); 12,332.1235 + > SELECT _FUNC_(12332.123456, '##################.###'); + 12332.123 """) case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -1930,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression) override def right: Expression = d override def dataType: DataType = StringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(NumericType, TypeCollection(IntegerType, StringType)) + + private val defaultFormat = "#,###,###,###,###,###,##0" // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Option[Int] = None + private var lastDIntValue: Option[Int] = None + + @transient + private var lastDStringValue: Option[String] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. @@ -1950,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression) private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { - val dValue = dObject.asInstanceOf[Int] - if (dValue < 0) { - return null - } - - lastDValue match { - case Some(last) if last == dValue => - // use the current pattern - case _ => - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") - } + right.dataType match { + case IntegerType => + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { + return null } - lastDValue = Some(dValue) + lastDIntValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append(defaultFormat) + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + + lastDIntValue = Some(dValue) - numberFormat.applyLocalizedPattern(pattern.toString) + numberFormat.applyLocalizedPattern(pattern.toString) + } + case StringType => + val dValue = dObject.asInstanceOf[UTF8String].toString + lastDStringValue match { + case Some(last) if last == dValue => + case _ => + pattern.delete(0, pattern.length) + lastDStringValue = Some(dValue) + if (dValue.isEmpty) { + numberFormat.applyLocalizedPattern(defaultFormat) + } else { + numberFormat.applyLocalizedPattern(dValue) + } + } } x.dataType match { @@ -2008,35 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. val usLocale = "US" - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") - s""" - if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); - - if ($d > 0) { - $pattern.append("."); - for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + right.dataType match { + case IntegerType => + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val i = ctx.freshName("i") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("$defaultFormat"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $lastDValue = $d; + $numberFormat.applyLocalizedPattern($pattern.toString()); } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } else { + ${ev.value} = null; + ${ev.isNull} = true; } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); - } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); - } else { - ${ev.value} = null; - ${ev.isNull} = true; - } - """ + """ + case StringType => + val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""") + val dValue = ctx.freshName("dValue") + s""" + String $dValue = $d.toString(); + if (!$dValue.equals($lastDValue)) { + $lastDValue = $dValue; + if ($dValue.isEmpty()) { + $numberFormat.applyLocalizedPattern("$defaultFormat"); + } else { + $numberFormat.applyLocalizedPattern($dValue); + } + } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + """ + } }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index f1a6f9b8889fa..aa334e040d5fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -706,6 +706,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "15,159,339,180,002,773.2778") checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(12831273.23481d), + Literal("###,###,###,###,###.###")), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")), + "123,123,324,123") + checkEvaluation( + FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)), + Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null) + assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")), + "12,332") + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, StringType)), null) + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("find in set") { From 21e1fc7d4aed688d7b685be6ce93f76752159c98 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 31 May 2018 14:28:33 -0700 Subject: [PATCH 0902/1161] [SPARK-24232][K8S] Add support for secret env vars ## What changes were proposed in this pull request? * Allows to refer a secret as an env var. * Introduces new config properties in the form: spark.kubernetes{driver,executor}.secretKeyRef.ENV_NAME=name:key ENV_NAME is case sensitive. * Updates docs. * Adds required unit tests. ## How was this patch tested? Manually tested and confirmed that the secrets exist in driver's and executor's container env. Also job finished successfully. First created a secret with the following yaml: ``` apiVersion: v1 kind: Secret metadata: name: test-secret data: username: c3RhdnJvcwo= password: Mzk1MjgkdmRnN0pi ------- $ echo -n 'stavros' | base64 c3RhdnJvcw== $ echo -n '39528$vdg7Jb' | base64 MWYyZDFlMmU2N2Rm ``` Run a job as follows: ```./bin/spark-submit \ --master k8s://http://localhost:9000 \ --deploy-mode cluster \ --name spark-pi \ --class org.apache.spark.examples.SparkPi \ --conf spark.executor.instances=1 \ --conf spark.kubernetes.container.image=skonto/spark:k8envs3 \ --conf spark.kubernetes.driver.secretKeyRef.MY_USERNAME=test-secret:username \ --conf spark.kubernetes.driver.secretKeyRef.My_password=test-secret:password \ --conf spark.kubernetes.executor.secretKeyRef.MY_USERNAME=test-secret:username \ --conf spark.kubernetes.executor.secretKeyRef.My_password=test-secret:password \ local:///opt/spark/examples/jars/spark-examples_2.11-2.4.0-SNAPSHOT.jar 10000 ``` Secret loaded correctly at the driver container: ![image](https://user-images.githubusercontent.com/7945591/40174346-7fee70c8-59dd-11e8-8705-995a5472716f.png) Also if I log into the exec container: kubectl exec -it spark-pi-1526555613156-exec-1 bash bash-4.4# env > SPARK_EXECUTOR_MEMORY=1g > SPARK_EXECUTOR_CORES=1 > LANG=C.UTF-8 > HOSTNAME=spark-pi-1526555613156-exec-1 > SPARK_APPLICATION_ID=spark-application-1526555618626 > **MY_USERNAME=stavros** > > JAVA_HOME=/usr/lib/jvm/java-1.8-openjdk > KUBERNETES_PORT_443_TCP_PROTO=tcp > KUBERNETES_PORT_443_TCP_ADDR=10.100.0.1 > JAVA_VERSION=8u151 > KUBERNETES_PORT=tcp://10.100.0.1:443 > PWD=/opt/spark/work-dir > HOME=/root > SPARK_LOCAL_DIRS=/var/data/spark-b569b0ae-b7ef-4f91-bcd5-0f55535d3564 > KUBERNETES_SERVICE_PORT_HTTPS=443 > KUBERNETES_PORT_443_TCP_PORT=443 > SPARK_HOME=/opt/spark > SPARK_DRIVER_URL=spark://CoarseGrainedSchedulerspark-pi-1526555613156-driver-svc.default.svc:7078 > KUBERNETES_PORT_443_TCP=tcp://10.100.0.1:443 > SPARK_EXECUTOR_POD_IP=9.0.9.77 > TERM=xterm > SPARK_EXECUTOR_ID=1 > SHLVL=1 > KUBERNETES_SERVICE_PORT=443 > SPARK_CONF_DIR=/opt/spark/conf > PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/lib/jvm/java-1.8-openjdk/jre/bin:/usr/lib/jvm/java-1.8-openjdk/bin > JAVA_ALPINE_VERSION=8.151.12-r0 > KUBERNETES_SERVICE_HOST=10.100.0.1 > **My_password=39528$vdg7Jb** > _=/usr/bin/env > Author: Stavros Kontopoulos Closes #21317 from skonto/k8s-fix-env-secrets. --- docs/running-on-kubernetes.md | 22 +++++++ .../org/apache/spark/deploy/k8s/Config.scala | 2 + .../spark/deploy/k8s/KubernetesConf.scala | 11 +++- .../k8s/features/EnvSecretsFeatureStep.scala | 57 ++++++++++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 11 +++- .../k8s/KubernetesExecutorBuilder.scala | 12 +++- .../deploy/k8s/KubernetesConfSuite.scala | 12 +++- .../BasicDriverFeatureStepSuite.scala | 2 + .../BasicExecutorFeatureStepSuite.scala | 3 + ...ubernetesCredentialsFeatureStepSuite.scala | 3 + .../DriverServiceFeatureStepSuite.scala | 6 ++ .../features/EnvSecretsFeatureStepSuite.scala | 59 +++++++++++++++++++ .../KubernetesFeaturesTestUtils.scala | 7 ++- .../features/LocalDirsFeatureStepSuite.scala | 1 + .../MountSecretsFeatureStepSuite.scala | 1 + .../spark/deploy/k8s/submit/ClientSuite.scala | 1 + .../submit/KubernetesDriverBuilderSuite.scala | 13 +++- .../k8s/KubernetesExecutorBuilderSuite.scala | 11 +++- 18 files changed, 222 insertions(+), 12 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e9e1f3e280609..a4b2b98b0b649 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -140,6 +140,12 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ``` +To use a secret through an environment variable use the following options to the `spark-submit` command: +``` +--conf spark.kubernetes.driver.secretKeyRef.ENV_NAME=name:key +--conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key +``` + ## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and @@ -602,4 +608,20 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets. + + spark.kubernetes.driver.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the driver container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.driver.secretKeyRef.ENV_VAR=spark-secret:key. + + + + spark.kubernetes.executor.secretKeyRef.[EnvName] + (none) + + Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, + spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 4086970ffb256..560dedf431b08 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -162,10 +162,12 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 77b634ddfabcc..5a944187a7096 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -54,6 +54,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleLabels: Map[String, String], roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], + roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -129,6 +130,8 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) @@ -140,6 +143,7 @@ private[spark] object KubernetesConf { driverLabels, driverAnnotations, driverSecretNamesToMountPaths, + driverSecretEnvNamesToKeyRefs, driverEnvs) } @@ -167,8 +171,10 @@ private[spark] object KubernetesConf { executorCustomLabels val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap KubernetesConf( @@ -178,7 +184,8 @@ private[spark] object KubernetesConf { appId, executorLabels, executorAnnotations, - executorSecrets, + executorMountSecrets, + executorEnvSecrets, executorEnv) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala new file mode 100644 index 0000000000000..03ff7d48420ff --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class EnvSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedEnvSecrets = kubernetesConf + .roleSecretEnvNamesToKeyRefs + .map{ case (envName, keyRef) => + // Keyref parts + val keyRefParts = keyRef.split(":") + require(keyRefParts.size == 2, "SecretKeyRef must be in the form name:key.") + val name = keyRefParts(0) + val key = keyRefParts(1) + new EnvVarBuilder() + .withName(envName) + .withNewValueFrom() + .withNewSecretKeyRef() + .withKey(key) + .withName(name) + .endSecretKeyRef() + .endValueFrom() + .build() + } + + val containerWithEnvVars = new ContainerBuilder(pod.container) + .addAllToEnv(addedEnvSecrets.toSeq.asJava) + .build() + SparkPod(pod.pod, containerWithEnvVars) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 10b0154466a3a..fdc5eb0d75832 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -30,6 +30,9 @@ private[spark] class KubernetesDriverBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -41,10 +44,14 @@ private[spark] class KubernetesDriverBuilder( provideCredentialsStep(kubernetesConf), provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { val configuredPod = feature.configurePod(spec.pod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index d8f63d57574fb..d5e1de36a58df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = @@ -25,6 +25,9 @@ private[spark] class KubernetesExecutorBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -32,9 +35,14 @@ private[spark] class KubernetesExecutorBuilder( def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { executorPod = feature.configurePod(executorPod) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index f10202f7a3546..3d23e1cb90fd2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -40,6 +40,9 @@ class KubernetesConfSuite extends SparkFunSuite { private val SECRET_NAMES_TO_MOUNT_PATHS = Map( "secret1" -> "/mnt/secrets/secret1", "secret2" -> "/mnt/secrets/secret2") + private val SECRET_ENV_VARS = Map( + "envName1" -> "name1:key1", + "envName2" -> "name2:key2") private val CUSTOM_ENVS = Map( "customEnvKey1" -> "customEnvValue1", "customEnvKey2" -> "customEnvValue2") @@ -103,6 +106,9 @@ class KubernetesConfSuite extends SparkFunSuite { SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX$key", value) + } CUSTOM_ENVS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) } @@ -121,6 +127,7 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) } @@ -155,6 +162,9 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_ANNOTATIONS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX$key", value) + } SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) } @@ -170,6 +180,6 @@ class KubernetesConfSuite extends SparkFunSuite { SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) } - } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index eee85b8baa730..b2813d8b3265d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -69,6 +69,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, DRIVER_ENVS) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -138,6 +139,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, Map.empty) val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index a764f7630b5c8..9182134b3337c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -87,6 +87,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) val executor = step.configurePod(SparkPod.initialPod()) @@ -124,6 +125,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -142,6 +144,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map("qux" -> "quux"))) val executor = step.configurePod(SparkPod.initialPod()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 9f817d3bfc79a..f81894f8055f1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -59,6 +59,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -88,6 +89,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -124,6 +126,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index c299d56865ec0..f265522a8823a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -65,6 +65,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -94,6 +95,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -113,6 +115,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -141,6 +144,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) val driverService = configurationStep @@ -166,6 +170,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The driver bind address should not be allowed.") @@ -189,6 +194,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala new file mode 100644 index 0000000000000..8b0b2d0739c76 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class EnvSecretsFeatureStepSuite extends SparkFunSuite{ + private val KEY_REF_NAME_FOO = "foo" + private val KEY_REF_NAME_BAR = "bar" + private val KEY_REF_KEY_FOO = "key_foo" + private val KEY_REF_KEY_BAR = "key_bar" + private val ENV_NAME_FOO = "MY_FOO" + private val ENV_NAME_BAR = "MY_bar" + + test("sets up all keyRefs") { + val baseDriverPod = SparkPod.initialPod() + val envVarsToKeys = Map( + ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}", + ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}") + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + Map.empty, + envVarsToKeys, + Map.empty) + + val step = new EnvSecretsFeatureStep(kubernetesConf) + val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container + + val expectedVars = + Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}") + + expectedVars.foreach { envName => + assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index 27bff74ce38af..f90380e30e52a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -58,4 +60,7 @@ object KubernetesFeaturesTestUtils { .build()) } + def containerHasEnvVar(container: Container, envVarName: String): Boolean = { + container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 91e184b84b86e..2542a02d37766 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -43,6 +43,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9d02f56cc206d..9155793774123 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -41,6 +41,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, secretNamesToMountPaths, + Map.empty, Map.empty) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index c1b203e03a357..0775338098a13 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -142,6 +142,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index a511d254d2175..cb724068ea4f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -27,6 +27,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SERVICE_STEP_TYPE = "service" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -43,12 +44,16 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, _ => secretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Apply fundamental steps all the time.") { @@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("EnvName" -> "SecretName:secretKey"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -93,7 +100,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE + ) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 9ee86b5a423a9..753cd30a237f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,23 +20,27 @@ import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Basic steps are consistently applied.") { @@ -49,6 +53,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -64,12 +69,14 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("secret-name" -> "secret-key"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE) } private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { From 2c9c8629b7aa87c057ec9b84b6082f8b9b9319b1 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 1 Jun 2018 08:44:57 +0800 Subject: [PATCH 0903/1161] [MINOR][YARN] Add YARN-specific credential providers in debug logging message This PR adds a debugging log for YARN-specific credential providers which is loaded by service loader mechanism. It took me a while to debug if it's actually loaded or not. I had to explicitly set the deprecated configuration and check if that's actually being loaded. The change scope is manually tested. Logs are like: ``` Using the following builtin delegation token providers: hadoopfs, hive, hbase. Using the following YARN-specific credential providers: yarn-test. ``` Author: hyukjinkwon Closes #21466 from HyukjinKwon/minor-log. Change-Id: I18e2fb8eeb3289b148f24c47bb3130a560a881cf --- .../spark/deploy/security/HadoopDelegationTokenManager.scala | 4 ++-- .../yarn/security/YARNHadoopDelegationTokenManager.scala | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 5151df00476f9..ab8d8d96a9b08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging * * Also, each HadoopDelegationTokenProvider is controlled by * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to - * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be * enabled/disabled by the configuration spark.security.credentials.hive.enabled. * * @param sparkConf Spark configuration @@ -52,7 +52,7 @@ private[spark] class HadoopDelegationTokenManager( // Maintain all the registered delegation token providers private val delegationTokenProviders = getDelegationTokenProviders - logDebug(s"Using the following delegation token providers: " + + logDebug("Using the following builtin delegation token providers: " + s"${delegationTokenProviders.keys.mkString(", ")}.") /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index d4eeb6bbcf886..26a2e5d730218 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -44,6 +44,10 @@ private[yarn] class YARNHadoopDelegationTokenManager( // public for testing val credentialProviders = getCredentialProviders + if (credentialProviders.nonEmpty) { + logDebug("Using the following YARN-specific credential providers: " + + s"${credentialProviders.keys.mkString(", ")}.") + } /** * Writes delegation tokens to creds. Delegation tokens are fetched from all registered From cbaa729132e9aee0e5d8aed332848e21dd8670a8 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 1 Jun 2018 10:01:15 +0800 Subject: [PATCH 0904/1161] [SPARK-24330][SQL] Refactor ExecuteWriteTask and Use `while` in writing files ## What changes were proposed in this pull request? 1. Refactor ExecuteWriteTask in FileFormatWriter to reduce common logic and improve readability. After the change, callers only need to call `commit()` or `abort` at the end of task. Also there is less code in `SingleDirectoryWriteTask` and `DynamicPartitionWriteTask`. Definitions of related classes are moved to a new file, and `ExecuteWriteTask` is renamed to `FileFormatDataWriter`. 2. As per code style guide: https://github.com/databricks/scala-style-guide#traversal-and-zipwithindex , we avoid using `for` for looping in [FileFormatWriter](https://github.com/apache/spark/pull/21381/files#diff-3b69eb0963b68c65cfe8075f8a42e850L536) , or `foreach` in [WriteToDataSourceV2Exec](https://github.com/apache/spark/pull/21381/files#diff-6fbe10db766049a395bae2e785e9d56eL119). In such critical code path, using `while` is good for performance. ## How was this patch tested? Existing unit test. I tried the microbenchmark in https://github.com/apache/spark/pull/21409 | Workload | Before changes(Best/Avg Time(ms)) | After changes(Best/Avg Time(ms)) | | --- | --- | -- | |Output Single Int Column| 2018 / 2043 | 2096 / 2236 | |Output Single Double Column| 1978 / 2043 | 2013 / 2018 | |Output Int and String Column| 6332 / 6706 | 6162 / 6298 | |Output Partitions| 4458 / 5094 | 3792 / 4008 | |Output Buckets| 5695 / 6102 | 5120 / 5154 | Also a microbenchmark on my laptop for general comparison among while/foreach/for : ``` class Writer { var sum = 0L def write(l: Long): Unit = sum += l } def testWhile(iterator: Iterator[Long]): Long = { val w = new Writer while (iterator.hasNext) { w.write(iterator.next()) } w.sum } def testForeach(iterator: Iterator[Long]): Long = { val w = new Writer iterator.foreach(w.write) w.sum } def testFor(iterator: Iterator[Long]): Long = { val w = new Writer for (x <- iterator) { w.write(x) } w.sum } val data = 0L to 100000000L val start = System.nanoTime (0 to 10).foreach(_ => testWhile(data.iterator)) println("benchmark while: " + (System.nanoTime - start)/1000000) val start2 = System.nanoTime (0 to 10).foreach(_ => testForeach(data.iterator)) println("benchmark foreach: " + (System.nanoTime - start2)/1000000) val start3 = System.nanoTime (0 to 10).foreach(_ => testForeach(data.iterator)) println("benchmark for: " + (System.nanoTime - start3)/1000000) ``` Benchmark result: `while`: 15401 ms `foreach`: 43034 ms `for`: 41279 ms Author: Gengliang Wang Closes #21381 from gengliangwang/refactorExecuteWriteTask. --- .../datasources/BasicWriteStatsTracker.scala | 2 +- .../datasources/FileFormatDataWriter.scala | 313 ++++++++++++++++ .../datasources/FileFormatWriter.scala | 353 +----------------- .../datasources/v2/WriteToDataSourceV2.scala | 4 +- 4 files changed, 334 insertions(+), 338 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 69c03d862391e..ba7d2b7cbdb1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration /** - * Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]]. + * Simple metrics collected during an instance of [[FileFormatDataWriter]]. * These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703). */ case class BasicWriteTaskStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala new file mode 100644 index 0000000000000..6499328e89ce7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.SerializableConfiguration + +/** + * Abstract class for writing out data in a single Spark task. + * Exceptions thrown by the implementation of this trait will automatically trigger task aborts. + */ +abstract class FileFormatDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) { + /** + * Max number of files a single task writes out due to file size. In most cases the number of + * files written should be very small. This is just a safe guard to protect some really bad + * settings, e.g. maxRecordsPerFile = 1. + */ + protected val MAX_FILE_COUNTER: Int = 1000 * 1000 + protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() + protected var currentWriter: OutputWriter = _ + + /** Trackers for computing various statistics on the data as it's being written out. */ + protected val statsTrackers: Seq[WriteTaskStatsTracker] = + description.statsTrackers.map(_.newTaskInstance()) + + protected def releaseResources(): Unit = { + if (currentWriter != null) { + try { + currentWriter.close() + } finally { + currentWriter = null + } + } + } + + /** Writes a record */ + def write(record: InternalRow): Unit + + /** + * Returns the summary of relative information which + * includes the list of partition strings written out. The list of partitions is sent back + * to the driver and used to update the catalog. Other information will be sent back to the + * driver too and used to e.g. update the metrics in UI. + */ + def commit(): WriteTaskResult = { + releaseResources() + val summary = ExecutedWriteSummary( + updatedPartitions = updatedPartitions.toSet, + stats = statsTrackers.map(_.getFinalStats())) + WriteTaskResult(committer.commitTask(taskAttemptContext), summary) + } + + def abort(): Unit = { + try { + releaseResources() + } finally { + committer.abortTask(taskAttemptContext) + } + } +} + +/** FileFormatWriteTask for empty partitions */ +class EmptyDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol +) extends FileFormatDataWriter(description, taskAttemptContext, committer) { + override def write(record: InternalRow): Unit = {} +} + +/** Writes data to a single directory (used for non-dynamic-partition writes). */ +class SingleDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + // Initialize currentWriter and statsTrackers + newOutputWriter() + + private def newOutputWriter(): Unit = { + recordsInFile = 0 + releaseResources() + + val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) + val currentPath = committer.newTaskTempFile( + taskAttemptContext, + None, + f"-c$fileCounter%03d" + ext) + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter() + } + + currentWriter.write(record) + statsTrackers.foreach(_.newRow(record)) + recordsInFile += 1 + } +} + +/** + * Writes data to using dynamic partition writes, meaning this single function can write to + * multiple directories (partitions) or files (bucketing). + */ +class DynamicPartitionDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + + /** Flag saying whether or not the data to be written out is partitioned. */ + private val isPartitioned = description.partitionColumns.nonEmpty + + /** Flag saying whether or not the data to be written out is bucketed. */ + private val isBucketed = description.bucketIdExpression.isDefined + + assert(isPartitioned || isBucketed, + s"""DynamicPartitionWriteTask should be used for writing out data that's either + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description + """.stripMargin) + + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + private var currentPartionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + + /** Extracts the partition values out of an input row. */ + private lazy val getPartitionValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) + row => proj(row) + } + + /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ + private lazy val partitionPathExpression: Expression = Concat( + description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, + StringType, + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) + }) + + /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. */ + private lazy val getPartitionPath: InternalRow => String = { + val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) + row => proj(row).getString(0) + } + + /** Given an input row, returns the corresponding `bucketId` */ + private lazy val getBucketId: InternalRow => Int = { + val proj = + UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) + row => proj(row).getInt(0) + } + + /** Returns the data columns to be written given an input row */ + private val getOutputRow = + UnsafeProjection.create(description.dataColumns, description.allColumns) + + /** + * Opens a new OutputWriter given a partition key and/or a bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + recordsInFile = 0 + releaseResources() + + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // This must be in a form that matches our bucketing format. See BucketingUtils. + val ext = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartionValues != nextPartitionValues) { + currentPartionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + statsTrackers.foreach(_.newBucket(currentBucketId.get)) + } + + fileCounter = 0 + newOutputWriter(currentPartionValues, currentBucketId) + } else if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter(currentPartionValues, currentBucketId) + } + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** A shared job description for all the write tasks. */ +class WriteJobDescription( + val uuid: String, // prevent collision between different (appending) write jobs + val serializableHadoopConf: SerializableConfiguration, + val outputWriterFactory: OutputWriterFactory, + val allColumns: Seq[Attribute], + val dataColumns: Seq[Attribute], + val partitionColumns: Seq[Attribute], + val bucketIdExpression: Option[Expression], + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String], + val maxRecordsPerFile: Long, + val timeZoneId: String, + val statsTrackers: Seq[WriteJobStatsTracker]) + extends Serializable { + + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), + s""" + |All columns: ${allColumns.mkString(", ")} + |Partition columns: ${partitionColumns.mkString(", ")} + |Data columns: ${dataColumns.mkString(", ")} + """.stripMargin) +} + +/** The result of a successful write task. */ +case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + +/** + * Wrapper class for the metrics of writing data out. + * + * @param updatedPartitions the partitions updated during writing data out. Only valid + * for dynamic partition. + * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. + */ +case class ExecutedWriteSummary( + updatedPartitions: Set[String], + stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..52da8356ab835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.mutable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} -import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { - - /** - * Max number of files a single task writes out due to file size. In most cases the number of - * files written should be very small. This is just a safe guard to protect some really bad - * settings, e.g. maxRecordsPerFile = 1. - */ - private val MAX_FILE_COUNTER = 1000 * 1000 - /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( - outputPath: String, - customPartitionLocations: Map[TablePartitionSpec, String], - outputColumns: Seq[Attribute]) - - /** A shared job description for all the write tasks. */ - private class WriteJobDescription( - val uuid: String, // prevent collision between different (appending) write jobs - val serializableHadoopConf: SerializableConfiguration, - val outputWriterFactory: OutputWriterFactory, - val allColumns: Seq[Attribute], - val dataColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], - val bucketIdExpression: Option[Expression], - val path: String, - val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long, - val timeZoneId: String, - val statsTrackers: Seq[WriteJobStatsTracker]) - extends Serializable { - - assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), - s""" - |All columns: ${allColumns.mkString(", ")} - |Partition columns: ${partitionColumns.mkString(", ")} - |Data columns: ${dataColumns.mkString(", ")} - """.stripMargin) - } - - /** The result of a successful write task. */ - private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + outputPath: String, + customPartitionLocations: Map[TablePartitionSpec, String], + outputColumns: Seq[Attribute]) /** * Basic work flow of this command is: @@ -262,30 +223,27 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) - val writeTask = + val dataWriter = if (sparkPartitionId != 0 && !iterator.hasNext) { // In case of empty job, leave first partition to save meta for file format like parquet. - new EmptyDirectoryWriteTask(description) + new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { - new SingleDirectoryWriteTask(description, taskAttemptContext, committer) + new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionWriteTask(description, taskAttemptContext, committer) + new DynamicPartitionDataWriter(description, taskAttemptContext, committer) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val summary = writeTask.execute(iterator) - writeTask.releaseResources() - WriteTaskResult(committer.commitTask(taskAttemptContext), summary) - })(catchBlock = { - // If there is an error, release resource and then abort the task - try { - writeTask.releaseResources() - } finally { - committer.abortTask(taskAttemptContext) - logError(s"Job $jobId aborted.") + while (iterator.hasNext) { + dataWriter.write(iterator.next()) } + dataWriter.commit() + })(catchBlock = { + // If there is an error, abort the task + dataWriter.abort() + logError(s"Job $jobId aborted.") }) } catch { case e: FetchFailedException => @@ -302,7 +260,7 @@ object FileFormatWriter extends Logging { private def processStats( statsTrackers: Seq[WriteJobStatsTracker], statsPerTask: Seq[Seq[WriteTaskStats]]) - : Unit = { + : Unit = { val numStatsTrackers = statsTrackers.length assert(statsPerTask.forall(_.length == numStatsTrackers), @@ -321,281 +279,4 @@ object FileFormatWriter extends Logging { case (statsTracker, stats) => statsTracker.processStats(stats) } } - - /** - * A simple trait for writing out data in a single Spark task, without any concerns about how - * to commit or abort tasks. Exceptions thrown by the implementation of this trait will - * automatically trigger task aborts. - */ - private trait ExecuteWriteTask { - - /** - * Writes data out to files, and then returns the summary of relative information which - * includes the list of partition strings written out. The list of partitions is sent back - * to the driver and used to update the catalog. Other information will be sent back to the - * driver too and used to e.g. update the metrics in UI. - */ - def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary - def releaseResources(): Unit - } - - /** ExecuteWriteTask for empty partitions */ - private class EmptyDirectoryWriteTask(description: WriteJobDescription) - extends ExecuteWriteTask { - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = {} - } - - /** Writes data to a single directory (used for non-dynamic-partition writes). */ - private class SingleDirectoryWriteTask( - description: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - private[this] var currentWriter: OutputWriter = _ - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - private def newOutputWriter(fileCounter: Int): Unit = { - val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val currentPath = committer.newTaskTempFile( - taskAttemptContext, - None, - f"-c$fileCounter%03d" + ext) - - currentWriter = description.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = description.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.map(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - var fileCounter = 0 - var recordsInFile: Long = 0L - newOutputWriter(fileCounter) - - while (iter.hasNext) { - if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - recordsInFile = 0 - releaseResources() - newOutputWriter(fileCounter) - } - - val internalRow = iter.next() - currentWriter.write(internalRow) - statsTrackers.foreach(_.newRow(internalRow)) - recordsInFile += 1 - } - releaseResources() - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } - - /** - * Writes data to using dynamic partition writes, meaning this single function can write to - * multiple directories (partitions) or files (bucketing). - */ - private class DynamicPartitionWriteTask( - desc: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty - - /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined - - assert(isPartitioned || isBucketed, - s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: ${desc} - """.stripMargin) - - // currentWriter is initialized whenever we see a new key (partitionValues + BucketId) - private var currentWriter: OutputWriter = _ - - /** Trackers for computing various statistics on the data as it's being written out. */ - private val statsTrackers: Seq[WriteTaskStatsTracker] = - desc.statsTrackers.map(_.newTaskInstance()) - - /** Extracts the partition values out of an input row. */ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { - val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns) - row => proj(row) - } - - /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ - private lazy val partitionPathExpression: Expression = Concat( - desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val partitionName = ScalaUDF( - ExternalCatalogUtils.getPartitionPathString _, - StringType, - Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) - if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) - }) - - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ - private lazy val getPartitionPath: InternalRow => String = { - val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns) - row => proj(row).getString(0) - } - - /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { - val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns) - row => proj(row).getInt(0) - } - - /** Returns the data columns to be written given an input row */ - private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) - - /** - * Opens a new OutputWriter given a partition key and/or a bucket id. - * If bucket id is specified, we will append it to the end of the file name, but before the - * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet - * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` - * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to - * @param fileCounter the number of files that have been written in the past for this specific - * partition. This is used to limit the max number of records written for a - * single file. The value should start from 0. - * @param updatedPartitions the set of updated partition paths, we should add the new partition - * path of this writer to it. - */ - private def newOutputWriter( - partitionValues: Option[InternalRow], - bucketId: Option[Int], - fileCounter: Int, - updatedPartitions: mutable.Set[String]): Unit = { - - val partDir = partitionValues.map(getPartitionPath(_)) - partDir.foreach(updatedPartitions.add) - - val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - - // This must be in a form that matches our bucketing format. See BucketingUtils. - val ext = f"$bucketIdStr.c$fileCounter%03d" + - desc.outputWriterFactory.getFileExtension(taskAttemptContext) - - val customPath = partDir.flatMap { dir => - desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) - } - val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) - } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) - } - - currentWriter = desc.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = desc.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.foreach(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - // If anything below fails, we should abort the task. - var recordsInFile: Long = 0L - var fileCounter = 0 - val updatedPartitions = mutable.Set[String]() - var currentPartionValues: Option[UnsafeRow] = None - var currentBucketId: Option[Int] = None - - for (row <- iter) { - val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None - val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None - - if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartionValues != nextPartitionValues) { - currentPartionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) - } - - recordsInFile = 0 - fileCounter = 0 - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } else if (desc.maxRecordsPerFile > 0 && - recordsInFile >= desc.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. - recordsInFile = 0 - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } - val outputRow = getOutputRow(row) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 - } - releaseResources() - - ExecutedWriteSummary( - updatedPartitions = updatedPartitions.toSet, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } } - -/** - * Wrapper class for the metrics of writing data out. - * - * @param updatedPartitions the partitions updated during writing data out. Only valid - * for dynamic partition. - * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. - */ -case class ExecutedWriteSummary( - updatedPartitions: Set[String], - stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index ea283ed77efda..ea4bda327f36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -116,7 +116,9 @@ object DataWritingSparkTask extends Logging { // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - iter.foreach(dataWriter.write) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator From b2d022656298c7a39ff3e84b04f813d5f315cb95 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 1 Jun 2018 11:58:59 +0800 Subject: [PATCH 0905/1161] [SPARK-24444][DOCS][PYTHON] Improve Pandas UDF docs to explain column assignment ## What changes were proposed in this pull request? Added sections to pandas_udf docs, in the grouped map section, to indicate columns are assigned by position. ## How was this patch tested? NA Author: Bryan Cutler Closes #21471 from BryanCutler/arrow-doc-pandas_udf-column_by_pos-SPARK-21427. --- docs/sql-programming-guide.md | 9 +++++++++ python/pyspark/sql/functions.py | 9 ++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 50600861912b1..4d8a738507bd1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1752,6 +1752,15 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. +The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position, +not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their +position matches the corresponding field in the schema. + +Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column +can differ from the order that it was placed in the dictionary. It is recommended in this case to +explicitly define the column order using the `columns` keyword, e.g. +`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`. + Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index efcce25a08e04..fd656c5c35844 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2500,7 +2500,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. - The length of the returned `pandas.DataFrame` can be arbitrary. + The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be + indexed so that their position matches the corresponding field in the schema. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. @@ -2548,6 +2549,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2|6.0| +---+---+ + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG From 22df953f6bb191858053eafbabaa5b3ebca29f56 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 31 May 2018 21:25:45 -0700 Subject: [PATCH 0906/1161] [SPARK-24326][MESOS] add support for local:// scheme for the app jar ## What changes were proposed in this pull request? * Adds support for local:// scheme like in k8s case for image based deployments where the jar is already in the image. Affects cluster mode and the mesos dispatcher. Covers also file:// scheme. Keeps the default case where jar resolution happens on the host. ## How was this patch tested? Dispatcher image with the patch, use it to start DC/OS Spark service: skonto/spark-local-disp:test Test image with my application jar located at the root folder: skonto/spark-local:test Dockerfile for that image. From mesosphere/spark:2.3.0-2.2.1-2-hadoop-2.6 COPY spark-examples_2.11-2.2.1.jar / WORKDIR /opt/spark/dist Tests: The following work as expected: * local normal example ``` dcos spark run --submit-args="--conf spark.mesos.appJar.local.resolution.mode=container --conf spark.executor.memory=1g --conf spark.mesos.executor.docker.image=skonto/spark-local:test --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi local:///spark-examples_2.11-2.2.1.jar" ``` * make sure the flag does not affect other uris ``` dcos spark run --submit-args="--conf spark.mesos.appJar.local.resolution.mode=container --conf spark.executor.memory=1g --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi https://s3-eu-west-1.amazonaws.com/fdp-stavros-test/spark-examples_2.11-2.1.1.jar" ``` * normal example no local ``` dcos spark run --submit-args="--conf spark.executor.memory=1g --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi https://s3-eu-west-1.amazonaws.com/fdp-stavros-test/spark-examples_2.11-2.1.1.jar" ``` The following fails * uses local with no setting, default is host. ``` dcos spark run --submit-args="--conf spark.executor.memory=1g --conf spark.mesos.executor.docker.image=skonto/spark-local:test --conf spark.executor.cores=2 --conf spark.cores.max=8 --class org.apache.spark.examples.SparkPi local:///spark-examples_2.11-2.2.1.jar" ``` ![image](https://user-images.githubusercontent.com/7945591/40283021-8d349762-5c80-11e8-9d62-2a61a4318fd5.png) Author: Stavros Kontopoulos Closes #21378 from skonto/local-upstream. --- docs/running-on-mesos.md | 12 +++++++ .../cluster/mesos/MesosClusterScheduler.scala | 36 +++++++++++++++---- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 3c2a1501ca692..66ffb17949845 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -753,6 +753,18 @@ See the [configuration page](configuration.html) for information on Spark config spark.cores.max is reached + + spark.mesos.appJar.local.resolution.mode + host + + Provides support for the `local:///` scheme to reference the app jar resource in cluster mode. + If user uses a local resource (`local:///path/to/jar`) and the config option is not used it defaults to `host` eg. + the mesos fetcher tries to get the resource from the host's file system. + If the value is unknown it prints a warning msg in the dispatcher logs and defaults to `host`. + If the value is `container` then spark submit in the container will use the jar in the container's path: + `/path/to/jar`. + + # Troubleshooting and Debugging diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index b36f46456f9a5..7d80eedcc43ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,8 +30,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} -import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.deploy.mesos.config +import org.apache.spark.deploy.mesos.{config, MesosDriverDescription} import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils @@ -418,6 +417,18 @@ private[spark] class MesosClusterScheduler( envBuilder.build() } + private def isContainerLocalAppJar(desc: MesosDriverDescription): Boolean = { + val isLocalJar = desc.jarUrl.startsWith("local://") + val isContainerLocal = desc.conf.getOption("spark.mesos.appJar.local.resolution.mode").exists { + case "container" => true + case "host" => false + case other => + logWarning(s"Unknown spark.mesos.appJar.local.resolution.mode $other, using host.") + false + } + isLocalJar && isContainerLocal + } + private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -425,10 +436,14 @@ private[spark] class MesosClusterScheduler( _.map(_.split(",").map(_.trim)) ).flatten - val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") - - ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => - CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + if (isContainerLocalAppJar(desc)) { + (confUris ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } else { + val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") + ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } } private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { @@ -480,7 +495,14 @@ private[spark] class MesosClusterScheduler( (cmdExecutable, ".") } val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") - val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val primaryResource = { + if (isContainerLocalAppJar(desc)) { + new File(desc.jarUrl.stripPrefix("local://")).toString() + } else { + new File(sandboxPath, desc.jarUrl.split("/").last).toString() + } + } + val appArguments = desc.command.arguments.mkString(" ") s"$executable $cmdOptions $primaryResource $appArguments" From 98909c398dbcbffcae8015d36f44f185d3280af6 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 31 May 2018 22:04:26 -0700 Subject: [PATCH 0907/1161] [SPARK-23920][SQL] add array_remove to remove all elements that equal element from array ## What changes were proposed in this pull request? add array_remove to remove all elements that equal element from array ## How was this patch tested? add unit tests Author: Huaxin Gao Closes #21069 from huaxingao/spark-23920. --- python/pyspark/sql/functions.py | 16 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 123 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 58 +++++++++ .../org/apache/spark/sql/functions.scala | 9 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 29 +++++ 6 files changed, 236 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fd656c5c35844..1759195c6fcc0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1964,6 +1964,22 @@ def element_at(col, extraction): return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) +@since(2.4) +def array_remove(col, element): + """ + Collection function: Remove all elements that equal to element from the given array. + + :param col: name of column containing array + :param element: element to be removed from the array + + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 23a4a440fac23..49fb35b083580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -430,6 +430,7 @@ object FunctionRegistry { expression[Concat]("concat"), expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), + expression[ArrayRemove]("array_remove"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8a877b02c8191..176995affe701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2066,3 +2066,126 @@ case class ArrayRepeat(left: Expression, right: Expression) } } + +/** + * Remove all elements that equal to element from the given array + */ +@ExpressionDescription( + usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); + [1,2,null] + """, since = "2.4.0") +case class ArrayRemove(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } + + lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) + var pos = 0 + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null || !ordering.equiv(v, value)) { + newArray(pos) = v + pos += 1 + } + ) + new GenericArrayData(newArray.slice(0, pos)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val numsToRemove = ctx.freshName("numsToRemove") + val newArraySize = ctx.freshName("newArraySize") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + s""" + |int $numsToRemove = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && $isEqual) { + | $numsToRemove = $numsToRemove + 1; + | } + |} + |int $newArraySize = $arr.numElements() - $numsToRemove; + |${genCodeForResult(ctx, ev, arr, value, newArraySize)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + value: String, + newArraySize: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |int $pos = 0; + |Object[] $values = new Object[$newArraySize]; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values[$pos] = null; + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values[$pos] = $getValue; + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values.set$primitiveValueTypeName($pos, $getValue); + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } + + override def prettyName: String = "array_remove" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3fc0b08c56e02..f8ad624ce0e3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -622,4 +622,62 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + + test("Array remove") { + val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) + val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a4 = Literal.create(null, ArrayType(StringType)) + val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType)) + val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + + checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) + checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) + checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null) + + checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c")) + checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) + + checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null)) + checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null) + + checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) + + checkEvaluation(ArrayRemove(a4, Literal("a")), null) + + checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) + checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + + val dataToRemove1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation(ArrayRemove(b0, dataToRemove1), + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2))) + checkEvaluation(ArrayRemove(b0, nullBinary), null) + checkEvaluation(ArrayRemove(b1, dataToRemove1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayRemove(b2, dataToRemove1), Seq[Array[Byte]](null, Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType))) + val dataToRemove2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayRemove(c0, dataToRemove2), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 443ba2aa3757d..a2aae9a708ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3169,6 +3169,15 @@ object functions { */ def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** + * Remove all elements that equal to element from the given array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_remove(column: Column, element: Any): Column = withExpr { + ArrayRemove(column.expr, Literal(element)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index cc8bad4ded53e..59119bbbd8a2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1110,6 +1110,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } + test("array remove") { + val df = Seq( + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), + (Array.empty[Int], Array.empty[String], Array.empty[String]), + (null, null, null) + ).toDF("a", "b", "c") + checkAnswer( + df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", + "array_remove(c, \"\")"), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 6039b132304cc77ed39e4ca7813850507ae0b440 Mon Sep 17 00:00:00 2001 From: Huang Tengfei Date: Fri, 1 Jun 2018 10:47:53 -0700 Subject: [PATCH 0908/1161] [SPARK-24351][SS] offsetLog/commitLog purge thresholdBatchId should be computed with current committed epoch but not currentBatchId in CP mode ## What changes were proposed in this pull request? Compute the thresholdBatchId to purge metadata based on current committed epoch instead of currentBatchId in CP mode to avoid cleaning all the committed metadata in some case as described in the jira [SPARK-24351](https://issues.apache.org/jira/browse/SPARK-24351). ## How was this patch tested? Add new unit test. Author: Huang Tengfei Closes #21400 from ivoson/branch-cp-meta. --- .../continuous/ContinuousExecution.scala | 11 +++-- .../continuous/ContinuousSuite.scala | 46 +++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index d16b24c89ebef..e3d0cea608b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -318,9 +318,14 @@ class ContinuousExecution( } } - if (minLogEntriesToMaintain < currentBatchId) { - offsetLog.purge(currentBatchId - minLogEntriesToMaintain) - commitLog.purge(currentBatchId - minLogEntriesToMaintain) + // Since currentBatchId increases independently in cp mode, the current committed epoch may + // be far behind currentBatchId. It is not safe to discard the metadata with thresholdBatchId + // computed based on currentBatchId. As minLogEntriesToMaintain is used to keep the minimum + // number of batches that must be retained and made recoverable, so we should keep the + // specified number of metadata that have been committed. + if (minLogEntriesToMaintain <= epoch) { + offsetLog.purge(epoch + 1 - minLogEntriesToMaintain) + commitLog.purge(epoch + 1 - minLogEntriesToMaintain) } awaitProgressLock.lock() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index cd1704ac2fdad..4980b0cd41f81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -297,3 +297,49 @@ class ContinuousStressSuite extends ContinuousSuiteBase { CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))) } } + +class ContinuousMetaSuite extends ContinuousSuiteBase { + import testImplicits._ + + // We need to specify spark.sql.streaming.minBatchesToRetain to do the following test. + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[10]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true") + .set("spark.sql.streaming.minBatchesToRetain", "2"))) + + test("SPARK-24351: check offsetLog/commitLog retained in the checkpoint directory") { + withTempDir { checkpointDir => + val input = ContinuousMemoryStream[Int] + val df = input.toDF().mapPartitions(iter => { + // Sleep the task thread for 300 ms to make sure epoch processing time 3 times + // longer than epoch creating interval. So the gap between last committed + // epoch and currentBatchId grows over time. + Thread.sleep(300) + iter.map(row => row.getInt(0) * 2) + }) + + testStream(df)( + StartStream(trigger = Trigger.Continuous(100), + checkpointLocation = checkpointDir.getAbsolutePath), + AddData(input, 1), + CheckAnswer(2), + // Make sure epoch 2 has been committed before the following validation. + AwaitEpoch(2), + StopStream, + AssertOnQuery(q => { + q.commitLog.getLatest() match { + case Some((latestEpochId, _)) => + val commitLogValidateResult = q.commitLog.get(latestEpochId - 1).isDefined && + q.commitLog.get(latestEpochId - 2).isEmpty + val offsetLogValidateResult = q.offsetLog.get(latestEpochId - 1).isDefined && + q.offsetLog.get(latestEpochId - 2).isEmpty + commitLogValidateResult && offsetLogValidateResult + case None => false + } + }) + ) + } + } +} From d2c3de7efcfacadff20b023924d4566a5bf9ad7a Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 1 Jun 2018 11:51:10 -0700 Subject: [PATCH 0909/1161] Revert "[SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set" This reverts commit 1e46f92f956a00d04d47340489b6125d44dbd47b. --- .../optimizer/RewriteDistinctAggregates.scala | 7 +++---- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../src/test/resources/sql-tests/inputs/group-by.sql | 6 +----- .../test/resources/sql-tests/results/group-by.sql.out | 11 +---------- 4 files changed, 6 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index bc898ab0dc723..4448ace7105a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -115,8 +115,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Extract distinct aggregate expressions. - val distincgAggExpressions = aggExpressions.filter(_.isDistinct) - val distinctAggGroups = distincgAggExpressions.groupBy { e => + val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet if (unfoldableChildren.nonEmpty) { // Only expand the unfoldable children @@ -133,7 +132,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Aggregation strategy can handle queries with a single distinct group. - if (distincgAggExpressions.size > 1) { + if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. val gid = AttributeReference("gid", IntegerType, nullable = false)() val groupByMap = a.groupingExpressions.collect { @@ -152,7 +151,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b9452b58657a4..b97a87a122406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -386,7 +386,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { aggregateExpressions.partition(_.isDistinct) if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our `RewriteDistinctAggregates` should take care this case. + // column sets. Our MultipleDistinctRewriter should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 2c18d6aaabdba..c5070b734d521 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,8 +68,4 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z; - --- SPARK-24369 multiple distinct aggregations having the same argument set -SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) - FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); +where b.z != b.z diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 581aa1754ce14..c1abc6dff754b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 26 -- !query 0 @@ -241,12 +241,3 @@ where b.z != b.z struct<1:int> -- !query 25 output - - --- !query 26 -SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) - FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) --- !query 26 schema -struct --- !query 26 output -1.0 1.0 3 From 09e78c1eaa742b9cab4564928e5a5401fe0198a9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 1 Jun 2018 11:55:34 -0700 Subject: [PATCH 0910/1161] [INFRA] Close stale PRs. Closes #21444 From 8ef167a5f9ba8a79bb7ca98a9844fe9cfcfea060 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 1 Jun 2018 13:46:05 -0700 Subject: [PATCH 0911/1161] [SPARK-24340][CORE] Clean up non-shuffle disk block manager files following executor exits on a Standalone cluster ## What changes were proposed in this pull request? Currently we only clean up the local directories on application removed. However, when executors die and restart repeatedly, many temp files are left untouched in the local directories, which is undesired behavior and could cause disk space used up gradually. We can detect executor death in the Worker, and clean up the non-shuffle files (files not ended with ".index" or ".data") in the local directories, we should not touch the shuffle files since they are expected to be used by the external shuffle service. Scope of this PR is limited to only implement the cleanup logic on a Standalone cluster, we defer to experts familiar with other cluster managers(YARN/Mesos/K8s) to determine whether it's worth to add similar support. ## How was this patch tested? Add new test suite to cover. Author: Xingbo Jiang Closes #21390 from jiangxb1987/cleanupNonshuffleFiles. --- .../apache/spark/network/util/JavaUtils.java | 45 ++-- .../shuffle/ExternalShuffleBlockHandler.java | 7 + .../shuffle/ExternalShuffleBlockResolver.java | 43 ++++ .../shuffle/NonShuffleFilesCleanupSuite.java | 221 ++++++++++++++++++ .../shuffle/TestShuffleDataContext.java | 15 ++ .../spark/deploy/ExternalShuffleService.scala | 5 + .../apache/spark/deploy/worker/Worker.scala | 17 +- .../spark/deploy/worker/WorkerSuite.scala | 55 ++++- docs/spark-standalone.md | 12 + 9 files changed, 400 insertions(+), 20 deletions(-) create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index afc59efaef810..b5497087634ce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,10 +17,7 @@ package org.apache.spark.network.util; -import java.io.Closeable; -import java.io.EOFException; -import java.io.File; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; @@ -91,11 +88,24 @@ public static String bytesToString(ByteBuffer b) { * @throws IOException if deletion is unsuccessful */ public static void deleteRecursively(File file) throws IOException { + deleteRecursively(file, null); + } + + /** + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * + * @param file Input file / dir to be deleted + * @param filter A filename filter that make sure only files / dirs with the satisfied filenames + * are deleted. + * @throws IOException if deletion is unsuccessful + */ + public static void deleteRecursively(File file, FilenameFilter filter) throws IOException { if (file == null) { return; } // On Unix systems, use operating system command to run faster // If that does not work out, fallback to the Java IO way - if (SystemUtils.IS_OS_UNIX) { + if (SystemUtils.IS_OS_UNIX && filter == null) { try { deleteRecursivelyUsingUnixNative(file); return; @@ -105,15 +115,17 @@ public static void deleteRecursively(File file) throws IOException { } } - deleteRecursivelyUsingJavaIO(file); + deleteRecursivelyUsingJavaIO(file, filter); } - private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { + private static void deleteRecursivelyUsingJavaIO( + File file, + FilenameFilter filter) throws IOException { if (file.isDirectory() && !isSymlink(file)) { IOException savedIOException = null; - for (File child : listFilesSafely(file)) { + for (File child : listFilesSafely(file, filter)) { try { - deleteRecursively(child); + deleteRecursively(child, filter); } catch (IOException e) { // In case of multiple exceptions, only last one will be thrown savedIOException = e; @@ -124,10 +136,13 @@ private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { } } - boolean deleted = file.delete(); - // Delete can also fail if the file simply did not exist. - if (!deleted && file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath()); + // Delete file only when it's a normal file or an empty directory. + if (file.isFile() || (file.isDirectory() && listFilesSafely(file, null).length == 0)) { + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } } } @@ -157,9 +172,9 @@ private static void deleteRecursivelyUsingUnixNative(File file) throws IOExcepti } } - private static File[] listFilesSafely(File file) throws IOException { + private static File[] listFilesSafely(File file, FilenameFilter filter) throws IOException { if (file.exists()) { - File[] files = file.listFiles(); + File[] files = file.listFiles(filter); if (files == null) { throw new IOException("Failed to list files for dir: " + file); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index fc7bba41185f0..098fa7974b87b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -138,6 +138,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + /** + * Clean up any non-shuffle files in any local directories associated with an finished executor. + */ + public void executorRemoved(String executorId, String appId) { + blockManager.executorRemoved(executorId, appId); + } + /** * Register an (application, executor) with the given shuffle info. * diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e6399897be9c2..58fb17f60a79d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -211,6 +211,26 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { } } + /** + * Removes all the non-shuffle files in any local directories associated with the finished + * executor. + */ + public void executorRemoved(String executorId, String appId) { + logger.info("Clean up non-shuffle files associated with the finished executor {}", executorId); + AppExecId fullId = new AppExecId(appId, executorId); + final ExecutorShuffleInfo executor = executors.get(fullId); + if (executor == null) { + // Executor not registered, skip clean up of the local directories. + logger.info("Executor is not registered (appId={}, execId={})", appId, executorId); + } else { + logger.info("Cleaning up non-shuffle files in executor {}'s {} local dirs", fullId, + executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(() -> deleteNonShuffleFiles(executor.localDirs)); + } + } + /** * Synchronously deletes each directory one at a time. * Should be executed in its own thread, as this may take a long time. @@ -226,6 +246,29 @@ private void deleteExecutorDirs(String[] dirs) { } } + /** + * Synchronously deletes non-shuffle files in each directory recursively. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteNonShuffleFiles(String[] dirs) { + FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. + return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir), filter); + logger.debug("Successfully cleaned up non-shuffle files in directory: {}", localDir); + } catch (Exception e) { + logger.error("Failed to delete non-shuffle files in directory: " + localDir, e); + } + } + } + /** * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java new file mode 100644 index 0000000000000..d22f3ace4103b --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class NonShuffleFilesCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + + @Test + public void cleanupOnRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(true); + } + + @Test + public void cleanupOnRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(false); + } + + private void cleanupOnRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + resolver.executorRemoved("exec0", "app"); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutorWithShuffleFiles() throws IOException { + cleanupUsesExecutor(true); + } + + @Test + public void cleanupUsesExecutorWithoutShuffleFiles() throws IOException { + cleanupUsesExecutor(false); + } + + private void cleanupUsesExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. + Executor noThreadExecutor = runnable -> cleanupCalled.set(true); + + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + manager.executorRemoved("exec0", "app"); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + } + + @Test + public void cleanupOnlyRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(true); + } + + @Test + public void cleanupOnlyRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(false); + } + + private void cleanupOnlyRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext0 = initDataContext(withShuffleFiles); + TestShuffleDataContext dataContext1 = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo(SORT_MANAGER)); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo(SORT_MANAGER)); + + + resolver.executorRemoved("exec-nonexistent", "app"); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(true); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(false); + } + + private void cleanupOnlyRegisteredExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + + resolver.executorRemoved("exec1", "app"); + assertStillThere(dataContext); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext); + } + + private static void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private static FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. + return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + private static boolean assertOnlyShuffleDataInDir(File[] dirs) { + for (File dir : dirs) { + assertTrue(dir.getName() + " wasn't cleaned up", !dir.exists() || + dir.listFiles(filter).length == 0 || assertOnlyShuffleDataInDir(dir.listFiles())); + } + return true; + } + + private static void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + File[] dirs = new File[] {new File(localDir)}; + assertOnlyShuffleDataInDir(dirs); + } + } + + private static TestShuffleDataContext initDataContext(boolean withShuffleFiles) + throws IOException { + if (withShuffleFiles) { + return initDataContextWithShuffleFiles(); + } else { + return initDataContextWithoutShuffleFiles(); + } + } + + private static TestShuffleDataContext initDataContextWithShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createShuffleFiles(dataContext); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext initDataContextWithoutShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext createDataContext() { + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + dataContext.create(); + return dataContext; + } + + private static void createShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + Random rand = new Random(123); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + } + + private static void createNonShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + // Create spill file(s) + dataContext.insertSpillData(); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 81e01949e50fa..6989c3baf2e28 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -22,6 +22,7 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.util.UUID; import com.google.common.io.Closeables; import com.google.common.io.Files; @@ -94,6 +95,20 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr } } + /** Creates spill file(s) within the local dirs. */ + public void insertSpillData() throws IOException { + String filename = "temp_local_" + UUID.randomUUID(); + OutputStream dataStream = null; + + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, filename)); + dataStream.write(42); + } finally { + Closeables.close(dataStream, false); + } + } + /** * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this * context's directories. diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f975fa5cb4e23..b59a4fe66587c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -94,6 +94,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */) } + /** Clean up all the non-shuffle files associated with an executor that has exited. */ + def executorRemoved(executorId: String, appId: String): Unit = { + blockHandler.executorRemoved(executorId, appId) + } + def stop() { if (server != null) { server.close() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 563b84934f264..ee1ca0bba5749 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} +import java.util.function.Supplier import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext @@ -49,7 +50,8 @@ private[deploy] class Worker( endpointName: String, workDirPath: String = null, val conf: SparkConf, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + externalShuffleServiceSupplier: Supplier[ExternalShuffleService] = null) extends ThreadSafeRpcEndpoint with Logging { private val host = rpcEnv.address.host @@ -97,6 +99,10 @@ private[deploy] class Worker( private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + // Whether or not cleanup the non-shuffle files on executor exits. + private val CLEANUP_NON_SHUFFLE_FILES_ENABLED = + conf.getBoolean("spark.storage.cleanupFilesAfterExecutorExit", true) + private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None @@ -142,7 +148,11 @@ private[deploy] class Worker( WorkerWebUI.DEFAULT_RETAINED_DRIVERS) // The shuffle service is not actually started unless configured. - private val shuffleService = new ExternalShuffleService(conf, securityMgr) + private val shuffleService = if (externalShuffleServiceSupplier != null) { + externalShuffleServiceSupplier.get() + } else { + new ExternalShuffleService(conf, securityMgr) + } private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -732,6 +742,9 @@ private[deploy] class Worker( trimFinishedExecutorsIfNecessary() coresUsed -= executor.cores memoryUsed -= executor.memory + if (CLEANUP_NON_SHUFFLE_FILES_ENABLED) { + shuffleService.executorRemoved(executorStateChanged.execId.toString, appId) + } case None => logInfo("Unknown Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index ce212a7513310..e3fe2b696aa1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,10 +17,19 @@ package org.apache.spark.deploy.worker +import java.util.concurrent.atomic.AtomicBoolean +import java.util.function.Supplier + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -29,6 +38,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleService: ExternalShuffleService = _ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -36,15 +47,21 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private var _worker: Worker = _ - private def makeWorker(conf: SparkConf): Worker = { + private def makeWorker( + conf: SparkConf, + shuffleServiceSupplier: Supplier[ExternalShuffleService] = null): Worker = { assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, securityMgr) + "Worker", "/tmp", conf, securityMgr, shuffleServiceSupplier) _worker } + before { + MockitoAnnotations.initMocks(this) + } + after { if (_worker != null) { _worker.rpcEnv.shutdown() @@ -194,4 +211,36 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { assert(worker.finishedDrivers.size === expectedValue) } } + + test("cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=true") { + testCleanupFilesWithConfig(true) + } + + test("don't cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=false") { + testCleanupFilesWithConfig(false) + } + + private def testCleanupFilesWithConfig(value: Boolean) = { + val conf = new SparkConf().set("spark.storage.cleanupFilesAfterExecutorExit", value.toString) + + val cleanupCalled = new AtomicBoolean(false) + when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + cleanupCalled.set(true) + } + }) + val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] { + override def get: ExternalShuffleService = shuffleService + } + val worker = makeWorker(conf, externalShuffleServiceSupplier) + // initialize workers + for (i <- 0 until 10) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(cleanupCalled.get() == value) + } } diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f06e72a387df1..14d742de5655c 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -254,6 +254,18 @@ SPARK_WORKER_OPTS supports the following system properties: especially if you run jobs very frequently. + + spark.storage.cleanupFilesAfterExecutorExit + true + + Enable cleanup non-shuffle files(such as temp. shuffle blocks, cached RDD/broadcast blocks, + spill files, etc) of worker directories following executor exits. Note that this doesn't + overlap with `spark.worker.cleanup.enabled`, as this enables cleanup of non-shuffle files in + local directories of a dead executor, while `spark.worker.cleanup.enabled` enables cleanup of + all files/subdirectories of a stopped and timeout application. + This only affects Standalone mode, support of other cluster manangers can be added in the future. + + spark.worker.ui.compressedLogFileLengthCacheSize 100 From a36c1a6bbd1deb119d96316ccbb6dc96ad174796 Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 1 Jun 2018 23:43:10 -0700 Subject: [PATCH 0912/1161] [SPARK-23668][K8S] Added missing config property in running-on-kubernetes.md ## What changes were proposed in this pull request? PR https://github.com/apache/spark/pull/20811 introduced a new Spark configuration property `spark.kubernetes.container.image.pullSecrets` for specifying image pull secrets. However, the documentation wasn't updated accordingly. This PR adds the property introduced into running-on-kubernetes.md. ## How was this patch tested? N/A. foxish mccheah please help merge this. Thanks! Author: Yinan Li Closes #21480 from liyinan926/master. --- docs/running-on-kubernetes.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index a4b2b98b0b649..4eac9bd9032e4 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -327,6 +327,13 @@ specific to Spark on Kubernetes. Container image pull policy used when pulling images within Kubernetes. + + spark.kubernetes.container.image.pullSecrets + + + Comma separated list of Kubernetes secrets used to pull images from private image registries. + + spark.kubernetes.allocation.batch.size 5 From de4feae3cd2fef9e83cac749b04ea9395bdd805e Mon Sep 17 00:00:00 2001 From: Misha Dmitriev Date: Sat, 2 Jun 2018 23:07:39 -0500 Subject: [PATCH 0913/1161] [SPARK-24356][CORE] Duplicate strings in File.path managed by FileSegmentManagedBuffer This patch eliminates duplicate strings that come from the 'path' field of java.io.File objects created by FileSegmentManagedBuffer. That is, we want to avoid the situation when multiple File instances for the same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String instance. In some scenarios such duplicate strings may waste a lot of memory (~ 10% of the heap). To avoid that, we intern the pathname with String.intern(), and before that we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, the code in java.io.File would normalize it later, creating a new "foo/bar" String copy. Unfortunately, the normalization code that java.io.File uses internally is in the package-private class java.io.FileSystem, so we cannot call it here directly. ## What changes were proposed in this pull request? Added code to ExternalShuffleBlockResolver.getFile(), that normalizes and then interns the pathname string before passing it to the File() constructor. ## How was this patch tested? Added unit test Author: Misha Dmitriev Closes #21456 from countmdm/misha/spark-24356. --- .../shuffle/ExternalShuffleBlockResolver.java | 30 ++++++++++++++++++- .../ExternalShuffleBlockResolverSuite.java | 20 +++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 58fb17f60a79d..0b7a27402369d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -24,6 +24,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +61,7 @@ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); private static final ObjectMapper mapper = new ObjectMapper(); + /** * This a common prefix to the key for each app registration we stick in leveldb, so they * are easy to find, since leveldb lets you search based on prefix. @@ -66,6 +69,8 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); + // Map containing all registered executors' metadata. @VisibleForTesting final ConcurrentMap executors; @@ -302,7 +307,8 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(new File(localDir, String.format("%02x", subDirId)), filename); + return new File(createNormalizedInternedPathname( + localDir, String.format("%02x", subDirId), filename)); } void close() { @@ -315,6 +321,28 @@ void close() { } } + /** + * This method is needed to avoid the situation when multiple File instances for the + * same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String. + * According to measurements, in some scenarios such duplicate strings may waste a lot + * of memory (~ 10% of the heap). To avoid that, we intern the pathname, and before that + * we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, + * the internal code in java.io.File would normalize it later, creating a new "foo/bar" + * String copy. Unfortunately, we cannot just reuse the normalization code that java.io.File + * uses, since it is in the package-private class java.io.FileSystem. + */ + @VisibleForTesting + static String createNormalizedInternedPathname(String dir1, String dir2, String fname) { + String pathname = dir1 + File.separator + dir2 + File.separator + fname; + Matcher m = MULTIPLE_SEPARATORS.matcher(pathname); + pathname = m.replaceAll("/"); + // A single trailing slash needs to be taken care of separately + if (pathname.length() > 1 && pathname.endsWith("/")) { + pathname = pathname.substring(0, pathname.length() - 1); + } + return pathname.intern(); + } + /** Simply encodes an executor's full ID, which is appId + execId. */ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 6d201b8fe8d7d..d2072a54fa415 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -135,4 +136,23 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { "\"subDirsPerLocalDir\": 7, \"shuffleManager\": " + "\"" + SORT_MANAGER + "\"}"; assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); } + + @Test + public void testNormalizeAndInternPathname() { + assertPathsMatch("/foo", "bar", "baz", "/foo/bar/baz"); + assertPathsMatch("//foo/", "bar/", "//baz", "/foo/bar/baz"); + assertPathsMatch("foo", "bar", "baz///", "foo/bar/baz"); + assertPathsMatch("/foo/", "/bar//", "/baz", "/foo/bar/baz"); + assertPathsMatch("/", "", "", "/"); + assertPathsMatch("/", "/", "/", "/"); + } + + private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) { + String normPathname = + ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3); + assertEquals(expectedPathname, normPathname); + File file = new File(normPathname); + String returnedPath = file.getPath(); + assertTrue(normPathname == returnedPath); + } } From a2166ecddaec030f78acaa66ce660d979a35079c Mon Sep 17 00:00:00 2001 From: xueyu Date: Mon, 4 Jun 2018 08:10:49 +0700 Subject: [PATCH 0914/1161] [SPARK-24455][CORE] fix typo in TaskSchedulerImpl comment change runTasks to submitTasks in the TaskSchedulerImpl.scala 's comment Author: xueyu Author: Xue Yu <278006819@qq.com> Closes #21485 from xueyumusic/fixtypo1. --- .../org/apache/spark/scheduler/TaskSchedulerImpl.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8e97b3da33820..598b62f85a1fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * up to launch speculative tasks, etc. * * Clients should first call initialize() and start(), then submit task sets through the - * runTasks method. + * submitTasks method. * * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some @@ -62,7 +62,7 @@ private[spark] class TaskSchedulerImpl( this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // Lazily initializing blacklistTrackerOpt to avoid getting empty ExecutorAllocationClient, // because ExecutorAllocationClient is created after this TaskSchedulerImpl. private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) @@ -228,7 +228,7 @@ private[spark] class TaskSchedulerImpl( // 1. The task set manager has been created and some tasks have been scheduled. // In this case, send a kill signal to the executors to kill the task and then abort // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // 2. The task set manager has been created but no tasks have been scheduled. In this case, // simply abort the stage. tsm.runningTasksSet.foreach { tid => taskIdToExecutorId.get(tid).foreach(execId => @@ -694,7 +694,7 @@ private[spark] class TaskSchedulerImpl( * * After stage failure and retry, there may be multiple TaskSetManagers for the stage. * If an earlier attempt of a stage completes a task, we should ensure that the later attempts - * do not also submit those same tasks. That also means that a task completion from an earlier + * do not also submit those same tasks. That also means that a task completion from an earlier * attempt can lead to the entire stage getting marked as successful. */ private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { From 416cd1fd96c0db9194e32ba877b1396b6dc13c8e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 3 Jun 2018 21:57:42 -0700 Subject: [PATCH 0915/1161] [SPARK-24369][SQL] Correct handling for multiple distinct aggregations having the same argument set ## What changes were proposed in this pull request? bring back https://github.com/apache/spark/pull/21443 This is a different approach: just change the check to count distinct columns with `toSet` ## How was this patch tested? a new test to verify the planner behavior. Author: Wenchen Fan Author: Takeshi Yamamuro Closes #21487 from cloud-fan/back. --- .../spark/sql/execution/SparkStrategies.scala | 4 ++-- .../resources/sql-tests/inputs/group-by.sql | 6 +++++- .../sql-tests/results/group-by.sql.out | 11 +++++++++- .../spark/sql/execution/PlannerSuite.scala | 21 +++++++++++++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b97a87a122406..be34387f6a874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -384,9 +384,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b2aba8e72c5db..98a50fbd52b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext { testPartialAggregationPlan(query) } + test("mixed aggregates with same distinct columns") { + def assertNoExpand(plan: SparkPlan): Unit = { + assert(plan.collect { case e: ExpandExec => e }.isEmpty) + } + + withTempView("v") { + Seq((1, 1.0, 1.0), (1, 2.0, 2.0)).toDF("i", "j", "k").createTempView("v") + // one distinct column + val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i") + assertNoExpand(query1.queryExecution.executedPlan) + + // 2 distinct columns + val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i") + assertNoExpand(query2.queryExecution.executedPlan) + + // 2 distinct columns with different order + val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i") + assertNoExpand(query3.queryExecution.executedPlan) + } + } + test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType]): Unit = { withTempView("testLimit") { From 1d9338bb10b953daddb23b8879ff99aa5c57dbea Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 3 Jun 2018 22:02:21 -0700 Subject: [PATCH 0916/1161] [SPARK-23786][SQL] Checking column names of csv headers ## What changes were proposed in this pull request? Currently column names of headers in CSV files are not checked against provided schema of CSV data. It could cause errors like showed in the [SPARK-23786](https://issues.apache.org/jira/browse/SPARK-23786) and https://github.com/apache/spark/pull/20894#issuecomment-375957777. I introduced new CSV option - `enforceSchema`. If it is enabled (by default `true`), Spark forcibly applies provided or inferred schema to CSV files. In that case, CSV headers are ignored and not checked against the schema. If `enforceSchema` is set to `false`, additional checks can be performed. For example, if column in CSV header and in the schema have different ordering, the following exception is thrown: ``` java.lang.IllegalArgumentException: CSV file header does not contain the expected fields Header: depth, temperature Schema: temperature, depth CSV file: marina.csv ``` ## How was this patch tested? The changes were tested by existing tests of CSVSuite and by 2 new tests. Author: Maxim Gekk Author: Maxim Gekk Closes #20894 from MaxGekk/check-column-names. --- python/pyspark/sql/readwriter.py | 15 +- python/pyspark/sql/streaming.py | 15 +- python/pyspark/sql/tests.py | 18 ++ .../apache/spark/sql/DataFrameReader.scala | 19 ++ .../datasources/csv/CSVDataSource.scala | 126 +++++++++++- .../datasources/csv/CSVFileFormat.scala | 9 +- .../datasources/csv/CSVOptions.scala | 6 + .../execution/datasources/csv/CSVUtils.scala | 22 +- .../datasources/csv/UnivocityParser.scala | 26 +-- .../execution/datasources/csv/CSVSuite.scala | 192 ++++++++++++++++++ 10 files changed, 411 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 448a4732001b5..a0e20d39c20da 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None): + samplingRatio=None, enforceSchema=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -373,6 +373,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -449,7 +459,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, + enforceSchema=enforceSchema) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 15f9407389864..fae50b3d5d532 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -564,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + enforceSchema=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -592,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -664,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a2450932e303d..ea2dd7605dc57 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3056,6 +3056,24 @@ def test_csv_sampling_ratio(self): .csv(rdd, samplingRatio=0.5).schema self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + def test_checking_csv_header(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.createDataFrame([[1, 1000], [2000, 2]])\ + .toDF('f1', 'f2').write.option("header", "true").csv(path) + schema = StructType([ + StructField('f2', IntegerType(), nullable=True), + StructField('f1', IntegerType(), nullable=True)]) + df = self.spark.read.option('header', 'true').schema(schema)\ + .csv(path, enforceSchema=False) + self.assertRaisesRegexp( + Exception, + "CSV header does not conform to the schema", + lambda: df.collect()) + finally: + shutil.rmtree(path) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ac4580a0919ad..de6be5f76e15a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,6 +22,7 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper +import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * it determines the columns as string types and it reads only the first line to determine the * names and the number of fields. * + * If the enforceSchema is set to `false`, only the CSV header in the first line is checked + * to conform specified or inferred schema. + * * @param csvDataset input Dataset with one CSV row per record * @since 2.2.0 */ @@ -499,6 +503,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + CSVDataSource.checkHeader( + firstLine, + new CsvParser(parsedOptions.asParserSettings), + actualSchema, + csvDataset.getClass.getCanonicalName, + parsedOptions.enforceSchema, + sparkSession.sessionState.conf.caseSensitiveAnalysis) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) @@ -539,6 +550,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `comment` (default empty string): sets a single character used for skipping lines * beginning with this character. By default, it is disabled.
  • *
  • `header` (default `false`): uses the first line as names of columns.
  • + *
  • `enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema + * will be forcibly applied to datasource files, and headers in CSV files will be ignored. + * If the option is set to `false`, the schema will be validated against all headers in CSV files + * in the case when the `header` option is set to `true`. Field names in the schema + * and column names in CSV headers are checked by their positions taking into account + * `spark.sql.caseSensitive`. Though the default value is true, it is recommended to disable + * the `enforceSchema` option to avoid incorrect results.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • *
  • `samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.
  • @@ -583,6 +601,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * + * * @since 2.0.0 */ @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index dc54d182651b1..82322df407521 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] + requiredSchema: StructType, + // Actual schema of data in the csv file + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. @@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable { } } -object CSVDataSource { +object CSVDataSource extends Logging { def apply(options: CSVOptions): CSVDataSource = { if (options.multiLine) { MultiLineCSVDataSource @@ -118,6 +122,84 @@ object CSVDataSource { TextInputCSVDataSource } } + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema - provided (or inferred) schema to which CSV must conform. + * @param columnNames - names of CSV columns that must be checked against to the schema. + * @param fileName - name of CSV file that are currently checked. It is used in error messages. + * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column + * names are checked for conformance to the schema. In the case if + * the column name don't conform to the schema, an exception is thrown. + * @param caseSensitive - if it is set to `false`, comparison of column names and schema field + * names is not case sensitive. + */ + def checkHeaderColumnNames( + schema: StructType, + columnNames: Array[String], + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. + | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |CSV file: $fileName""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |CSV file: $fileName""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } + + /** + * Checks that CSV header contains the same column names as fields names in the given schema + * by taking into account case sensitivity. + */ + def checkHeader( + header: String, + parser: CsvParser, + schema: StructType, + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + checkHeaderColumnNames( + schema, + parser.parseLine(header), + fileName, + enforceSchema, + caseSensitive) + } } object TextInputCSVDataSource extends CSVDataSource { @@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource { } } - val shouldDropHeader = parser.options.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) + val hasHeader = parser.options.headerFlag && file.start == 0 + if (hasHeader) { + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. + // Note: if there are only comments in the first block, the header would probably + // be not extracted. + CSVUtils.extractHeader(lines, parser.options).foreach { header => + CSVDataSource.checkHeader( + header, + parser.tokenizer, + dataSchema, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + } + + UnivocityParser.parseIterator(lines, parser, requiredSchema) } override def infer( @@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { + def checkHeader(header: Array[String]): Unit = { + CSVDataSource.checkHeaderColumnNames( + dataSchema, + header, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, - schema) + requiredSchema, + checkHeader) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 21279d6daf7ad..b90275de9f40a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -130,6 +130,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { "df.filter($\"_corrupt_record\".isNotNull).count()." ) } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -137,7 +138,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) + CSVDataSource(parsedOptions).readFile( + conf, + file, + parser, + requiredSchema, + dataSchema, + caseSensitive) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 7119189a4e131..fab8d62da0c1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -156,6 +156,12 @@ class CSVOptions( val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + /** + * Forcibly apply the specified or inferred schema to datasource files. + * If the option is enabled, headers of CSV files will be ignored. + */ + val enforceSchema = getBool("enforceSchema", default = true) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 9dae41b63e810..1012e774118e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -68,12 +68,8 @@ object CSVUtils { } } - /** - * Drop header line so that only data can remain. - * This is similar with `filterHeaderLine` above and currently being used in CSV reading path. - */ - def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - val nonEmptyLines = if (options.isCommentSet) { + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { val commentPrefix = options.comment.toString iter.dropWhile { line => line.trim.isEmpty || line.trim.startsWith(commentPrefix) @@ -81,11 +77,19 @@ object CSVUtils { } else { iter.dropWhile(_.trim.isEmpty) } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - iter } + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } /** * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 4f00cc5eb3f39..5f7d5696b71a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -45,7 +45,7 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = { + val tokenizer = { val parserSetting = options.asParserSettings if (options.columnPruning && requiredSchema.length < dataSchema.length) { val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) @@ -250,14 +250,15 @@ private[csv] object UnivocityParser { inputStream: InputStream, shouldDropHeader: Boolean, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + schema: StructType, + checkHeader: Array[String] => Unit): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( input => Seq(parser.convert(input)), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) - convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten } @@ -265,11 +266,14 @@ private[csv] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer: CsvParser, + checkHeader: Array[String] => Unit = _ => ())( + convert: Array[String] => T) = new Iterator[T] { tokenizer.beginParsing(inputStream) private var nextRecord = { if (shouldDropHeader) { - tokenizer.parseNext() + val firstRecord = tokenizer.parseNext() + checkHeader(firstRecord) } tokenizer.parseNext() } @@ -291,21 +295,11 @@ private[csv] object UnivocityParser { */ def parseIterator( lines: Iterator[String], - shouldDropHeader: Boolean, parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { val options = parser.options - val linesWithoutHeader = if (shouldDropHeader) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, options) - } else { - lines - } - - val filteredLines: Iterator[String] = - CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( input => Seq(parser.parse(input)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index afe10bdc4de26..d2f166c7d1877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -23,9 +23,13 @@ import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale +import scala.collection.JavaConverters._ + import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.log4j.{AppenderSkeleton, LogManager} +import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} @@ -1410,4 +1414,192 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) } } + + def checkHeader(multiLine: Boolean): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val exception = intercept[SparkException] { + spark.read + .schema(ischema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + + val shortSchema = new StructType().add("f1", DoubleType) + val exceptionForShortSchema = intercept[SparkException] { + spark.read + .schema(shortSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForShortSchema.getMessage.contains( + "Number of column in CSV header is not equal to number of fields in the schema")) + + val longSchema = new StructType() + .add("f1", DoubleType) + .add("f2", DoubleType) + .add("f3", DoubleType) + + val exceptionForLongSchema = intercept[SparkException] { + spark.read + .schema(longSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3")) + + val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType) + val caseSensitiveException = intercept[SparkException] { + spark.read + .schema(caseSensitiveSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(caseSensitiveException.getMessage.contains( + "CSV header does not conform to the schema")) + } + } + } + + test(s"SPARK-23786: Checking column names against schema in the multiline mode") { + checkHeader(multiLine = true) + } + + test(s"SPARK-23786: Checking column names against schema in the per-line mode") { + checkHeader(multiLine = false) + } + + test("SPARK-23786: CSV header must not be checked if it doesn't exist") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", false).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val idf = spark.read + .schema(ischema) + .option("header", false) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + + checkAnswer(idf, odf) + } + } + + test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val oschema = new StructType().add("A", StringType) + val odf = spark.createDataFrame(List(Row("0")).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("a", StringType) + val idf = spark.read.schema(ischema) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + checkAnswer(idf, odf) + } + } + } + + test("SPARK-23786: check header on parsing of dataset of strings") { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + val exception = intercept[IllegalArgumentException] { + spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds) + } + + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: enforce inferred schema") { + val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType) + val withHeader = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .csv(Seq("_c0,_c1", "1.0,a").toDS()) + assert(withHeader.schema == expectedSchema) + checkAnswer(withHeader, Seq(Row(1.0, "a"))) + + // Ignore the inferSchema flag if an user sets a schema + val schema = new StructType().add("colA", DoubleType).add("colB", StringType) + val ds = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("colA,colB", "1.0,a").toDS()) + assert(ds.schema == schema) + checkAnswer(ds, Seq(Row(1.0, "a"))) + + val exception = intercept[IllegalArgumentException] { + spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("col1,col2", "1.0,a").toDS()) + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: warning should be printed if CSV header doesn't conform to schema") { + class TestAppender extends AppenderSkeleton { + var events = new java.util.ArrayList[LoggingEvent] + override def close(): Unit = {} + override def requiresLayout: Boolean = false + protected def append(event: LoggingEvent): Unit = events.add(event) + } + + val testAppender1 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender1) + try { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + + spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds) + } finally { + LogManager.getRootLogger.removeAppender(testAppender1) + } + assert(testAppender1.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + + val testAppender2 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender2) + try { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + spark.read + .schema(ischema) + .option("header", true) + .option("enforceSchema", true) + .csv(path.getCanonicalPath) + .collect() + } + } finally { + LogManager.getRootLogger.removeAppender(testAppender2) + } + assert(testAppender2.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + } } From 0be5aa27460f87b5627f9de16ec25b09368d205a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 4 Jun 2018 10:16:13 -0700 Subject: [PATCH 0917/1161] [SPARK-23903][SQL] Add support for date extract ## What changes were proposed in this pull request? Add support for date `extract` function: ```sql spark-sql> SELECT EXTRACT(YEAR FROM TIMESTAMP '2000-12-16 12:21:13'); 2000 ``` Supported field same as [Hive](https://github.com/apache/hive/blob/rel/release-2.3.3/ql/src/java/org/apache/hadoop/hive/ql/parse/IdentifiersParser.g#L308-L316): `YEAR`, `QUARTER`, `MONTH`, `WEEK`, `DAY`, `DAYOFWEEK`, `HOUR`, `MINUTE`, `SECOND`. ## How was this patch tested? unit tests Author: Yuming Wang Closes #21479 from wangyum/SPARK-23903. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 + .../sql/catalyst/parser/AstBuilder.scala | 28 ++++++ .../parser/TableIdentifierParserSuite.scala | 2 +- .../resources/sql-tests/inputs/extract.sql | 21 ++++ .../sql-tests/results/extract.sql.out | 96 +++++++++++++++++++ 5 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/extract.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/extract.sql.out diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 7c54851097af3..3fe00eefde7d8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -592,6 +592,7 @@ primaryExpression | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract ; constant @@ -739,6 +740,7 @@ nonReserved | VIEW | REPLACE | IF | POSITION + | EXTRACT | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION @@ -878,6 +880,7 @@ TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; +EXTRACT: 'EXTRACT'; EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b9ece295c2510..383ebde3229d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1206,6 +1206,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging new StringLocate(expression(ctx.substr), expression(ctx.str)) } + /** + * Create a Extract expression. + */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + ctx.field.getText.toUpperCase(Locale.ROOT) match { + case "YEAR" => + Year(expression(ctx.source)) + case "QUARTER" => + Quarter(expression(ctx.source)) + case "MONTH" => + Month(expression(ctx.source)) + case "WEEK" => + WeekOfYear(expression(ctx.source)) + case "DAY" => + DayOfMonth(expression(ctx.source)) + case "DAYOFWEEK" => + DayOfWeek(expression(ctx.source)) + case "HOUR" => + Hour(expression(ctx.source)) + case "MINUTE" => + Minute(expression(ctx.source)) + case "SECOND" => + Second(expression(ctx.source)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + /** * Create a (windowed) Function expression. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 89903c2825125..ff0de0fb7c1f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql new file mode 100644 index 0000000000000..9adf5d70056e2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql @@ -0,0 +1,21 @@ +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c; + +select extract(year from c) from t; + +select extract(quarter from c) from t; + +select extract(month from c) from t; + +select extract(week from c) from t; + +select extract(day from c) from t; + +select extract(dayofweek from c) from t; + +select extract(hour from c) from t; + +select extract(minute from c) from t; + +select extract(second from c) from t; + +select extract(not_supported from c) from t; diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out new file mode 100644 index 0000000000000..160e4c7d78455 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -0,0 +1,96 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select extract(year from c) from t +-- !query 1 schema +struct +-- !query 1 output +2011 + + +-- !query 2 +select extract(quarter from c) from t +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +select extract(month from c) from t +-- !query 3 schema +struct +-- !query 3 output +5 + + +-- !query 4 +select extract(week from c) from t +-- !query 4 schema +struct +-- !query 4 output +18 + + +-- !query 5 +select extract(day from c) from t +-- !query 5 schema +struct +-- !query 5 output +6 + + +-- !query 6 +select extract(dayofweek from c) from t +-- !query 6 schema +struct +-- !query 6 output +6 + + +-- !query 7 +select extract(hour from c) from t +-- !query 7 schema +struct +-- !query 7 output +7 + + +-- !query 8 +select extract(minute from c) from t +-- !query 8 schema +struct +-- !query 8 output +8 + + +-- !query 9 +select extract(second from c) from t +-- !query 9 schema +struct +-- !query 9 output +9 + + +-- !query 10 +select extract(not_supported from c) from t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +Literals of type 'NOT_SUPPORTED' are currently not supported.(line 1, pos 7) + +== SQL == +select extract(not_supported from c) from t +-------^^^ From 7297ae04d87b6e3d48b747a7c1d53687fcc3971c Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Mon, 4 Jun 2018 13:28:16 -0700 Subject: [PATCH 0918/1161] [SPARK-21896][SQL] Fix StackOverflow caused by window functions inside aggregate functions ## What changes were proposed in this pull request? This PR explicitly prohibits window functions inside aggregates. Currently, this will cause StackOverflow during analysis. See PR #19193 for previous discussion. ## How was this patch tested? This PR comes with a dedicated unit test. Author: aokolnychyi Closes #21473 from aokolnychyi/fix-stackoverflow-window-funcs. --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++++-- .../spark/sql/DataFrameAggregateSuite.scala | 34 +++++++++++++++++-- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3eaa9ecf5d075..f9947d1fa6c78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1744,10 +1744,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = - projectList.exists(hasWindowFunction) + private def hasWindowFunction(exprs: Seq[Expression]): Boolean = + exprs.exists(hasWindowFunction) - private def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: Expression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -1830,6 +1830,10 @@ class Analyzer( seenWindowAggregates += newAgg WindowExpression(newAgg, spec) + case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + failAnalysis("It is not allowed to use a window function inside an aggregate " + + "function. Please use the inner window function in a sub-query.") + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 96c28961e5aaf..f495a949ebc5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.util.Random -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.scalatest.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -687,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21896: Window functions inside aggregate functions") { + def checkWindowError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("not allowed to use a window function")) + } + + checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a))))) + checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a))))) + checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b))))) + checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a))))) + checkWindowError( + testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3)) + checkAnswer( + testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3), + Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) + + checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) + checkAnswer( + sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), + Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) + } + } From b24d3dba6571fd3c9e2649aceeaadc3f9c6cc90f Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 4 Jun 2018 14:54:31 -0700 Subject: [PATCH 0919/1161] [SPARK-24290][ML] add support for Array input for instrumentation.logNamedValue ## What changes were proposed in this pull request? Extend instrumentation.logNamedValue to support Array input change the logging for "clusterSizes" to new method ## How was this patch tested? N/A Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21347 from ludatabricks/SPARK-24290. --- .../spark/ml/clustering/BisectingKMeans.scala | 3 +-- .../spark/ml/clustering/GaussianMixture.scala | 3 +-- .../org/apache/spark/ml/clustering/KMeans.scala | 3 +-- .../org/apache/spark/ml/util/Instrumentation.scala | 13 +++++++++++++ 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 1ad4e097246a3..9c9614509c64f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -276,8 +276,7 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 3091bb5a2e54c..64ecc1ebda589 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -426,8 +426,7 @@ class GaussianMixture @Since("2.0.0") ( $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) model.setSummary(Some(summary)) instr.logNamedValue("logLikelihood", logLikelihood) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e72d7f9485e6a..1704412741d49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -359,8 +359,7 @@ class KMeans @Since("1.5.0") ( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 467130b37c16e..3a1c166d46257 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -132,6 +132,19 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Array[String]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Long]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Double]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + /** * Logs the successful completion of the training session. */ From ff0501b0c27dc8149bd5fb38a19d9b0056698766 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Mon, 4 Jun 2018 16:08:27 -0700 Subject: [PATCH 0920/1161] [SPARK-24300][ML] change the way to set seed in ml.cluster.LDASuite.generateLDAData ## What changes were proposed in this pull request? Using different RNG in all different partitions. ## How was this patch tested? manually Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Lu WANG Closes #21492 from ludatabricks/SPARK-24300. --- .../test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 096b5416899e1..db92132d18b7b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -34,9 +34,8 @@ object LDASuite { vocabSize: Int): DataFrame = { val avgWC = 1 // average instances of each word in a doc val sc = spark.sparkContext - val rng = new java.util.Random() - rng.setSeed(1) val rdd = sc.parallelize(1 to rows).map { i => + val rng = new java.util.Random(i) Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) }.map(v => new TestRow(v)) spark.createDataFrame(rdd) From dbb4d83829ec4b51d6e6d3a96f7a4e611d8827bc Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 5 Jun 2018 08:23:08 +0700 Subject: [PATCH 0921/1161] [SPARK-24215][PYSPARK] Implement _repr_html_ for dataframes in PySpark ## What changes were proposed in this pull request? Implement `_repr_html_` for PySpark while in notebook and add config named "spark.sql.repl.eagerEval.enabled" to control this. The dev list thread for context: http://apache-spark-developers-list.1001551.n3.nabble.com/eager-execution-and-debuggability-td23928.html ## How was this patch tested? New ut in DataFrameSuite and manual test in jupyter. Some screenshot below. **After:** ![image](https://user-images.githubusercontent.com/4833765/40268422-8db5bef0-5b9f-11e8-80f1-04bc654a4f2c.png) **Before:** ![image](https://user-images.githubusercontent.com/4833765/40268431-9f92c1b8-5b9f-11e8-9db9-0611f0940b26.png) Author: Yuanjian Li Closes #21370 from xuanyuanking/SPARK-24215. --- docs/configuration.md | 27 ++++++ python/pyspark/sql/dataframe.py | 65 +++++++++++++- python/pyspark/sql/tests.py | 30 +++++++ .../scala/org/apache/spark/sql/Dataset.scala | 84 ++++++++++++------- 4 files changed, 176 insertions(+), 30 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 64af0e98a82f5..5588c372d3e42 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -456,6 +456,33 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. + + spark.sql.repl.eagerEval.enabled + false + + Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation, + Dataset will be ran automatically. The HTML table which generated by _repl_html_ + called by notebooks like Jupyter will feedback the queries user have defined. For plain Python + REPL, the output will be shown like dataframe.show() + (see SPARK-24215 for more details). + + + + spark.sql.repl.eagerEval.maxNumRows + 20 + + Default number of rows in eager evaluation output HTML table generated by _repr_html_ or plain text, + this only take effect when spark.sql.repl.eagerEval.enabled is set to true. + + + + spark.sql.repl.eagerEval.truncate + 20 + + Default number of truncate in eager evaluation output HTML table generated by _repr_html_ or + plain text, this only take effect when spark.sql.repl.eagerEval.enabled set to true. + + spark.files diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 808235ab25440..1e6a1acebb5ca 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx): self.is_cached = False self._schema = None # initialized lazily self._lazy_rdd = None + # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice + # by __repr__ and _repr_html_ while eager evaluation opened. + self._support_repr_html = False @property @since(1.3) @@ -351,8 +354,68 @@ def show(self, n=20, truncate=True, vertical=False): else: print(self._jdf.showString(n, int(truncate), vertical)) + @property + def _eager_eval(self): + """Returns true if the eager evaluation enabled. + """ + return self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.enabled", "false").lower() == "true" + + @property + def _max_num_rows(self): + """Returns the max row number for eager evaluation. + """ + return int(self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.maxNumRows", "20")) + + @property + def _truncate(self): + """Returns the truncate length for eager evaluation. + """ + return int(self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.truncate", "20")) + def __repr__(self): - return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + if not self._support_repr_html and self._eager_eval: + vertical = False + return self._jdf.showString( + self._max_num_rows, self._truncate, vertical) + else: + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def _repr_html_(self): + """Returns a dataframe with html code when you enabled eager evaluation + by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are + using support eager evaluation with HTML. + """ + import cgi + if not self._support_repr_html: + self._support_repr_html = True + if self._eager_eval: + max_num_rows = max(self._max_num_rows, 0) + vertical = False + sock_info = self._jdf.getRowsToPython( + max_num_rows, self._truncate, vertical) + rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + head = rows[0] + row_data = rows[1:] + has_more_data = len(row_data) > max_num_rows + row_data = row_data[:max_num_rows] + + html = "\n" + # generate table head + html += "\n" % "\n" % "
    %s
    ".join(map(lambda x: cgi.escape(x), head)) + # generate table rows + for row in row_data: + html += "
    %s
    ".join( + map(lambda x: cgi.escape(x), row)) + html += "
    \n" + if has_more_data: + html += "only showing top %d %s\n" % ( + max_num_rows, "row" if max_num_rows == 1 else "rows") + return html + else: + return None @since(2.1) def checkpoint(self, eager=True): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ea2dd7605dc57..487eb19c3b98a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3074,6 +3074,36 @@ def test_checking_csv_header(self): finally: shutil.rmtree(path) + def test_repr_html(self): + import re + pattern = re.compile(r'^ *\|', re.MULTILINE) + df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) + self.assertEquals(None, df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """ + | + | + | + |
    keyvalue
    11
    2222222222
    + |""" + self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """ + | + | + | + |
    keyvalue
    11
    222222
    + |""" + self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """ + | + | + |
    keyvalue
    11
    + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index abb5ae53f4d73..f5526104690d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -231,16 +231,17 @@ class Dataset[T] private[sql]( } /** - * Compose the string representing rows for output + * Get rows represented in Sequence by specific truncate and vertical requirement. * - * @param _numRows Number of rows to show + * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, prints output rows vertically (one line per column value). + * @param vertical If set to true, the rows to return do not need truncate. */ - private[sql] def showString( - _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0).min(Int.MaxValue - 1) + private[sql] def getRows( + numRows: Int, + truncate: Int, + vertical: Boolean): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -251,14 +252,12 @@ class Dataset[T] private[sql]( Column(col).cast(StringType) } } - val takeResult = newDf.select(castCols: _*).take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val data = newDf.select(castCols: _*).take(numRows + 1) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => val str = cell match { case null => "null" @@ -274,6 +273,26 @@ class Dataset[T] private[sql]( } }: Seq[String] } + } + + /** + * Compose the string representing rows for output + * + * @param _numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). + */ + private[sql] def showString( + _numRows: Int, + truncate: Int = 20, + vertical: Boolean = false): String = { + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. + val tmpRows = getRows(numRows, truncate, vertical) + + val hasMoreData = tmpRows.length - 1 > numRows + val rows = tmpRows.take(numRows + 1) val sb = new StringBuilder val numCols = schema.fieldNames.length @@ -291,31 +310,25 @@ class Dataset[T] private[sql]( } } + val paddedRows = rows.map { row => + row.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + } + } + // Create SeparateLine val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - + paddedRows.head.addString(sb, "|", "|", "|\n") sb.append(sep) // data - rows.tail.foreach { - _.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - + paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) sb.append(sep) } else { // Extended display mode enabled @@ -346,7 +359,7 @@ class Dataset[T] private[sql]( } // Print a footer - if (vertical && data.isEmpty) { + if (vertical && rows.tail.isEmpty) { // In a vertical mode, print an empty row set explicitly sb.append("(0 rows)\n") } else if (hasMoreData) { @@ -3209,6 +3222,19 @@ class Dataset[T] private[sql]( } } + private[sql] def getRowsToPython( + _numRows: Int, + truncate: Int, + vertical: Boolean): Array[Any] = { + EvaluatePython.registerPicklers() + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray + val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + rows.iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-GetRows") + } + /** * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ From b3417b731d4e323398a0d7ec6e86405f4464f4f9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 5 Jun 2018 08:29:29 +0700 Subject: [PATCH 0922/1161] [SPARK-16451][REPL] Fail shell if SparkSession fails to start. Currently, in spark-shell, if the session fails to start, the user sees a bunch of unrelated errors which are caused by code in the shell initialization that references the "spark" variable, which does not exist in that case. Things like: ``` :14: error: not found: value spark import spark.sql ``` The user is also left with a non-working shell (unless they want to just write non-Spark Scala or Python code, that is). This change fails the whole shell session at the point where the failure occurs, so that the last error message is the one with the actual information about the failure. For the python error handling, I moved the session initialization code to session.py, so that traceback.print_exc() only shows the last error. Otherwise, the printed exception would contain all previous exceptions with a message "During handling of the above exception, another exception occurred", making the actual error kinda hard to parse. Tested with spark-shell, pyspark (with 2.7 and 3.5), by forcing an error during SparkContext initialization. Author: Marcelo Vanzin Closes #21368 from vanzin/SPARK-16451. --- python/pyspark/shell.py | 26 ++----- python/pyspark/sql/session.py | 34 +++++++++ .../scala/org/apache/spark/repl/Main.scala | 72 ++++++++++--------- 3 files changed, 81 insertions(+), 51 deletions(-) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index b5fcf7092d93a..472c3cd4452f0 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -38,25 +38,13 @@ SparkContext._ensure_initialized() try: - # Try to access HiveConf, it will raise exception if Hive is not added - conf = SparkConf() - if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() - else: - spark = SparkSession.builder.getOrCreate() -except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() -except TypeError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() + spark = SparkSession._create_shell_session() +except Exception: + import sys + import traceback + warnings.warn("Failed to initialize Spark session.") + traceback.print_exc(file=sys.stderr) + sys.exit(1) sc = spark.sparkContext sql = spark.sql diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index d675a240172a7..e880dd1ca6d1a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -547,6 +547,40 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): df._schema = schema return df + @staticmethod + def _create_shell_session(): + """ + Initialize a SparkSession for a pyspark shell session. This is called from shell.py + to make error handling simpler without needing to declare local variables in that + script, which would expose those to users. + """ + import py4j + from pyspark.conf import SparkConf + from pyspark.context import SparkContext + try: + # Try to access HiveConf, it will raise exception if Hive is not added + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + return SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + return SparkSession.builder.getOrCreate() + except py4j.protocol.Py4JError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + try: + return SparkSession.builder.getOrCreate() + except TypeError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + return SparkSession.builder.getOrCreate() + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index cc76a703bdf8f..e4ddcef9772e4 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -44,6 +44,7 @@ object Main extends Logging { var interp: SparkILoop = _ private var hasErrors = false + private var isShellSession = false private def scalaOptionError(msg: String): Unit = { hasErrors = true @@ -53,6 +54,7 @@ object Main extends Logging { } def main(args: Array[String]) { + isShellSession = true doMain(args, new SparkILoop) } @@ -79,44 +81,50 @@ object Main extends Logging { } def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - conf.setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - if (System.getenv("SPARK_HOME") != null) { - conf.setSparkHome(System.getenv("SPARK_HOME")) - } + try { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + conf.setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } - val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { - if (SparkSession.hiveClassesArePresent) { - // In the case that the property is not set at all, builder's config - // does not have this value set to 'hive' yet. The original default - // behavior is that when there are hive classes, we use hive catalog. - sparkSession = builder.enableHiveSupport().getOrCreate() - logInfo("Created Spark session with Hive support") + val builder = SparkSession.builder.config(conf) + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { + if (SparkSession.hiveClassesArePresent) { + // In the case that the property is not set at all, builder's config + // does not have this value set to 'hive' yet. The original default + // behavior is that when there are hive classes, we use hive catalog. + sparkSession = builder.enableHiveSupport().getOrCreate() + logInfo("Created Spark session with Hive support") + } else { + // Need to change it back to 'in-memory' if no hive classes are found + // in the case that the property is set to hive in spark-defaults.conf + builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + sparkSession = builder.getOrCreate() + logInfo("Created Spark session") + } } else { - // Need to change it back to 'in-memory' if no hive classes are found - // in the case that the property is set to hive in spark-defaults.conf - builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + // In the case that the property is set but not to 'hive', the internal + // default is 'in-memory'. So the sparkSession will use in-memory catalog. sparkSession = builder.getOrCreate() logInfo("Created Spark session") } - } else { - // In the case that the property is set but not to 'hive', the internal - // default is 'in-memory'. So the sparkSession will use in-memory catalog. - sparkSession = builder.getOrCreate() - logInfo("Created Spark session") + sparkContext = sparkSession.sparkContext + sparkSession + } catch { + case e: Exception if isShellSession => + logError("Failed to initialize Spark session.", e) + sys.exit(1) } - sparkContext = sparkSession.sparkContext - sparkSession } } From e8c1a0c2fdb09a628d9cc925676af870d5a7a946 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 4 Jun 2018 21:24:35 -0700 Subject: [PATCH 0923/1161] [SPARK-15784] Add Power Iteration Clustering to spark.ml ## What changes were proposed in this pull request? According to the discussion on JIRA. I rewrite the Power Iteration Clustering API in `spark.ml`. ## How was this patch tested? Unit test. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu Closes #21493 from WeichenXu123/pic_api. --- .../clustering/PowerIterationClustering.scala | 157 +++++---------- .../PowerIterationClusteringSuite.scala | 179 ++++++++---------- 2 files changed, 125 insertions(+), 211 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa947..1b9a3499947d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -18,21 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.Transformer import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ /** * Common params for PowerIterationClustering */ private[clustering] trait PowerIterationClusteringParams extends Params with HasMaxIter - with HasPredictionCol { + with HasWeightCol { /** * The number of clusters to create (k). Must be > 1. Default: 2. @@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has def getInitMode: String = $(initMode) /** - * Param for the name of the input column for vertex IDs. - * Default: "id" + * Param for the name of the input column for source vertex IDs. + * Default: "src" * @group param */ @Since("2.4.0") - val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.", (value: String) => value.nonEmpty) - setDefault(idCol, "id") - - /** @group getParam */ - @Since("2.4.0") - def getIdCol: String = getOrDefault(idCol) - - /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "neighbors" - * @group param - */ - @Since("2.4.0") - val neighborsCol = new Param[String](this, "neighborsCol", - "Name of the input column for neighbors in the adjacency list representation.", - (value: String) => value.nonEmpty) - - setDefault(neighborsCol, "neighbors") - /** @group getParam */ @Since("2.4.0") - def getNeighborsCol: String = $(neighborsCol) + def getSrcCol: String = getOrDefault(srcCol) /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "similarities" + * Name of the input column for destination vertex IDs. + * Default: "dst" * @group param */ @Since("2.4.0") - val similaritiesCol = new Param[String](this, "similaritiesCol", - "Name of the input column for neighbors in the adjacency list representation.", + val dstCol = new Param[String](this, "dstCol", + "Name of the input column for destination vertex IDs.", (value: String) => value.nonEmpty) - setDefault(similaritiesCol, "similarities") - /** @group getParam */ @Since("2.4.0") - def getSimilaritiesCol: String = $(similaritiesCol) + def getDstCol: String = $(dstCol) - protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) - SchemaUtils.checkColumnTypes(schema, $(neighborsCol), - Seq(ArrayType(IntegerType, containsNull = false), - ArrayType(LongType, containsNull = false))) - SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), - Seq(ArrayType(FloatType, containsNull = false), - ArrayType(DoubleType, containsNull = false))) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - } + setDefault(srcCol -> "src", dstCol -> "dst") } /** @@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has * PIC finds a very low-dimensional embedding of a dataset using truncated power * iteration on a normalized pair-wise similarity matrix of the data. * - * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix - * is a symmetric matrix whose entries are non-negative similarities between items. - * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: - * - `idCol`: vertex ID - * - `neighborsCol`: neighbors of vertex in `idCol` - * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex - * in `idCol` and each neighbor in `neighborsCol` - * PIC returns a cluster assignment for each input vertex. It appends a new column `predictionCol` - * containing the cluster assignment in `[0,k)` for each row (vertex). - * - * Notes: - * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. - * Transform runs the iterative PIC algorithm to cluster the whole input dataset. - * - Input validation: This validates that similarities are non-negative but does NOT validate - * that the input matrix is symmetric. + * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the + * PowerIterationClustering algorithm. * * @see * Spectral clustering (Wikipedia) @@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has @Experimental class PowerIterationClustering private[clustering] ( @Since("2.4.0") override val uid: String) - extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + extends PowerIterationClusteringParams with DefaultParamsWritable { setDefault( k -> 2, @@ -164,10 +121,6 @@ class PowerIterationClustering private[clustering] ( @Since("2.4.0") def this() = this(Identifiable.randomUID("PowerIterationClustering")) - /** @group setParam */ - @Since("2.4.0") - def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group setParam */ @Since("2.4.0") def setK(value: Int): this.type = set(k, value) @@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] ( /** @group setParam */ @Since("2.4.0") - def setIdCol(value: String): this.type = set(idCol, value) + def setSrcCol(value: String): this.type = set(srcCol, value) /** @group setParam */ @Since("2.4.0") - def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + def setDstCol(value: String): this.type = set(dstCol, value) /** @group setParam */ @Since("2.4.0") - def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Run the PIC algorithm and returns a cluster assignment for each input vertex. + * + * @param dataset A dataset with columns src, dst, weight representing the affinity matrix, + * which is the matrix A in the PIC paper. Suppose the src column value is i, + * the dst column value is j, the weight column value is similarity s,,ij,, + * which must be nonnegative. This is a symmetric matrix and hence + * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + * ignored, because we assume s,,ij,, = 0.0. + * + * @return A dataset that contains columns of vertex id and the corresponding cluster for the id. + * The schema of it will be: + * - id: Long + * - cluster: Int + */ @Since("2.4.0") - override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema, logging = true) + def assignClusters(dataset: Dataset[_]): DataFrame = { + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { + lit(1.0) + } else { + col($(weightCol)).cast(DoubleType) + } - val sparkSession = dataset.sparkSession - val idColValue = $(idCol) - val rdd: RDD[(Long, Long, Double)] = - dataset.select( - col($(idCol)).cast(LongType), - col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), - col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) - ).rdd.flatMap { - case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => - require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + - s"equal to the the length of the neighbor similarity list. Row for ID " + - s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + - s"of length ${sims.length}.") - nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { - case (nbr, similarity) => (id, nbr, similarity) - } - } + SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType)) + val rdd: RDD[(Long, Long, Double)] = dataset.select( + col($(srcCol)).cast(LongType), + col($(dstCol)).cast(LongType), + w).rdd.map { + case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight) + } val algorithm = new MLlibPowerIterationClustering() .setK($(k)) .setInitializationMode($(initMode)) .setMaxIterations($(maxIter)) val model = algorithm.run(rdd) - val predictionsRDD: RDD[Row] = model.assignments.map { assignment => - Row(assignment.id, assignment.cluster) - } - - val predictionsSchema = StructType(Seq( - StructField($(idCol), LongType, nullable = false), - StructField($(predictionCol), IntegerType, nullable = false))) - val predictions = { - val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.schema($(idCol)).dataType match { - case _: LongType => - uncastPredictions - case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) - } - } - - dataset.join(predictions, $(idCol)) - } - - @Since("2.4.0") - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + import dataset.sparkSession.implicits._ + model.assignments.toDF } @Since("2.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baff..b7072728d48f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -17,19 +17,19 @@ package org.apache.spark.ml.clustering -import scala.collection.mutable - import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Dataset[_] = _ final val r1 = 1.0 final val n1 = 10 @@ -48,10 +48,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite assert(pic.getK === 2) assert(pic.getMaxIter === 20) assert(pic.getInitMode === "random") - assert(pic.getPredictionCol === "prediction") - assert(pic.getIdCol === "id") - assert(pic.getNeighborsCol === "neighbors") - assert(pic.getSimilaritiesCol === "similarities") + assert(pic.getSrcCol === "src") + assert(pic.getDstCol === "dst") + assert(!pic.isDefined(pic.weightCol)) } test("parameter validation") { @@ -62,125 +61,102 @@ class PowerIterationClusteringSuite extends SparkFunSuite new PowerIterationClustering().setInitMode("no_such_a_mode") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setIdCol("") + new PowerIterationClustering().setSrcCol("") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setNeighborsCol("") - } - intercept[IllegalArgumentException] { - new PowerIterationClustering().setSimilaritiesCol("") + new PowerIterationClustering().setDstCol("") } } test("power iteration clustering") { val n = n1 + n2 - val model = new PowerIterationClustering() + val assignments = new PowerIterationClustering() .setK(2) .setMaxIter(40) - val result = model.transform(data) - - val predictions = Array.fill(2)(mutable.Set.empty[Long]) - result.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions(cluster) += id - } - assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) - - val result2 = new PowerIterationClustering() + .setWeightCol("weight") + .assignClusters(data) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++ + (n1 until n).map(x => (x, 0)).toSet + assert(localAssignments === expectedResult) + + val assignments2 = new PowerIterationClustering() .setK(2) .setMaxIter(10) .setInitMode("degree") - .transform(data) - val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - result2.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions2(cluster) += id - } - assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + .setWeightCol("weight") + .assignClusters(data) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + assert(localAssignments2 === expectedResult) } test("supported input types") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setK(2) .setMaxIter(1) + .setWeightCol("weight") - def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = { val typedData = data.select( - col("id").cast(idType).alias("id"), - col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), - col("similarities").cast(ArrayType(similarityType, containsNull = false)) - .alias("similarities") + col("src").cast(srcType).alias("src"), + col("dst").cast(dstType).alias("dst"), + col("weight").cast(weightType).alias("weight") ) - model.transform(typedData).collect() - } - - for (idType <- Seq(IntegerType, LongType)) { - runTest(idType, LongType, DoubleType) - } - for (neighborType <- Seq(IntegerType, LongType)) { - runTest(LongType, neighborType, DoubleType) - } - for (similarityType <- Seq(FloatType, DoubleType)) { - runTest(LongType, LongType, similarityType) + pic.assignClusters(typedData).collect() } - } - test("invalid input: wrong types") { - val model = new PowerIterationClustering() - .setK(2) - .setMaxIter(1) - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id").cast(DoubleType).alias("id"), - col("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (srcType <- Seq(IntegerType, LongType)) { + runTest(srcType, LongType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (dstType <- Seq(IntegerType, LongType)) { + runTest(LongType, dstType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors"), - col("neighbors").alias("similarities") - ) - model.transform(typedData) + for (weightType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, weightType) } } test("invalid input: negative similarity") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setMaxIter(1) + .setWeightCol("weight") val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(-1.0)), - (1, Array(0), Array(-1.0)) - )).toDF("id", "neighbors", "similarities") + (0, 1, -1.0), + (1, 0, -1.0) + )).toDF("src", "dst", "weight") val msg = intercept[SparkException] { - model.transform(badData) + pic.assignClusters(badData) }.getCause.getMessage assert(msg.contains("Similarity must be nonnegative")) } - test("invalid input: mismatched lengths for neighbor and similarity arrays") { - val model = new PowerIterationClustering() - .setMaxIter(1) - val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(0.5)), - (1, Array(0, 2), Array(0.5)), - (2, Array(1), Array(0.5)) - )).toDF("id", "neighbors", "similarities") - val msg = intercept[SparkException] { - model.transform(badData) - }.getCause.getMessage - assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + - "the neighbor similarity list.")) - assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + test("test default weight") { + val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst) + + val assignments = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithoutWeight) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0)) + + val assignments2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithWeightOne) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + assert(localAssignments === localAssignments2) } test("read/write") { @@ -188,10 +164,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setK(4) .setMaxIter(100) .setInitMode("degree") - .setIdCol("test_id") - .setNeighborsCol("myNeighborsCol") - .setSimilaritiesCol("mySimilaritiesCol") - .setPredictionCol("test_prediction") + .setSrcCol("src1") + .setDstCol("dst1") + .setWeightCol("weight") testDefaultReadWrite(t) } } @@ -222,17 +197,13 @@ object PowerIterationClusteringSuite { val n = n1 + n2 val points = genCircle(r1, n1) ++ genCircle(r2, n2) - val rows = for (i <- 1 until n) yield { - val neighbors = for (j <- 0 until i) yield { - j.toLong + val rows = (for (i <- 1 until n) yield { + for (j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) } - val similarities = for (j <- 0 until i) yield { - sim(points(i), points(j)) - } - (i.toLong, neighbors.toArray, similarities.toArray) - } + }).flatMap(_.iterator) - spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + spark.createDataFrame(rows).toDF("src", "dst", "weight") } } From 2c2a86b5d5be6f77ee72d16f990b39ae59f479b9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 5 Jun 2018 01:08:55 -0700 Subject: [PATCH 0924/1161] [SPARK-24453][SS] Fix error recovering from the failure in a no-data batch ## What changes were proposed in this pull request? The error occurs when we are recovering from a failure in a no-data batch (say X) that has been planned (i.e. written to offset log) but not executed (i.e. not written to commit log). Upon recovery the following sequence of events happen. 1. `MicroBatchExecution.populateStartOffsets` sets `currentBatchId` to X. Since there was no data in the batch, the `availableOffsets` is same as `committedOffsets`, so `isNewDataAvailable` is `false`. 2. When `MicroBatchExecution.constructNextBatch` is called, ideally it should immediately return true because the next batch has already been constructed. However, the check of whether the batch has been constructed was `if (isNewDataAvailable) return true`. Since the planned batch is a no-data batch, it escaped this check and proceeded to plan the same batch X *once again*. The solution is to have an explicit flag that signifies whether a batch has already been constructed or not. `populateStartOffsets` is going to set the flag appropriately. ## How was this patch tested? new unit test Author: Tathagata Das Closes #21491 from tdas/SPARK-24453. --- .../streaming/MicroBatchExecution.scala | 38 ++++++---- .../streaming/MicroBatchExecutionSuite.scala | 71 +++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 2 +- 3 files changed, 98 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 7817360810bde..17ffa2a517312 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -126,6 +126,12 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signifies whether current batch (i.e. for the batch `currentBatchId`) has been constructed + * (i.e. written to the offsetLog) and is ready for execution. + */ + private var isCurrentBatchConstructed = false + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. @@ -154,7 +160,6 @@ class MicroBatchExecution( triggerExecutor.execute(() => { if (isActive) { - var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run var currentBatchHasNewData = false // Whether the current batch had new data startTrigger() @@ -175,7 +180,9 @@ class MicroBatchExecution( // new data to process as `constructNextBatch` may decide to run a batch for // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data // is available or not. - currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + if (!isCurrentBatchConstructed) { + isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) + } // Remember whether the current batch has data or not. This will be required later // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed @@ -183,7 +190,7 @@ class MicroBatchExecution( currentBatchHasNewData = isNewDataAvailable currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) - if (currentBatchIsRunnable) { + if (isCurrentBatchConstructed) { if (currentBatchHasNewData) updateStatusMessage("Processing new data") else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) @@ -194,9 +201,12 @@ class MicroBatchExecution( finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded - // If the current batch has been executed, then increment the batch id, else there was - // no data to execute the batch - if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) + // If the current batch has been executed, then increment the batch id and reset flag. + // Otherwise, there was no data to execute the batch and sleep for some time + if (isCurrentBatchConstructed) { + currentBatchId += 1 + isCurrentBatchConstructed = false + } else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -231,6 +241,7 @@ class MicroBatchExecution( /* First assume that we are re-executing the latest known batch * in the offset log */ currentBatchId = latestBatchId + isCurrentBatchConstructed = true availableOffsets = nextOffsets.toStreamProgress(sources) /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. */ @@ -269,6 +280,7 @@ class MicroBatchExecution( // here, so we do nothing here. } currentBatchId = latestCommittedBatchId + 1 + isCurrentBatchConstructed = false committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets } else if (latestCommittedBatchId < latestBatchId - 1) { @@ -313,11 +325,8 @@ class MicroBatchExecution( * - If either of the above is true, then construct the next batch by committing to the offset * log that range of offsets that the next batch will process. */ - private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { - // If new data is already available that means this method has already been called before - // and it must have already committed the offset range of next batch to the offset log. - // Hence do nothing, just return true. - if (isNewDataAvailable) return true + private def constructNextBatch(noDataBatchesEnabled: Boolean): Boolean = withProgressLocked { + if (isCurrentBatchConstructed) return true // Generate a map from each unique source to the next available offset. val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { @@ -348,9 +357,14 @@ class MicroBatchExecution( batchTimestampMs = triggerClock.getTimeMillis()) // Check whether next batch should be constructed - val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled && Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + logTrace( + s"noDataBatchesEnabled = $noDataBatchesEnabled, " + + s"lastExecutionRequiresAnotherBatch = $lastExecutionRequiresAnotherBatch, " + + s"isNewDataAvailable = $isNewDataAvailable, " + + s"shouldConstructNextBatch = $shouldConstructNextBatch") if (shouldConstructNextBatch) { // Commit the next batch offset range to the offset log diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala new file mode 100644 index 0000000000000..c228740df07c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.streaming.StreamTest + +class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("SPARK-24156: do not plan a no-data batch again after it has already been planned") { + val inputData = MemoryStream[Int] + val df = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(df)( + AddData(inputData, 10, 11, 12, 13, 14, 15), // Set watermark to 5 + CheckAnswer(), + AddData(inputData, 25), // Set watermark to 15 to make MicroBatchExecution run no-data batch + CheckAnswer((10, 5)), // Last batch should be a no-data batch + StopStream, + Execute { q => + // Delete the last committed batch from the commit log to signify that the last batch + // (a no-data batch) never completed + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purgeAfter(commit - 1) + }, + // Add data before start so that MicroBatchExecution can plan a batch. It should not, + // it should first re-run the incomplete no-data batch and then run a new batch to process + // new data. + AddData(inputData, 30), + StartStream(), + CheckNewAnswer((15, 1)), // This should not throw the error reported in SPARK-24156 + StopStream, + Execute { q => + // Delete the entire commit log + val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L) + q.commitLog.purge(commit + 1) + }, + AddData(inputData, 50), + StartStream(), + CheckNewAnswer((25, 1), (30, 1)) // This should not throw the error reported in SPARK-24156 + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index f348dac1319cb..4c3fd58cb2e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -292,7 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { def apply(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }) + AssertOnQuery(query => { func(query); true }, "Execute") } object AwaitEpoch { From 93df3cd03503fca7745141fbd2676b8bf70fe92f Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 5 Jun 2018 11:32:42 -0700 Subject: [PATCH 0925/1161] [SPARK-22384][SQL] Refine partition pruning when attribute is wrapped in Cast ## What changes were proposed in this pull request? Sql below will get all partitions from metastore, which put much burden on metastore; ``` CREATE TABLE `partition_test`(`col` int) PARTITIONED BY (`pt` byte) SELECT * FROM partition_test WHERE CAST(pt AS INT)=1 ``` The reason is that the the analyzed attribute `dt` is wrapped in `Cast` and `HiveShim` fails to generate a proper partition filter. This pr proposes to take `Cast` into consideration when generate partition filter. ## How was this patch tested? Test added. This pr proposes to use analyzed expressions in `HiveClientSuite` Author: jinxing Closes #19602 from jinxing64/SPARK-22384. --- .../spark/sql/hive/client/HiveShim.scala | 23 +++- .../sql/hive/client/HiveClientSuite.scala | 102 ++++++++++++------ 2 files changed, 86 insertions(+), 39 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 948ba542b5733..130e258e78ca2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.Path @@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled + object ExtractAttribute { + def unapply(expr: Expression): Option[Attribute] = { + expr match { + case attr: Attribute => Some(attr) + case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child) + case _ => None + } + } + } + def convert(expr: Expression): Option[String] = expr match { - case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced => + case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced => + case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values)) + if useAdvanced => Some(convertInToOr(name, values)) - case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) => + case op @ SpecialBinaryComparison( + ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) => Some(s"$name ${op.symbol} $value") - case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) => + case op @ SpecialBinaryComparison( + ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) => Some(s"$value ${op.symbol} $name") case And(expr1, expr2) if useAdvanced => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index f991352b207d4..55275f6b37945 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.LongType // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { - import CatalystSqlParser._ private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname @@ -46,8 +46,7 @@ class HiveClientSuite(version: String) val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) val client = buildClient(hadoopConf) - client - .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") val partitions = for { @@ -66,6 +65,15 @@ class HiveClientSuite(version: String) client } + private def attr(name: String): Attribute = { + client.getTable("default", "test").partitionSchema.fields + .find(field => field.name.equals(name)) match { + case Some(field) => AttributeReference(field.name, field.dataType)() + case None => + fail(s"Illegal name of partition attribute: $name") + } + } + override def beforeAll() { super.beforeAll() client = init(true) @@ -74,7 +82,7 @@ class HiveClientSuite(version: String) test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(parseExpression("ds=20170101"))) + Seq(attr("ds") === 20170101)) assert(filteredPartitions.size == testPartitionCount) } @@ -82,7 +90,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds<=>20170101") { // Should return all partitions where <=> is not supported testMetastorePartitionFiltering( - "ds<=>20170101", + attr("ds") <=> 20170101, 20170101 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -90,7 +98,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( - "ds=20170101", + attr("ds") === 20170101, 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -100,7 +108,7 @@ class HiveClientSuite(version: String) // Should return all partitions where h=0 because getPartitionsByFilter does not support // comparisons to non-literal values testMetastorePartitionFiltering( - "ds=(20170101 + 1) and h=0", + attr("ds") === (Literal(20170101) + 1) && attr("h") === 0, 20170101 to 20170103, 0 to 0, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -108,7 +116,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk='aa'") { testMetastorePartitionFiltering( - "chunk='aa'", + attr("chunk") === "aa", 20170101 to 20170103, 0 to 23, "aa" :: Nil) @@ -116,7 +124,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( - "20170101=ds", + Literal(20170101) === attr("ds"), 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -124,7 +132,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 and h=10") { testMetastorePartitionFiltering( - "ds=20170101 and h=10", + attr("ds") === 20170101 && attr("h") === 10, + 20170101 to 20170101, + 10 to 10, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, 10 to 10, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -132,7 +148,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 or ds=20170102") { testMetastorePartitionFiltering( - "ds=20170101 or ds=20170102", + attr("ds") === 20170101 || attr("ds") === 20170102, 20170101 to 20170102, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -140,7 +156,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -148,7 +172,19 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") + { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -159,7 +195,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil) @@ -167,7 +203,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { @@ -179,26 +215,24 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) // Day 2 should include all hours because we can't build a filter for h<(7+1) val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil) } test("getPartitionsByFilter: " + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) - testMetastorePartitionFiltering( - "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && + ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), day1 :: day2 :: Nil) } @@ -207,41 +241,41 @@ class HiveClientSuite(version: String) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String]): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String], transform: Expression => Expression): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + transform) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { - testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], transform: Expression => Expression): Unit = { val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq( - transform(parseExpression(filterString)) + transform(filterExpr) )) val expectedPartitionCount = expectedPartitionCubes.map { From e9efb62e0795c8d5233b7e5bfc276d74953942b8 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 6 Jun 2018 08:31:35 +0700 Subject: [PATCH 0926/1161] [SPARK-24187][R][SQL] Add array_join function to SparkR ## What changes were proposed in this pull request? This PR adds array_join function to SparkR ## How was this patch tested? Add unit test in test_sparkSQL.R Author: Huaxin Gao Closes #21313 from huaxingao/spark-24187. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 29 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 15 ++++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 73a33af4dd48b..9696f6987ad78 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_join", "array_max", "array_min", "array_position", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index abc91aeeb4825..3bff633fbc1ff 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -221,7 +221,9 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) -#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) +#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) +#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} NULL #' Window functions for Column operations @@ -3006,6 +3008,27 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_join}: Concatenates the elements of column using the delimiter. +#' Null values are replaced with nullReplacement if set, otherwise they are ignored. +#' +#' @param delimiter a character string that is used to concatenate the elements of column. +#' @param nullReplacement an optional character string that is used to replace the Null values. +#' @rdname column_collection_functions +#' @aliases array_join array_join,Column-method +#' @note array_join since 2.4.0 +setMethod("array_join", + signature(x = "Column", delimiter = "character"), + function(x, delimiter, nullReplacement = NULL) { + jc <- if (is.null(nullReplacement)) { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter) + } else { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter, + as.character(nullReplacement)) + } + column(jc) + }) + #' @details #' \code{array_max}: Returns the maximum value of the array. #' @@ -3197,8 +3220,8 @@ setMethod("size", #' (or starting from the end if start is negative) with the specified length. #' #' @rdname column_collection_functions -#' @param start an index indicating the first element occuring in the result. -#' @param length a number of consecutive elements choosen to the result. +#' @param start an index indicating the first element occurring in the result. +#' @param length a number of consecutive elements chosen to the result. #' @aliases slice slice,Column-method #' @note slice since 2.4.0 setMethod("slice", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8894cb1c5b92f..9321bbaf96ff8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_max", function(x) { standardGeneric("array_max") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 16c1fd5a065eb..36e0f78bb0599 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1518,6 +1518,21 @@ test_that("column functions", { result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] expect_equal(result, c(TRUE, FALSE, NA)) + # Test array_join() + df <- createDataFrame(list(list(list("Hello", "World!")))) + result <- collect(select(df, array_join(df[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df2 <- createDataFrame(list(list(list("Hello", NA, "World!")))) + result <- collect(select(df2, array_join(df2[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df2, array_join(df2[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df3 <- createDataFrame(list(list(list("Hello", NULL, "World!")))) + result <- collect(select(df3, array_join(df3[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df3, array_join(df3[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) From e76b0124fbe463def00b1dffcfd8fd47e04772fe Mon Sep 17 00:00:00 2001 From: Asher Saban Date: Wed, 6 Jun 2018 07:14:08 -0700 Subject: [PATCH 0927/1161] [SPARK-23803][SQL] Support bucket pruning ## What changes were proposed in this pull request? support bucket pruning when filtering on a single bucketed column on the following predicates - EqualTo, EqualNullSafe, In, And/Or predicates ## How was this patch tested? refactored unit tests to test the above. based on gatorsmile work in https://github.com/apache/spark/commit/e3c75c6398b1241500343ff237e9bcf78b5396f9 Author: Asher Saban Author: asaban Closes #20915 from sabanas/filter-prune-buckets. --- .../sql/execution/DataSourceScanExec.scala | 32 ++++- .../datasources/BucketingUtils.scala | 14 +++ .../datasources/DataSourceStrategy.scala | 12 -- .../datasources/FileSourceStrategy.scala | 98 +++++++++++++++- .../spark/sql/sources/BucketedReadSuite.scala | 110 +++++++++++++++--- 5 files changed, 231 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 61c14fee09337..d7f2654be0451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +import org.apache.spark.util.collection.BitSet trait DataSourceScanExec extends LeafExecNode with CodegenSupport { val relation: BaseRelation @@ -151,6 +152,7 @@ case class RowDataSourceScanExec( * @param output Output attributes of the scan, including data attributes and partition attributes. * @param requiredSchema Required schema of the underlying relation, excluding partition columns. * @param partitionFilters Predicates to use for partition pruning. + * @param optionalBucketSet Bucket ids for bucket pruning * @param dataFilters Filters on non-partition columns. * @param tableIdentifier identifier for the table in the metastore. */ @@ -159,6 +161,7 @@ case class FileSourceScanExec( output: Seq[Attribute], requiredSchema: StructType, partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], dataFilters: Seq[Expression], override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with ColumnarBatchScan { @@ -286,7 +289,20 @@ case class FileSourceScanExec( } getOrElse { metadata } - withOptPartitionCount + + val withSelectedBucketsCount = relation.bucketSpec.map { spec => + val numSelectedBuckets = optionalBucketSet.map { b => + b.cardinality() + } getOrElse { + spec.numBuckets + } + withOptPartitionCount + ("SelectedBucketsCount" -> + s"$numSelectedBuckets out of ${spec.numBuckets}") + } getOrElse { + withOptPartitionCount + } + + withSelectedBucketsCount } private lazy val inputRDD: RDD[InternalRow] = { @@ -365,7 +381,7 @@ case class FileSourceScanExec( selectedPartitions: Seq[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") - val bucketed = + val filesGroupedToBuckets = selectedPartitions.flatMap { p => p.files.map { f => val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen) @@ -377,8 +393,17 @@ case class FileSourceScanExec( .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) } + val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { + val bucketSet = optionalBucketSet.get + filesGroupedToBuckets.filter { + f => bucketSet.get(f._1) + } + } else { + filesGroupedToBuckets + } + val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId => - FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil)) } new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) @@ -503,6 +528,7 @@ case class FileSourceScanExec( output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, QueryPlan.normalizePredicates(partitionFilters, output), + optionalBucketSet, QueryPlan.normalizePredicates(dataFilters, output), None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index ea4fe9c8ade5f..a776fc3e7021d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning + object BucketingUtils { // The file name of bucketed data should have 3 parts: // 1. some other information in the head of file name @@ -35,5 +38,16 @@ object BucketingUtils { case other => None } + // Given bucketColumn, numBuckets and value, returns the corresponding bucketId + def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { + val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) + mutableInternalRow.update(0, value) + + val bucketIdGenerator = UnsafeProjection.create( + HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil, + bucketColumn :: Nil) + bucketIdGenerator(mutableInternalRow).getInt(0) + } + def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 3f41612c08065..7b129435c45db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with case _ => Nil } - // Get the bucket ID based on the bucketing values. - // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { - val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) - val bucketIdGeneration = UnsafeProjection.create( - HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, - bucketColumn :: Nil) - - bucketIdGeneration(mutableRow).getInt(0) - } - // Based on Public API. private def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 0a568d6b8adce..fe27b78bf3360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} +import org.apache.spark.util.collection.BitSet /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -50,6 +51,91 @@ import org.apache.spark.sql.execution.SparkPlan * and add it. Proceed to the next file. */ object FileSourceStrategy extends Strategy with Logging { + + // should prune buckets iff num buckets is greater than 1 and there is only one bucket column + private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = { + bucketSpec match { + case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1 + case None => false + } + } + + private def getExpressionBuckets( + expr: Expression, + bucketColumnName: String, + numBuckets: Int): BitSet = { + + def getBucketNumber(attr: Attribute, v: Any): Int = { + BucketingUtils.getBucketIdFromValue(attr, numBuckets, v) + } + + def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + iter + .map(v => getBucketNumber(attr, v)) + .foreach(bucketNum => matchedBuckets.set(bucketNum)) + matchedBuckets + } + + def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = { + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.set(getBucketNumber(attr, v)) + matchedBuckets + } + + expr match { + case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => + getBucketSetFromValue(a, v) + case expressions.In(a: Attribute, list) + if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow))) + case expressions.InSet(a: Attribute, hset) + if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => + getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow))) + case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => + getBucketSetFromValue(a, null) + case expressions.And(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) & + getExpressionBuckets(right, bucketColumnName, numBuckets) + case expressions.Or(left, right) => + getExpressionBuckets(left, bucketColumnName, numBuckets) | + getExpressionBuckets(right, bucketColumnName, numBuckets) + case _ => + val matchedBuckets = new BitSet(numBuckets) + matchedBuckets.setUntil(numBuckets) + matchedBuckets + } + } + + private def genBucketSet( + normalizedFilters: Seq[Expression], + bucketSpec: BucketSpec): Option[BitSet] = { + if (normalizedFilters.isEmpty) { + return None + } + + val bucketColumnName = bucketSpec.bucketColumnNames.head + val numBuckets = bucketSpec.numBuckets + + val normalizedFiltersAndExpr = normalizedFilters + .reduce(expressions.And) + val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName, + numBuckets) + + val numBucketsSelected = matchedBuckets.cardinality() + + logInfo { + s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." + } + + // None means all the buckets need to be scanned + if (numBucketsSelected == numBuckets) { + None + } else { + Some(matchedBuckets) + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => @@ -82,6 +168,13 @@ object FileSourceStrategy extends Strategy with Logging { logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec + val bucketSet = if (shouldPruneBuckets(bucketSpec)) { + genBucketSet(normalizedFilters, bucketSpec.get) + } else { + None + } + val dataColumns = l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) @@ -111,6 +204,7 @@ object FileSourceStrategy extends Strategy with Logging { outputAttributes, outputSchema, partitionKeyFilters.toSeq, + bucketSet, dataFilters, table.map(_.identifier)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index fb61fa716b946..a9414200e70f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,10 +22,11 @@ import java.net.URI import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ @@ -52,6 +53,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g") } yield (i % 5, s, i % 13)).toDF("i", "j", "k") + // number of buckets that doesn't yield empty buckets when bucketing on column j on df/nullDF + // empty buckets before filtering might hide bugs in pruning logic + private val NumBucketsForPruningDF = 7 + private val NumBucketsForPruningNullDf = 5 + test("read bucketed data") { withTable("bucketed_table") { df.write @@ -90,32 +96,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column assert(bucketColumnNames.length == 1) val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) - val matchedBuckets = new BitSet(numBuckets) - bucketValues.foreach { value => - matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) - } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) assert(rdd.isDefined, plan) - val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => - if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator() + // if nothing should be pruned, skip the pruning test + if (bucketValues.nonEmpty) { + val matchedBuckets = new BitSet(numBuckets) + bucketValues.foreach { value => + matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value)) + } + val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of partitions that should have been pruned and are not empty + if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (invalidBuckets.nonEmpty) { + fail(s"Buckets ${invalidBuckets.mkString(",")} should have been pruned from:\n$plan") + } } - // TODO: These tests are not testing the right columns. -// // checking if all the pruned buckets are empty -// val invalidBuckets = checkedResult.collect().toList -// if (invalidBuckets.nonEmpty) { -// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") -// } checkAnswer( bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), @@ -125,7 +136,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -155,13 +166,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = Seq(j, j + 1, j + 2, j + 3), filterCondition = $"j".isin(j, j + 1, j + 2, j + 3), df) + + // Case 4: InSet + val inSetExpr = expressions.InSet($"j".expr, Set(j, j + 1, j + 2, j + 3).map(lit(_).expr)) + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1, j + 2, j + 3), + filterCondition = Column(inSetExpr), + df) } } } test("read non-partitioning bucketed tables with bucket pruning filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -181,7 +200,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having null in bucketing key") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningNullDf val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here nullDF.write @@ -208,7 +227,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { test("read partitioning bucketed tables having composite filters") { withTable("bucketed_table") { - val numBuckets = 8 + val numBuckets = NumBucketsForPruningDF val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) // json does not support predicate push-down, and thus json is used here df.write @@ -229,7 +248,62 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { bucketValues = j :: Nil, filterCondition = $"j" === j && $"i" > j % 5, df) + + // check multiple bucket values OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j, j + 1), + filterCondition = $"j" === j || $"j" === (j + 1), + df) + + // check bucket value and none bucket value OR condition + checkPrunedAnswers( + bucketSpec, + bucketValues = Nil, + filterCondition = $"j" === j || $"i" === 0, + df) + + // check AND condition in complex expression + checkPrunedAnswers( + bucketSpec, + bucketValues = Seq(j), + filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, + df) + } + } + } + + test("read bucketed table without filters") { + withTable("bucketed_table") { + val numBuckets = NumBucketsForPruningDF + val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil) + // json does not support predicate push-down, and thus json is used here + df.write + .format("json") + .bucketBy(numBuckets, "j") + .saveAsTable("bucketed_table") + + val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") + val plan = bucketedDataFrame.queryExecution.executedPlan + val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) + assert(rdd.isDefined, plan) + + val emptyBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => + // return indexes of empty partitions + if (iter.isEmpty) { + Iterator(index) + } else { + Iterator() + } + }.collect() + + if (emptyBuckets.nonEmpty) { + fail(s"Buckets ${emptyBuckets.mkString(",")} should not have been pruned from:\n$plan") } + + checkAnswer( + bucketedDataFrame.orderBy("i", "j", "k"), + df.orderBy("i", "j", "k")) } } From 1462bba4fd99a264ebc8679db91dfc62d0b9a35f Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 8 Jun 2018 13:27:52 +0200 Subject: [PATCH 0928/1161] [SPARK-24119][SQL] Add interpreted execution to SortPrefix expression ## What changes were proposed in this pull request? Implemented eval in SortPrefix expression. ## How was this patch tested? - ran existing sbt SQL tests - added unit test - ran existing Python SQL tests - manual tests: disabling codegen -- patching code to disable beyond what spark.sql.codegen.wholeStage=false can do -- and running sbt SQL tests Author: Bruce Robbins Closes #21231 from bersprockets/sortprefixeval. --- .../sql/catalyst/expressions/SortOrder.scala | 37 +++++++- .../SortOrderExpressionsSuite.scala | 95 +++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 2ce9d072c71c9..76a881146a146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ abstract sealed class SortDirection { @@ -148,7 +149,41 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { (!child.isAscending && child.nullOrdering == NullsLast) } - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + private lazy val calcPrefix: Any => Long = child.child.dataType match { + case BooleanType => (raw) => + if (raw.asInstanceOf[Boolean]) 1 else 0 + case DateType | TimestampType | _: IntegralType => (raw) => + raw.asInstanceOf[java.lang.Number].longValue() + case FloatType | DoubleType => (raw) => { + val dVal = raw.asInstanceOf[java.lang.Number].doubleValue() + DoublePrefixComparator.computePrefix(dVal) + } + case StringType => (raw) => + StringPrefixComparator.computePrefix(raw.asInstanceOf[UTF8String]) + case BinaryType => (raw) => + BinaryPrefixComparator.computePrefix(raw.asInstanceOf[Array[Byte]]) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + _.asInstanceOf[Decimal].toUnscaledLong + case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => + val p = Decimal.MAX_LONG_DIGITS + val s = p - (dt.precision - dt.scale) + (raw) => { + val value = raw.asInstanceOf[Decimal] + if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue + } + case dt: DecimalType => (raw) => + DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble) + case _ => (Any) => 0L + } + + override def eval(input: InternalRow): Any = { + val value = child.child.eval(input) + if (value == null) { + null + } else { + calcPrefix(value) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala new file mode 100644 index 0000000000000..cc2e2a993d629 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ + +class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("SortPrefix") { + val b1 = Literal.create(false, BooleanType) + val b2 = Literal.create(true, BooleanType) + val i1 = Literal.create(20132983, IntegerType) + val i2 = Literal.create(-20132983, IntegerType) + val l1 = Literal.create(20132983, LongType) + val l2 = Literal.create(-20132983, LongType) + val millis = 1524954911000L; + // Explicitly choose a time zone, since Date objects can create different values depending on + // local time zone of the machine on which the test is running + val oldDefaultTZ = TimeZone.getDefault + val d1 = try { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + Literal.create(new java.sql.Date(millis), DateType) + } finally { + TimeZone.setDefault(oldDefaultTZ) + } + val t1 = Literal.create(new Timestamp(millis), TimestampType) + val f1 = Literal.create(0.7788229f, FloatType) + val f2 = Literal.create(-0.7788229f, FloatType) + val db1 = Literal.create(0.7788229d, DoubleType) + val db2 = Literal.create(-0.7788229d, DoubleType) + val s1 = Literal.create("T", StringType) + val s2 = Literal.create("This is longer than 8 characters", StringType) + val bin1 = Literal.create(Array[Byte](12), BinaryType) + val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]), + BinaryType) + val dec1 = Literal(Decimal(20132983L, 10, 2)) + val dec2 = Literal(Decimal(20132983L, 19, 2)) + val dec3 = Literal(Decimal(20132983L, 21, 2)) + val list1 = Literal(List(1, 2), ArrayType(IntegerType)) + val nullVal = Literal.create(null, IntegerType) + + checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L) + checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L) + checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L) + // For some reason, the Literal.create code gives us the number of days since the epoch + checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L) + checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000) + checkEvaluation(SortPrefix(SortOrder(f1, Ascending)), + DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(f2, Ascending)), + DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(db1, Ascending)), + DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(db2, Ascending)), + DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(s1, Ascending)), + StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(s2, Ascending)), + StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)), + BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)), + BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L) + checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)), + DoublePrefixComparator.computePrefix(201329.83d)) + checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null) + } +} From 2c100209f0b73e882ab953993b307867d1df7c2f Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 8 Jun 2018 08:44:59 -0500 Subject: [PATCH 0929/1161] [SPARK-24224][ML-EXAMPLES] Java example code for Power Iteration Clustering in spark.ml ## What changes were proposed in this pull request? Java example code for Power Iteration Clustering in spark.ml ## How was this patch tested? Locally tested Author: Shahid Closes #21283 from shahidki31/JavaPicExample. --- .../JavaPowerIterationClusteringExample.java | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java new file mode 100644 index 0000000000000..51865637df6f6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.clustering.PowerIterationClustering; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaPowerIterationClusteringExample { + public static void main(String[] args) { + // Create a SparkSession. + SparkSession spark = SparkSession + .builder() + .appName("JavaPowerIterationClustering") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(0L, 1L, 1.0), + RowFactory.create(0L, 2L, 1.0), + RowFactory.create(1L, 2L, 1.0), + RowFactory.create(3L, 4L, 1.0), + RowFactory.create(4L, 0L, 0.1) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("src", DataTypes.LongType, false, Metadata.empty()), + new StructField("dst", DataTypes.LongType, false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + PowerIterationClustering model = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("degree") + .setWeightCol("weight"); + + Dataset result = model.assignClusters(df); + result.show(false); + // $example off$ + spark.stop(); + } +} From a5d775a1f3aad7bef0ac0f93869eaf96b677411b Mon Sep 17 00:00:00 2001 From: Shahid Date: Fri, 8 Jun 2018 08:45:56 -0500 Subject: [PATCH 0930/1161] [SPARK-24191][ML] Scala Example code for Power Iteration Clustering ## What changes were proposed in this pull request? Added example code for Power Iteration Clustering in Spark ML examples Author: Shahid Closes #21248 from shahidki31/sparkCommit. --- .../ml/PowerIterationClusteringExample.scala | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala new file mode 100644 index 0000000000000..ca8f7affb14e8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.clustering.PowerIterationClustering +// $example off$ +import org.apache.spark.sql.SparkSession + +object PowerIterationClusteringExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + + // $example on$ + val dataset = spark.createDataFrame(Seq( + (0L, 1L, 1.0), + (0L, 2L, 1.0), + (1L, 2L, 1.0), + (3L, 4L, 1.0), + (4L, 0L, 0.1) + )).toDF("src", "dst", "weight") + + val model = new PowerIterationClustering(). + setK(2). + setMaxIter(20). + setInitMode("degree"). + setWeightCol("weight") + + val prediction = model.assignClusters(dataset).select("id", "cluster") + + // Shows the cluster assignment + prediction.show(false) + // $example off$ + + spark.stop() + } + } From 173fe450df203b262b58f7e71c6b52a79db95ee0 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 8 Jun 2018 09:32:11 -0700 Subject: [PATCH 0931/1161] [SPARK-24477][SPARK-24454][ML][PYTHON] Imports submodule in ml/__init__.py and add ImageSchema into __all__ ## What changes were proposed in this pull request? This PR attaches submodules to ml's `__init__.py` module. Also, adds `ImageSchema` into `image.py` explicitly. ## How was this patch tested? Before: ```python >>> from pyspark import ml >>> ml.image Traceback (most recent call last): File "", line 1, in AttributeError: 'module' object has no attribute 'image' >>> ml.image.ImageSchema Traceback (most recent call last): File "", line 1, in AttributeError: 'module' object has no attribute 'image' ``` ```python >>> "image" in globals() False >>> from pyspark.ml import * >>> "image" in globals() False >>> image Traceback (most recent call last): File "", line 1, in NameError: name 'image' is not defined ``` After: ```python >>> from pyspark import ml >>> ml.image >>> ml.image.ImageSchema ``` ```python >>> "image" in globals() False >>> from pyspark.ml import * >>> "image" in globals() True >>> image ``` Author: hyukjinkwon Closes #21483 from HyukjinKwon/SPARK-24454. --- python/pyspark/ml/__init__.py | 8 +++++++- python/pyspark/ml/image.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 129d7d68f7cbb..d99a25390db15 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -21,5 +21,11 @@ """ from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer from pyspark.ml.pipeline import Pipeline, PipelineModel +from pyspark.ml import classification, clustering, evaluation, feature, fpm, \ + image, pipeline, recommendation, regression, stat, tuning, util, linalg, param -__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"] +__all__ = [ + "Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel", + "classification", "clustering", "evaluation", "feature", "fpm", "image", + "recommendation", "regression", "stat", "tuning", "util", "linalg", "param", +] diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py index 96d702f844839..5f0c57ee3cc67 100644 --- a/python/pyspark/ml/image.py +++ b/python/pyspark/ml/image.py @@ -31,6 +31,8 @@ from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string from pyspark.sql import DataFrame, SparkSession +__all__ = ["ImageSchema"] + class _ImageSchema(object): """ From 1a644afbac35c204f9ad55f86999319a9ab458c6 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 8 Jun 2018 11:18:34 -0700 Subject: [PATCH 0932/1161] [SPARK-23984][K8S] Initial Python Bindings for PySpark on K8s ## What changes were proposed in this pull request? Introducing Python Bindings for PySpark. - [x] Running PySpark Jobs - [x] Increased Default Memory Overhead value - [ ] Dependency Management for virtualenv/conda ## How was this patch tested? This patch was tested with - [x] Unit Tests - [x] Integration tests with [this addition](https://github.com/apache-spark-on-k8s/spark-integration/pull/46) ``` KubernetesSuite: - Run SparkPi with no resources - Run SparkPi with a very long application name. - Run SparkPi with a master URL without a scheme. - Run SparkPi with an argument. - Run SparkPi with custom labels, annotations, and environment variables. - Run SparkPi with a test secret mounted into the driver and executor pods - Run extraJVMOptions check on driver - Run SparkRemoteFileTest using a remote data file - Run PySpark on simple pi.py example - Run PySpark with Python2 to test a pyfiles example - Run PySpark with Python3 to test a pyfiles example Run completed in 4 minutes, 28 seconds. Total number of tests run: 11 Suites: completed 2, aborted 0 Tests: succeeded 11, failed 0, canceled 0, ignored 0, pending 0 All tests passed. ``` Author: Ilan Filonenko Author: Ilan Filonenko Closes #21092 from ifilonenko/master. --- bin/docker-image-tool.sh | 23 +++- .../org/apache/spark/deploy/SparkSubmit.scala | 14 ++- docs/running-on-kubernetes.md | 16 ++- .../src/main/python/py_container_checks.py | 32 ++++++ examples/src/main/python/pyfiles.py | 38 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 40 +++++++ .../apache/spark/deploy/k8s/Constants.scala | 7 +- .../spark/deploy/k8s/KubernetesConf.scala | 62 +++++++--- .../spark/deploy/k8s/KubernetesUtils.scala | 2 +- .../k8s/features/BasicDriverFeatureStep.scala | 14 +-- .../features/BasicExecutorFeatureStep.scala | 3 +- .../bindings/JavaDriverFeatureStep.scala | 44 +++++++ .../bindings/PythonDriverFeatureStep.scala | 73 ++++++++++++ .../submit/KubernetesClientApplication.scala | 16 ++- .../k8s/submit/KubernetesDriverBuilder.scala | 39 +++++-- .../deploy/k8s/submit/MainAppResource.scala | 5 + .../k8s/KubernetesExecutorBuilder.scala | 22 ++-- .../deploy/k8s/KubernetesConfSuite.scala | 66 +++++++++-- .../BasicDriverFeatureStepSuite.scala | 58 +++++++++- .../BasicExecutorFeatureStepSuite.scala | 9 +- ...ubernetesCredentialsFeatureStepSuite.scala | 9 +- .../DriverServiceFeatureStepSuite.scala | 18 ++- .../features/EnvSecretsFeatureStepSuite.scala | 3 +- .../features/LocalDirsFeatureStepSuite.scala | 3 +- .../MountSecretsFeatureStepSuite.scala | 3 +- .../bindings/JavaDriverFeatureStepSuite.scala | 60 ++++++++++ .../PythonDriverFeatureStepSuite.scala | 108 ++++++++++++++++++ .../spark/deploy/k8s/submit/ClientSuite.scala | 3 +- .../submit/KubernetesDriverBuilderSuite.scala | 78 +++++++++++-- .../k8s/KubernetesExecutorBuilderSuite.scala | 6 +- .../spark/bindings/python/Dockerfile | 39 +++++++ .../src/main/dockerfiles/spark/entrypoint.sh | 30 +++++ 32 files changed, 842 insertions(+), 101 deletions(-) create mode 100644 examples/src/main/python/py_container_checks.py create mode 100644 examples/src/main/python/pyfiles.py create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala create mode 100644 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index f090240065bf1..a871ab5d448c3 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -63,12 +63,20 @@ function build { if [ ! -d "$IMG_PATH" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi - - local DOCKERFILE=${DOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local BINDING_BUILD_ARGS=( + --build-arg + base_img=$(image_ref spark) + ) + local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} + local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} docker build "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ - -f "$DOCKERFILE" . + -f "$BASEDOCKERFILE" . + + docker build "${BINDING_BUILD_ARGS[@]}" \ + -t $(image_ref spark-py) \ + -f "$PYDOCKERFILE" . } function push { @@ -86,7 +94,8 @@ Commands: push Push a pre-built image to a registry. Requires a repository address to be provided. Options: - -f file Dockerfile to build. By default builds the Dockerfile shipped with Spark. + -f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark. + -p file Dockerfile with Python baked in. By default builds the Dockerfile shipped with Spark. -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. @@ -116,12 +125,14 @@ fi REPO= TAG= -DOCKERFILE= +BASEDOCKERFILE= +PYDOCKERFILE= while getopts f:mr:t: option do case "${option}" in - f) DOCKERFILE=${OPTARG};; + f) BASEDOCKERFILE=${OPTARG};; + p) PYDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; m) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a46af26feb061..e83d82f847c61 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -285,8 +285,6 @@ private[spark] class SparkSubmit extends Logging { case (STANDALONE, CLUSTER) if args.isR => error("Cluster deploy mode is currently not supported for R " + "applications on standalone clusters.") - case (KUBERNETES, _) if args.isPython => - error("Python applications are currently not supported for Kubernetes.") case (KUBERNETES, _) if args.isR => error("R applications are currently not supported for Kubernetes.") case (LOCAL, CLUSTER) => @@ -694,9 +692,17 @@ private[spark] class SparkSubmit extends Logging { if (isKubernetesCluster) { childMainClass = KUBERNETES_CLUSTER_SUBMIT_CLASS if (args.primaryResource != SparkLauncher.NO_RESOURCE) { - childArgs ++= Array("--primary-java-resource", args.primaryResource) + if (args.isPython) { + childArgs ++= Array("--primary-py-file", args.primaryResource) + childArgs ++= Array("--main-class", "org.apache.spark.deploy.PythonRunner") + if (args.pyFiles != null) { + childArgs ++= Array("--other-py-files", args.pyFiles) + } + } else { + childArgs ++= Array("--primary-java-resource", args.primaryResource) + childArgs ++= Array("--main-class", args.mainClass) + } } - childArgs ++= Array("--main-class", args.mainClass) if (args.childArgs != null) { args.childArgs.foreach { arg => childArgs += ("--arg", arg) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 4eac9bd9032e4..408e446ea4822 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -270,7 +270,6 @@ future versions of the spark-kubernetes integration. Some of these include: -* PySpark * R * Dynamic Executor Scaling * Local File Dependency Management @@ -631,4 +630,19 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + spark.kubernetes.memoryOverheadFactor + 0.1 + + This sets the Memory Overhead Factor that will allocate memory to non-JVM memory, which includes off-heap memory allocations, non-JVM tasks, and various systems processes. For JVM-based jobs this value will default to 0.10 and 0.40 for non-JVM jobs. + This is done as non-JVM tasks need more non-JVM heap space and such tasks commonly fail with "Memory Overhead Exceeded" errors. This prempts this error with a higher default. + + + + spark.kubernetes.pyspark.pythonversion + "2" + + This sets the major Python version of the docker image used to run the driver and executor containers. Can either be 2 or 3. + + diff --git a/examples/src/main/python/py_container_checks.py b/examples/src/main/python/py_container_checks.py new file mode 100644 index 0000000000000..f6b3be2806c82 --- /dev/null +++ b/examples/src/main/python/py_container_checks.py @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys + + +def version_check(python_env, major_python_version): + """ + These are various tests to test the Python container image. + This file will be distributed via --py-files in the e2e tests. + """ + env_version = os.environ.get('PYSPARK_PYTHON') + print("Python runtime version check is: " + + str(sys.version_info[0] == major_python_version)) + + print("Python environment version check is: " + + str(env_version == python_env)) diff --git a/examples/src/main/python/pyfiles.py b/examples/src/main/python/pyfiles.py new file mode 100644 index 0000000000000..4193654b49a12 --- /dev/null +++ b/examples/src/main/python/pyfiles.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: pyfiles [major_python_version] + """ + spark = SparkSession \ + .builder \ + .appName("PyFilesTest") \ + .getOrCreate() + + from py_container_checks import version_check + # Begin of Python container checks + version_check(sys.argv[1], 2 if sys.argv[1] == "python" else 3) + + spark.stop() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 560dedf431b08..590deaa72e7ee 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -117,6 +117,28 @@ private[spark] object Config extends Logging { .stringConf .createWithDefault("spark") + val KUBERNETES_PYSPARK_PY_FILES = + ConfigBuilder("spark.kubernetes.python.pyFiles") + .doc("The PyFiles that are distributed via client arguments") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_MAIN_APP_RESOURCE = + ConfigBuilder("spark.kubernetes.python.mainAppResource") + .doc("The main app resource for pyspark jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_PYSPARK_APP_ARGS = + ConfigBuilder("spark.kubernetes.python.appArgs") + .doc("The app arguments for PySpark Jobs") + .internal() + .stringConf + .createOptional + + val KUBERNETES_ALLOCATION_BATCH_SIZE = ConfigBuilder("spark.kubernetes.allocation.batch.size") .doc("Number of pods to launch at once in each round of executor allocation.") @@ -154,6 +176,24 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val MEMORY_OVERHEAD_FACTOR = + ConfigBuilder("spark.kubernetes.memoryOverheadFactor") + .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + + "which in the case of JVM tasks will default to 0.10 and 0.40 for non-JVM jobs") + .doubleConf + .checkValue(mem_overhead => mem_overhead >= 0 && mem_overhead < 1, + "Ensure that memory overhead is a double between 0 --> 1.0") + .createWithDefault(0.1) + + val PYSPARK_MAJOR_PYTHON_VERSION = + ConfigBuilder("spark.kubernetes.pyspark.pythonversion") + .doc("This sets the major Python version. Either 2 or 3. (Python2 or Python3)") + .stringConf + .checkValue(pv => List("2", "3").contains(pv), + "Ensure that major Python version is either Python2 or Python3") + .createWithDefault("2") + + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 8da5f24044aad..69bd03d1eda6f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -71,9 +71,14 @@ private[spark] object Constants { val SPARK_CONF_FILE_NAME = "spark.properties" val SPARK_CONF_PATH = s"$SPARK_CONF_DIR_INTERNAL/$SPARK_CONF_FILE_NAME" + // BINDINGS + val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" + val ENV_PYSPARK_FILES = "PYSPARK_FILES" + val ENV_PYSPARK_ARGS = "PYSPARK_APP_ARGS" + val ENV_PYSPARK_MAJOR_PYTHON_VERSION = "PYSPARK_MAJOR_PYTHON_VERSION" + // Miscellaneous val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" - val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN_MIB = 384L } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 5a944187a7096..b0ccaa36b01ed 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -16,14 +16,17 @@ */ package org.apache.spark.deploy.k8s +import scala.collection.mutable + import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config.ConfigEntry + private[spark] sealed trait KubernetesRoleSpecificConf /* @@ -55,7 +58,8 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], roleSecretEnvNamesToKeyRefs: Map[String, String], - roleEnvs: Map[String, String]) { + roleEnvs: Map[String, String], + sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -64,10 +68,14 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( .map(str => str.split(",").toSeq) .getOrElse(Seq.empty[String]) - def sparkFiles(): Seq[String] = sparkConf - .getOption("spark.files") - .map(str => str.split(",").toSeq) - .getOrElse(Seq.empty[String]) + def pyFiles(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_PY_FILES) + + def pySparkMainResource(): Option[String] = sparkConf + .get(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE) + + def pySparkPythonVersion(): String = sparkConf + .get(PYSPARK_MAJOR_PYTHON_VERSION) def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) @@ -102,17 +110,30 @@ private[spark] object KubernetesConf { appId: String, mainAppResource: Option[MainAppResource], mainClass: String, - appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + appArgs: Array[String], + maybePyFiles: Option[String]): KubernetesConf[KubernetesDriverSpecificConf] = { val sparkConfWithMainAppJar = sparkConf.clone() + val additionalFiles = mutable.ArrayBuffer.empty[String] mainAppResource.foreach { - case JavaMainAppResource(res) => - val previousJars = sparkConf - .getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty) - if (!previousJars.contains(res)) { - sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) - } + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + // The function of this outer match is to account for multiple nonJVM + // bindings that will all have increased MEMORY_OVERHEAD_FACTOR to 0.4 + case nonJVM: NonJVMResource => + nonJVM match { + case PythonMainAppResource(res) => + additionalFiles += res + maybePyFiles.foreach{maybePyFiles => + additionalFiles.appendAll(maybePyFiles.split(","))} + sparkConfWithMainAppJar.set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, res) + } + sparkConfWithMainAppJar.setIfMissing(MEMORY_OVERHEAD_FACTOR, 0.4) } val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( @@ -135,6 +156,11 @@ private[spark] object KubernetesConf { val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val sparkFiles = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) ++ additionalFiles + KubernetesConf( sparkConfWithMainAppJar, KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), @@ -144,7 +170,8 @@ private[spark] object KubernetesConf { driverAnnotations, driverSecretNamesToMountPaths, driverSecretEnvNamesToKeyRefs, - driverEnvs) + driverEnvs, + sparkFiles) } def createExecutorConf( @@ -186,6 +213,7 @@ private[spark] object KubernetesConf { executorAnnotations, executorMountSecrets, executorEnvSecrets, - executorEnv) + executorEnv, + Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index ee629068ad90d..593fb531a004d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -52,7 +52,7 @@ private[spark] object KubernetesUtils { } } - private def resolveFileUri(uri: String): String = { + def resolveFileUri(uri: String): String = { val fileUri = Utils.resolveURI(uri) val fileScheme = Option(fileUri.getScheme).getOrElse("file") fileScheme match { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 07bdccbe0479d..143dc8a12304e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -25,8 +25,8 @@ import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit._ import org.apache.spark.internal.config._ -import org.apache.spark.launcher.SparkLauncher private[spark] class BasicDriverFeatureStep( conf: KubernetesConf[KubernetesDriverSpecificConf]) @@ -48,7 +48,8 @@ private[spark] class BasicDriverFeatureStep( private val driverMemoryMiB = conf.get(DRIVER_MEMORY) private val memoryOverheadMiB = conf .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + .getOrElse(math.max((conf.get(MEMORY_OVERHEAD_FACTOR) * driverMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB override def configurePod(pod: SparkPod): SparkPod = { @@ -88,13 +89,6 @@ private[spark] class BasicDriverFeatureStep( .addToRequests("memory", driverMemoryQuantity) .addToLimits("memory", driverMemoryQuantity) .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", conf.roleSpecificConf.mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - .addToArgs(conf.roleSpecificConf.appArgs: _*) .build() val driverPod = new PodBuilder(pod.pod) @@ -122,7 +116,7 @@ private[spark] class BasicDriverFeatureStep( val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( conf.sparkJars()) val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( - conf.sparkFiles()) + conf.sparkFiles) if (resolvedSparkJars.nonEmpty) { additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 529069d3b8a0c..91c54a9776982 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -54,7 +54,8 @@ private[spark] class BasicExecutorFeatureStep( private val memoryOverheadMiB = kubernetesConf .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + .getOrElse(math.max( + (kubernetesConf.get(MEMORY_OVERHEAD_FACTOR) * executorMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala new file mode 100644 index 0000000000000..f52ec9fdc677e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStep.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants.SPARK_CONF_PATH +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep +import org.apache.spark.launcher.SparkLauncher + +private[spark] class JavaDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val withDriverArgs = new ContainerBuilder(pod.container) + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", kubernetesConf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. + .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(kubernetesConf.roleSpecificConf.appArgs: _*) + .build() + SparkPod(pod.pod, withDriverArgs) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala new file mode 100644 index 0000000000000..c20bcac1f8987 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStep.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.features.KubernetesFeatureConfigStep + +private[spark] class PythonDriverFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val roleConf = kubernetesConf.roleSpecificConf + require(roleConf.mainAppResource.isDefined, "PySpark Main Resource must be defined") + val maybePythonArgs = Option(roleConf.appArgs).filter(_.nonEmpty).map( + pyArgs => + new EnvVarBuilder() + .withName(ENV_PYSPARK_ARGS) + .withValue(pyArgs.mkString(",")) + .build()) + val maybePythonFiles = kubernetesConf.pyFiles().map( + // Dilineation by ":" is to append the PySpark Files to the PYTHONPATH + // of the respective PySpark pod + pyFiles => + new EnvVarBuilder() + .withName(ENV_PYSPARK_FILES) + .withValue(KubernetesUtils.resolveFileUrisAndPath(pyFiles.split(",")) + .mkString(":")) + .build()) + val envSeq = + Seq(new EnvVarBuilder() + .withName(ENV_PYSPARK_PRIMARY) + .withValue(KubernetesUtils.resolveFileUri(kubernetesConf.pySparkMainResource().get)) + .build(), + new EnvVarBuilder() + .withName(ENV_PYSPARK_MAJOR_PYTHON_VERSION) + .withValue(kubernetesConf.pySparkPythonVersion()) + .build()) + val pythonEnvs = envSeq ++ + maybePythonArgs.toSeq ++ + maybePythonFiles.toSeq + + val withPythonPrimaryContainer = new ContainerBuilder(pod.container) + .addAllToEnv(pythonEnvs.asJava) + .addToArgs("driver-py") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", roleConf.mainClass) + .build() + + SparkPod(pod.pod, withPythonPrimaryContainer) + } + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index a97f5650fb869..eaff47205dbbc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -39,11 +39,13 @@ import org.apache.spark.util.Utils * @param mainAppResource the main application resource if any * @param mainClass the main class of the application to run * @param driverArgs arguments to the driver + * @param maybePyFiles additional Python files via --py-files */ private[spark] case class ClientArguments( mainAppResource: Option[MainAppResource], mainClass: String, - driverArgs: Array[String]) + driverArgs: Array[String], + maybePyFiles: Option[String]) private[spark] object ClientArguments { @@ -51,10 +53,15 @@ private[spark] object ClientArguments { var mainAppResource: Option[MainAppResource] = None var mainClass: Option[String] = None val driverArgs = mutable.ArrayBuffer.empty[String] + var maybePyFiles : Option[String] = None args.sliding(2, 2).toList.foreach { case Array("--primary-java-resource", primaryJavaResource: String) => mainAppResource = Some(JavaMainAppResource(primaryJavaResource)) + case Array("--primary-py-file", primaryPythonResource: String) => + mainAppResource = Some(PythonMainAppResource(primaryPythonResource)) + case Array("--other-py-files", pyFiles: String) => + maybePyFiles = Some(pyFiles) case Array("--main-class", clazz: String) => mainClass = Some(clazz) case Array("--arg", arg: String) => @@ -69,7 +76,8 @@ private[spark] object ClientArguments { ClientArguments( mainAppResource, mainClass.get, - driverArgs.toArray) + driverArgs.toArray, + maybePyFiles) } } @@ -206,6 +214,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + sparkConf.set(KUBERNETES_PYSPARK_PY_FILES, clientArguments.maybePyFiles.getOrElse("")) val kubernetesConf = KubernetesConf.createDriverConf( sparkConf, appName, @@ -213,7 +222,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { kubernetesAppId, clientArguments.mainAppResource, clientArguments.mainClass, - clientArguments.driverArgs) + clientArguments.driverArgs, + clientArguments.maybePyFiles) val builder = new KubernetesDriverBuilder val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index fdc5eb0d75832..5762d8245f778 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -33,9 +34,17 @@ private[spark] class KubernetesDriverBuilder( provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) - => LocalDirsFeatureStep = - new LocalDirsFeatureStep(_)) { + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => LocalDirsFeatureStep) = + new LocalDirsFeatureStep(_), + provideJavaStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => JavaDriverFeatureStep) = + new JavaDriverFeatureStep(_), + providePythonStep: ( + KubernetesConf[KubernetesDriverSpecificConf] + => PythonDriverFeatureStep) = + new PythonDriverFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { @@ -44,13 +53,23 @@ private[spark] class KubernetesDriverBuilder( provideCredentialsStep(kubernetesConf), provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures - allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) - } else allFeatures + val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Some(provideSecretsStep(kubernetesConf)) } else None + + val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Some(provideEnvSecretsStep(kubernetesConf)) } else None + + val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { + case JavaMainAppResource(_) => + provideJavaStep(kubernetesConf) + case PythonMainAppResource(_) => + providePythonStep(kubernetesConf)}.getOrElse(provideJavaStep(kubernetesConf)) + + val allFeatures: Seq[KubernetesFeatureConfigStep] = + (baseFeatures :+ bindingsStep) ++ + maybeRoleSecretNamesStep.toSeq ++ + maybeProvideSecretsStep.toSeq var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala index cca9f4627a1f6..cbe081ae35683 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/MainAppResource.scala @@ -18,4 +18,9 @@ package org.apache.spark.deploy.k8s.submit private[spark] sealed trait MainAppResource +private[spark] sealed trait NonJVMResource + private[spark] case class JavaMainAppResource(primaryResource: String) extends MainAppResource + +private[spark] case class PythonMainAppResource(primaryResource: String) + extends MainAppResource with NonJVMResource diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index d5e1de36a58df..769a0a5a63047 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = @@ -34,14 +34,20 @@ private[spark] class KubernetesExecutorBuilder( def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) - } else baseFeatures + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideLocalDirsStep(kubernetesConf)) - allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) - } else allFeatures + val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Some(provideSecretsStep(kubernetesConf)) } else None + + val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Some(provideEnvSecretsStep(kubernetesConf)) } else None + + val allFeatures: Seq[KubernetesFeatureConfigStep] = + baseFeatures ++ + maybeRoleSecretNamesStep.toSeq ++ + maybeProvideSecretsStep.toSeq var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index 3d23e1cb90fd2..661f942435921 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit._ class KubernetesConfSuite extends SparkFunSuite { @@ -56,9 +56,10 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.appId === APP_ID) assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) @@ -79,7 +80,8 @@ class KubernetesConfSuite extends SparkFunSuite { APP_ID, mainAppJar, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") .split(",") === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) @@ -88,15 +90,59 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithoutMainJar.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.1) } - test("Resolve driver labels, annotations, secret mount paths, and envs.") { + test("Creating driver conf with a python primary file") { + val mainResourceFile = "local:///opt/spark/main.py" + val inputPyFiles = Array("local:///opt/spark/example2.py", "local:///example3.py") val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + .set("spark.files", "local:///opt/spark/example4.py") + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val kubernetesConfWithMainResource = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + Some(inputPyFiles.mkString(","))) + assert(kubernetesConfWithMainResource.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + assert(kubernetesConfWithMainResource.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.4) + assert(kubernetesConfWithMainResource.sparkFiles + === Array("local:///opt/spark/example4.py", mainResourceFile) ++ inputPyFiles) + } + + test("Testing explicit setting of memory overhead on non-JVM tasks") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) + + val mainResourceFile = "local:///opt/spark/main.py" + val mainAppResource = Some(PythonMainAppResource(mainResourceFile)) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppResource, + MAIN_CLASS, + APP_ARGS, + None) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) + } + + test("Resolve driver labels, annotations, secret mount paths, envs, and memory overhead") { + val sparkConf = new SparkConf(false) + .set(MEMORY_OVERHEAD_FACTOR, 0.3) CUSTOM_LABELS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) } @@ -118,9 +164,10 @@ class KubernetesConfSuite extends SparkFunSuite { APP_NAME, RESOURCE_NAME_PREFIX, APP_ID, - None, + mainAppResource = None, MAIN_CLASS, - APP_ARGS) + APP_ARGS, + maybePyFiles = None) assert(conf.roleLabels === Map( SPARK_APP_ID_LABEL -> APP_ID, SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ @@ -129,6 +176,7 @@ class KubernetesConfSuite extends SparkFunSuite { assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) + assert(conf.sparkConf.get(MEMORY_OVERHEAD_FACTOR) === 0.3) } test("Basic executor translated fields.") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index b2813d8b3265d..04b909db9d9f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -24,6 +24,8 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource class BasicDriverFeatureStepSuite extends SparkFunSuite { @@ -33,6 +35,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val PY_MAIN_CLASS = "org.apache.spark.deploy.PythonRunner" private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" @@ -60,7 +63,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val kubernetesConf = KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("")), APP_NAME, MAIN_CLASS, APP_ARGS), @@ -70,7 +73,8 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_ANNOTATIONS, Map.empty, Map.empty, - DRIVER_ENVS) + DRIVER_ENVS, + Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) val basePod = SparkPod.initialPod() @@ -110,7 +114,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") - val expectedSparkConf = Map( KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", "spark.app.id" -> APP_ID, @@ -119,6 +122,50 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) } + test("Check appropriate entrypoint rerouting for various bindings") { + val javaSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val pythonSparkConf = new SparkConf() + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "4g") + .set(CONTAINER_IMAGE, "spark-driver:latest") + val javaKubernetesConf = KubernetesConf( + javaSparkConf, + KubernetesDriverSpecificConf( + Some(JavaMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Seq.empty[String]) + val pythonKubernetesConf = KubernetesConf( + pythonSparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("")), + APP_NAME, + PY_MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty, + DRIVER_ENVS, + Seq.empty[String]) + val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) + val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) + val basePod = SparkPod.initialPod() + val configuredJavaPod = javaFeatureStep.configurePod(basePod) + val configuredPythonPod = pythonFeatureStep.configurePod(basePod) + } + test("Additional system properties resolve jars and set cluster-mode confs.") { val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") @@ -130,7 +177,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { val kubernetesConf = KubernetesConf( sparkConf, KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("")), APP_NAME, MAIN_CLASS, APP_ARGS), @@ -140,7 +187,8 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_ANNOTATIONS, Map.empty, Map.empty, - Map.empty) + DRIVER_ENVS, + allFiles) val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index 9182134b3337c..f06030aa55c0c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -88,7 +88,8 @@ class BasicExecutorFeatureStepSuite ANNOTATIONS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) // The executor pod name and default labels. @@ -126,7 +127,8 @@ class BasicExecutorFeatureStepSuite ANNOTATIONS, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -145,7 +147,8 @@ class BasicExecutorFeatureStepSuite ANNOTATIONS, Map.empty, Map.empty, - Map("qux" -> "quux"))) + Map("qux" -> "quux"), + Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) checkEnv(executor, diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index f81894f8055f1..7cea83591f3e8 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -60,7 +60,8 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) @@ -90,7 +91,8 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -127,7 +129,8 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index f265522a8823a..77d38bf19cd10 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -66,7 +66,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) @@ -96,7 +97,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX val expectedHostName = s"$expectedServiceName.my-namespace.svc" @@ -116,7 +118,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty)) + Map.empty, + Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() .head @@ -145,7 +148,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty), + Map.empty, + Seq.empty[String]), clock) val driverService = configurationStep .getAdditionalKubernetesResources() @@ -171,7 +175,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty), + Map.empty, + Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") } catch { @@ -195,7 +200,8 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty), + Map.empty, + Seq.empty[String]), clock) fail("The driver host address should not be allowed.") } catch { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index 8b0b2d0739c76..af6b35eae484a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -44,7 +44,8 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ Map.empty, Map.empty, envVarsToKeys, - Map.empty) + Map.empty, + Seq.empty[String]) val step = new EnvSecretsFeatureStep(kubernetesConf) val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 2542a02d37766..bd6ce4b42fc8e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -44,7 +44,8 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) } test("Resolve to default local dir if neither env nor configuration are set") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9155793774123..eff75b8a15daa 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -42,7 +42,8 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, secretNamesToMountPaths, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..0f2bf2fa1d9b5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource + +class JavaDriverFeatureStepSuite extends SparkFunSuite { + + test("Java Step modifies container correctly") { + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.jar")), + "test-class", + "java-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + sparkFiles = Seq.empty[String]) + + val step = new JavaDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithJavaStep = step.configurePod(baseDriverPod).container + assert(driverContainerwithJavaStep.getArgs.size === 7) + val args = driverContainerwithJavaStep + .getArgs.asScala + assert(args === List( + "driver", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class", + "spark-internal", "5 7")) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..a1f9a5d9e264e --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features.bindings + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.PythonMainAppResource + +class PythonDriverFeatureStepSuite extends SparkFunSuite { + + test("Python Step modifies container correctly") { + val expectedMainResource = "/main.py" + val mainResource = "local:///main.py" + val pyFiles = Seq("local:///example2.py", "local:///example3.py") + val expectedPySparkFiles = + "/example2.py:/example3.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(KUBERNETES_PYSPARK_PY_FILES, pyFiles.mkString(",")) + .set("spark.files", "local:///example.py") + .set(PYSPARK_MAJOR_PYTHON_VERSION, "2") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-app", + "python-runner", + Seq("5 7")), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + sparkFiles = Seq.empty[String]) + + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverPod = step.configurePod(baseDriverPod).pod + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + assert(driverContainerwithPySpark.getEnv.size === 4) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_PRIMARY) === expectedMainResource) + assert(envs(ENV_PYSPARK_FILES) === expectedPySparkFiles) + assert(envs(ENV_PYSPARK_ARGS) === "5 7") + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "2") + } + test("Python Step testing empty pyfiles") { + val mainResource = "local:///main.py" + val baseDriverPod = SparkPod.initialPod() + val sparkConf = new SparkConf(false) + .set(KUBERNETES_PYSPARK_MAIN_APP_RESOURCE, mainResource) + .set(PYSPARK_MAJOR_PYTHON_VERSION, "3") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("local:///main.py")), + "test-class-py", + "python-runner", + Seq.empty[String]), + appResourceNamePrefix = "", + appId = "", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + sparkFiles = Seq.empty[String]) + val step = new PythonDriverFeatureStep(kubernetesConf) + val driverContainerwithPySpark = step.configurePod(baseDriverPod).container + val args = driverContainerwithPySpark + .getArgs.asScala + assert(driverContainerwithPySpark.getArgs.size === 5) + assert(args === List( + "driver-py", + "--properties-file", SPARK_CONF_PATH, + "--class", "test-class-py")) + val envs = driverContainerwithPySpark + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(ENV_PYSPARK_MAJOR_PYTHON_VERSION) === "3") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 0775338098a13..a8a8218c621ea 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -143,7 +143,8 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) when(podOperations.withName(POD_NAME)).thenReturn(namedPods) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index cb724068ea4f3..4e8c300543430 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -27,6 +28,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SERVICE_STEP_TYPE = "service" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val JAVA_STEP_TYPE = "java-bindings" + private val PYSPARK_STEP_TYPE = "pyspark-bindings" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( @@ -44,6 +47,12 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val javaStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + JAVA_STEP_TYPE, classOf[JavaDriverFeatureStep]) + + private val pythonStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + PYSPARK_STEP_TYPE, classOf[PythonDriverFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) @@ -54,13 +63,15 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => serviceStep, _ => secretsStep, _ => envSecretsStep, - _ => localDirsStep) + _ => localDirsStep, + _ => javaStep, + _ => pythonStep) test("Apply fundamental steps all the time.") { val conf = KubernetesConf( new SparkConf(false), KubernetesDriverSpecificConf( - None, + Some(JavaMainAppResource("example.jar")), "test-app", "main", Seq.empty), @@ -70,13 +81,15 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, - LOCAL_DIRS_STEP_TYPE) + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) } test("Apply secrets step if secrets are present.") { @@ -93,7 +106,8 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map("secret" -> "secretMountPath"), Map("EnvName" -> "SecretName:secretKey"), - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, @@ -101,8 +115,58 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, SECRETS_STEP_TYPE, - ENV_SECRETS_STEP_TYPE - ) + ENV_SECRETS_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply Java step if main resource is none.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + JAVA_STEP_TYPE) + } + + test("Apply Python step if main resource is python.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + Some(PythonMainAppResource("example.py")), + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + PYSPARK_STEP_TYPE) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 753cd30a237f3..a6bc8bce32926 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -54,7 +54,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) } @@ -70,7 +71,8 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map("secret" -> "secretMountPath"), Map("secret-name" -> "secret-key"), - Map.empty) + Map.empty, + Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile new file mode 100644 index 0000000000000..72bb9620b45de --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +ARG base_img +FROM $base_img +WORKDIR / +RUN mkdir ${SPARK_HOME}/python +COPY python/lib ${SPARK_HOME}/python/lib +# TODO: Investigate running both pip and pip3 via virtualenvs +RUN apk add --no-cache python && \ + apk add --no-cache python3 && \ + python -m ensurepip && \ + python3 -m ensurepip && \ + # We remove ensurepip since it adds no functionality since pip is + # installed on the image and it just takes up 1.6MB on the image + rm -r /usr/lib/python*/ensurepip && \ + pip install --upgrade pip setuptools && \ + # You may install with python3 packages by using pip3.6 + # Removed the .cache to save space + rm -r /root/.cache + +ENV PYTHONPATH ${SPARK_HOME}/python/lib/pyspark.zip:${SPARK_HOME}/python/lib/py4j-*.zip + +WORKDIR /opt/spark/work-dir +ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 3e166116aa3fd..acdb4b1f09e0a 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -53,6 +53,28 @@ if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then cp -R "$SPARK_MOUNTED_FILES_DIR/." . fi +if [ -n "$PYSPARK_FILES" ]; then + PYTHONPATH="$PYTHONPATH:$PYSPARK_FILES" +fi + +PYSPARK_ARGS="" +if [ -n "$PYSPARK_APP_ARGS" ]; then + PYSPARK_ARGS="$PYSPARK_APP_ARGS" +fi + + +if [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "2" ]; then + pyv="$(python -V 2>&1)" + export PYTHON_VERSION="${pyv:7}" + export PYSPARK_PYTHON="python" + export PYSPARK_DRIVER_PYTHON="python" +elif [ "$PYSPARK_MAJOR_PYTHON_VERSION" == "3" ]; then + pyv3="$(python3 -V 2>&1)" + export PYTHON_VERSION="${pyv3:7}" + export PYSPARK_PYTHON="python3" + export PYSPARK_DRIVER_PYTHON="python3" +fi + case "$SPARK_K8S_CMD" in driver) CMD=( @@ -62,6 +84,14 @@ case "$SPARK_K8S_CMD" in "$@" ) ;; + driver-py) + CMD=( + "$SPARK_HOME/bin/spark-submit" + --conf "spark.driver.bindAddress=$SPARK_DRIVER_BIND_ADDRESS" + --deploy-mode client + "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS + ) + ;; executor) CMD=( From b070ded2843e88131c90cb9ef1b4f8d533f8361d Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 9 Jun 2018 01:27:51 +0700 Subject: [PATCH 0933/1161] [SPARK-17756][PYTHON][STREAMING] Workaround to avoid return type mismatch in PythonTransformFunction ## What changes were proposed in this pull request? This PR proposes to wrap the transformed rdd within `TransformFunction`. `PythonTransformFunction` looks requiring to return `JavaRDD` in `_jrdd`. https://github.com/apache/spark/blob/39e2bad6a866d27c3ca594d15e574a1da3ee84cc/python/pyspark/streaming/util.py#L67 https://github.com/apache/spark/blob/6ee28423ad1b2e6089b82af64a31d77d3552bb38/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala#L43 However, this could be `JavaPairRDD` by some APIs, for example, `zip` in PySpark's RDD API. `_jrdd` could be checked as below: ```python >>> rdd.zip(rdd)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaPairRDD' ``` So, here, I wrapped it with `map` so that it ensures returning `JavaRDD`. ```python >>> rdd.zip(rdd).map(lambda x: x)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaRDD' ``` I tried to elaborate some failure cases as below: ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]) \ .transform(lambda rdd: rdd.cartesian(rdd)) \ .pprint() ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.cartesian(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).union(rdd.zip(rdd))) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).coalesce(1)) ssc.start() ``` ## How was this patch tested? Unit tests were added in `python/pyspark/streaming/tests.py` and manually tested. Author: hyukjinkwon Closes #19498 from HyukjinKwon/SPARK-17756. --- python/pyspark/streaming/context.py | 2 +- python/pyspark/streaming/tests.py | 6 ++++++ python/pyspark/streaming/util.py | 11 ++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 17c34f8a1c54c..dd924ef89868e 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -338,7 +338,7 @@ def transform(self, dstreams, transformFunc): jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + lambda t, *rdds: transformFunc(rdds), *[d._jrdd_deserializer for d in dstreams]) jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index e4a428a0b27e7..373784f826677 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -779,6 +779,12 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. + dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + def test_get_active(self): self.assertEqual(StreamingContext.getActive(), None) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index df184471993ff..b4b9f97feb7ca 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -20,6 +20,8 @@ import traceback import sys +from py4j.java_gateway import is_instance_of + from pyspark import SparkContext, RDD @@ -65,7 +67,14 @@ def call(self, milliseconds, jrdds): t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: - return r._jrdd + # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`. + # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return + # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`. + # See SPARK-17756. + if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"): + return r._jrdd + else: + return r.map(lambda x: x)._jrdd except: self.failure = traceback.format_exc() From f433ef786770e48e3594ad158ce9908f98ef0d9a Mon Sep 17 00:00:00 2001 From: Sean Suchter Date: Fri, 8 Jun 2018 15:15:24 -0700 Subject: [PATCH 0934/1161] [SPARK-23010][K8S] Initial checkin of k8s integration tests. These tests were developed in the https://github.com/apache-spark-on-k8s/spark-integration repo by several contributors. This is a copy of the current state into the main apache spark repo. The only changes from the current spark-integration repo state are: * Move the files from the repo root into resource-managers/kubernetes/integration-tests * Add a reference to these tests in the root README.md * Fix a path reference in dev/dev-run-integration-tests.sh * Add a TODO in include/util.sh ## What changes were proposed in this pull request? Incorporation of Kubernetes integration tests. ## How was this patch tested? This code has its own unit tests, but the main purpose is to provide the integration tests. I tested this on my laptop by running dev/dev-run-integration-tests.sh --spark-tgz ~/spark-2.4.0-SNAPSHOT-bin--.tgz The spark-integration tests have already been running for months in AMPLab, here is an example: https://amplab.cs.berkeley.edu/jenkins/job/testing-k8s-scheduled-spark-integration-master/ Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Sean Suchter Author: Sean Suchter Closes #20697 from ssuchter/ssuchter-k8s-integration-tests. --- README.md | 2 + dev/tox.ini | 2 +- pom.xml | 1 + .../kubernetes/integration-tests/README.md | 52 ++++ .../dev/dev-run-integration-tests.sh | 93 ++++++ .../integration-tests/dev/spark-rbac.yaml | 52 ++++ .../kubernetes/integration-tests/pom.xml | 155 +++++++++ .../scripts/setup-integration-test-env.sh | 91 ++++++ .../src/test/resources/log4j.properties | 31 ++ .../k8s/integrationtest/KubernetesSuite.scala | 294 ++++++++++++++++++ .../KubernetesTestComponents.scala | 120 +++++++ .../k8s/integrationtest/ProcessUtils.scala | 46 +++ .../SparkReadinessWatcher.scala | 41 +++ .../deploy/k8s/integrationtest/Utils.scala | 30 ++ .../backend/IntegrationTestBackend.scala | 43 +++ .../backend/minikube/Minikube.scala | 84 +++++ .../minikube/MinikubeTestBackend.scala | 42 +++ .../deploy/k8s/integrationtest/config.scala | 38 +++ .../k8s/integrationtest/constants.scala | 22 ++ 19 files changed, 1238 insertions(+), 1 deletion(-) create mode 100644 resource-managers/kubernetes/integration-tests/README.md create mode 100755 resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh create mode 100644 resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml create mode 100644 resource-managers/kubernetes/integration-tests/pom.xml create mode 100755 resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh create mode 100644 resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala diff --git a/README.md b/README.md index 1e521a7e7b178..531d330234062 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,8 @@ can be run using: Please see the guidance on how to [run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). +There is also a Kubernetes integration test, see resource-managers/kubernetes/integration-tests/README.md + ## A Note About Hadoop Versions Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported diff --git a/dev/tox.ini b/dev/tox.ini index 583c1eaaa966b..28dad8f3b5c7c 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -16,4 +16,4 @@ [pycodestyle] ignore=E402,E731,E241,W503,E226,E722,E741,E305 max-line-length=100 -exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/* +exclude=cloudpickle.py,heapq3.py,shared.py,python/docs/conf.py,work/*/*.py,python/.eggs/*,dist/* diff --git a/pom.xml b/pom.xml index 883c096ae1ae9..23bbd3b09734e 100644 --- a/pom.xml +++ b/pom.xml @@ -2705,6 +2705,7 @@ kubernetes resource-managers/kubernetes/core + resource-managers/kubernetes/integration-tests diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md new file mode 100644 index 0000000000000..b3863e6b7d1af --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -0,0 +1,52 @@ +--- +layout: global +title: Spark on Kubernetes Integration Tests +--- + +# Running the Kubernetes Integration Tests + +Note that the integration test framework is currently being heavily revised and +is subject to change. Note that currently the integration tests only run with Java 8. + +The simplest way to run the integration tests is to install and run Minikube, then run the following: + + dev/dev-run-integration-tests.sh + +The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should +run with a minimum of 3 CPUs and 4G of memory: + + minikube start --cpus 3 --memory 4096 + +You can download Minikube [here](https://github.com/kubernetes/minikube/releases). + +# Integration test customization + +Configuration of the integration test runtime is done through passing different arguments to the test script. The main useful options are outlined below. + +## Re-using Docker Images + +By default, the test framework will build new Docker images on every test execution. A unique image tag is generated, +and it is written to file at `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker image tag +that you have built by other means already, pass the tag to the test script: + + dev/dev-run-integration-tests.sh --image-tag + +where if you still want to use images that were built before by the test framework: + + dev/dev-run-integration-tests.sh --image-tag $(cat target/imageTag.txt) + +## Spark Distribution Under Test + +The Spark code to test is handed to the integration test system via a tarball. Here is the option that is used to specify the tarball: + +* `--spark-tgz ` - set `` to point to a tarball containing the Spark distribution to test. + +TODO: Don't require the packaging of the built Spark artifacts into this tarball, just read them out of the current tree. + +## Customizing the Namespace and Service Account + +* `--namespace ` - set `` to the namespace in which the tests should be run. +* `--service-account ` - set `` to the name of the Kubernetes service account to +use in the namespace specified by the `--namespace`. The service account is expected to have permissions to get, list, watch, +and create pods. For clusters with RBAC turned on, it's important that the right permissions are granted to the service account +in the namespace through an appropriate role and role binding. A reference RBAC configuration is provided in `dev/spark-rbac.yaml`. diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh new file mode 100755 index 0000000000000..ea893fa39eede --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +TEST_ROOT_DIR=$(git rev-parse --show-toplevel)/resource-managers/kubernetes/integration-tests + +cd "${TEST_ROOT_DIR}" + +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +SPARK_TGZ="N/A" +IMAGE_TAG="N/A" +SPARK_MASTER= +NAMESPACE= +SERVICE_ACCOUNT= + +# Parse arguments +while (( "$#" )); do + case $1 in + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + --spark-master) + SPARK_MASTER="$2" + shift + ;; + --namespace) + NAMESPACE="$2" + shift + ;; + --service-account) + SERVICE_ACCOUNT="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +cd $TEST_ROOT_DIR + +properties=( + -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \ + -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \ + -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \ + -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE +) + +if [ -n $NAMESPACE ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) +fi + +if [ -n $SERVICE_ACCOUNT ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT ) +fi + +if [ -n $SPARK_MASTER ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) +fi + +../../../build/mvn integration-test ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml new file mode 100644 index 0000000000000..a4c242f2f2645 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/dev/spark-rbac.yaml @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +apiVersion: v1 +kind: Namespace +metadata: + name: spark +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: spark-sa + namespace: spark +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRole +metadata: + name: spark-role +rules: +- apiGroups: + - "" + resources: + - "pods" + verbs: + - "*" +--- +apiVersion: rbac.authorization.k8s.io/v1beta1 +kind: ClusterRoleBinding +metadata: + name: spark-role-binding +subjects: +- kind: ServiceAccount + name: spark-sa + namespace: spark +roleRef: + kind: ClusterRole + name: spark-role + apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml new file mode 100644 index 0000000000000..520bda89e034d --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -0,0 +1,155 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes-integration-tests_2.11 + spark-kubernetes-integration-tests + + 1.3.0 + 1.4.0 + + 3.0.0 + 3.2.2 + 1.0 + kubernetes-integration-tests + ${project.build.directory}/spark-dist-unpacked + N/A + ${project.build.directory}/imageTag.txt + minikube + docker.io/kubespark + + + jar + Spark Project Kubernetes Integration Tests + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + io.fabric8 + kubernetes-client + ${kubernetes-client.version} + + + + + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + + setup-integration-test-env + pre-integration-test + + exec + + + scripts/setup-integration-test-env.sh + + --unpacked-spark-tgz + ${spark.kubernetes.test.unpackSparkDir} + + --image-repo + ${spark.kubernetes.test.imageRepo} + + --image-tag + ${spark.kubernetes.test.imageTag} + + --image-tag-output-file + ${spark.kubernetes.test.imageTagFile} + + --deploy-mode + ${spark.kubernetes.test.deployMode} + + --spark-tgz + ${spark.kubernetes.test.sparkTgz} + + + + + + + + org.scalatest + scalatest-maven-plugin + ${scalatest-maven-plugin.version} + + ${project.build.directory}/surefire-reports + . + SparkTestSuite.txt + -ea -Xmx3g -XX:ReservedCodeCacheSize=512m ${extraScalaTestArgs} + + + file:src/test/resources/log4j.properties + true + ${spark.kubernetes.test.imageTagFile} + ${spark.kubernetes.test.unpackSparkDir} + ${spark.kubernetes.test.imageRepo} + ${spark.kubernetes.test.deployMode} + ${spark.kubernetes.test.master} + ${spark.kubernetes.test.namespace} + ${spark.kubernetes.test.serviceAccountName} + + ${test.exclude.tags} + + + + test + + test + + + + (?<!Suite) + + + + integration-test + integration-test + + test + + + + + + + + + diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh new file mode 100755 index 0000000000000..ccfb8e767c529 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +TEST_ROOT_DIR=$(git rev-parse --show-toplevel) +UNPACKED_SPARK_TGZ="$TEST_ROOT_DIR/target/spark-dist-unpacked" +IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" +DEPLOY_MODE="minikube" +IMAGE_REPO="docker.io/kubespark" +IMAGE_TAG="N/A" +SPARK_TGZ="N/A" + +# Parse arguments +while (( "$#" )); do + case $1 in + --unpacked-spark-tgz) + UNPACKED_SPARK_TGZ="$2" + shift + ;; + --image-repo) + IMAGE_REPO="$2" + shift + ;; + --image-tag) + IMAGE_TAG="$2" + shift + ;; + --image-tag-output-file) + IMAGE_TAG_OUTPUT_FILE="$2" + shift + ;; + --deploy-mode) + DEPLOY_MODE="$2" + shift + ;; + --spark-tgz) + SPARK_TGZ="$2" + shift + ;; + *) + break + ;; + esac + shift +done + +if [[ $SPARK_TGZ == "N/A" ]]; +then + echo "Must specify a Spark tarball to build Docker images against with --spark-tgz." && exit 1; +fi + +rm -rf $UNPACKED_SPARK_TGZ +mkdir -p $UNPACKED_SPARK_TGZ +tar -xzvf $SPARK_TGZ --strip-components=1 -C $UNPACKED_SPARK_TGZ; + +if [[ $IMAGE_TAG == "N/A" ]]; +then + IMAGE_TAG=$(uuidgen); + cd $UNPACKED_SPARK_TGZ + if [[ $DEPLOY_MODE == cloud ]] ; + then + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build + if [[ $IMAGE_REPO == gcr.io* ]] ; + then + gcloud docker -- push $IMAGE_REPO/spark:$IMAGE_TAG + else + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG push + fi + else + # -m option for minikube. + $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build + fi + cd - +fi + +rm -f $IMAGE_TAG_OUTPUT_FILE +echo -n $IMAGE_TAG > $IMAGE_TAG_OUTPUT_FILE diff --git a/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..866126bc3c1c2 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/integration-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/integration-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala new file mode 100644 index 0000000000000..65c513cf241a4 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File +import java.nio.file.{Path, Paths} +import java.util.UUID +import java.util.regex.Pattern + +import scala.collection.JavaConverters._ + +import com.google.common.io.PatternFilenameFilter +import io.fabric8.kubernetes.api.model.{Container, Pod} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.{Eventually, PatienceConfiguration} +import org.scalatest.time.{Minutes, Seconds, Span} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} +import org.apache.spark.deploy.k8s.integrationtest.config._ + +private[spark] class KubernetesSuite extends SparkFunSuite + with BeforeAndAfterAll with BeforeAndAfter { + + import KubernetesSuite._ + + private var testBackend: IntegrationTestBackend = _ + private var sparkHomeDir: Path = _ + private var kubernetesTestComponents: KubernetesTestComponents = _ + private var sparkAppConf: SparkAppConf = _ + private var image: String = _ + private var containerLocalSparkDistroExamplesJar: String = _ + private var appLocator: String = _ + private var driverPodName: String = _ + + override def beforeAll(): Unit = { + // The scalatest-maven-plugin gives system properties that are referenced but not set null + // values. We need to remove the null-value properties before initializing the test backend. + val nullValueProperties = System.getProperties.asScala + .filter(entry => entry._2.equals("null")) + .map(entry => entry._1.toString) + nullValueProperties.foreach { key => + System.clearProperty(key) + } + + val sparkDirProp = System.getProperty("spark.kubernetes.test.unpackSparkDir") + require(sparkDirProp != null, "Spark home directory must be provided in system properties.") + sparkHomeDir = Paths.get(sparkDirProp) + require(sparkHomeDir.toFile.isDirectory, + s"No directory found for spark home specified at $sparkHomeDir.") + val imageTag = getTestImageTag + val imageRepo = getTestImageRepo + image = s"$imageRepo/spark:$imageTag" + + val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) + .toFile + .listFiles(new PatternFilenameFilter(Pattern.compile("^spark-examples_.*\\.jar$")))(0) + containerLocalSparkDistroExamplesJar = s"local:///opt/spark/examples/jars/" + + s"${sparkDistroExamplesJarFile.getName}" + testBackend = IntegrationTestBackendFactory.getTestBackend + testBackend.initialize() + kubernetesTestComponents = new KubernetesTestComponents(testBackend.getKubernetesClient) + } + + override def afterAll(): Unit = { + testBackend.cleanUp() + } + + before { + appLocator = UUID.randomUUID().toString.replaceAll("-", "") + driverPodName = "spark-test-app-" + UUID.randomUUID().toString.replaceAll("-", "") + sparkAppConf = kubernetesTestComponents.newSparkAppConf() + .set("spark.kubernetes.container.image", image) + .set("spark.kubernetes.driver.pod.name", driverPodName) + .set("spark.kubernetes.driver.label.spark-app-locator", appLocator) + .set("spark.kubernetes.executor.label.spark-app-locator", appLocator) + if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.createNamespace() + } + } + + after { + if (!kubernetesTestComponents.hasUserSpecifiedNamespace) { + kubernetesTestComponents.deleteNamespace() + } + deleteDriverPod() + } + + test("Run SparkPi with no resources") { + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a very long application name.") { + sparkAppConf.set("spark.app.name", "long" * 40) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with a master URL without a scheme.") { + val url = kubernetesTestComponents.kubernetesClient.getMasterUrl + val k8sMasterUrl = if (url.getPort < 0) { + s"k8s://${url.getHost}" + } else { + s"k8s://${url.getHost}:${url.getPort}" + } + sparkAppConf.set("spark.master", k8sMasterUrl) + runSparkPiAndVerifyCompletion() + } + + test("Run SparkPi with an argument.") { + runSparkPiAndVerifyCompletion(appArgs = Array("5")) + } + + test("Run SparkPi with custom labels, annotations, and environment variables.") { + sparkAppConf + .set("spark.kubernetes.driver.label.label1", "label1-value") + .set("spark.kubernetes.driver.label.label2", "label2-value") + .set("spark.kubernetes.driver.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.driver.annotation.annotation2", "annotation2-value") + .set("spark.kubernetes.driverEnv.ENV1", "VALUE1") + .set("spark.kubernetes.driverEnv.ENV2", "VALUE2") + .set("spark.kubernetes.executor.label.label1", "label1-value") + .set("spark.kubernetes.executor.label.label2", "label2-value") + .set("spark.kubernetes.executor.annotation.annotation1", "annotation1-value") + .set("spark.kubernetes.executor.annotation.annotation2", "annotation2-value") + .set("spark.executorEnv.ENV1", "VALUE1") + .set("spark.executorEnv.ENV2", "VALUE2") + + runSparkPiAndVerifyCompletion( + driverPodChecker = (driverPod: Pod) => { + doBasicDriverPodCheck(driverPod) + checkCustomSettings(driverPod) + }, + executorPodChecker = (executorPod: Pod) => { + doBasicExecutorPodCheck(executorPod) + checkCustomSettings(executorPod) + }) + } + + // TODO(ssuchter): Enable the below after debugging + // test("Run PageRank using remote data file") { + // sparkAppConf + // .set("spark.kubernetes.mountDependencies.filesDownloadDir", + // CONTAINER_LOCAL_FILE_DOWNLOAD_PATH) + // .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + // runSparkPageRankAndVerifyCompletion( + // appArgs = Array(CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE)) + // } + + private def runSparkPiAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String] = Array.empty[String], + appLocator: String = appLocator): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_PI_MAIN_CLASS, + Seq("Pi is roughly 3"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator) + } + + private def runSparkPageRankAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, + appArgs: Array[String], + appLocator: String = appLocator): Unit = { + runSparkApplicationAndVerifyCompletion( + appResource, + SPARK_PAGE_RANK_MAIN_CLASS, + Seq("1 has rank", "2 has rank", "3 has rank", "4 has rank"), + appArgs, + driverPodChecker, + executorPodChecker, + appLocator) + } + + private def runSparkApplicationAndVerifyCompletion( + appResource: String, + mainClass: String, + expectedLogOnCompletion: Seq[String], + appArgs: Array[String], + driverPodChecker: Pod => Unit, + executorPodChecker: Pod => Unit, + appLocator: String): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch(appArguments, sparkAppConf, TIMEOUT.value.toSeconds.toInt, sparkHomeDir) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + driverPodChecker(driverPod) + + val executorPods = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "executor") + .list() + .getItems + executorPods.asScala.foreach { pod => + executorPodChecker(pod) + } + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedLogOnCompletion.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } + } + + private def doBasicDriverPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === image) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + + private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === image) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + + private def checkCustomSettings(pod: Pod): Unit = { + assert(pod.getMetadata.getLabels.get("label1") === "label1-value") + assert(pod.getMetadata.getLabels.get("label2") === "label2-value") + assert(pod.getMetadata.getAnnotations.get("annotation1") === "annotation1-value") + assert(pod.getMetadata.getAnnotations.get("annotation2") === "annotation2-value") + + val container = pod.getSpec.getContainers.get(0) + val envVars = container + .getEnv + .asScala + .map { env => + (env.getName, env.getValue) + } + .toMap + assert(envVars("ENV1") === "VALUE1") + assert(envVars("ENV2") === "VALUE2") + } + + private def deleteDriverPod(): Unit = { + kubernetesTestComponents.kubernetesClient.pods().withName(driverPodName).delete() + Eventually.eventually(TIMEOUT, INTERVAL) { + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPodName) + .get() == null) + } + } +} + +private[spark] object KubernetesSuite { + + val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) + val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) + val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + + // val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + + // val REMOTE_PAGE_RANK_DATA_FILE = + // "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + // val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = + // s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + + // case object ShuffleNotReadyException extends Exception +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala new file mode 100644 index 0000000000000..48727142dd052 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.nio.file.{Path, Paths} +import java.util.UUID + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.client.DefaultKubernetesClient +import org.scalatest.concurrent.Eventually + +import org.apache.spark.internal.Logging + +private[spark] class KubernetesTestComponents(defaultClient: DefaultKubernetesClient) { + + val namespaceOption = Option(System.getProperty("spark.kubernetes.test.namespace")) + val hasUserSpecifiedNamespace = namespaceOption.isDefined + val namespace = namespaceOption.getOrElse(UUID.randomUUID().toString.replaceAll("-", "")) + private val serviceAccountName = + Option(System.getProperty("spark.kubernetes.test.serviceAccountName")) + .getOrElse("default") + val kubernetesClient = defaultClient.inNamespace(namespace) + val clientConfig = kubernetesClient.getConfiguration + + def createNamespace(): Unit = { + defaultClient.namespaces.createNew() + .withNewMetadata() + .withName(namespace) + .endMetadata() + .done() + } + + def deleteNamespace(): Unit = { + defaultClient.namespaces.withName(namespace).delete() + Eventually.eventually(KubernetesSuite.TIMEOUT, KubernetesSuite.INTERVAL) { + val namespaceList = defaultClient + .namespaces() + .list() + .getItems + .asScala + require(!namespaceList.exists(_.getMetadata.getName == namespace)) + } + } + + def newSparkAppConf(): SparkAppConf = { + new SparkAppConf() + .set("spark.master", s"k8s://${kubernetesClient.getMasterUrl}") + .set("spark.kubernetes.namespace", namespace) + .set("spark.executor.memory", "500m") + .set("spark.executor.cores", "1") + .set("spark.executors.instances", "1") + .set("spark.app.name", "spark-test-app") + .set("spark.ui.enabled", "true") + .set("spark.testing", "false") + .set("spark.kubernetes.submission.waitAppCompletion", "false") + .set("spark.kubernetes.authenticate.driver.serviceAccountName", serviceAccountName) + } +} + +private[spark] class SparkAppConf { + + private val map = mutable.Map[String, String]() + + def set(key: String, value: String): SparkAppConf = { + map.put(key, value) + this + } + + def get(key: String): String = map.getOrElse(key, "") + + def setJars(jars: Seq[String]): Unit = set("spark.jars", jars.mkString(",")) + + override def toString: String = map.toString + + def toStringArray: Iterable[String] = map.toList.flatMap(t => List("--conf", s"${t._1}=${t._2}")) +} + +private[spark] case class SparkAppArguments( + mainAppResource: String, + mainClass: String, + appArgs: Array[String]) + +private[spark] object SparkAppLauncher extends Logging { + + def launch( + appArguments: SparkAppArguments, + appConf: SparkAppConf, + timeoutSecs: Int, + sparkHomeDir: Path): Unit = { + val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) + logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") + val appArgsArray = + if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" ")) + else Array[String]() + val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--class", appArguments.mainClass, + "--master", appConf.get("spark.master") + ) ++ appConf.toStringArray :+ + appArguments.mainAppResource) ++ + appArgsArray + ProcessUtils.executeProcess(commandLine, timeoutSecs) + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala new file mode 100644 index 0000000000000..d8f3a6cec05c3 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/ProcessUtils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.ArrayBuffer +import scala.io.Source + +import org.apache.spark.internal.Logging + +object ProcessUtils extends Logging { + /** + * executeProcess is used to run a command and return the output if it + * completes within timeout seconds. + */ + def executeProcess(fullCommand: Array[String], timeout: Long): Seq[String] = { + val pb = new ProcessBuilder().command(fullCommand: _*) + pb.redirectErrorStream(true) + val proc = pb.start() + val outputLines = new ArrayBuffer[String] + Utils.tryWithResource(proc.getInputStream)( + Source.fromInputStream(_, "UTF-8").getLines().foreach { line => + logInfo(line) + outputLines += line + }) + assert(proc.waitFor(timeout, TimeUnit.SECONDS), + s"Timed out while executing ${fullCommand.mkString(" ")}") + assert(proc.exitValue == 0, s"Failed to execute ${fullCommand.mkString(" ")}") + outputLines + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala new file mode 100644 index 0000000000000..f1fd6dc19ce54 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkReadinessWatcher.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.util.concurrent.TimeUnit + +import com.google.common.util.concurrent.SettableFuture +import io.fabric8.kubernetes.api.model.HasMetadata +import io.fabric8.kubernetes.client.{KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.internal.readiness.Readiness + +private[spark] class SparkReadinessWatcher[T <: HasMetadata] extends Watcher[T] { + + private val signal = SettableFuture.create[Boolean] + + override def eventReceived(action: Action, resource: T): Unit = { + if ((action == Action.MODIFIED || action == Action.ADDED) && + Readiness.isReady(resource)) { + signal.set(true) + } + } + + override def onClose(cause: KubernetesClientException): Unit = {} + + def waitUntilReady(): Boolean = signal.get(60, TimeUnit.SECONDS) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala new file mode 100644 index 0000000000000..663f8b6523ac8 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/Utils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.Closeable +import java.net.URI + +import org.apache.spark.internal.Logging + +object Utils extends Logging { + + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala new file mode 100644 index 0000000000000..284712c6d250e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/IntegrationTestBackend.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s.integrationtest.backend + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.minikube.MinikubeTestBackend + +private[spark] trait IntegrationTestBackend { + def initialize(): Unit + def getKubernetesClient: DefaultKubernetesClient + def cleanUp(): Unit = {} +} + +private[spark] object IntegrationTestBackendFactory { + val deployModeConfigKey = "spark.kubernetes.test.deployMode" + + def getTestBackend: IntegrationTestBackend = { + val deployMode = Option(System.getProperty(deployModeConfigKey)) + .getOrElse("minikube") + if (deployMode == "minikube") { + MinikubeTestBackend + } else { + throw new IllegalArgumentException( + "Invalid " + deployModeConfigKey + ": " + deployMode) + } + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala new file mode 100644 index 0000000000000..6494cbc18f33e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import java.io.File +import java.nio.file.Paths + +import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient} + +import org.apache.spark.deploy.k8s.integrationtest.ProcessUtils +import org.apache.spark.internal.Logging + +// TODO support windows +private[spark] object Minikube extends Logging { + + private val MINIKUBE_STARTUP_TIMEOUT_SECONDS = 60 + + def getMinikubeIp: String = { + val outputs = executeMinikube("ip") + .filter(_.matches("^\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$")) + assert(outputs.size == 1, "Unexpected amount of output from minikube ip") + outputs.head + } + + def getMinikubeStatus: MinikubeStatus.Value = { + val statusString = executeMinikube("status") + .filter(line => line.contains("minikubeVM: ") || line.contains("minikube:")) + .head + .replaceFirst("minikubeVM: ", "") + .replaceFirst("minikube: ", "") + MinikubeStatus.unapply(statusString) + .getOrElse(throw new IllegalStateException(s"Unknown status $statusString")) + } + + def getKubernetesClient: DefaultKubernetesClient = { + val kubernetesMaster = s"https://${getMinikubeIp}:8443" + val userHome = System.getProperty("user.home") + val kubernetesConf = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(kubernetesMaster) + .withCaCertFile(Paths.get(userHome, ".minikube", "ca.crt").toFile.getAbsolutePath) + .withClientCertFile(Paths.get(userHome, ".minikube", "apiserver.crt").toFile.getAbsolutePath) + .withClientKeyFile(Paths.get(userHome, ".minikube", "apiserver.key").toFile.getAbsolutePath) + .build() + new DefaultKubernetesClient(kubernetesConf) + } + + private def executeMinikube(action: String, args: String*): Seq[String] = { + ProcessUtils.executeProcess( + Array("bash", "-c", s"minikube $action") ++ args, MINIKUBE_STARTUP_TIMEOUT_SECONDS) + } +} + +private[spark] object MinikubeStatus extends Enumeration { + + // The following states are listed according to + // https://github.com/docker/machine/blob/master/libmachine/state/state.go. + val STARTING = status("Starting") + val RUNNING = status("Running") + val PAUSED = status("Paused") + val STOPPING = status("Stopping") + val STOPPED = status("Stopped") + val ERROR = status("Error") + val TIMEOUT = status("Timeout") + val SAVED = status("Saved") + val NONE = status("") + + def status(value: String): Value = new Val(nextId, value) + def unapply(s: String): Option[Value] = values.find(s == _.toString) +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala new file mode 100644 index 0000000000000..cb9324179d70e --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/MinikubeTestBackend.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest.backend.minikube + +import io.fabric8.kubernetes.client.DefaultKubernetesClient + +import org.apache.spark.deploy.k8s.integrationtest.backend.IntegrationTestBackend + +private[spark] object MinikubeTestBackend extends IntegrationTestBackend { + + private var defaultClient: DefaultKubernetesClient = _ + + override def initialize(): Unit = { + val minikubeStatus = Minikube.getMinikubeStatus + require(minikubeStatus == MinikubeStatus.RUNNING, + s"Minikube must be running to use the Minikube backend for integration tests." + + s" Current status is: $minikubeStatus.") + defaultClient = Minikube.getKubernetesClient + } + + override def cleanUp(): Unit = { + super.cleanUp() + } + + override def getKubernetesClient: DefaultKubernetesClient = { + defaultClient + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala new file mode 100644 index 0000000000000..a81ef455c6766 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/config.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +package object config { + def getTestImageTag: String = { + val imageTagFileProp = System.getProperty("spark.kubernetes.test.imageTagFile") + require(imageTagFileProp != null, "Image tag file must be provided in system properties.") + val imageTagFile = new File(imageTagFileProp) + require(imageTagFile.isFile, s"No file found for image tag at ${imageTagFile.getAbsolutePath}.") + Files.toString(imageTagFile, Charsets.UTF_8).trim + } + + def getTestImageRepo: String = { + val imageRepo = System.getProperty("spark.kubernetes.test.imageRepo") + require(imageRepo != null, "Image repo must be provided in system properties.") + imageRepo + } +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala new file mode 100644 index 0000000000000..0807a68cd823c --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/constants.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.integrationtest + +package object constants { + val MINIKUBE_TEST_BACKEND = "minikube" + val GCE_TEST_BACKEND = "gce" +} From 36a3409134687d6a2894cd6a77554b8439cacec1 Mon Sep 17 00:00:00 2001 From: Thiruvasakan Paramasivan Date: Fri, 8 Jun 2018 17:17:43 -0700 Subject: [PATCH 0935/1161] [SPARK-24412][SQL] Adding docs about automagical type casting in `isin` and `isInCollection` APIs ## What changes were proposed in this pull request? Update documentation for `isInCollection` API to clealy explain the "auto-casting" of elements if their types are different. ## How was this patch tested? No-Op Author: Thiruvasakan Paramasivan Closes #21519 from trvskn/sql-doc-update. --- .../scala/org/apache/spark/sql/Column.scala | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b3e59f53ee3de..2dbb53e7c906b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -781,6 +781,14 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. * + * Note: Since the type of the elements in the list are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group expr_ops * @since 1.5.0 */ @@ -791,6 +799,14 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the provided collection. * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group expr_ops * @since 2.4.0 */ @@ -800,6 +816,14 @@ class Column(val expr: Expression) extends Logging { * A boolean expression that is evaluated to true if the value of this expression is contained * by the provided collection. * + * Note: Since the type of the elements in the collection are inferred only during the run time, + * the elements will be "up-casted" to the most common type for comparison. + * For eg: + * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the + * comparison will look like "String vs String". + * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the + * comparison will look like "Double vs Double" + * * @group java_expr_ops * @since 2.4.0 */ From f07c5064a3967cdddf57c2469635ee50a26d864c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 8 Jun 2018 18:51:56 -0700 Subject: [PATCH 0936/1161] [SPARK-24468][SQL] Handle negative scale when adjusting precision for decimal operations ## What changes were proposed in this pull request? In SPARK-22036 we introduced the possibility to allow precision loss in arithmetic operations (according to the SQL standard). The implementation was drawn from Hive's one, where Decimals with a negative scale are not allowed in the operations. The PR handles the case when the scale is negative, removing the assertion that it is not. ## How was this patch tested? added UTs Author: Marco Gaido Closes #21499 from mgaido91/SPARK-24468. --- .../apache/spark/sql/types/DecimalType.scala | 8 +- .../analysis/DecimalPrecisionSuite.scala | 9 + .../native/decimalArithmeticOperations.sql | 4 + .../decimalArithmeticOperations.sql.out | 164 +++++++++++------- 4 files changed, 117 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ef3b67c0d48d0..dbf51c398fa47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -161,13 +161,17 @@ object DecimalType extends AbstractDataType { * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. */ private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { - // Assumptions: + // Assumption: assert(precision >= scale) - assert(scale >= 0) if (precision <= MAX_PRECISION) { // Adjustment only needed when we exceed max precision DecimalType(precision, scale) + } else if (scale < 0) { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + DecimalType(MAX_PRECISION, scale) } else { // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. val intDigits = precision - scale diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index c86dc18dfa680..bd87ca6017e99 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -272,6 +272,15 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { } } + test("SPARK-24468: operations on decimals with negative scale") { + val a = AttributeReference("a", DecimalType(3, -10))() + val b = AttributeReference("b", DecimalType(1, -1))() + val c = AttributeReference("c", DecimalType(35, 1))() + checkType(Multiply(a, b), DecimalType(5, -11)) + checkType(Multiply(a, c), DecimalType(38, -9)) + checkType(Multiply(b, c), DecimalType(37, 0)) + } + /** strength reduction for integer/decimal comparisons */ def ruleTest(initial: Expression, transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index 9be7fcdadfea8..28a0e20c0f495 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -40,12 +40,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss are truncated select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; @@ -67,12 +69,14 @@ select 10.3000 * 3.0; select 10.30000 * 30.0; select 10.300000000000000000 * 3.000000000000000000; select 10.300000000000000000 * 3.0000000000000000000; +select 2.35E10 * 1.0; -- arithmetic operations causing an overflow return NULL select (5e36 + 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; select 12345678901234567890.0 * 12345678901234567890.0; select 1e35 / 0.1; +select 1.2345678901234567890E30 * 1.2345678901234567890E25; -- arithmetic operations causing a precision loss return NULL select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index 6bfdb84548d4d..cbf44548b3cce 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 36 +-- Number of queries: 40 -- !query 0 @@ -114,190 +114,222 @@ struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.00000000000000000 -- !query 13 -select (5e36 + 0.1) + 5e36 +select 2.35E10 * 1.0 -- !query 13 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 13 output -NULL +23500000000 -- !query 14 -select (-4e36 - 0.1) - 7e36 +select (5e36 + 0.1) + 5e36 -- !query 14 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 14 output NULL -- !query 15 -select 12345678901234567890.0 * 12345678901234567890.0 +select (-4e36 - 0.1) - 7e36 -- !query 15 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 15 output NULL -- !query 16 -select 1e35 / 0.1 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 16 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 16 output NULL -- !query 17 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select 1e35 / 0.1 -- !query 17 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> -- !query 17 output -10012345678912345678912345678911.246907 +NULL -- !query 18 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 18 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 18 output -138698367904130467.654320988515622621 +NULL -- !query 19 -select 12345678912345.123456789123 / 0.000000012345678 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 -- !query 19 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,6)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,6))):decimal(38,6)> -- !query 19 output -1000000073899961059796.725866332 +10012345678912345678912345678911.246907 -- !query 20 -set spark.sql.decimalOperations.allowPrecisionLoss=false +select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 20 schema -struct +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 20 output -spark.sql.decimalOperations.allowPrecisionLoss false +138698367904130467.654320988515622621 -- !query 21 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 12345678912345.123456789123 / 0.000000012345678 -- !query 21 schema -struct +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,9)> -- !query 21 output -1 1099 -899 NULL 0.1001001001001001 -2 24690.246 0 NULL 1 -3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 -4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 +1000000073899961059796.725866332 -- !query 22 -select id, a*10, b/10 from decimals_test order by id +set spark.sql.decimalOperations.allowPrecisionLoss=false -- !query 22 schema -struct +struct -- !query 22 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.1123456789123456789 +spark.sql.decimalOperations.allowPrecisionLoss false -- !query 23 -select 10.3 * 3.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 23 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +struct -- !query 23 output -30.9 +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 24 -select 10.3000 * 3.0 +select id, a*10, b/10 from decimals_test order by id -- !query 24 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +struct -- !query 24 output -30.9 +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 25 -select 10.30000 * 30.0 +select 10.3 * 3.0 -- !query 25 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 25 output -309 +30.9 -- !query 26 -select 10.300000000000000000 * 3.000000000000000000 +select 10.3000 * 3.0 -- !query 26 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 26 output 30.9 -- !query 27 -select 10.300000000000000000 * 3.0000000000000000000 +select 10.30000 * 30.0 -- !query 27 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 27 output -NULL +309 -- !query 28 -select (5e36 + 0.1) + 5e36 +select 10.300000000000000000 * 3.000000000000000000 -- !query 28 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> -- !query 28 output -NULL +30.9 -- !query 29 -select (-4e36 - 0.1) - 7e36 +select 10.300000000000000000 * 3.0000000000000000000 -- !query 29 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> -- !query 29 output NULL -- !query 30 -select 12345678901234567890.0 * 12345678901234567890.0 +select 2.35E10 * 1.0 -- !query 30 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct<(CAST(2.35E+10 AS DECIMAL(12,1)) * CAST(1.0 AS DECIMAL(12,1))):decimal(6,-7)> -- !query 30 output -NULL +23500000000 -- !query 31 -select 1e35 / 0.1 +select (5e36 + 0.1) + 5e36 -- !query 31 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 31 output NULL -- !query 32 -select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +select (-4e36 - 0.1) - 7e36 -- !query 32 schema -struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> -- !query 32 output NULL -- !query 33 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 12345678901234567890.0 * 12345678901234567890.0 -- !query 33 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> -- !query 33 output NULL -- !query 34 -select 12345678912345.123456789123 / 0.000000012345678 +select 1e35 / 0.1 -- !query 34 schema -struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> -- !query 34 output NULL -- !query 35 -drop table decimals_test +select 1.2345678901234567890E30 * 1.2345678901234567890E25 -- !query 35 schema -struct<> +struct<(CAST(1.2345678901234567890E+30 AS DECIMAL(25,-6)) * CAST(1.2345678901234567890E+25 AS DECIMAL(25,-6))):decimal(38,-17)> -- !query 35 output +NULL + + +-- !query 36 +select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.12345 +-- !query 36 schema +struct<(CAST(12345678912345678912345678912.1234567 AS DECIMAL(38,7)) + CAST(9999999999999999999999999999999.12345 AS DECIMAL(38,7))):decimal(38,7)> +-- !query 36 output +NULL + + +-- !query 37 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 37 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +-- !query 37 output +NULL + + +-- !query 38 +select 12345678912345.123456789123 / 0.000000012345678 +-- !query 38 schema +struct<(CAST(12345678912345.123456789123 AS DECIMAL(29,15)) / CAST(1.2345678E-8 AS DECIMAL(29,15))):decimal(38,18)> +-- !query 38 output +NULL + + +-- !query 39 +drop table decimals_test +-- !query 39 schema +struct<> +-- !query 39 output From 3e5b4ae63a468858ff8b9f7f3231cc877846a0af Mon Sep 17 00:00:00 2001 From: edorigatti Date: Mon, 11 Jun 2018 10:15:42 +0800 Subject: [PATCH 0937/1161] [SPARK-23754][PYTHON][FOLLOWUP] Move UDF stop iteration wrapping from driver to executor ## What changes were proposed in this pull request? SPARK-23754 was fixed in #21383 by changing the UDF code to wrap the user function, but this required a hack to save its argspec. This PR reverts this change and fixes the `StopIteration` bug in the worker ## How does this work? The root of the problem is that when an user-supplied function raises a `StopIteration`, pyspark might stop processing data, if this function is used in a for-loop. The solution is to catch `StopIteration`s exceptions and re-raise them as `RuntimeError`s, so that the execution fails and the error is reported to the user. This is done using the `fail_on_stopiteration` wrapper, in different ways depending on where the function is used: - In RDDs, the user function is wrapped in the driver, because this function is also called in the driver itself. - In SQL UDFs, the function is wrapped in the worker, since all processing happens there. Moreover, the worker needs the signature of the user function, which is lost when wrapping it, but passing this signature to the worker requires a not so nice hack. ## How was this patch tested? Same tests, plus tests for pandas UDFs Author: edorigatti Closes #21467 from e-dorigatti/fix_udf_hack. --- python/pyspark/sql/tests.py | 71 ++++++++++++++++++++++++++++--------- python/pyspark/sql/udf.py | 14 ++------ python/pyspark/tests.py | 37 +++++++++++-------- python/pyspark/util.py | 9 ++--- python/pyspark/worker.py | 18 ++++++---- 5 files changed, 92 insertions(+), 57 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 487eb19c3b98a..4a3941de8a6a6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,22 +900,6 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) - def test_stopiteration_in_udf(self): - # test for SPARK-23754 - from pyspark.sql.functions import udf - from py4j.protocol import Py4JJavaError - - def foo(x): - raise StopIteration() - - with self.assertRaises(Py4JJavaError) as cm: - self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() - - self.assertIn( - "Caught StopIteration thrown from user's code; failing the task", - cm.exception.java_exception.toString() - ) - def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column @@ -4144,6 +4128,61 @@ def foo(df): def foo(k, v, w): return k + def test_stopiteration_in_udf(self): + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn('v', udf(foo)('id')).collect + ) + + # pandas scalar udf + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect + ) + + # pandas grouped map + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + + # pandas grouped agg + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).collect + ) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c8fb49d7c2b65..9dbe49b831cef 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec, fail_on_stopiteration +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -157,17 +157,7 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - func = fail_on_stopiteration(self.func) - - # for pandas UDFs the worker needs to know if the function takes - # one or two arguments, but the signature is lost when wrapping with - # fail_on_stopiteration, so we store it here - if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - func._argspec = _get_argspec(self.func) - - wrapped_func = _wrap_function(sc, func, self.returnType) + wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 30723b8e15b36..18b2f251dc9fd 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1291,27 +1291,34 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) - def test_stopiteration_in_client_code(self): + def test_stopiteration_in_user_code(self): def stopit(*x): raise StopIteration() seq_rdd = self.sc.parallelize(range(10)) keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - - self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) - self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) - - # the exception raised is non-deterministic - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, lambda *x: 1, stopit) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e95a9b523393f..f015542c8799d 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,12 +53,7 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - - if hasattr(f, '_argspec'): - # only used for pandas UDF: they wrap the user function, losing its signature - # workers need this signature, so UDF saves it here - argspec = f._argspec - elif sys.version_info[0] < 3: + if sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: # `getargspec` is deprecated since python3.0 (incompatible with function annotations). @@ -97,7 +92,7 @@ def majorMinorVersion(sparkVersion): def fail_on_stopiteration(f): """ Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' - prevents silent loss of data when 'f' is used in a for loop + prevents silent loss of data when 'f' is used in a for loop in Spark code """ def wrapper(*args, **kwargs): try: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fbcb8af8bfb24..a30d6bf523a50 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -35,7 +35,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle pickleSer = PickleSerializer() @@ -92,10 +92,9 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd - argspec = _get_argspec(f) if len(argspec.args) == 1: result = f(pd.concat(value_series, axis=1)) @@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) + # make sure StopIteration's raised in the user code are not ignored + # when they are processed in a for loop, raise them as RuntimeError's instead + func = fail_on_stopiteration(row_func) + # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + argspec = _get_argspec(row_func) # signature was lost when wrapping it + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: - return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) + return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: - return arg_offsets, wrap_udf(row_func, return_type) + return arg_offsets, wrap_udf(func, return_type) else: raise ValueError("Unknown eval type: {}".format(eval_type)) From a99d284c16cc4e00ce7c83ecdc3db6facd467552 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Mon, 11 Jun 2018 12:15:14 -0700 Subject: [PATCH 0938/1161] [SPARK-19826][ML][PYTHON] add spark.ml Python API for PIC ## What changes were proposed in this pull request? add spark.ml Python API for PIC ## How was this patch tested? add doctest Author: Huaxin Gao Closes #21513 from huaxingao/spark--19826. --- python/pyspark/ml/clustering.py | 184 +++++++++++++++++++++++++++++++- 1 file changed, 179 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index b3d5fb17f6b81..4aa1cf84b5824 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,14 +19,15 @@ from pyspark import since, keyword_only from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper from pyspark.ml.param.shared import * from pyspark.ml.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary', 'KMeans', 'KMeansModel', 'GaussianMixture', 'GaussianMixtureModel', 'GaussianMixtureSummary', - 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel'] + 'LDA', 'LDAModel', 'LocalLDAModel', 'DistributedLDAModel', 'PowerIterationClustering'] class ClusteringSummary(JavaWrapper): @@ -836,7 +837,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter Terminology: - - "term" = "word": an el + - "term" = "word": an element of the vocabulary - "token": instance of a term appearing in a document - "topic": multinomial distribution over terms representing some concept - "document": one piece of text, corresponding to one row in the input data @@ -938,7 +939,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) """ super(LDA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.LDA", self.uid) @@ -967,7 +968,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,\ subsamplingRate=0.05, optimizeDocConcentration=True,\ docConcentration=None, topicConcentration=None,\ - topicDistributionCol="topicDistribution", keepLastCheckpoint=True): + topicDistributionCol="topicDistribution", keepLastCheckpoint=True) Sets params for LDA. """ @@ -1156,6 +1157,179 @@ def getKeepLastCheckpoint(self): return self.getOrDefault(self.keepLastCheckpoint) +@inherit_doc +class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReadable, + JavaMLWritable): + """ + .. note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by + Lin and Cohen. From the abstract: + PIC finds a very low-dimensional embedding of a dataset using truncated power + iteration on a normalized pair-wise similarity matrix of the data. + + This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method + to run the PowerIterationClustering algorithm. + + .. seealso:: `Wikipedia on Spectral clustering \ + `_ + + >>> data = [(1, 0, 0.5), \ + (2, 0, 0.5), (2, 1, 0.7), \ + (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9), \ + (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1), \ + (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3)] + >>> df = spark.createDataFrame(data).toDF("src", "dst", "weight") + >>> pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight") + >>> assignments = pic.assignClusters(df) + >>> assignments.sort(assignments.id).show(truncate=False) + +---+-------+ + |id |cluster| + +---+-------+ + |0 |1 | + |1 |1 | + |2 |1 | + |3 |1 | + |4 |1 | + |5 |0 | + +---+-------+ + ... + >>> pic_path = temp_path + "/pic" + >>> pic.save(pic_path) + >>> pic2 = PowerIterationClustering.load(pic_path) + >>> pic2.getK() + 2 + >>> pic2.getMaxIter() + 40 + + .. versionadded:: 2.4.0 + """ + + k = Param(Params._dummy(), "k", + "The number of clusters to create. Must be > 1.", + typeConverter=TypeConverters.toInt) + initMode = Param(Params._dummy(), "initMode", + "The initialization algorithm. This can be either " + + "'random' to use a random vector as vertex properties, or 'degree' to use " + + "a normalized sum of similarities with other vertices. Supported options: " + + "'random' and 'degree'.", + typeConverter=TypeConverters.toString) + srcCol = Param(Params._dummy(), "srcCol", + "Name of the input column for source vertex IDs.", + typeConverter=TypeConverters.toString) + dstCol = Param(Params._dummy(), "dstCol", + "Name of the input column for destination vertex IDs.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + __init__(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + """ + super(PowerIterationClustering, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid) + self._setDefault(k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst", + weightCol=None): + """ + setParams(self, k=2, maxIter=20, initMode="random", srcCol="src", dstCol="dst",\ + weightCol=None) + Sets params for PowerIterationClustering. + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def setK(self, value): + """ + Sets the value of :py:attr:`k`. + """ + return self._set(k=value) + + @since("2.4.0") + def getK(self): + """ + Gets the value of :py:attr:`k` or its default value. + """ + return self.getOrDefault(self.k) + + @since("2.4.0") + def setInitMode(self, value): + """ + Sets the value of :py:attr:`initMode`. + """ + return self._set(initMode=value) + + @since("2.4.0") + def getInitMode(self): + """ + Gets the value of :py:attr:`initMode` or its default value. + """ + return self.getOrDefault(self.initMode) + + @since("2.4.0") + def setSrcCol(self, value): + """ + Sets the value of :py:attr:`srcCol`. + """ + return self._set(srcCol=value) + + @since("2.4.0") + def getSrcCol(self): + """ + Gets the value of :py:attr:`srcCol` or its default value. + """ + return self.getOrDefault(self.srcCol) + + @since("2.4.0") + def setDstCol(self, value): + """ + Sets the value of :py:attr:`dstCol`. + """ + return self._set(dstCol=value) + + @since("2.4.0") + def getDstCol(self): + """ + Gets the value of :py:attr:`dstCol` or its default value. + """ + return self.getOrDefault(self.dstCol) + + @since("2.4.0") + def assignClusters(self, dataset): + """ + Run the PIC algorithm and returns a cluster assignment for each input vertex. + + :param dataset: + A dataset with columns src, dst, weight representing the affinity matrix, + which is the matrix A in the PIC paper. Suppose the src column value is i, + the dst column value is j, the weight column value is similarity s,,ij,, + which must be nonnegative. This is a symmetric matrix and hence + s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + ignored, because we assume s,,ij,, = 0.0. + + :return: + A dataset that contains columns of vertex id and the corresponding cluster for + the id. The schema of it will be: + - id: Long + - cluster: Int + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.assignClusters(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) + + if __name__ == "__main__": import doctest import pyspark.ml.clustering From 9b6f24202f6f8d9d76bbe53f379743318acb19f9 Mon Sep 17 00:00:00 2001 From: Jonathan Kelly Date: Mon, 11 Jun 2018 16:41:15 -0500 Subject: [PATCH 0939/1161] [MINOR][CORE] Log committer class used by HadoopMapRedCommitProtocol ## What changes were proposed in this pull request? When HadoopMapRedCommitProtocol is used (e.g., when using saveAsTextFile() or saveAsHadoopFile() with RDDs), it's not easy to determine which output committer class was used, so this PR simply logs the class that was used, similarly to what is done in SQLHadoopMapReduceCommitProtocol. ## How was this patch tested? Built Spark then manually inspected logging when calling saveAsTextFile(): ```scala scala> sc.setLogLevel("INFO") scala> sc.textFile("README.md").saveAsTextFile("/tmp/out") ... 18/05/29 10:06:20 INFO HadoopMapRedCommitProtocol: Using output committer class org.apache.hadoop.mapred.FileOutputCommitter ``` Author: Jonathan Kelly Closes #21452 from ejono/master. --- .../apache/spark/internal/io/HadoopMapRedCommitProtocol.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala index ddbd624b380d4..af0aa41518766 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala @@ -31,6 +31,8 @@ class HadoopMapRedCommitProtocol(jobId: String, path: String) override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = { val config = context.getConfiguration.asInstanceOf[JobConf] - config.getOutputCommitter + val committer = config.getOutputCommitter + logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") + committer } } From 2dc047a3189290411def92f6d7e9a4e01bdb2c30 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Mon, 11 Jun 2018 17:12:33 -0500 Subject: [PATCH 0940/1161] [SPARK-24520] Double braces in documentations There are double braces in the markdown, which break the link. ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Fokko Driesprong Closes #21528 from Fokko/patch-1. --- docs/structured-streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 602a4c70848e7..0842e8dd88672 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -926,7 +926,7 @@ event time. For a specific window starting at time `T`, the engine will maintain data to update the state until `(max event time seen by the engine - late threshold > T)`. In other words, late data within the threshold will be aggregated, but data later than the threshold will start getting dropped -(see [later]((#semantic-guarantees-of-aggregation-with-watermarking)) +(see [later](#semantic-guarantees-of-aggregation-with-watermarking) in the section for the exact guarantees). Let's understand this with an example. We can easily define watermarking on the previous example using `withWatermark()` as shown below. From f5af86ea753c446df59a0a8c16c685224690d633 Mon Sep 17 00:00:00 2001 From: Xiaodong <11539188+XD-DENG@users.noreply.github.com> Date: Mon, 11 Jun 2018 17:13:11 -0500 Subject: [PATCH 0941/1161] [SPARK-24134][DOCS] A missing full-stop in doc "Tuning Spark". MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In the document [Tuning Spark -> Determining Memory Consumption](https://spark.apache.org/docs/latest/tuning.html#determining-memory-consumption), a full stop was missing in the second paragraph. It's `...use SizeEstimator’s estimate method This is useful for experimenting...`, while there is supposed to be a full stop before `This`. Screenshot showing before change is attached below. screen shot 2018-05-01 at 5 22 32 pm ## How was this patch tested? This is a simple change in doc. Only one full stop was added in plain text. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Xiaodong <11539188+XD-DENG@users.noreply.github.com> Closes #21205 from XD-DENG/patch-1. --- docs/tuning.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tuning.md b/docs/tuning.md index 912c39879be8f..1c3bd0e8758ff 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -132,7 +132,7 @@ The best way to size the amount of memory consumption a dataset will require is into cache, and look at the "Storage" page in the web UI. The page will tell you how much memory the RDD is occupying. -To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method +To estimate the memory consumption of a particular object, use `SizeEstimator`'s `estimate` method. This is useful for experimenting with different data layouts to trim memory usage, as well as determining the amount of space a broadcast variable will occupy on each executor heap. From 048197749ef990e4def1fcbf488f3ded38d95cae Mon Sep 17 00:00:00 2001 From: liutang123 Date: Mon, 11 Jun 2018 17:48:07 -0700 Subject: [PATCH 0942/1161] [SPARK-22144][SQL] ExchangeCoordinator combine the partitions of an 0 sized pre-shuffle to 0 ## What changes were proposed in this pull request? when the length of pre-shuffle's partitions is 0, the length of post-shuffle's partitions should be 0 instead of spark.sql.shuffle.partitions. ## How was this patch tested? ExchangeCoordinator converted a pre-shuffle that partitions is 0 to a post-shuffle that partitions is 0 instead of one that partitions is spark.sql.shuffle.partitions. Author: liutang123 Closes #19364 from liutang123/SPARK-22144. --- .../spark/sql/execution/exchange/ExchangeCoordinator.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 78f11ca8d8c78..051e610eb2705 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -232,16 +232,16 @@ class ExchangeCoordinator( // number of post-shuffle partitions. val partitionStartIndices = if (mapOutputStatistics.length == 0) { - None + Array.empty[Int] } else { - Some(estimatePartitionStartIndices(mapOutputStatistics)) + estimatePartitionStartIndices(mapOutputStatistics) } var k = 0 while (k < numExchanges) { val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), Some(partitionStartIndices)) newPostShuffleRDDs.put(exchange, rdd) k += 1 From dc22465f3e1ef5ad59306b1f591d6fd16d674eb7 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 12 Jun 2018 09:32:14 +0800 Subject: [PATCH 0943/1161] [SPARK-23732][DOCS] Fix source links in generated scaladoc. Apply the suggestion on the bug to fix source links. Tested with the 2.3.1 release docs. Author: Marcelo Vanzin Closes #21521 from vanzin/SPARK-23732. --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4cb6495a33b61..adc2b6b5d273d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -729,7 +729,8 @@ object Unidoc { scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( "-groups", // Group similar methods together based on the @group annotation. - "-skip-packages", "org.apache.hadoop" + "-skip-packages", "org.apache.hadoop", + "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath ) ++ ( // Add links to sources when generating Scaladoc for a non-snapshot release if (!isSnapshot.value) { From 01452ea9c75ff027ceeb8314368c6bbedefdb2bf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 11 Jun 2018 22:08:44 -0700 Subject: [PATCH 0944/1161] [SPARK-24502][SQL] flaky test: UnsafeRowSerializerSuite ## What changes were proposed in this pull request? `UnsafeRowSerializerSuite` calls `UnsafeProjection.create` which accesses `SQLConf.get`, while the current active SparkSession may already be stopped, and we may hit exception like this ``` sbt.ForkMain$ForkError: java.lang.IllegalStateException: LiveListenerBus is stopped. at org.apache.spark.scheduler.LiveListenerBus.addToQueue(LiveListenerBus.scala:97) at org.apache.spark.scheduler.LiveListenerBus.addToStatusQueue(LiveListenerBus.scala:80) at org.apache.spark.sql.internal.SharedState.(SharedState.scala:93) at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120) at org.apache.spark.sql.SparkSession$$anonfun$sharedState$1.apply(SparkSession.scala:120) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.sql.SparkSession.sharedState$lzycompute(SparkSession.scala:120) at org.apache.spark.sql.SparkSession.sharedState(SparkSession.scala:119) at org.apache.spark.sql.internal.BaseSessionStateBuilder.build(BaseSessionStateBuilder.scala:286) at org.apache.spark.sql.test.TestSparkSession.sessionState$lzycompute(TestSQLContext.scala:42) at org.apache.spark.sql.test.TestSparkSession.sessionState(TestSQLContext.scala:41) at org.apache.spark.sql.SparkSession$$anonfun$1$$anonfun$apply$1.apply(SparkSession.scala:95) at org.apache.spark.sql.SparkSession$$anonfun$1$$anonfun$apply$1.apply(SparkSession.scala:95) at scala.Option.map(Option.scala:146) at org.apache.spark.sql.SparkSession$$anonfun$1.apply(SparkSession.scala:95) at org.apache.spark.sql.SparkSession$$anonfun$1.apply(SparkSession.scala:94) at org.apache.spark.sql.internal.SQLConf$.get(SQLConf.scala:126) at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:54) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:157) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:150) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite.org$apache$spark$sql$execution$UnsafeRowSerializerSuite$$unsafeRowConverter(UnsafeRowSerializerSuite.scala:54) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite.org$apache$spark$sql$execution$UnsafeRowSerializerSuite$$toUnsafeRow(UnsafeRowSerializerSuite.scala:49) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite$$anonfun$2.apply(UnsafeRowSerializerSuite.scala:63) at org.apache.spark.sql.execution.UnsafeRowSerializerSuite$$anonfun$2.apply(UnsafeRowSerializerSuite.scala:60) ... ``` ## How was this patch tested? N/A Author: Wenchen Fan Closes #21518 from cloud-fan/test. --- .../apache/spark/sql/LocalSparkSession.scala | 4 + .../execution/UnsafeRowSerializerSuite.scala | 80 +++++++------------ 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index d66a6902b0510..cbef1c7828319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -30,11 +30,15 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def beforeAll() { super.beforeAll() InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } override def afterEach() { try { resetSparkContext() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } finally { super.afterEach() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index a3ae93810aa3c..d305ce3e698ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -21,15 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} import java.util.Properties import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter /** @@ -43,7 +41,7 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea } } -class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val converter = unsafeRowConverter(schema) @@ -58,7 +56,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("toUnsafeRow() test helper method") { - // This currently doesnt work because the generic getter throws an exception. + // This currently doesn't work because the generic getter throws an exception. val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) @@ -97,59 +95,43 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } test("SPARK-10466: external sorter spilling with unsafe row serializer") { - var sc: SparkContext = null - var outputFile: File = null - val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten - Utils.tryWithSafeFinally { - val conf = new SparkConf() - .set("spark.shuffle.spill.initialMemoryThreshold", "1") - .set("spark.shuffle.sort.bypassMergeThreshold", "0") - .set("spark.testing.memory", "80000") - - sc = new SparkContext("local", "test", conf) - outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") - // prepare data - val converter = unsafeRowConverter(Array(IntegerType)) - val data = (1 to 10000).iterator.map { i => - (i, converter(Row(i))) - } - val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - - val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( - taskContext, - partitioner = Some(new HashPartitioner(10)), - serializer = new UnsafeRowSerializer(numFields = 1)) - - // Ensure we spilled something and have to merge them later - assert(sorter.numSpills === 0) - sorter.insertAll(data) - assert(sorter.numSpills > 0) + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.testing.memory", "80000") + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() + val outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + outputFile.deleteOnExit() + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 10000).iterator.map { i => + (i, converter(Row(i))) + } + val taskMemoryManager = new TaskMemoryManager(spark.sparkContext.env.memoryManager, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) - // Merging spilled files should not throw assertion error - sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) - } { - // Clean up - if (sc != null) { - sc.stop() - } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, + partitioner = Some(new HashPartitioner(10)), + serializer = new UnsafeRowSerializer(numFields = 1)) - // restore the spark env - SparkEnv.set(oldEnv) + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) - if (outputFile != null) { - outputFile.delete() - } - } + // Merging spilled files should not throw assertion error + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } test("SPARK-10403: unsafe row serializer with SortShuffleManager") { val conf = new SparkConf().set("spark.shuffle.manager", "sort") - sc = new SparkContext("local", "test", conf) + spark = SparkSession.builder().master("local").appName("test").config(conf).getOrCreate() val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) - val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) - .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val rowsRDD = spark.sparkContext.parallelize( + Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow)) + ).asInstanceOf[RDD[Product2[Int, InternalRow]]] val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rowsRDD, From 1d7db65e968de1c601e7f8b1ec9bc783ef2dbd01 Mon Sep 17 00:00:00 2001 From: Tom Saleeba Date: Tue, 12 Jun 2018 09:22:52 -0500 Subject: [PATCH 0945/1161] docs: fix typo no => no[t] ## What changes were proposed in this pull request? Fixing a typo. ## How was this patch tested? Visual check of the docs. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Tom Saleeba Closes #21496 from tomsaleeba/patch-1. --- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 2dbb53e7c906b..4eee3de5f7d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -104,7 +104,7 @@ class TypedColumn[-T, U]( * * {{{ * df("columnName") // On a specific `df` DataFrame. - * col("columnName") // A generic column no yet associated with a DataFrame. + * col("columnName") // A generic column not yet associated with a DataFrame. * col("columnName.field") // Extracting a struct field * col("`a.column.with.dots`") // Escape `.` in column names. * $"columnName" // Scala short hand for a named column. From 5d6a53d9831cc1e2115560db5cebe0eea2565dcd Mon Sep 17 00:00:00 2001 From: Lee Dongjin Date: Tue, 12 Jun 2018 08:16:37 -0700 Subject: [PATCH 0946/1161] [SPARK-15064][ML] Locale support in StopWordsRemover ## What changes were proposed in this pull request? Add locale support for `StopWordsRemover`. ## How was this patch tested? [Scala|Python] unit tests. Author: Lee Dongjin Closes #21501 from dongjinleekr/feature/SPARK-15064. --- .../spark/ml/feature/StopWordsRemover.scala | 30 +++++++++-- .../ml/feature/StopWordsRemoverSuite.scala | 51 +++++++++++++++++++ python/pyspark/ml/feature.py | 30 +++++++++-- python/pyspark/ml/tests.py | 7 +++ 4 files changed, 109 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3fcd84c029e61..0f946dd2e015b 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.feature +import java.util.Locale + import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -84,7 +86,27 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false) + /** + * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]] + * is true. + * Default: Locale.getDefault.toString + * @group param + */ + @Since("2.4.0") + val locale: Param[String] = new Param[String](this, "locale", + "Locale of the input for case insensitive matching. Ignored when caseSensitive is true.", + ParamValidators.inArray[String](Locale.getAvailableLocales.map(_.toString))) + + /** @group setParam */ + @Since("2.4.0") + def setLocale(value: String): this.type = set(locale, value) + + /** @group getParam */ + @Since("2.4.0") + def getLocale: String = $(locale) + + setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive -> false, locale -> Locale.getDefault.toString) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -95,8 +117,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String terms.filter(s => !stopWordsSet.contains(s)) } } else { - // TODO: support user locale (SPARK-15064) - val toLower = (s: String) => if (s != null) s.toLowerCase else s + val lc = new Locale($(locale)) + val toLower = (s: String) => if (s != null) s.toLowerCase(lc) else s val lowerStopWords = $(stopWords).map(toLower(_)).toSet udf { terms: Seq[String] => terms.filter(s => !lowerStopWords.contains(toLower(s))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 21259a50916d2..20972d1f403b9 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -65,6 +65,57 @@ class StopWordsRemoverSuite extends MLTest with DefaultReadWriteTest { testStopWordsRemover(remover, dataSet) } + test("StopWordsRemover with localed input (case insensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(false) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. + val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with localed input (case sensitive)") { + val stopWords = Array("milk", "cookie") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setCaseSensitive(true) + .setLocale("tr") // Turkish alphabet: has no Q, W, X but has dotted and dotless 'I's. + val dataSet = Seq( + // scalastyle:off + (Seq("mİlk", "and", "nuts"), Seq("mİlk", "and", "nuts")), + // scalastyle:on + (Seq("cookIe", "and", "nuts"), Seq("cookIe", "and", "nuts")), + (Seq(null), Seq(null)), + (Seq(), Seq()) + ).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with invalid locale") { + intercept[IllegalArgumentException] { + val stopWords = Array("test", "a", "an", "the") + new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + .setLocale("rt") // invalid locale + } + } + test("StopWordsRemover case sensitive") { val remover = new StopWordsRemover() .setInputCol("raw") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cdda30cfab482..14800d4d9327a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2582,25 +2582,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl typeConverter=TypeConverters.toListString) caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + "comparison over the stop words", typeConverter=TypeConverters.toBoolean) + locale = Param(Params._dummy(), "locale", "locale of the input. ignored when case sensitive " + + "is true", typeConverter=TypeConverters.toString) @keyword_only - def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) """ super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), - caseSensitive=False) + caseSensitive=False, locale=self._java_obj.getLocale()) kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False, + locale=None): """ - setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false, \ + locale=None) Sets params for this StopWordRemover. """ kwargs = self._input_kwargs @@ -2634,6 +2640,20 @@ def getCaseSensitive(self): """ return self.getOrDefault(self.caseSensitive) + @since("2.4.0") + def setLocale(self, value): + """ + Sets the value of :py:attr:`locale`. + """ + return self._set(locale=value) + + @since("2.4.0") + def getLocale(self): + """ + Gets the value of :py:attr:`locale`. + """ + return self.getOrDefault(self.locale) + @staticmethod @since("2.0.0") def loadDefaultStopWords(language): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0dde0db9e3339..ebd36cbb5f7a7 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -681,6 +681,13 @@ def test_stopwordsremover(self): self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BELKİ"] + dataset = self.spark.createDataFrame([Row(input=["belki"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) def test_count_vectorizer_with_binary(self): dataset = self.spark.createDataFrame([ From 2824f1436bb0371b7216730455f02456ef8479ce Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 12 Jun 2018 09:56:35 -0700 Subject: [PATCH 0947/1161] [SPARK-24531][TESTS] Remove version 2.2.0 from testing versions in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? Removing version 2.2.0 from testing versions in HiveExternalCatalogVersionsSuite as it is not present anymore in the mirrors and this is blocking all the open PRs. ## How was this patch tested? running UTs Author: Marco Gaido Closes #21540 from mgaido91/SPARK-24531. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index ea86ab9772bc7..6f904c937348d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.0") protected var spark: SparkSession = _ From 3af1d3e6d95719e15a997877d5ecd3bb40c08b9c Mon Sep 17 00:00:00 2001 From: Sanket Chintapalli Date: Tue, 12 Jun 2018 13:55:08 -0500 Subject: [PATCH 0948/1161] [SPARK-24416] Fix configuration specification for killBlacklisted executors ## What changes were proposed in this pull request? spark.blacklist.killBlacklistedExecutors is defined as (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create, executors when they are blacklisted. Note that, when an entire node is added to the blacklist, all of the executors on that node will be killed. I presume the killing of blacklisted executors only happens after the stage completes successfully and all tasks have completed or on fetch failures (updateBlacklistForFetchFailure/updateBlacklistForSuccessfulTaskSet). It is confusing because the definition states that the executor will be attempted to be recreated as soon as it is blacklisted. This is not true while the stage is in progress and an executor is blacklisted, it will not attempt to cleanup until the stage finishes. Author: Sanket Chintapalli Closes #21475 from redsanket/SPARK-24416. --- docs/configuration.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 5588c372d3e42..6aa7878fe614d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1656,9 +1656,10 @@ Apart from these, the following properties are also available, and may be useful spark.blacklist.killBlacklistedExecutors false - (Experimental) If set to "true", allow Spark to automatically kill, and attempt to re-create, - executors when they are blacklisted. Note that, when an entire node is added to the blacklist, - all of the executors on that node will be killed. + (Experimental) If set to "true", allow Spark to automatically kill the executors + when they are blacklisted on fetch failure or blacklisted for the entire application, + as controlled by spark.blacklist.application.*. Note that, when an entire node is added + to the blacklist, all of the executors on that node will be killed. From f0ef1b311dd5399290ad6abe4ca491bdb13478f0 Mon Sep 17 00:00:00 2001 From: DylanGuedes Date: Tue, 12 Jun 2018 11:57:25 -0700 Subject: [PATCH 0949/1161] [SPARK-23931][SQL] Adds arrays_zip function to sparksql Signed-off-by: DylanGuedes ## What changes were proposed in this pull request? Addition of arrays_zip function to spark sql functions. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Unit tests that checks if the results are correct. Author: DylanGuedes Closes #21045 from DylanGuedes/SPARK-23931. --- python/pyspark/sql/functions.py | 17 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 166 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 86 +++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../spark/sql/DataFrameFunctionsSuite.scala | 47 +++++ 6 files changed, 325 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1759195c6fcc0..0715297042520 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2394,6 +2394,23 @@ def array_repeat(col, count): return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) +@since(2.4) +def arrays_zip(*cols): + """ + Collection function: Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. + + :param cols: columns of arrays to be merged. + + >>> from pyspark.sql.functions import arrays_zip + >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2']) + >>> df.select(arrays_zip(df.vals1, df.vals2).alias('zipped')).collect() + [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 49fb35b083580..3c0b72873af54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -423,6 +423,7 @@ object FunctionRegistry { expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), + expression[ArraysZip]("arrays_zip"), expression[SortArray]("sort_array"), expression[ArrayMin]("array_min"), expression[ArrayMax]("array_max"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 176995affe701..d76f3013f0c41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -128,6 +128,172 @@ case class MapKeys(child: Expression) override def prettyName: String = "map_keys" } +@ExpressionDescription( + usage = """ + _FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); + [[1, 2], [2, 3], [3, 4]] + > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); + [[1, 2, 3], [2, 3, 4]] + """, + since = "2.4.0") +case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) + + override def dataType: DataType = ArrayType(mountSchema) + + override def nullable: Boolean = children.exists(_.nullable) + + private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + + private lazy val arrayElementTypes = arrayTypes.map(_.elementType) + + @transient private lazy val mountSchema: StructType = { + val fields = children.zip(arrayElementTypes).zipWithIndex.map { + case ((expr: NamedExpression, elementType), _) => + StructField(expr.name, elementType, nullable = true) + case ((_, elementType), idx) => + StructField(idx.toString, elementType, nullable = true) + } + StructType(fields) + } + + @transient lazy val numberOfArrays: Int = children.length + + @transient lazy val genericArrayData = classOf[GenericArrayData].getName + + def emptyInputGenCode(ev: ExprCode): ExprCode = { + ev.copy(code""" + |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]); + |boolean ${ev.isNull} = false; + """.stripMargin) + } + + def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val genericInternalRow = classOf[GenericInternalRow].getName + val arrVals = ctx.freshName("arrVals") + val biggestCardinality = ctx.freshName("biggestCardinality") + + val currentRow = ctx.freshName("currentRow") + val j = ctx.freshName("j") + val i = ctx.freshName("i") + val args = ctx.freshName("args") + + val evals = children.map(_.genCode(ctx)) + val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) => + s""" + |if ($biggestCardinality != -1) { + | ${eval.code} + | if (!${eval.isNull}) { + | $arrVals[$index] = ${eval.value}; + | $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements()); + | } else { + | $biggestCardinality = -1; + | } + |} + """.stripMargin + } + + val splittedGetValuesAndCardinalities = ctx.splitExpressions( + expressions = getValuesAndCardinalities, + funcName = "getValuesAndCardinalities", + returnType = "int", + makeSplitFunction = body => + s""" + |$body + |return $biggestCardinality; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), + arguments = + ("ArrayData[]", arrVals) :: + ("int", biggestCardinality) :: Nil) + + val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) => + val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i) + s""" + |if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) { + | $currentRow[$idx] = $g; + |} else { + | $currentRow[$idx] = null; + |} + """.stripMargin + } + + val getValueForTypeSplitted = ctx.splitExpressions( + expressions = getValueForType, + funcName = "extractValue", + arguments = + ("int", i) :: + ("Object[]", currentRow) :: + ("ArrayData[]", arrVals) :: Nil) + + val initVariables = s""" + |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |int $biggestCardinality = 0; + |${CodeGenerator.javaType(dataType)} ${ev.value} = null; + """.stripMargin + + ev.copy(code""" + |$initVariables + |$splittedGetValuesAndCardinalities + |boolean ${ev.isNull} = $biggestCardinality == -1; + |if (!${ev.isNull}) { + | Object[] $args = new Object[$biggestCardinality]; + | for (int $i = 0; $i < $biggestCardinality; $i ++) { + | Object[] $currentRow = new Object[$numberOfArrays]; + | $getValueForTypeSplitted + | $args[$i] = new $genericInternalRow($currentRow); + | } + | ${ev.value} = new $genericArrayData($args); + |} + """.stripMargin) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (numberOfArrays == 0) { + emptyInputGenCode(ev) + } else { + nonEmptyInputGenCode(ctx, ev) + } + } + + override def eval(input: InternalRow): Any = { + val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) + if (inputArrays.contains(null)) { + null + } else { + val biggestCardinality = if (inputArrays.isEmpty) { + 0 + } else { + inputArrays.map(_.numElements()).max + } + + val result = new Array[InternalRow](biggestCardinality) + val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex + + for (i <- 0 until biggestCardinality) { + val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) => + if (i < arr.numElements() && !arr.isNullAt(i)) { + arr.get(i, arrayElementTypes(index)) + } else { + null + } + } + + result(i) = InternalRow.apply(currentLayer: _*) + } + new GenericArrayData(result) + } + } + + override def prettyName: String = "arrays_zip" +} + /** * Returns an unordered array containing the values of the map. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f8ad624ce0e3d..85e692bdc4ef1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -315,6 +316,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Some(Literal.create(null, StringType))), null) } + test("ArraysZip") { + val literals = Seq( + Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)), + Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)), + Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)), + Literal.create(Seq("a", null, "c"), ArrayType(StringType)), + Literal.create(Seq(null, false, true), ArrayType(BooleanType)), + Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)), + Literal.create(Seq(), ArrayType(NullType)), + Literal.create(Seq(null), ArrayType(NullType)), + Literal.create(Seq(192.toByte), ArrayType(ByteType)), + Literal.create( + Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))), + Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType)) + ) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1))), + List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(2))), + List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(3))), + List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(4))), + List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(5))), + List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(6))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(7))), + List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) + + checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))), + List( + Row(9001, null, -1, "a"), + Row(9002, 1L, -3, null), + Row(9003, null, 900, "c"), + Row(null, 4L, null, null), + Row(null, 11L, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))), + List( + Row(null, 1.1, null, null, 192.toByte), + Row(false, null, null, null, null), + Row(true, 1.3, null, null, null), + Row(null, null, null, null, null))) + + checkEvaluation(ArraysZip(Seq(literals(9), literals(0))), + List( + Row(List(1, 2, 3), 9001), + Row(null, 9002), + Row(List(4, 5), 9003), + Row(List(1, null, 3), null))) + + checkEvaluation(ArraysZip(Seq(literals(7), literals(10))), + List(Row(null, Array[Byte](1.toByte, 5.toByte)))) + + val longLiteral = + Literal.create((0 to 1000).toSeq, ArrayType(IntegerType)) + + checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)), + List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++ + (3 to 1000).map { Row(null, _) }.toList) + + val manyLiterals = (0 to 1000).map { _ => + Literal.create(Seq(1), ArrayType(IntegerType)) + }.toSeq + + val numbers = List( + Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*), + Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*), + Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*)) + checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals), + List(numbers(0), numbers(1), numbers(2), numbers(3))) + + checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null) + checkEvaluation(ArraysZip(Seq()), List()) + } + test("Array Min") { checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) checkEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a2aae9a708ff3..266a136fc2410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3508,6 +3508,14 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + /** + * Returns a merged array of structs in which the N-th struct contains all N-th values of input + * arrays. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + ////////////////////////////////////////////////////////////////////////////////////////////// // Mask functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 59119bbbd8a2c..959a77a9ea345 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -479,6 +479,53 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("dataframe arrays_zip function") { + val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2") + val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3") + val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2") + val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2") + val df5 = Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4") + val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null))) + .toDF("v1", "v2", "v3", "v4") + val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2") + val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2") + + val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6))) + checkAnswer(df1.select(arrays_zip($"val1", $"val2")), expectedValue1) + checkAnswer(df1.selectExpr("arrays_zip(val1, val2)"), expectedValue1) + + val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11))) + checkAnswer(df2.select(arrays_zip($"val1", $"val2", $"val3")), expectedValue2) + checkAnswer(df2.selectExpr("arrays_zip(val1, val2, val3)"), expectedValue2) + + val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6))) + checkAnswer(df3.select(arrays_zip($"val1", $"val2")), expectedValue3) + checkAnswer(df3.selectExpr("arrays_zip(val1, val2)"), expectedValue3) + + val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null))) + checkAnswer(df4.select(arrays_zip($"val1", $"val2")), expectedValue4) + checkAnswer(df4.selectExpr("arrays_zip(val1, val2)"), expectedValue4) + + val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null))) + checkAnswer(df5.select(arrays_zip($"val1", $"val2", $"val3", $"val4")), expectedValue5) + checkAnswer(df5.selectExpr("arrays_zip(val1, val2, val3, val4)"), expectedValue5) + + val expectedValue6 = Row(Seq( + Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null))) + checkAnswer(df6.select(arrays_zip($"v1", $"v2", $"v3", $"v4")), expectedValue6) + checkAnswer(df6.selectExpr("arrays_zip(v1, v2, v3, v4)"), expectedValue6) + + val expectedValue7 = Row(Seq( + Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2))) + checkAnswer(df7.select(arrays_zip($"v1", $"v2")), expectedValue7) + checkAnswer(df7.selectExpr("arrays_zip(v1, v2)"), expectedValue7) + + val expectedValue8 = Row(Seq( + Row(Array[Byte](1.toByte, 5.toByte), null))) + checkAnswer(df8.select(arrays_zip($"v1", $"v2")), expectedValue8) + checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) + } + test("map size function") { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), From cc88d7fad16e8b5cbf7b6b9bfe412908782b4a45 Mon Sep 17 00:00:00 2001 From: Fangshi Li Date: Tue, 12 Jun 2018 12:10:08 -0700 Subject: [PATCH 0950/1161] [SPARK-24216][SQL] Spark TypedAggregateExpression uses getSimpleName that is not safe in scala ## What changes were proposed in this pull request? When user create a aggregator object in scala and pass the aggregator to Spark Dataset's agg() method, Spark's will initialize TypedAggregateExpression with the nodeName field as aggregator.getClass.getSimpleName. However, getSimpleName is not safe in scala environment, depending on how user creates the aggregator object. For example, if the aggregator class full qualified name is "com.my.company.MyUtils$myAgg$2$", the getSimpleName will throw java.lang.InternalError "Malformed class name". This has been reported in scalatest https://github.com/scalatest/scalatest/pull/1044 and discussed in many scala upstream jiras such as SI-8110, SI-5425. To fix this issue, we follow the solution in https://github.com/scalatest/scalatest/pull/1044 to add safer version of getSimpleName as a util method, and TypedAggregateExpression will invoke this util method rather than getClass.getSimpleName. ## How was this patch tested? added unit test Author: Fangshi Li Closes #21276 from fangshil/SPARK-24216. --- .../org/apache/spark/util/AccumulatorV2.scala | 6 +- .../scala/org/apache/spark/util/Utils.scala | 59 ++++++++++++++++++- .../org/apache/spark/util/UtilsSuite.scala | 16 +++++ .../spark/ml/util/Instrumentation.scala | 5 +- .../aggregate/TypedAggregateExpression.scala | 5 +- .../v2/DataSourceV2StringFormat.scala | 4 +- 6 files changed, 89 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 3b469a69437b9..bf618b4afbce0 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -200,10 +200,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } override def toString: String = { + // getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead if (metadata == null) { - "Un-registered Accumulator: " + getClass.getSimpleName + "Un-registered Accumulator: " + Utils.getSimpleName(getClass) } else { - getClass.getSimpleName + s"(id: $id, name: $name, value: $value)" + Utils.getSimpleName(getClass) + s"(id: $id, name: $name, value: $value)" } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f9191a59c1655..7428db2158538 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.{Byte => JByte} +import java.lang.InternalError import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} @@ -1820,7 +1821,7 @@ private[spark] object Utils extends Logging { /** Return the class name of the given object, removing all dollar signs */ def getFormattedClassName(obj: AnyRef): String = { - obj.getClass.getSimpleName.replace("$", "") + getSimpleName(obj.getClass).replace("$", "") } /** @@ -2715,6 +2716,62 @@ private[spark] object Utils extends Logging { HashCodes.fromBytes(secretBytes).toString() } + /** + * Safer than Class obj's getSimpleName which may throw Malformed class name error in scala. + * This method mimicks scalatest's getSimpleNameOfAnObjectsClass. + */ + def getSimpleName(cls: Class[_]): String = { + try { + return cls.getSimpleName + } catch { + case err: InternalError => return stripDollars(stripPackages(cls.getName)) + } + } + + /** + * Remove the packages from full qualified class name + */ + private def stripPackages(fullyQualifiedName: String): String = { + fullyQualifiedName.split("\\.").takeRight(1)(0) + } + + /** + * Remove trailing dollar signs from qualified class name, + * and return the trailing part after the last dollar sign in the middle + */ + private def stripDollars(s: String): String = { + val lastDollarIndex = s.lastIndexOf('$') + if (lastDollarIndex < s.length - 1) { + // The last char is not a dollar sign + if (lastDollarIndex == -1 || !s.contains("$iw")) { + // The name does not have dollar sign or is not an intepreter + // generated class, so we should return the full string + s + } else { + // The class name is intepreter generated, + // return the part after the last dollar sign + // This is the same behavior as getClass.getSimpleName + s.substring(lastDollarIndex + 1) + } + } + else { + // The last char is a dollar sign + // Find last non-dollar char + val lastNonDollarChar = s.reverse.find(_ != '$') + lastNonDollarChar match { + case None => s + case Some(c) => + val lastNonDollarIndex = s.lastIndexOf(c) + if (lastNonDollarIndex == -1) { + s + } else { + // Strip the trailing dollar signs + // Invoke stripDollars again to get the simple name + stripDollars(s.substring(0, lastNonDollarIndex + 1)) + } + } + } + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3b4273184f1e9..418d2f9b88500 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1168,6 +1168,22 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.checkAndGetK8sMasterUrl("k8s://foo://host:port") } } + + object MalformedClassObject { + class MalformedClass + } + + test("Safe getSimpleName") { + // getSimpleName on class of MalformedClass will result in error: Malformed class name + // Utils.getSimpleName works + val err = intercept[java.lang.InternalError] { + classOf[MalformedClassObject.MalformedClass].getSimpleName + } + assert(err.getMessage === "Malformed class name") + + assert(Utils.getSimpleName(classOf[MalformedClassObject.MalformedClass]) === + "UtilsSuite$MalformedClassObject$MalformedClass") + } } private class SimpleExtension diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 3a1c166d46257..11f46eb9e4359 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -30,6 +30,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.Param import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log @@ -47,7 +48,9 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( private val id = UUID.randomUUID() private val prefix = { - val className = estimator.getClass.getSimpleName + // estimator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + val className = Utils.getSimpleName(estimator.getClass) s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index aab8cc50b9526..6d44890704f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( @@ -109,7 +110,9 @@ trait TypedAggregateExpression extends AggregateFunction { s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + // aggregator.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + override def nodeName: String = Utils.getSimpleName(aggregator.getClass).stripSuffix("$"); } // TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala index 693e67dcd108e..97e6c6d702acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala @@ -53,7 +53,9 @@ trait DataSourceV2StringFormat { private def sourceName: String = source match { case registered: DataSourceRegister => registered.shortName() - case _ => source.getClass.getSimpleName.stripSuffix("$") + // source.getClass.getSimpleName can cause Malformed class name error, + // call safer `Utils.getSimpleName` instead + case _ => Utils.getSimpleName(source.getClass) } def metadataString: String = { From ada28f25955a9e8ddd182ad41b2a4ef278f3d809 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 12 Jun 2018 12:31:22 -0700 Subject: [PATCH 0951/1161] [SPARK-23933][SQL] Add map_from_arrays function ## What changes were proposed in this pull request? The PR adds the SQL function `map_from_arrays`. The behavior of the function is based on Presto's `map`. Since SparkSQL already had a `map` function, we prepared the different name for this behavior. This function returns returns a map from a pair of arrays for keys and values. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21258 from kiszk/SPARK-23933. --- python/pyspark/sql/functions.py | 19 +++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/complexTypeCreator.scala | 72 ++++++++++++++++++- .../expressions/ComplexTypeSuite.scala | 44 ++++++++++++ .../org/apache/spark/sql/functions.scala | 11 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 30 ++++++++ 6 files changed, 176 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0715297042520..1cdbb8a4c3e8b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1819,6 +1819,25 @@ def create_map(*cols): return Column(jc) +@since(2.4) +def map_from_arrays(col1, col2): + """Creates a new map from two arrays. + + :param col1: name of column containing a set of keys. All elements should not be null + :param col2: name of column containing a set of values + + >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v']) + >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show() + +----------------+ + | map| + +----------------+ + |[2 -> a, 5 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def array(*cols): """Creates a new array column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3c0b72873af54..3700c63d817ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -417,6 +417,7 @@ object FunctionRegistry { expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), + expression[MapFromArrays]("map_from_arrays"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a9867aaeb0cfe..0a5f8a907b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods @@ -236,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def prettyName: String = "map" } +/** + * Returns a catalyst Map containing the two arrays in children expressions as keys and values. + */ +@ExpressionDescription( + usage = """ + _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements + in keys should not be null""", + examples = """ + Examples: + > SELECT _FUNC_([1.0, 3.0], ['2', '4']); + {1.0:"2",3.0:"4"} + """, since = "2.4.0") +case class MapFromArrays(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) + + override def dataType: DataType = { + MapType( + keyType = left.dataType.asInstanceOf[ArrayType].elementType, + valueType = right.dataType.asInstanceOf[ArrayType].elementType, + valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) + } + + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { + val keyArrayData = keyArray.asInstanceOf[ArrayData] + val valueArrayData = valueArray.asInstanceOf[ArrayData] + if (keyArrayData.numElements != valueArrayData.numElements) { + throw new RuntimeException("The given two arrays should have the same length") + } + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + if (leftArrayType.containsNull) { + var i = 0 + while (i < keyArrayData.numElements) { + if (keyArrayData.isNullAt(i)) { + throw new RuntimeException("Cannot use null as map key!") + } + i += 1 + } + } + new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { + val arrayBasedMapData = classOf[ArrayBasedMapData].getName + val leftArrayType = left.dataType.asInstanceOf[ArrayType] + val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { + val i = ctx.freshName("i") + s""" + |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { + | if ($keyArrayData.isNullAt($i)) { + | throw new RuntimeException("Cannot use null as map key!"); + | } + |} + """.stripMargin + } + s""" + |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { + | throw new RuntimeException("The given two arrays should have the same length"); + |} + |$keyArrayElemNullCheck + |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); + """.stripMargin + }) + } + + override def prettyName: String = "map_from_arrays" +} + /** * An expression representing a not yet available attribute name. This expression is unevaluable * and as its name suggests it is a temporary place holder until we're able to determine the diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b4138ce366b3a..726193b411737 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("MapFromArrays") { + def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order. + scala.collection.immutable.ListMap(keys.zip(values): _*) + } + + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25) + val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) + val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) + + val intArray = Literal.create(intSeq, ArrayType(IntegerType, false)) + val longArray = Literal.create(longSeq, ArrayType(LongType, false)) + val strArray = Literal.create(strSeq, ArrayType(StringType, false)) + + val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true)) + val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) + val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) + + val nullArray = Literal.create(null, ArrayType(StringType, false)) + + checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) + checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) + checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) + + checkEvaluation( + MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation( + MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq)) + checkEvaluation(MapFromArrays(nullArray, nullArray), null) + + intercept[RuntimeException] { + checkEvaluation(MapFromArrays(intWithNullArray, strArray), null) + } + intercept[RuntimeException] { + checkEvaluation( + MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) + } + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 266a136fc2410..87bd7b3b0f9c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1070,6 +1070,17 @@ object functions { @scala.annotation.varargs def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } + /** + * Creates a new map column. The array in the first column is used for keys. The array in the + * second column is used for values. All elements in the array for key should not be null. + * + * @group normal_funcs + * @since 2.4 + */ + def map_from_arrays(keys: Column, values: Column): Column = withExpr { + MapFromArrays(keys.expr, values.expr) + } + /** * Marks a DataFrame as small enough for use in broadcast joins. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 959a77a9ea345..4e5c1c56e2673 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -62,6 +62,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(row.getMap[Int, String](0) === Map(2 -> "a")) } + test("map with arrays") { + val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v") + val expectedType = MapType(IntegerType, StringType, valueContainsNull = true) + val row = df1.select(map_from_arrays($"k", $"v")).first() + assert(row.schema(0).dataType === expectedType) + assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b")) + checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b")))) + + val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v") + checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b")))) + + val df3 = Seq((null, null)).toDF("k", "v") + checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null))) + + val df4 = Seq((1, "a")).toDF("k", "v") + intercept[AnalysisException] { + df4.select(map_from_arrays($"k", $"v")) + } + + val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") + intercept[RuntimeException] { + df5.select(map_from_arrays($"k", $"v")).collect + } + + val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") + intercept[RuntimeException] { + df6.select(map_from_arrays($"k", $"v")).collect + } + } + test("struct with column name") { val df = Seq((1, "str")).toDF("a", "b") val row = df.select(struct("a", "b")).first() From 0d3714d221460a2a1141134c3d451f18c4e0d46f Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Tue, 12 Jun 2018 15:57:43 -0700 Subject: [PATCH 0952/1161] [SPARK-23010][BUILD][FOLLOWUP] Fix java checkstyle failure of kubernetes-integration-tests ## What changes were proposed in this pull request? Fix java checkstyle failure of kubernetes-integration-tests ## How was this patch tested? Checked manually on my local environment. Author: Xingbo Jiang Closes #21545 from jiangxb1987/k8s-checkstyle. --- project/SparkBuild.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index adc2b6b5d273d..b606f9355e03b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -57,11 +57,11 @@ object BuildCommons { val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, streamingFlumeSink, streamingFlume, streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, - dockerIntegrationTests, hadoopCloud) = + dockerIntegrationTests, hadoopCloud, kubernetesIntegrationTests) = Seq("kubernetes", "mesos", "yarn", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", - "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) + "docker-integration-tests", "hadoop-cloud", "kubernetes-integration-tests").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") From f53818d35bdef5d20a2718b14a2fed4c468545c6 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 12 Jun 2018 16:42:44 -0700 Subject: [PATCH 0953/1161] [SPARK-24506][UI] Add UI filters to tabs added after binding ## What changes were proposed in this pull request? Currently, `spark.ui.filters` are not applied to the handlers added after binding the server. This means that every page which is added after starting the UI will not have the filters configured on it. This can allow unauthorized access to the pages. The PR adds the filters also to the handlers added after the UI starts. ## How was this patch tested? manual tests (without the patch, starting the thriftserver with `--conf spark.ui.filters=org.apache.hadoop.security.authentication.server.AuthenticationFilter --conf spark.org.apache.hadoop.security.authentication.server.AuthenticationFilter.params="type=simple"` you can access `http://localhost:4040/sqlserver`; with the patch, 401 is the response as for the other pages). Author: Marco Gaido Closes #21523 from mgaido91/SPARK-24506. --- .../org/apache/spark/deploy/history/HistoryServer.scala | 1 - core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index a9a4d5a4ec6a2..066275e8f8425 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -152,7 +152,6 @@ class HistoryServer( assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") handlers.synchronized { ui.getHandlers.foreach(attachHandler) - addFilters(ui.getHandlers, conf) } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index d6a025a6f12da..52a955111231a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -263,7 +263,7 @@ private[spark] object JettyUtils extends Logging { filters.foreach { case filter : String => if (!filter.isEmpty) { - logInfo("Adding filter: " + filter) + logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.") val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // Get any parameters for each filter @@ -407,7 +407,7 @@ private[spark] object JettyUtils extends Logging { } pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - ServerInfo(server, httpPort, securePort, collection) + ServerInfo(server, httpPort, securePort, conf, collection) } catch { case e: Exception => server.stop() @@ -507,10 +507,12 @@ private[spark] case class ServerInfo( server: Server, boundPort: Int, securePort: Option[Int], + conf: SparkConf, private val rootHandler: ContextHandlerCollection) { - def addHandler(handler: ContextHandler): Unit = { + def addHandler(handler: ServletContextHandler): Unit = { handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) + JettyUtils.addFilters(Seq(handler), conf) rootHandler.addHandler(handler) if (!handler.isStarted()) { handler.start() From 9786ce66c52d41b1d58ddedb3a984f561fd09ff3 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Wed, 13 Jun 2018 09:10:52 +0800 Subject: [PATCH 0954/1161] [SPARK-22239][SQL][PYTHON] Enable grouped aggregate pandas UDFs as window functions with unbounded window frames ## What changes were proposed in this pull request? This PR enables using a grouped aggregate pandas UDFs as window functions. The semantics is the same as using SQL aggregation function as window functions. ``` >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> from pyspark.sql import Window >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) >>> pandas_udf("double", PandasUDFType.GROUPED_AGG) ... def mean_udf(v): ... return v.mean() >>> w = Window.partitionBy('id') >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() +---+----+------+ | id| v|mean_v| +---+----+------+ | 1| 1.0| 1.5| | 1| 2.0| 1.5| | 2| 3.0| 6.0| | 2| 5.0| 6.0| | 2|10.0| 6.0| +---+----+------+ ``` The scope of this PR is somewhat limited in terms of: (1) Only supports unbounded window, which acts essentially as group by. (2) Only supports aggregation functions, not "transform" like window functions (n -> n mapping) Both of these are left as future work. Especially, (1) needs careful thinking w.r.t. how to pass rolling window data to python efficiently. (2) is a bit easier but does require more changes therefore I think it's better to leave it as a separate PR. ## How was this patch tested? WindowPandasUDFTests Author: Li Jin Closes #21082 from icexelloss/SPARK-22239-window-udf. --- .../spark/api/python/PythonRunner.scala | 2 + python/pyspark/rdd.py | 1 + python/pyspark/sql/functions.py | 34 ++- python/pyspark/sql/tests.py | 238 ++++++++++++++++++ python/pyspark/worker.py | 20 +- .../sql/catalyst/analysis/Analyzer.scala | 11 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 +- .../sql/catalyst/expressions/PythonUDF.scala | 6 +- .../expressions/windowExpressions.scala | 33 ++- .../sql/catalyst/optimizer/Optimizer.scala | 7 +- .../sql/catalyst/planning/patterns.scala | 42 +++- .../spark/sql/execution/SparkPlanner.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 20 +- .../execution/python/ExtractPythonUDFs.scala | 2 +- .../execution/python/WindowInPandasExec.scala | 173 +++++++++++++ 15 files changed, 580 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 41eac10d9b267..ebabedf950e39 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -40,6 +40,7 @@ private[spark] object PythonEvalType { val SQL_SCALAR_PANDAS_UDF = 200 val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 + val SQL_WINDOW_AGG_PANDAS_UDF = 203 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -47,6 +48,7 @@ private[spark] object PythonEvalType { case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" + case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 14d9128502ab0..7e7e5822a6b20 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -74,6 +74,7 @@ class PythonEvalType(object): SQL_SCALAR_PANDAS_UDF = 200 SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 + SQL_WINDOW_AGG_PANDAS_UDF = 203 def portable_hash(x): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1cdbb8a4c3e8b..a5e3384e802b8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2616,10 +2616,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): The returned scalar can be either a python primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as - output types. + :class:`MapType` and :class:`StructType` are currently not supported as output types. - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` + Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and + :class:`pyspark.sql.Window` + + This example shows using grouped aggregated UDFs with groupby: >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( @@ -2636,7 +2638,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 6.0| +---+-----------+ - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` + This example shows using grouped aggregated UDFs as window functions. Note that only + unbounded window frame is supported at the moment: + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP + ... def mean_udf(v): + ... return v.mean() + >>> w = Window.partitionBy('id') \\ + ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.5| + | 1| 2.0| 1.5| + | 2| 3.0| 6.0| + | 2| 5.0| 6.0| + | 2|10.0| 6.0| + +---+----+------+ + + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions are considered deterministic by default. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4a3941de8a6a6..2d7a4f62d4ee8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5454,6 +5454,15 @@ def test_retain_group_columns(self): expected1 = df.groupby(df.id).agg(sum(df.v)) self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2')) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + def test_invalid_args(self): from pyspark.sql.functions import mean @@ -5479,6 +5488,235 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + +@unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) +class WindowPandasUDFTests(ReusedSQLTestCase): + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))) \ + .drop('vs') \ + .withColumn('w', lit(1.0)) + + @property + def python_plus_one(self): + from pyspark.sql.functions import udf + return udf(lambda v: v + 1, 'double') + + @property + def pandas_scalar_time_two(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + return pandas_udf(lambda v: v * 2, 'double') + + @property + def pandas_agg_mean_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def avg(v): + return v.mean() + return avg + + @property + def pandas_agg_max_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def max(v): + return v.max() + return max + + @property + def pandas_agg_min_udf(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + @pandas_udf('double', PandasUDFType.GROUPED_AGG) + def min(v): + return v.min() + return min + + @property + def unbounded_window(self): + return Window.partitionBy('id') \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + + @property + def ordered_window(self): + return Window.partitionBy('id').orderBy('v') + + @property + def unpartitioned_window(self): + return Window.partitionBy() + + def test_simple(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max + + df = self.data + w = self.unbounded_window + + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_multiple_udfs(self): + from pyspark.sql.functions import max, min, mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \ + .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \ + .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w)) + + expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \ + .withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('min_w', min(df['w']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_replace_existing(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v', mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_sql(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1) + expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + + def test_mixed_udf(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unbounded_window + + plus_one = self.python_plus_one + time_two = self.pandas_scalar_time_two + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn( + 'v2', + plus_one(mean_udf(plus_one(df['v'])).over(w))) + expected1 = df.withColumn( + 'v2', + plus_one(mean(plus_one(df['v'])).over(w))) + + result2 = df.withColumn( + 'v2', + time_two(mean_udf(time_two(df['v'])).over(w))) + expected2 = df.withColumn( + 'v2', + time_two(mean(time_two(df['v'])).over(w))) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_without_partitionBy(self): + from pyspark.sql.functions import mean + + df = self.data + w = self.unpartitioned_window + mean_udf = self.pandas_agg_mean_udf + + result1 = df.withColumn('v2', mean_udf(df['v']).over(w)) + expected1 = df.withColumn('v2', mean(df['v']).over(w)) + + result2 = df.select(mean_udf(df['v']).over(w)) + expected2 = df.select(mean(df['v']).over(w)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + + def test_mixed_sql_and_udf(self): + from pyspark.sql.functions import max, min, rank, col + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + max_udf = self.pandas_agg_max_udf + min_udf = self.pandas_agg_min_udf + + result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w)) + expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w)) + + # Test mixing sql window function and window udf in the same expression + result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w)) + expected2 = expected1 + + # Test chaining sql aggregate function and udf + result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('min_v', min(df['v']).over(w)) \ + .withColumn('v_diff', col('max_v') - col('min_v')) \ + .drop('max_v', 'min_v') + expected3 = expected1 + + # Test mixing sql window function and udf + result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + expected4 = df.withColumn('max_v', max(df['v']).over(w)) \ + .withColumn('rank', rank().over(ow)) + + self.assertPandasEqual(expected1.toPandas(), result1.toPandas()) + self.assertPandasEqual(expected2.toPandas(), result2.toPandas()) + self.assertPandasEqual(expected3.toPandas(), result3.toPandas()) + self.assertPandasEqual(expected4.toPandas(), result4.toPandas()) + + def test_array_type(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + + array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG) + result1 = df.withColumn('v2', array_udf(df['v']).over(w)) + self.assertEquals(result1.first()['v2'], [1.0, 2.0]) + + def test_invalid_args(self): + from pyspark.sql.functions import mean, pandas_udf, PandasUDFType + + df = self.data + w = self.unbounded_window + ow = self.ordered_window + mean_udf = self.pandas_agg_mean_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*not supported within a window function'): + foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP) + df.withColumn('v2', foo_udf(df['v']).over(w)) + + with QuietTest(self.sc): + with self.assertRaisesRegexp( + AnalysisException, + '.*Only unbounded window frame is supported.*'): + df.withColumn('mean_v', mean_udf(df['v']).over(ow)) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a30d6bf523a50..38fe2ef06eac5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -128,6 +128,21 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) +def wrap_window_agg_pandas_udf(f, return_type): + # This is similar to grouped_agg_pandas_udf, the only difference + # is that window_agg_pandas_udf needs to repeat the return value + # to match window length, where grouped_agg_pandas_udf just returns + # the scalar value. + arrow_return_type = to_arrow_type(return_type) + + def wrapped(*series): + import pandas as pd + result = f(*series) + return pd.Series([result]).repeat(len(series[0])) + + return lambda *a: (wrapped(*a), arrow_return_type) + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -151,6 +166,8 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: + return arg_offsets, wrap_window_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return arg_offsets, wrap_udf(func, return_type) else: @@ -195,7 +212,8 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): timezone = utf8_deserializer.loads(infile) ser = ArrowStreamPandasSerializer(timezone) else: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f9947d1fa6c78..6e3107f1c6f75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1739,9 +1739,10 @@ class Analyzer( * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions * it two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for * all regular expressions. - * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s. - * 3. For every distinct [[WindowSpecDefinition]], creates a [[Window]] operator and inserts - * it into the plan tree. + * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s + * and [[WindowFunctionType]]s. + * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a + * [[Window]] operator and inserts it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { private def hasWindowFunction(exprs: Seq[Expression]): Boolean = @@ -1901,7 +1902,7 @@ class Analyzer( s"Please file a bug report with this error message, stack trace, and the query.") } else { val spec = distinctWindowSpec.head - (spec.partitionSpec, spec.orderSpec) + (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr)) } }.toSeq @@ -1909,7 +1910,7 @@ class Analyzer( // setting this to the child of the next Window operator. val windowOps = groupedWindowExpressions.foldLeft(child) { - case (last, ((partitionSpec, orderSpec), windowExpressions)) => + case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => Window(windowExpressions, partitionSpec, orderSpec, last) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 90bda2a72ad82..af256b98b34f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ @@ -112,12 +113,19 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") + case _ @ WindowExpression(_: PythonUDF, + WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame)) + if !frame.isUnbounded => + failAnalysis("Only unbounded window frame is supported with Pandas UDFs.") + case w @ WindowExpression(e, s) => // Only allow window functions with an aggregate expression or an offset window - // function. + // function or a Pandas window UDF. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => w + case f: PythonUDF if PythonUDF.isWindowPandasUDF(f) => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } @@ -154,7 +162,7 @@ trait CheckAnalysis extends PredicateHelper { case Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression) = { - expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr) + expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) } def checkValidAggregateExpression(expr: Expression): Unit = expr match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index efd664dde725a..6530b176968f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -34,10 +34,14 @@ object PythonUDF { e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType) } - def isGroupAggPandasUDF(e: Expression): Boolean = { + def isGroupedAggPandasUDF(e: Expression): Boolean = { e.isInstanceOf[PythonUDF] && e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF } + + // This is currently same as GroupedAggPandasUDF, but we might support new types in the future, + // e.g, N -> N transform. + def isWindowPandasUDF(e: Expression): Boolean = isGroupedAggPandasUDF(e) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 9fe2fb2b95e4d..f957aaa96e98c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp} import org.apache.spark.sql.types._ /** @@ -297,6 +297,37 @@ trait WindowFunction extends Expression { def frame: WindowFrame = UnspecifiedFrame } +/** + * Case objects that describe whether a window function is a SQL window function or a Python + * user-defined window function. + */ +sealed trait WindowFunctionType + +object WindowFunctionType { + case object SQL extends WindowFunctionType + case object Python extends WindowFunctionType + + def functionType(windowExpression: NamedExpression): WindowFunctionType = { + val t = windowExpression.collectFirst { + case _: WindowFunction | _: AggregateFunction => SQL + case udf: PythonUDF if PythonUDF.isWindowPandasUDF(udf) => Python + } + + // Normally a window expression would either have a SQL window function, a SQL + // aggregate function or a python window UDF. However, sometimes the optimizer will replace + // the window function if the value of the window function can be predetermined. + // For example, for query: + // + // select count(NULL) over () from values 1.0, 2.0, 3.0 T(a) + // + // The window function will be replaced by expression literal(0) + // To handle this case, if a window expression doesn't have a regular window function, we + // consider its type to be SQL as literal(0) is also a SQL expression. + t.getOrElse(SQL) + } +} + + /** * An offset window function is a window function that returns the value of the input column offset * by a number of rows within the partition. For instance: an OffsetWindowfunction for value x with diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bfa61116a6658..aa992def1ce6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -621,12 +621,15 @@ object CollapseRepartition extends Rule[LogicalPlan] { /** * Collapse Adjacent Window Expression. * - If the partition specs and order specs are the same and the window expression are - * independent, collapse into the parent. + * independent and are of the same window function type, collapse into the parent. */ object CollapseWindow extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild)) - if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty => + if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty && + // This assumes Window contains the same type of window expressions. This is ensured + // by ExtractWindowFunctions. + WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) => w1.copy(windowExpressions = we2 ++ we1, child = grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 626f905707191..84be677e438a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.planning import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -215,7 +216,7 @@ object PhysicalAggregation { case agg: AggregateExpression if !equivalentAggregateExpressions.addExpr(agg) => agg case udf: PythonUDF - if PythonUDF.isGroupAggPandasUDF(udf) && + if PythonUDF.isGroupedAggPandasUDF(udf) && !equivalentAggregateExpressions.addExpr(udf) => udf } } @@ -245,7 +246,7 @@ object PhysicalAggregation { equivalentAggregateExpressions.getEquivalentExprs(ae).headOption .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute // Similar to AggregateExpression - case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) => + case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) => equivalentAggregateExpressions.getEquivalentExprs(ue).headOption .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute case expression => @@ -268,3 +269,40 @@ object PhysicalAggregation { case _ => None } } + +/** + * An extractor used when planning physical execution of a window. This extractor outputs + * the window function type of the logical window. + * + * The input logical window must contain same type of window functions, which is ensured by + * the rule ExtractWindowExpressions in the analyzer. + */ +object PhysicalWindow { + // windowFunctionType, windowExpression, partitionSpec, orderSpec, child + private type ReturnType = + (WindowFunctionType, Seq[NamedExpression], Seq[Expression], Seq[SortOrder], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case expr @ logical.Window(windowExpressions, partitionSpec, orderSpec, child) => + + // The window expression should not be empty here, otherwise it's a bug. + if (windowExpressions.isEmpty) { + throw new AnalysisException(s"Window expression is empty in $expr") + } + + val windowFunctionType = windowExpressions.map(WindowFunctionType.functionType) + .reduceLeft { (t1: WindowFunctionType, t2: WindowFunctionType) => + if (t1 != t2) { + // We shouldn't have different window function type here, otherwise it's a bug. + throw new AnalysisException( + s"Found different window function type in $windowExpressions") + } else { + t1 + } + } + + Some((windowFunctionType, windowExpressions, partitionSpec, orderSpec, child)) + + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 74048871f8d42..75f5ec0e253df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -41,6 +41,7 @@ class SparkPlanner( DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: + Window :: JoinSelection :: InMemoryScans :: BasicOperators :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index be34387f6a874..d6951ad01fb0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -327,7 +327,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) { + if (aggregateExpressions.exists(PythonUDF.isGroupedAggPandasUDF)) { throw new AnalysisException( "Streaming aggregation doesn't support group aggregate pandas UDF") } @@ -428,6 +428,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object Window extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalWindow( + WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) => + execution.window.WindowExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case PhysicalWindow( + WindowFunctionType.Python, windowExprs, partitionSpec, orderSpec, child) => + execution.python.WindowInPandasExec( + windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil + + case _ => Nil + } + } + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { @@ -548,8 +564,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil case e @ logical.Expand(_, _, child) => execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil - case logical.Window(windowExprs, partitionSpec, orderSpec, child) => - execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 9d56f48249982..1e096100f7f43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -39,7 +39,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { */ private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { e.isInstanceOf[AggregateExpression] || - PythonUDF.isGroupAggPandasUDF(e) || + PythonUDF.isGroupedAggPandasUDF(e) || agg.groupingExpressions.exists(_.semanticEquals(e)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala new file mode 100644 index 0000000000000..c76832a1a3829 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.util.Utils + +case class WindowInPandasExec( + windowExpression: Seq[NamedExpression], + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + child: SparkPlan) extends UnaryExecNode { + + override def output: Seq[Attribute] = + child.output ++ windowExpression.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = { + if (partitionSpec.isEmpty) { + // Only show warning when the number of bytes is larger than 100 MB? + logWarning("No Partition Defined for Window operation! Moving all data to a single " + + "partition, this can cause serious performance degradation.") + AllTuples :: Nil + } else { + ClusteredDistribution(partitionSpec) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override def outputPartitioning: Partitioning = child.outputPartitioning + + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) + } + } + + /** + * Create the resulting projection. + * + * This method uses Code Generation. It can only be used on the executor side. + * + * @param expressions unbound ordered function expressions. + * @return the final resulting projection. + */ + private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = { + val references = expressions.zipWithIndex.map { case (e, i) => + // Results of window expressions will be on the right side of child's output + BoundReference(child.output.size + i, e.dataType, e.nullable) + } + val unboundToRefMap = expressions.zip(references).toMap + val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) + UnsafeProjection.create( + child.output ++ patchedWindowExpression, + child.output) + } + + protected override def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + + // Extract window expressions and window functions + val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) + + val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF]) + + val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip + + // Filter child output attributes down to only those that are UDF inputs. + // Also eliminate duplicate UDF inputs. + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + + // Schema of input rows to the python runner + val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }) + + inputRDD.mapPartitionsInternal { iter => + val context = TaskContext.get() + + val grouped = if (partitionSpec.isEmpty) { + // Use an empty unsafe row as a place holder for the grouping key + Iterator((new UnsafeRow(), iter)) + } else { + GroupedIterator(iter, partitionSpec, child.output) + } + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + context.addTaskCompletionListener { _ => + queue.close() + } + + val inputProj = UnsafeProjection.create(allInputs, child.output) + val pythonInput = grouped.map { case (_, rows) => + rows.map { row => + queue.add(row.asInstanceOf[UnsafeRow]) + inputProj(row) + } + } + + val windowFunctionResult = new ArrowPythonRunner( + pyFuncs, bufferSize, reuseWorker, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + argOffsets, windowInputSchema, + sessionLocalTimeZone, pandasRespectSessionTimeZone) + .compute(pythonInput, context.partitionId(), context) + + val joined = new JoinedRow + val resultProj = createResultProjection(expressions) + + windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput => + val leftRow = queue.remove() + val joinedRow = joined(leftRow, windowOutput) + resultProj(joinedRow) + } + } + } +} From 3352d6fe9a1efb6dee18e40bdf584930b10d1d3e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 13 Jun 2018 12:34:46 +0800 Subject: [PATCH 0955/1161] [SPARK-24466][SS] Fix TextSocketMicroBatchReader to be compatible with netcat again ## What changes were proposed in this pull request? TextSocketMicroBatchReader was no longer be compatible with netcat due to launching temporary reader for reading schema, and closing reader, and re-opening reader. While reliable socket server should be able to handle this without any issue, nc command normally can't handle multiple connections and simply exits when closing temporary reader. This patch fixes TextSocketMicroBatchReader to be compatible with netcat again, via deferring opening socket to the first call of planInputPartitions() instead of constructor. ## How was this patch tested? Added unit test which fails on current and succeeds with the patch. And also manually tested. Author: Jungtaek Lim Closes #21497 from HeartSaVioR/SPARK-24466. --- .../execution/streaming/sources/socket.scala | 7 +- .../sources/TextSocketStreamSuite.scala | 119 ++++++++++++------ 2 files changed, 85 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 8240e06d4ab72..91e3b7179c34a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -22,6 +22,7 @@ import java.net.Socket import java.sql.Timestamp import java.text.SimpleDateFormat import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -76,7 +77,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR @GuardedBy("this") private var lastOffsetCommitted: LongOffset = LongOffset(-1L) - initialize() + private val initialized: AtomicBoolean = new AtomicBoolean(false) /** This method is only used for unit test */ private[sources] def getCurrentOffset(): LongOffset = synchronized { @@ -149,6 +150,10 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { + if (initialized.compareAndSet(false, true)) { + initialize() + } + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 batches.slice(sliceStart, sliceEnd) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index a15a980bb92fd..52e8386f6b1fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming.sources -import java.io.IOException -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp @@ -33,9 +32,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} @@ -101,7 +101,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -130,7 +130,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val socket = spark .readStream .format("socket") @@ -216,20 +216,11 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "socket source does not support a user-specified schema")) } - test("no server up") { - val provider = new TextSocketSourceProvider - val parameters = Map("host" -> "localhost", "port" -> "0") - intercept[IOException] { - batchReader = provider.createMicroBatchReader( - Optional.empty(), "", new DataSourceOptions(parameters.asJava)) - } - } - test("input row metrics") { serverThread = new ServerThread() serverThread.start() - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { val ref = spark import ref.implicits._ @@ -256,6 +247,66 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before } } + test("verify ServerThread only accepts the first connection") { + serverThread = new ServerThread() + serverThread.start() + + withSQLConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED.key -> "false") { + val ref = spark + import ref.implicits._ + + val socket = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + assert(socket.schema === StructType(StructField("value", StringType) :: Nil)) + + testStream(socket)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + // we are trying to connect to the server once again which should fail + try { + val socket2 = spark + .readStream + .format("socket") + .options(Map("host" -> "localhost", "port" -> serverThread.port.toString)) + .load() + .as[String] + + testStream(socket2)( + StartStream(), + AddSocketData("hello"), + CheckAnswer("hello"), + AddSocketData("world"), + CheckLastBatch("world"), + CheckAnswer("hello", "world"), + StopStream + ) + + fail("StreamingQueryException is expected!") + } catch { + case e: StreamingQueryException if e.cause.isInstanceOf[SocketException] => // pass + } + } + } + + /** + * This class tries to mimic the behavior of netcat, so that we can ensure + * TextSocketStream supports netcat, which only accepts the first connection + * and exits the process when the first connection is closed. + * + * Please refer SPARK-24466 for more details. + */ private class ServerThread extends Thread with Logging { private val serverSocketChannel = ServerSocketChannel.open() serverSocketChannel.bind(new InetSocketAddress(0)) @@ -265,36 +316,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before override def run(): Unit = { try { + val clientSocketChannel = serverSocketChannel.accept() + + // Close server socket channel immediately to mimic the behavior that + // only first connection will be made and deny any further connections + // Note that the first client socket channel will be available + serverSocketChannel.close() + + clientSocketChannel.configureBlocking(false) + clientSocketChannel.socket().setTcpNoDelay(true) + while (true) { - val clientSocketChannel = serverSocketChannel.accept() - clientSocketChannel.configureBlocking(false) - clientSocketChannel.socket().setTcpNoDelay(true) - - // Check whether remote client is closed but still send data to this closed socket. - // This happens in DataStreamReader where a source will be created to get the schema. - var remoteIsClosed = false - var cnt = 0 - while (cnt < 3 && !remoteIsClosed) { - if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) { - cnt += 1 - Thread.sleep(100) - } else { - remoteIsClosed = true - } - } - - if (remoteIsClosed) { - logInfo(s"remote client ${clientSocketChannel.socket()} is closed") - } else { - while (true) { - val line = messageQueue.take() + "\n" - clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) - } - } + val line = messageQueue.take() + "\n" + clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8"))) } } catch { case e: InterruptedException => } finally { + // no harm to call close() again... serverSocketChannel.close() } } From 4c388bccf1bcac8f833fd9214096dd164c3ea065 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 13 Jun 2018 12:36:20 +0800 Subject: [PATCH 0956/1161] [SPARK-24485][SS] Measure and log elapsed time for filesystem operations in HDFSBackedStateStoreProvider ## What changes were proposed in this pull request? This patch measures and logs elapsed time for each operation which communicate with file system (mostly remote HDFS in production) in HDFSBackedStateStoreProvider to help investigating any latency issue. ## How was this patch tested? Manually tested. Author: Jungtaek Lim Closes #21506 from HeartSaVioR/SPARK-24485. --- .../scala/org/apache/spark/util/Utils.scala | 11 ++- .../state/HDFSBackedStateStoreProvider.scala | 83 +++++++++++-------- .../streaming/statefulOperators.scala | 9 +- 3 files changed, 62 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7428db2158538..c139db46b63a3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -31,6 +31,7 @@ import java.nio.file.Files import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ +import java.util.concurrent.TimeUnit.NANOSECONDS import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream @@ -434,7 +435,7 @@ private[spark] object Utils extends Logging { new URI("file:///" + rawFileName).getPath.substring(1) } - /** + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -507,6 +508,14 @@ private[spark] object Utils extends Logging { targetFile } + /** Records the duration of running `body`. */ + def timeTakenMs[T](body: => T): (T, Long) = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + (result, math.max(NANOSECONDS.toMillis(endTime - startTime), 0)) + } + /** * Download `in` to `tempFile`, then move it to `destFile`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index df722b953228b..118c82aa75e68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.execution.streaming.state import java.io._ -import java.nio.channels.ClosedChannelException import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams @@ -280,38 +278,49 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit if (loadedCurrentVersionMap.isDefined) { return loadedCurrentVersionMap.get } - val snapshotCurrentVersionMap = readSnapshotFile(version) - if (snapshotCurrentVersionMap.isDefined) { - synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } - return snapshotCurrentVersionMap.get - } - // Find the most recent map before this version that we can. - // [SPARK-22305] This must be done iteratively to avoid stack overflow. - var lastAvailableVersion = version - var lastAvailableMap: Option[MapType] = None - while (lastAvailableMap.isEmpty) { - lastAvailableVersion -= 1 + logWarning(s"The state for version $version doesn't exist in loadedMaps. " + + "Reading snapshot file and delta files if needed..." + + "Note that this is normal for the first batch of starting query.") - if (lastAvailableVersion <= 0) { - // Use an empty map for versions 0 or less. - lastAvailableMap = Some(new MapType) - } else { - lastAvailableMap = - synchronized { loadedMaps.get(lastAvailableVersion) } - .orElse(readSnapshotFile(lastAvailableVersion)) + val (result, elapsedMs) = Utils.timeTakenMs { + val snapshotCurrentVersionMap = readSnapshotFile(version) + if (snapshotCurrentVersionMap.isDefined) { + synchronized { loadedMaps.put(version, snapshotCurrentVersionMap.get) } + return snapshotCurrentVersionMap.get + } + + // Find the most recent map before this version that we can. + // [SPARK-22305] This must be done iteratively to avoid stack overflow. + var lastAvailableVersion = version + var lastAvailableMap: Option[MapType] = None + while (lastAvailableMap.isEmpty) { + lastAvailableVersion -= 1 + + if (lastAvailableVersion <= 0) { + // Use an empty map for versions 0 or less. + lastAvailableMap = Some(new MapType) + } else { + lastAvailableMap = + synchronized { loadedMaps.get(lastAvailableVersion) } + .orElse(readSnapshotFile(lastAvailableVersion)) + } + } + + // Load all the deltas from the version after the last available one up to the target version. + // The last available version is the one with a full snapshot, so it doesn't need deltas. + val resultMap = new MapType(lastAvailableMap.get) + for (deltaVersion <- lastAvailableVersion + 1 to version) { + updateFromDeltaFile(deltaVersion, resultMap) } - } - // Load all the deltas from the version after the last available one up to the target version. - // The last available version is the one with a full snapshot, so it doesn't need deltas. - val resultMap = new MapType(lastAvailableMap.get) - for (deltaVersion <- lastAvailableVersion + 1 to version) { - updateFromDeltaFile(deltaVersion, resultMap) + synchronized { loadedMaps.put(version, resultMap) } + resultMap } - synchronized { loadedMaps.put(version, resultMap) } - resultMap + logDebug(s"Loading state for $version takes $elapsedMs ms.") + + result } private def writeUpdateToDeltaFile( @@ -490,7 +499,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Perform a snapshot of the store to allow delta files to be consolidated */ private def doSnapshot(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val lastVersion = files.last.version val deltaFilesForLastVersion = @@ -498,7 +509,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { - writeSnapshotFile(lastVersion, map) + val (_, e2) = Utils.timeTakenMs(writeSnapshotFile(lastVersion, map)) + logDebug(s"writeSnapshotFile() took $e2 ms.") } case None => // The last map is not loaded, probably some other instance is in charge @@ -517,7 +529,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit */ private[state] def cleanup(): Unit = { try { - val files = fetchFiles() + val (files, e1) = Utils.timeTakenMs(fetchFiles()) + logDebug(s"fetchFiles() took $e1 ms.") + if (files.nonEmpty) { val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain if (earliestVersionToRetain > 0) { @@ -527,9 +541,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit mapsToRemove.foreach(loadedMaps.remove) } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) - filesToDelete.foreach { f => - fm.delete(f.path) + val (_, e2) = Utils.timeTakenMs { + filesToDelete.foreach { f => + fm.delete(f.path) + } } + logDebug(s"deleting files took $e2 ms.") logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 1691a6320a526..6759fb42b4052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ -import org.apache.spark.util.{CompletionIterator, NextIterator} +import org.apache.spark.util.{CompletionIterator, NextIterator, Utils} /** Used to identify the state store for a given operator. */ @@ -97,12 +97,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => } /** Records the duration of running `body` for the next query progress update. */ - protected def timeTakenMs(body: => Unit): Long = { - val startTime = System.nanoTime() - val result = body - val endTime = System.nanoTime() - math.max(NANOSECONDS.toMillis(endTime - startTime), 0) - } + protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2 /** * Set the SQL metrics related to the state store. From 7703b46d2843db99e28110c4c7ccf60934412504 Mon Sep 17 00:00:00 2001 From: Arun Mahadevan Date: Wed, 13 Jun 2018 20:43:16 +0800 Subject: [PATCH 0957/1161] [SPARK-24479][SS] Added config for registering streamingQueryListeners ## What changes were proposed in this pull request? Currently a "StreamingQueryListener" can only be registered programatically. We could have a new config "spark.sql.streamingQueryListeners" similar to "spark.sql.queryExecutionListeners" and "spark.extraListeners" for users to register custom streaming listeners. ## How was this patch tested? New unit test and running example programs. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Arun Mahadevan Closes #21504 from arunmahadevan/SPARK-24480. --- .../spark/sql/internal/StaticSQLConf.scala | 8 +++ .../sql/streaming/StreamingQueryManager.scala | 15 +++++ .../StreamingQueryListenersConfSuite.scala | 66 +++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index fe0ad39c29025..382ef28f49a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -96,6 +96,14 @@ object StaticSQLConf { .toSequence .createOptional + val STREAMING_QUERY_LISTENERS = buildStaticConf("spark.sql.streaming.streamingQueryListeners") + .doc("List of class names implementing StreamingQueryListener that will be automatically " + + "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional + val UI_RETAINED_EXECUTIONS = buildStaticConf("spark.sql.ui.retainedExecutions") .doc("Number of executions to retain in the Spark UI.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 97da2b1325f58..25bb05212d66f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} @@ -32,6 +33,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger} import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -55,6 +57,19 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo @GuardedBy("awaitTerminationLock") private var lastTerminatedQuery: StreamingQuery = null + try { + sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[StreamingQueryListener], classNames, + sparkSession.sparkContext.conf).foreach(listener => { + addListener(listener) + logInfo(s"Registered listener ${listener.getClass.getName}") + }) + } + } catch { + case e: Exception => + throw new SparkException("Exception when registering StreamingQueryListener", e) + } + /** * Returns a list of active queries associated with this SQLContext * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala new file mode 100644 index 0000000000000..1aaf8a9aa2d55 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenersConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.language.reflectiveCalls + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener._ + + +class StreamingQueryListenersConfSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + + + override protected def sparkConf: SparkConf = + super.sparkConf.set("spark.sql.streaming.streamingQueryListeners", + "org.apache.spark.sql.streaming.TestListener") + + test("test if the configured query lister is loaded") { + testStream(MemoryStream[Int].toDS)( + StartStream(), + StopStream + ) + + assert(TestListener.queryStartedEvent != null) + assert(TestListener.queryTerminatedEvent != null) + } + +} + +object TestListener { + @volatile var queryStartedEvent: QueryStartedEvent = null + @volatile var queryTerminatedEvent: QueryTerminatedEvent = null +} + +class TestListener(sparkConf: SparkConf) extends StreamingQueryListener { + + override def onQueryStarted(event: QueryStartedEvent): Unit = { + TestListener.queryStartedEvent = event + } + + override def onQueryProgress(event: QueryProgressEvent): Unit = {} + + override def onQueryTerminated(event: QueryTerminatedEvent): Unit = { + TestListener.queryTerminatedEvent = event + } +} From 299d297e250ca3d46616a97e4256aa9ad6a135e5 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 13 Jun 2018 07:09:48 -0700 Subject: [PATCH 0958/1161] [SPARK-24500][SQL] Make sure streams are materialized during Tree transforms. ## What changes were proposed in this pull request? If you construct catalyst trees using `scala.collection.immutable.Stream` you can run into situations where valid transformations do not seem to have any effect. There are two causes for this behavior: - `Stream` is evaluated lazily. Note that default implementation will generally only evaluate a function for the first element (this makes testing a bit tricky). - `TreeNode` and `QueryPlan` use side effects to detect if a tree has changed. Mapping over a stream is lazy and does not need to trigger this side effect. If this happens the node will invalidly assume that it did not change and return itself instead if the newly created node (this is for GC reasons). This PR fixes this issue by forcing materialization on streams in `TreeNode` and `QueryPlan`. ## How was this patch tested? Unit tests were added to `TreeNodeSuite` and `LogicalPlanSuite`. An integration test was added to the `PlannerSuite` Author: Herman van Hovell Closes #21539 from hvanhovell/SPARK-24500. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 1 + .../spark/sql/catalyst/trees/TreeNode.scala | 122 ++++++++---------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 20 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 25 +++- .../spark/sql/execution/PlannerSuite.scala | 11 +- 5 files changed, 109 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 64cb8c726772f..e431c9523a9da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other case null => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f99ee10..becfa8d982213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case nonChild: AnyRef => nonChild + case null => null + } val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. - case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } + case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case nonChild: AnyRef => nonChild case null => null } @@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) @@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } - - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } - - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Traversable[_] => args.map(mapChild) case nonChild: AnyRef => nonChild case null => null } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 14041747fd20e..bf569cb869428 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite { assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } + + test("transformExpressions works with a Stream") { + val id1 = NamedExpression.newExprId + val id2 = NamedExpression.newExprId + val plan = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(2), "b")(exprId = id2)), + OneRowRelation()) + val result = plan.transformExpressions { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(3), "b")(exprId = id2)), + OneRowRelation()) + assert(result.sameResult(expected)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7bef642..b7092f4c42d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { @@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite { val right = JsonMethods.parse(rightJson) assert(left == right) } + + test("transform works on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the + // situation in which the TreeNode.mapChildren function's change detection is not triggered. A + // stream's first element is typically materialized, so in order to not trip the TreeNode change + // detection logic, we should not change the first element in the sequence. + val result = before.transform { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } + + test("withNewChildren on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + val result = before.withNewChildren(Stream(Literal(1), Literal(3))) + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 98a50fbd52b4d..ed0ff1be476c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,8 +21,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -679,6 +679,13 @@ class PlannerSuite extends SharedSQLContext { } assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + + test("SPARK-24500: create union with stream of children") { + val df = Union(Stream( + Range(1, 1, 1, 1), + Range(1, 2, 1, 1))) + df.queryExecution.executedPlan.execute() + } } // Used for unit-testing EnsureRequirements From 1b46f41c55f5cd29956e17d7da95a95580cf273f Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Wed, 13 Jun 2018 13:13:01 -0700 Subject: [PATCH 0959/1161] [SPARK-24235][SS] Implement continuous shuffle writer for single reader partition. ## What changes were proposed in this pull request? https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit Implement continuous shuffle write RDD for a single reader partition. (I don't believe any implementation changes are actually required for multiple reader partitions, but this PR is already very large, so I want to exclude those for now to keep the size down.) ## How was this patch tested? new unit tests Author: Jose Torres Closes #21428 from jose-torres/writerTask. --- .../shuffle/ContinuousShuffleReadRDD.scala | 6 +- .../shuffle/ContinuousShuffleWriter.scala | 27 ++ ...scala => RPCContinuousShuffleReader.scala} | 24 +- .../shuffle/RPCContinuousShuffleWriter.scala | 60 +++++ ...ite.scala => ContinuousShuffleSuite.scala} | 231 ++++++++++++++---- 5 files changed, 281 insertions(+), 67 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/{UnsafeRowReceiver.scala => RPCContinuousShuffleReader.scala} (86%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala rename sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/{ContinuousShuffleReadSuite.scala => ContinuousShuffleSuite.scala} (65%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index 801b28b751bee..cf6572d3de1f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -34,8 +34,10 @@ case class ContinuousShuffleReadPartition( // Initialized only on the executor, and only once even as we call compute() multiple times. lazy val (reader: ContinuousShuffleReader, endpoint) = { val env = SparkEnv.get.rpcEnv - val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..47b1f78b24505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for writing to a continuous processing shuffle. + */ +trait ContinuousShuffleWriter { + def write(epoch: Iterator[UnsafeRow]): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index d81f552d56626..834e84675c7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -20,26 +20,24 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable - import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.util.NextIterator /** - * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming row or an epoch marker. * * Each message comes tagged with writerId, identifying which writer the message is coming * from. The receiver will only begin the next epoch once all writers have sent an epoch * marker ending the current epoch. */ -private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable { +private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable { def writerId: Int } private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow) - extends UnsafeRowReceiverMessage -private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage + extends RPCContinuousShuffleMessage +private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContinuousShuffleMessage /** * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle @@ -48,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class UnsafeRowReceiver( +private[shuffle] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, @@ -57,7 +55,7 @@ private[shuffle] class UnsafeRowReceiver( // Note that this queue will be drained from the main task thread and populated in the RPC // response thread. private val queues = Array.fill(numShuffleWriters) { - new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize) } // Exposed for testing to determine if the endpoint gets stopped on task end. @@ -68,7 +66,9 @@ private[shuffle] class UnsafeRowReceiver( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case r: UnsafeRowReceiverMessage => + case r: RPCContinuousShuffleMessage => + // Note that this will block a thread the shared RPC handler pool! + // The TCP based shuffle handler (SPARK-24541) will avoid this problem. queues(r.writerId).put(r) context.reply(()) } @@ -79,10 +79,10 @@ private[shuffle] class UnsafeRowReceiver( private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false) private val executor = Executors.newFixedThreadPool(numShuffleWriters) - private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor) + private val completion = new ExecutorCompletionService[RPCContinuousShuffleMessage](executor) - private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] { - override def call(): UnsafeRowReceiverMessage = queues(writerId).take() + private def completionTask(writerId: Int) = new Callable[RPCContinuousShuffleMessage] { + override def call(): RPCContinuousShuffleMessage = queues(writerId).take() } // Initialize by submitting tasks to read the first row from each writer. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala new file mode 100644 index 0000000000000..1c6f3ddb395e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import scala.concurrent.Future +import scala.concurrent.duration.Duration + +import org.apache.spark.Partitioner +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.ThreadUtils + +/** + * A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances. + * + * @param writerId The partition ID of this writer. + * @param outputPartitioner The partitioner on the reader side of the shuffle. + * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by + * partition ID within outputPartitioner. + */ +class RPCContinuousShuffleWriter( + writerId: Int, + outputPartitioner: Partitioner, + endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter { + + if (outputPartitioner.numPartitions != 1) { + throw new IllegalArgumentException("multiple readers not yet supported") + } + + if (outputPartitioner.numPartitions != endpoints.length) { + throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " + + s"not match endpoint count ${endpoints.length}") + } + + def write(epoch: Iterator[UnsafeRow]): Unit = { + while (epoch.hasNext) { + val row = epoch.next() + endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId, row)) + } + + val futures = endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq + implicit val ec = ThreadUtils.sameThread + ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala similarity index 65% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index 2e4d607a403ca..a8e3611b585cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,29 +17,14 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle -import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String -class ContinuousShuffleReadSuite extends StreamTest { - - private def unsafeRow(value: Int) = { - UnsafeProjection.create(Array(IntegerType : DataType))( - new GenericInternalRow(Array(value: Any))) - } - - private def unsafeRow(value: String) = { - UnsafeProjection.create(Array(StringType : DataType))( - new GenericInternalRow(Array(UTF8String.fromString(value): Any))) - } - - private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { - messages.foreach(endpoint.askSync[Unit](_)) - } - +class ContinuousShuffleSuite extends StreamTest { // In this unit test, we emulate that we're in the task thread where // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context // thread local to be set. @@ -58,39 +43,29 @@ class ContinuousShuffleReadSuite extends StreamTest { super.afterEach() } - test("receiver stopped with row last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverEpochMarker(0), - ReceiverRow(0, unsafeRow(111)) - ) + private implicit def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def unsafeRow(value: String) = { + UnsafeProjection.create(Array(StringType : DataType))( + new GenericInternalRow(Array(UTF8String.fromString(value): Any))) } - test("receiver stopped with marker last") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) - val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint - send( - endpoint, - ReceiverRow(0, unsafeRow(111)), - ReceiverEpochMarker(0) - ) + private def send(endpoint: RpcEndpointRef, messages: RPCContinuousShuffleMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } - ctx.markTaskCompleted(None) - val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader - eventually(timeout(streamingTimeout)) { - assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) - } + private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = { + rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint } - test("one epoch") { + private def readEpoch(rdd: ContinuousShuffleReadRDD) = { + rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0)) + } + + test("reader - one epoch") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -105,7 +80,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) } - test("multiple epochs") { + test("reader - multiple epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -124,7 +99,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) } - test("empty epochs") { + test("reader - empty epochs") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -148,7 +123,7 @@ class ContinuousShuffleReadSuite extends StreamTest { assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) } - test("multiple partitions") { + test("reader - multiple partitions") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { @@ -169,7 +144,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("blocks waiting for new rows") { + test("reader - blocks waiting for new rows") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue) val epoch = rdd.compute(rdd.partitions(0), ctx) @@ -195,7 +170,7 @@ class ContinuousShuffleReadSuite extends StreamTest { } } - test("multiple writers") { + test("reader - multiple writers") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint send( @@ -213,7 +188,7 @@ class ContinuousShuffleReadSuite extends StreamTest { Set("writer0-row0", "writer1-row0", "writer2-row0")) } - test("epoch only ends when all writers send markers") { + test("reader - epoch only ends when all writers send markers") { val rdd = new ContinuousShuffleReadRDD( sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint @@ -233,6 +208,7 @@ class ContinuousShuffleReadSuite extends StreamTest { // After checking the right rows, block until we get an epoch marker indicating there's no next. // (Also fail the assertion if for some reason we get a row.) + val readEpochMarkerThread = new Thread { override def run(): Unit = { assert(!epoch.hasNext) @@ -251,10 +227,10 @@ class ContinuousShuffleReadSuite extends StreamTest { } // Join to pick up assertion failures. - readEpochMarkerThread.join() + readEpochMarkerThread.join(streamingTimeout.toMillis) } - test("writer epochs non aligned") { + test("reader - writer epochs non aligned") { val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3) val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should @@ -288,4 +264,153 @@ class ContinuousShuffleReadSuite extends StreamTest { val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet assert(thirdEpoch == Set("writer1-row1", "writer2-row0")) } + + test("one epoch") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + } + + test("multiple epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator(1, 2, 3)) + writer.write(Iterator(4, 5, 6)) + + assert(readEpoch(reader) == Seq(1, 2, 3)) + assert(readEpoch(reader) == Seq(4, 5, 6)) + } + + test("empty epochs") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + writer.write(Iterator()) + writer.write(Iterator(1, 2)) + writer.write(Iterator()) + writer.write(Iterator()) + writer.write(Iterator(3, 4)) + writer.write(Iterator()) + + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(1, 2)) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq()) + assert(readEpoch(reader) == Seq(3, 4)) + assert(readEpoch(reader) == Seq()) + } + + test("blocks waiting for writer") { + val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val writer = new RPCContinuousShuffleWriter( + 0, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readRowThread = new Thread { + override def run(): Unit = { + assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1)) + } + } + readRowThread.start() + + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.TIMED_WAITING) + } + + // Once we write the epoch the thread should stop waiting and succeed. + writer.write(Iterator(1)) + readRowThread.join(streamingTimeout.toMillis) + } + + test("multiple writer partitions") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(0).write(Iterator(1, 4, 7)) + writers(1).write(Iterator(2, 5)) + writers(2).write(Iterator(3, 6)) + + writers(0).write(Iterator(4, 7, 10)) + writers(1).write(Iterator(5, 8)) + writers(2).write(Iterator(6, 9)) + + // Since there are multiple asynchronous writers, the original row sequencing is not guaranteed. + // The epochs should be deterministically preserved, however. + assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet) + assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet) + } + + test("reader epoch only ends when all writer partitions write it") { + val numWriterPartitions = 3 + + val reader = new ContinuousShuffleReadRDD( + sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions) + val writers = (0 until 3).map { idx => + new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader))) + } + + writers(1).write(Iterator()) + writers(2).write(Iterator()) + + val readerEpoch = reader.compute(reader.partitions(0), ctx) + + val readEpochMarkerThread = new Thread { + override def run(): Unit = { + assert(!readerEpoch.hasNext) + } + } + + readEpochMarkerThread.start() + eventually(timeout(streamingTimeout)) { + assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING) + } + + writers(0).write(Iterator()) + readEpochMarkerThread.join(streamingTimeout.toMillis) + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(0), + ReceiverRow(0, unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(0, unsafeRow(111)), + ReceiverEpochMarker(0) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get()) + } + } } From 3bf76918fb67fb3ee9aed254d4fb3b87a7e66117 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 13 Jun 2018 15:18:19 -0700 Subject: [PATCH 0960/1161] [SPARK-24531][TESTS] Replace 2.3.0 version with 2.3.1 ## What changes were proposed in this pull request? The PR updates the 2.3 version tested to the new release 2.3.1. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #21543 from mgaido91/patch-1. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 6f904c937348d..514921875f1f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -195,7 +195,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") protected var spark: SparkSession = _ From 534065efeb51ff0d308fa6cc9dea0715f8ce25ad Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 14 Jun 2018 14:20:48 +0800 Subject: [PATCH 0961/1161] [MINOR][CORE][TEST] Remove unnecessary sort in UnsafeInMemorySorterSuite ## What changes were proposed in this pull request? We don't require specific ordering of the input data, the sort action is not necessary and misleading. ## How was this patch tested? Existing test suite. Author: Xingbo Jiang Closes #21536 from jiangxb1987/sorterSuite. --- .../util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index c145532328514..85ffdca436e14 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -129,7 +129,6 @@ public int compare( final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; - Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); final String str = From fdadc4be08dcf1a06383bbb05e53540da2092c63 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 14 Jun 2018 09:20:41 -0700 Subject: [PATCH 0962/1161] [SPARK-24495][SQL] EnsureRequirement returns wrong plan when reordering equal keys ## What changes were proposed in this pull request? `EnsureRequirement` in its `reorder` method currently assumes that the same key appears only once in the join condition. This of course might not be the case, and when it is not satisfied, it returns a wrong plan which produces a wrong result of the query. ## How was this patch tested? added UT Author: Marco Gaido Closes #21529 from mgaido91/SPARK-24495. --- .../execution/exchange/EnsureRequirements.scala | 14 ++++++++++++-- .../scala/org/apache/spark/sql/JoinSuite.scala | 11 +++++++++++ .../spark/sql/execution/PlannerSuite.scala | 17 +++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e3d28388c5470..ad95879d86f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { val leftKeysBuffer = ArrayBuffer[Expression]() val rightKeysBuffer = ArrayBuffer[Expression]() + val pickedIndexes = mutable.Set[Int]() + val keysAndIndexes = currentOrderOfKeys.zipWithIndex expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + val index = keysAndIndexes.find { case (e, idx) => + // As we may have the same key used many times, we need to filter out its occurrence we + // have already used. + e.semanticEquals(expression) && !pickedIndexes.contains(idx) + }.map(_._2).get + pickedIndexes += index leftKeysBuffer.append(leftKeys(index)) rightKeysBuffer.append(rightKeys(index)) }) @@ -270,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * partitioning of the join nodes' children. */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - plan.transformUp { + plan match { case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = @@ -288,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8fa747465cb1a..44767dfc92497 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil) } } + + test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(0, 100, 1, 2) + val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2")) + val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id") + checkAnswer(res, Row(0, 0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ed0ff1be476c7..37d468739c613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -680,6 +680,23 @@ class PlannerSuite extends SharedSQLContext { assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprA)) + assert(rightKeys == Seq(exprB, exprC)) + case _ => fail() + } + } + test("SPARK-24500: create union with stream of children") { val df = Union(Stream( Range(1, 1, 1, 1), From d3eed8fd6d65d95306abfb513a9e0fde05b703ac Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 14 Jun 2018 13:16:20 -0700 Subject: [PATCH 0963/1161] =?UTF-8?q?[SPARK-24563][PYTHON]=20Catch=20TypeE?= =?UTF-8?q?rror=20when=20testing=20existence=20of=20HiveConf=20when=20crea?= =?UTF-8?q?ting=20pysp=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ark shell ## What changes were proposed in this pull request? This PR catches TypeError when testing existence of HiveConf when creating pyspark shell ## How was this patch tested? Manually tested. Here are the manual test cases: Build with hive: ``` (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ bin/pyspark Python 3.6.5 | packaged by conda-forge | (default, Apr 6 2018, 13:44:09) [GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)] on darwin Type "help", "copyright", "credits" or "license" for more information. 18/06/14 14:55:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.4.0-SNAPSHOT /_/ Using Python version 3.6.5 (default, Apr 6 2018 13:44:09) SparkSession available as 'spark'. >>> spark.conf.get('spark.sql.catalogImplementation') 'hive' ``` Build without hive: ``` (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ bin/pyspark Python 3.6.5 | packaged by conda-forge | (default, Apr 6 2018, 13:44:09) [GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)] on darwin Type "help", "copyright", "credits" or "license" for more information. 18/06/14 15:04:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /__ / .__/\_,_/_/ /_/\_\ version 2.4.0-SNAPSHOT /_/ Using Python version 3.6.5 (default, Apr 6 2018 13:44:09) SparkSession available as 'spark'. >>> spark.conf.get('spark.sql.catalogImplementation') 'in-memory' ``` Failed to start shell: ``` (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ bin/pyspark Python 3.6.5 | packaged by conda-forge | (default, Apr 6 2018, 13:44:09) [GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)] on darwin Type "help", "copyright", "credits" or "license" for more information. 18/06/14 15:07:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). /Users/icexelloss/workspace/spark/python/pyspark/shell.py:45: UserWarning: Failed to initialize Spark session. warnings.warn("Failed to initialize Spark session.") Traceback (most recent call last): File "/Users/icexelloss/workspace/spark/python/pyspark/shell.py", line 41, in spark = SparkSession._create_shell_session() File "/Users/icexelloss/workspace/spark/python/pyspark/sql/session.py", line 581, in _create_shell_session return SparkSession.builder.getOrCreate() File "/Users/icexelloss/workspace/spark/python/pyspark/sql/session.py", line 168, in getOrCreate raise py4j.protocol.Py4JError("Fake Py4JError") py4j.protocol.Py4JError: Fake Py4JError (pyarrow-dev) Lis-MacBook-Pro:spark icexelloss$ ``` Author: Li Jin Closes #21569 from icexelloss/SPARK-24563-fix-pyspark-shell-without-hive. --- python/pyspark/sql/session.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e880dd1ca6d1a..f1ad6b1212ed9 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -567,14 +567,7 @@ def _create_shell_session(): .getOrCreate() else: return SparkSession.builder.getOrCreate() - except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - - try: - return SparkSession.builder.getOrCreate() - except TypeError: + except (py4j.protocol.Py4JError, TypeError): if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': warnings.warn("Fall back to non-hive support because failing to access HiveConf, " "please make sure you build spark with hive") From b8f27ae3b34134a01998b77db4b7935e7f82a4fe Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 14 Jun 2018 13:27:27 -0700 Subject: [PATCH 0964/1161] [SPARK-24543][SQL] Support any type as DDL string for from_json's schema ## What changes were proposed in this pull request? In the PR, I propose to support any DataType represented as DDL string for the from_json function. After the changes, it will be possible to specify `MapType` in SQL like: ```sql select from_json('{"a":1, "b":2}', 'map') ``` and in Scala (similar in other languages) ```scala val in = Seq("""{"a": {"b": 1}}""").toDS() val schema = "map>" val out = in.select(from_json($"value", schema, Map.empty[String, String])) ``` ## How was this patch tested? Added a couple sql tests and modified existing tests for Python and Scala. The former tests were modified because it is not imported for them in which format schema for `from_json` is provided. Author: Maxim Gekk Closes #21550 from MaxGekk/from_json-ddl-schema. --- python/pyspark/sql/functions.py | 3 +-- .../catalyst/expressions/jsonExpressions.scala | 5 ++--- .../org/apache/spark/sql/types/DataType.scala | 11 +++++++++++ .../scala/org/apache/spark/sql/functions.scala | 2 +- .../sql-tests/inputs/json-functions.sql | 4 ++++ .../sql-tests/results/json-functions.sql.out | 18 +++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 4 ++-- 7 files changed, 38 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a5e3384e802b8..e6346691fb1d4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2168,8 +2168,7 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] - >>> schema = MapType(StringType(), IntegerType()) - >>> df.select(from_json(df.value, schema).alias("json")).collect() + >>> df.select(from_json(df.value, "MAP").alias("json")).collect() [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 04a4eb0ffc032..f6d74f5b74c8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -747,8 +746,8 @@ case class StructsToJson( object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): StructType = exp match { - case Literal(s, StringType) => CatalystSqlParser.parseTableSchema(s.toString) + def validateSchemaLiteral(exp: Expression): DataType = exp match { + case Literal(s, StringType) => DataType.fromDDL(s.toString) case e => throw new AnalysisException(s"Expected a string literal instead of $e") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 0bef11659fc9e..fd40741cfb5f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types import java.util.Locale +import scala.util.control.NonFatal + import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -26,6 +28,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -110,6 +113,14 @@ abstract class DataType extends AbstractDataType { @InterfaceStability.Stable object DataType { + def fromDDL(ddl: String): DataType = { + try { + CatalystSqlParser.parseDataType(ddl) + } catch { + case NonFatal(_) => CatalystSqlParser.parseTableSchema(ddl) + } + } + def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 87bd7b3b0f9c6..8551058ec58ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3369,7 +3369,7 @@ object functions { val dataType = try { DataType.fromJson(schema) } catch { - case NonFatal(_) => StructType.fromDDL(schema) + case NonFatal(_) => DataType.fromDDL(schema) } from_json(e, dataType, options) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index fea069eac4d48..dc15d13cd1dd3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -31,3 +31,7 @@ CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable; -- Clean up DROP VIEW IF EXISTS jsonTable; + +-- from_json - complex types +select from_json('{"a":1, "b":2}', 'map'); +select from_json('{"a":1, "b":"2"}', 'struct'); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 14a69128ffb41..2b3288dc5a137 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 28 -- !query 0 @@ -258,3 +258,19 @@ DROP VIEW IF EXISTS jsonTable struct<> -- !query 25 output + + +-- !query 26 +select from_json('{"a":1, "b":2}', 'map') +-- !query 26 schema +struct> +-- !query 26 output +{"a":1,"b":2} + + +-- !query 27 +select from_json('{"a":1, "b":"2"}', 'struct') +-- !query 27 schema +struct> +-- !query 27 output +{"a":1,"b":"2"} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 055e1fc5640f3..7bf17cbcd9c97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -354,8 +354,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-24027: from_json - map>") { val in = Seq("""{"a": {"b": 1}}""").toDS() - val schema = MapType(StringType, MapType(StringType, IntegerType)) - val out = in.select(from_json($"value", schema)) + val schema = "map>" + val out = in.select(from_json($"value", schema, Map.empty[String, String])) checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) } From 18cb0c07988578156c869682d8a2c4151e8d35e5 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Thu, 14 Jun 2018 14:54:46 -0700 Subject: [PATCH 0965/1161] [SPARK-24319][SPARK SUBMIT] Fix spark-submit execution where no main class is required. ## What changes were proposed in this pull request? With [PR 20925](https://github.com/apache/spark/pull/20925) now it's not possible to execute the following commands: * run-example * run-example --help * run-example --version * run-example --usage-error * run-example --status ... * run-example --kill ... In this PR the execution will be allowed for the mentioned commands. ## How was this patch tested? Existing unit tests extended + additional written. Author: Gabor Somogyi Closes #21450 from gaborgsomogyi/SPARK-24319. --- .../java/org/apache/spark/launcher/Main.java | 36 +++++++++---- .../launcher/SparkSubmitCommandBuilder.java | 33 ++++++------ .../SparkSubmitCommandBuilderSuite.java | 54 +++++++++++++++++-- 3 files changed, 90 insertions(+), 33 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 1e34bb8c73279..d967aa39a4827 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -17,6 +17,7 @@ package org.apache.spark.launcher; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -54,10 +55,12 @@ public static void main(String[] argsArray) throws Exception { String className = args.remove(0); boolean printLaunchCommand = !isEmpty(System.getenv("SPARK_PRINT_LAUNCH_COMMAND")); - AbstractCommandBuilder builder; + Map env = new HashMap<>(); + List cmd; if (className.equals("org.apache.spark.deploy.SparkSubmit")) { try { - builder = new SparkSubmitCommandBuilder(args); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(args); + cmd = buildCommand(builder, env, printLaunchCommand); } catch (IllegalArgumentException e) { printLaunchCommand = false; System.err.println("Error: " + e.getMessage()); @@ -76,17 +79,12 @@ public static void main(String[] argsArray) throws Exception { help.add(parser.className); } help.add(parser.USAGE_ERROR); - builder = new SparkSubmitCommandBuilder(help); + AbstractCommandBuilder builder = new SparkSubmitCommandBuilder(help); + cmd = buildCommand(builder, env, printLaunchCommand); } } else { - builder = new SparkClassCommandBuilder(className, args); - } - - Map env = new HashMap<>(); - List cmd = builder.buildCommand(env); - if (printLaunchCommand) { - System.err.println("Spark Command: " + join(" ", cmd)); - System.err.println("========================================"); + AbstractCommandBuilder builder = new SparkClassCommandBuilder(className, args); + cmd = buildCommand(builder, env, printLaunchCommand); } if (isWindows()) { @@ -101,6 +99,22 @@ public static void main(String[] argsArray) throws Exception { } } + /** + * Prepare spark commands with the appropriate command builder. + * If printLaunchCommand is set then the commands will be printed to the stderr. + */ + private static List buildCommand( + AbstractCommandBuilder builder, + Map env, + boolean printLaunchCommand) throws IOException, IllegalArgumentException { + List cmd = builder.buildCommand(env); + if (printLaunchCommand) { + System.err.println("Spark Command: " + join(" ", cmd)); + System.err.println("========================================"); + } + return cmd; + } + /** * Prepare a command line for execution from a Windows batch script. * diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5cb6457bf5c21..cc65f78b45c30 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -90,7 +90,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { final List userArgs; private final List parsedArgs; - private final boolean requiresAppResource; + // Special command means no appResource and no mainClass required + private final boolean isSpecialCommand; private final boolean isExample; /** @@ -105,7 +106,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { * spark-submit argument list to be modified after creation. */ SparkSubmitCommandBuilder() { - this.requiresAppResource = true; + this.isSpecialCommand = false; this.isExample = false; this.parsedArgs = new ArrayList<>(); this.userArgs = new ArrayList<>(); @@ -138,25 +139,26 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { case RUN_EXAMPLE: isExample = true; + appResource = SparkLauncher.NO_RESOURCE; submitArgs = args.subList(1, args.size()); } this.isExample = isExample; OptionParser parser = new OptionParser(true); parser.parse(submitArgs); - this.requiresAppResource = parser.requiresAppResource; + this.isSpecialCommand = parser.isSpecialCommand; } else { this.isExample = isExample; - this.requiresAppResource = false; + this.isSpecialCommand = true; } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL.equals(appResource) && requiresAppResource) { + if (PYSPARK_SHELL.equals(appResource) && !isSpecialCommand) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL.equals(appResource) && requiresAppResource) { + } else if (SPARKR_SHELL.equals(appResource) && !isSpecialCommand) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -166,18 +168,18 @@ public List buildCommand(Map env) List buildSparkSubmitArgs() { List args = new ArrayList<>(); OptionParser parser = new OptionParser(false); - final boolean requiresAppResource; + final boolean isSpecialCommand; // If the user args array is not empty, we need to parse it to detect exactly what // the user is trying to run, so that checks below are correct. if (!userArgs.isEmpty()) { parser.parse(userArgs); - requiresAppResource = parser.requiresAppResource; + isSpecialCommand = parser.isSpecialCommand; } else { - requiresAppResource = this.requiresAppResource; + isSpecialCommand = this.isSpecialCommand; } - if (!allowsMixedArguments && requiresAppResource) { + if (!allowsMixedArguments && !isSpecialCommand) { checkArgument(appResource != null, "Missing application resource."); } @@ -229,7 +231,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (isExample) { + if (isExample && !isSpecialCommand) { checkArgument(mainClass != null, "Missing example class name."); } @@ -421,7 +423,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean requiresAppResource = true; + boolean isSpecialCommand = false; private final boolean errorOnUnknownArgs; OptionParser(boolean errorOnUnknownArgs) { @@ -470,17 +472,14 @@ protected boolean handle(String opt, String value) { break; case KILL_SUBMISSION: case STATUS: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); parsedArgs.add(value); break; case HELP: case USAGE_ERROR: - requiresAppResource = false; - parsedArgs.add(opt); - break; case VERSION: - requiresAppResource = false; + isSpecialCommand = true; parsedArgs.add(opt); break; default: diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 2e050f8413074..b343094b2e7b8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.File; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -27,7 +28,10 @@ import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + import static org.junit.Assert.*; public class SparkSubmitCommandBuilderSuite extends BaseSuite { @@ -35,6 +39,9 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { private static File dummyPropsFile; private static SparkSubmitOptionParser parser; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @BeforeClass public static void setUp() throws Exception { dummyPropsFile = File.createTempFile("spark", "properties"); @@ -74,8 +81,11 @@ public void testCliHelpAndNoArg() throws Exception { @Test public void testCliKillAndStatus() throws Exception { - testCLIOpts(parser.STATUS); - testCLIOpts(parser.KILL_SUBMISSION); + List params = Arrays.asList("driver-20160531171222-0000"); + testCLIOpts(null, parser.STATUS, params); + testCLIOpts(null, parser.KILL_SUBMISSION, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.STATUS, params); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.KILL_SUBMISSION, params); } @Test @@ -190,6 +200,33 @@ public void testSparkRShell() throws Exception { env.get("SPARKR_SUBMIT_ARGS")); } + @Test(expected = IllegalArgumentException.class) + public void testExamplesRunnerNoArg() throws Exception { + List sparkSubmitArgs = Arrays.asList(SparkSubmitCommandBuilder.RUN_EXAMPLE); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + + @Test + public void testExamplesRunnerNoMainClass() throws Exception { + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.HELP, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.USAGE_ERROR, null); + testCLIOpts(SparkSubmitCommandBuilder.RUN_EXAMPLE, parser.VERSION, null); + } + + @Test + public void testExamplesRunnerWithMasterNoMainClass() throws Exception { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("Missing example class name."); + + List sparkSubmitArgs = Arrays.asList( + SparkSubmitCommandBuilder.RUN_EXAMPLE, + parser.MASTER + "=foo" + ); + Map env = new HashMap<>(); + buildCommand(sparkSubmitArgs, env); + } + @Test public void testExamplesRunner() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -344,10 +381,17 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } - private void testCLIOpts(String opt) throws Exception { - List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + private void testCLIOpts(String appResource, String opt, List params) throws Exception { + List args = new ArrayList<>(); + if (appResource != null) { + args.add(appResource); + } + args.add(opt); + if (params != null) { + args.addAll(params); + } Map env = new HashMap<>(); - List cmd = buildCommand(helpArgs, env); + List cmd = buildCommand(args, env); assertTrue(opt + " should be contained in the final cmd.", cmd.contains(opt)); } From 270a9a3cac25f3e799460320d0fc94ccd7ecfaea Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 14 Jun 2018 15:56:21 -0700 Subject: [PATCH 0966/1161] [SPARK-24248][K8S] Use level triggering and state reconciliation in scheduling and lifecycle ## What changes were proposed in this pull request? Previously, the scheduler backend was maintaining state in many places, not only for reading state but also writing to it. For example, state had to be managed in both the watch and in the executor allocator runnable. Furthermore, one had to keep track of multiple hash tables. We can do better here by: 1. Consolidating the places where we manage state. Here, we take inspiration from traditional Kubernetes controllers. These controllers tend to follow a level-triggered mechanism. This means that the controller will continuously monitor the API server via watches and polling, and on periodic passes, the controller will reconcile the current state of the cluster with the desired state. We implement this by introducing the concept of a pod snapshot, which is a given state of the executors in the Kubernetes cluster. We operate periodically on snapshots. To prevent overloading the API server with polling requests to get the state of the cluster (particularly for executor allocation where we want to be checking frequently to get executors to launch without unbearably bad latency), we use watches to populate snapshots by applying observed events to a previous snapshot to get a new snapshot. Whenever we do poll the cluster, the polled state replaces any existing snapshot - this ensures eventual consistency and mirroring of the cluster, as is desired in a level triggered architecture. 2. Storing less specialized in-memory state in general. Previously we were creating hash tables to represent the state of executors. Instead, it's easier to represent state solely by the snapshots. ## How was this patch tested? Integration tests should test there's no regressions end to end. Unit tests to be updated, in particular focusing on different orderings of events, particularly accounting for when events come in unexpected ordering. Author: mcheah Closes #21366 from mccheah/event-queue-driven-scheduling. --- LICENSE | 1 + .../org/apache/spark/util/ThreadUtils.scala | 31 +- licenses/LICENSE-jmock.txt | 28 ++ pom.xml | 6 + resource-managers/kubernetes/core/pom.xml | 12 +- .../org/apache/spark/deploy/k8s/Config.scala | 19 +- .../cluster/k8s/ExecutorPodStates.scala | 37 ++ .../cluster/k8s/ExecutorPodsAllocator.scala | 149 ++++++ .../k8s/ExecutorPodsLifecycleManager.scala | 176 +++++++ .../ExecutorPodsPollingSnapshotSource.scala | 68 +++ .../cluster/k8s/ExecutorPodsSnapshot.scala | 74 +++ .../k8s/ExecutorPodsSnapshotsStore.scala | 32 ++ .../k8s/ExecutorPodsSnapshotsStoreImpl.scala | 113 +++++ .../k8s/ExecutorPodsWatchSnapshotSource.scala | 67 +++ .../k8s/KubernetesClusterManager.scala | 42 +- .../KubernetesClusterSchedulerBackend.scala | 417 +++------------- .../spark/deploy/k8s/Fabric8Aliases.scala | 30 ++ .../spark/deploy/k8s/submit/ClientSuite.scala | 9 +- ...erministicExecutorPodsSnapshotsStore.scala | 51 ++ .../k8s/ExecutorLifecycleTestUtils.scala | 123 +++++ .../k8s/ExecutorPodsAllocatorSuite.scala | 179 +++++++ .../ExecutorPodsLifecycleManagerSuite.scala | 126 +++++ ...ecutorPodsPollingSnapshotSourceSuite.scala | 85 ++++ .../k8s/ExecutorPodsSnapshotSuite.scala | 60 +++ .../k8s/ExecutorPodsSnapshotsStoreSuite.scala | 137 ++++++ ...ExecutorPodsWatchSnapshotSourceSuite.scala | 75 +++ ...bernetesClusterSchedulerBackendSuite.scala | 451 +++--------------- 27 files changed, 1842 insertions(+), 756 deletions(-) create mode 100644 licenses/LICENSE-jmock.txt create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala diff --git a/LICENSE b/LICENSE index 820f14dbdeed0..cc1f580207a75 100644 --- a/LICENSE +++ b/LICENSE @@ -237,6 +237,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) + (BSD 3 Clause) jmock (org.jmock:jmock-junit4:2.8.4 - http://jmock.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 165a15c73e7ca..0f08a2b0ad895 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,13 +19,12 @@ package org.apache.spark.util import java.util.concurrent._ +import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal -import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} - import org.apache.spark.SparkException private[spark] object ThreadUtils { @@ -103,6 +102,22 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor. + */ + def newDaemonThreadPoolScheduledExecutor(threadNamePrefix: String, numThreads: Int) + : ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(s"$threadNamePrefix-%d") + .build() + val executor = new ScheduledThreadPoolExecutor(numThreads, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Run a piece of code in a new thread and return the result. Exception in the new thread is * thrown in the caller thread with an adjusted stack trace that removes references to this @@ -229,4 +244,14 @@ private[spark] object ThreadUtils { } } // scalastyle:on awaitready + + def shutdown( + executor: ExecutorService, + gracePeriod: Duration = FiniteDuration(30, TimeUnit.SECONDS)): Unit = { + executor.shutdown() + executor.awaitTermination(gracePeriod.toMillis, TimeUnit.MILLISECONDS) + if (!executor.isShutdown) { + executor.shutdownNow() + } + } } diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-jmock.txt new file mode 100644 index 0000000000000..ed7964fe3d9ef --- /dev/null +++ b/licenses/LICENSE-jmock.txt @@ -0,0 +1,28 @@ +Copyright (c) 2000-2017, jMock.org +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. Redistributions +in binary form must reproduce the above copyright notice, this list of +conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +Neither the name of jMock nor the names of its contributors may be +used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pom.xml b/pom.xml index 23bbd3b09734e..4b4e6c13ea8fd 100644 --- a/pom.xml +++ b/pom.xml @@ -760,6 +760,12 @@ 1.10.19 test + + org.jmock + jmock-junit4 + test + 2.8.4 + org.scalacheck scalacheck_${scala.binary.version} diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a62f271273465..a6dd47a6b7d95 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -77,6 +77,12 @@ + + com.squareup.okhttp3 + okhttp + 3.8.1 + + org.mockito mockito-core @@ -84,9 +90,9 @@ - com.squareup.okhttp3 - okhttp - 3.8.1 + org.jmock + jmock-junit4 + test diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 590deaa72e7ee..bf33179ae3dab 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -176,6 +176,24 @@ private[spark] object Config extends Logging { .checkValue(interval => interval > 0, s"Logging interval must be a positive time value.") .createWithDefaultString("1s") + val KUBERNETES_EXECUTOR_API_POLLING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.apiPollingInterval") + .doc("Interval between polls against the Kubernetes API server to inspect the " + + "state of executors.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"API server polling interval must be a" + + " positive time value.") + .createWithDefaultString("30s") + + val KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL = + ConfigBuilder("spark.kubernetes.executor.eventProcessingInterval") + .doc("Interval between successive inspection of executor events sent from the" + + " Kubernetes API.") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(interval => interval > 0, s"Event processing interval must be a positive" + + " time value.") + .createWithDefaultString("1s") + val MEMORY_OVERHEAD_FACTOR = ConfigBuilder("spark.kubernetes.memoryOverheadFactor") .doc("This sets the Memory Overhead Factor that will allocate memory to non-JVM jobs " + @@ -193,7 +211,6 @@ private[spark] object Config extends Logging { "Ensure that major Python version is either Python2 or Python3") .createWithDefault("2") - val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = "spark.kubernetes.authenticate.submission" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala new file mode 100644 index 0000000000000..83daddf714489 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +sealed trait ExecutorPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends ExecutorPodState + +case class PodPending(pod: Pod) extends ExecutorPodState + +sealed trait FinalPodState extends ExecutorPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends ExecutorPodState diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala new file mode 100644 index 0000000000000..5a143ad3600fd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import io.fabric8.kubernetes.api.model.PodBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, Utils} + +private[spark] class ExecutorPodsAllocator( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + clock: Clock) extends Logging { + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + + private val totalExpectedExecutors = new AtomicInteger(0) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val driverPod = kubernetesClient.pods() + .withName(kubernetesDriverPodName) + .get() + + // Executor IDs that have been requested from Kubernetes but have not been detected in any + // snapshot yet. Mapped to the timestamp when they were created. + private val newlyCreatedExecutors = mutable.Map.empty[Long, Long] + + def start(applicationId: String): Unit = { + snapshotsStore.addSubscriber(podAllocationDelay) { + onNewSnapshots(applicationId, _) + } + } + + def setTotalExpectedExecutors(total: Int): Unit = totalExpectedExecutors.set(total) + + private def onNewSnapshots(applicationId: String, snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + newlyCreatedExecutors --= snapshots.flatMap(_.executorPods.keys) + // For all executors we've created against the API but have not seen in a snapshot + // yet - check the current time. If the current time has exceeded some threshold, + // assume that the pod was either never created (the API server never properly + // handled the creation request), or the API server created the pod but we missed + // both the creation and deletion events. In either case, delete the missing pod + // if possible, and mark such a pod to be rescheduled below. + newlyCreatedExecutors.foreach { case (execId, timeCreated) => + val currentTime = clock.getTimeMillis() + if (currentTime - timeCreated > podCreationTimeout) { + logWarning(s"Executor with id $execId was not detected in the Kubernetes" + + s" cluster after $podCreationTimeout milliseconds despite the fact that a" + + " previous allocation attempt tried to create it. The executor may have been" + + " deleted but the application missed the deletion event.") + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withLabel(SPARK_EXECUTOR_ID_LABEL, execId.toString) + .delete() + } + newlyCreatedExecutors -= execId + } else { + logDebug(s"Executor with id $execId was not found in the Kubernetes cluster since it" + + s" was created ${currentTime - timeCreated} milliseconds ago.") + } + } + + if (snapshots.nonEmpty) { + // Only need to examine the cluster as of the latest snapshot, the "current" state, to see if + // we need to allocate more executors or not. + val latestSnapshot = snapshots.last + val currentRunningExecutors = latestSnapshot.executorPods.values.count { + case PodRunning(_) => true + case _ => false + } + val currentPendingExecutors = latestSnapshot.executorPods.values.count { + case PodPending(_) => true + case _ => false + } + val currentTotalExpectedExecutors = totalExpectedExecutors.get + logDebug(s"Currently have $currentRunningExecutors running executors and" + + s" $currentPendingExecutors pending executors. $newlyCreatedExecutors executors" + + s" have been requested but are pending appearance in the cluster.") + if (newlyCreatedExecutors.isEmpty + && currentPendingExecutors == 0 + && currentRunningExecutors < currentTotalExpectedExecutors) { + val numExecutorsToAllocate = math.min( + currentTotalExpectedExecutors - currentRunningExecutors, podAllocationSize) + logInfo(s"Going to request $numExecutorsToAllocate executors from Kubernetes.") + for ( _ <- 0 until numExecutorsToAllocate) { + val newExecutorId = EXECUTOR_ID_COUNTER.incrementAndGet() + val executorConf = KubernetesConf.createExecutorConf( + conf, + newExecutorId.toString, + applicationId, + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + kubernetesClient.pods().create(podWithAttachedContainer) + newlyCreatedExecutors(newExecutorId) = clock.getTimeMillis() + logDebug(s"Requested executor with id $newExecutorId from Kubernetes.") + } + } else if (currentRunningExecutors >= currentTotalExpectedExecutors) { + // TODO handle edge cases if we end up with more running executors than expected. + logDebug("Current number of running executors is equal to the number of requested" + + " executors. Not scaling up further.") + } else if (newlyCreatedExecutors.nonEmpty || currentPendingExecutors != 0) { + logDebug(s"Still waiting for ${newlyCreatedExecutors.size + currentPendingExecutors}" + + s" executors to begin running before requesting for more executors. # of executors in" + + s" pending status in the cluster: $currentPendingExecutors. # of executors that we have" + + s" created but we have not observed as being present in the cluster yet:" + + s" ${newlyCreatedExecutors.size}.") + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala new file mode 100644 index 0000000000000..b28d93990313e --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.Cache +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsLifecycleManager( + conf: SparkConf, + executorBuilder: KubernetesExecutorBuilder, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + // Use a best-effort to track which executors have been removed already. It's not generally + // job-breaking if we remove executors more than once but it's ideal if we make an attempt + // to avoid doing so. Expire cache entries so that this data structure doesn't grow beyond + // bounds. + removedExecutorsCache: Cache[java.lang.Long, java.lang.Long]) extends Logging { + + import ExecutorPodsLifecycleManager._ + + private val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) + + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { + snapshotsStore.addSubscriber(eventProcessingInterval) { + onNewSnapshots(schedulerBackend, _) + } + } + + private def onNewSnapshots( + schedulerBackend: KubernetesClusterSchedulerBackend, + snapshots: Seq[ExecutorPodsSnapshot]): Unit = { + val execIdsRemovedInThisRound = mutable.HashSet.empty[Long] + snapshots.foreach { snapshot => + snapshot.executorPods.foreach { case (execId, state) => + state match { + case deleted@PodDeleted(_) => + logDebug(s"Snapshot reported deleted executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + removeExecutorFromSpark(schedulerBackend, deleted, execId) + execIdsRemovedInThisRound += execId + case failed@PodFailed(_) => + logDebug(s"Snapshot reported failed executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}") + onFinalNonDeletedState(failed, execId, schedulerBackend, execIdsRemovedInThisRound) + case succeeded@PodSucceeded(_) => + logDebug(s"Snapshot reported succeeded executor with id $execId," + + s" pod name ${state.pod.getMetadata.getName}. Note that succeeded executors are" + + s" unusual unless Spark specifically informed the executor to exit.") + onFinalNonDeletedState(succeeded, execId, schedulerBackend, execIdsRemovedInThisRound) + case _ => + } + } + } + + // Reconcile the case where Spark claims to know about an executor but the corresponding pod + // is missing from the cluster. This would occur if we miss a deletion event and the pod + // transitions immediately from running io absent. We only need to check against the latest + // snapshot for this, and we don't do this for executors in the deleted executors cache or + // that we just removed in this round. + if (snapshots.nonEmpty) { + val latestSnapshot = snapshots.last + (schedulerBackend.getExecutorIds().map(_.toLong).toSet + -- latestSnapshot.executorPods.keySet + -- execIdsRemovedInThisRound).foreach { missingExecutorId => + if (removedExecutorsCache.getIfPresent(missingExecutorId) == null) { + val exitReasonMessage = s"The executor with ID $missingExecutorId was not found in the" + + s" cluster but we didn't get a reason why. Marking the executor as failed. The" + + s" executor may have been deleted but the driver missed the deletion event." + logDebug(exitReasonMessage) + val exitReason = ExecutorExited( + UNKNOWN_EXIT_CODE, + exitCausedByApp = false, + exitReasonMessage) + schedulerBackend.doRemoveExecutor(missingExecutorId.toString, exitReason) + execIdsRemovedInThisRound += missingExecutorId + } + } + } + logDebug(s"Removed executors with ids ${execIdsRemovedInThisRound.mkString(",")}" + + s" from Spark that were either found to be deleted or non-existent in the cluster.") + } + + private def onFinalNonDeletedState( + podState: FinalPodState, + execId: Long, + schedulerBackend: KubernetesClusterSchedulerBackend, + execIdsRemovedInRound: mutable.Set[Long]): Unit = { + removeExecutorFromK8s(podState.pod) + removeExecutorFromSpark(schedulerBackend, podState, execId) + execIdsRemovedInRound += execId + } + + private def removeExecutorFromK8s(updatedPod: Pod): Unit = { + // If deletion failed on a previous try, we can try again if resync informs us the pod + // is still around. + // Delete as best attempt - duplicate deletes will throw an exception but the end state + // of getting rid of the pod is what matters. + Utils.tryLogNonFatalError { + kubernetesClient + .pods() + .withName(updatedPod.getMetadata.getName) + .delete() + } + } + + private def removeExecutorFromSpark( + schedulerBackend: KubernetesClusterSchedulerBackend, + podState: FinalPodState, + execId: Long): Unit = { + if (removedExecutorsCache.getIfPresent(execId) == null) { + removedExecutorsCache.put(execId, execId) + val exitReason = findExitReason(podState, execId) + schedulerBackend.doRemoveExecutor(execId.toString, exitReason) + } + } + + private def findExitReason(podState: FinalPodState, execId: Long): ExecutorExited = { + val exitCode = findExitCode(podState) + val (exitCausedByApp, exitMessage) = podState match { + case PodDeleted(_) => + (false, s"The executor with id $execId was deleted by a user or the framework.") + case _ => + val msg = exitReasonMessage(podState, execId, exitCode) + (true, msg) + } + ExecutorExited(exitCode, exitCausedByApp, exitMessage) + } + + private def exitReasonMessage(podState: FinalPodState, execId: Long, exitCode: Int) = { + val pod = podState.pod + s""" + |The executor with id $execId exited with exit code $exitCode. + |The API gave the following brief reason: ${pod.getStatus.getReason} + |The API gave the following message: ${pod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${pod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def findExitCode(podState: FinalPodState): Int = { + podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => + containerStatus.getState.getTerminated != null + }.map { terminatedContainer => + terminatedContainer.getState.getTerminated.getExitCode.toInt + }.getOrElse(UNKNOWN_EXIT_CODE) + } +} + +private object ExecutorPodsLifecycleManager { + val UNKNOWN_EXIT_CODE = -1 +} + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala new file mode 100644 index 0000000000000..e77e604d00e0f --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.{Future, ScheduledExecutorService, TimeUnit} + +import io.fabric8.kubernetes.client.KubernetesClient +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.ThreadUtils + +private[spark] class ExecutorPodsPollingSnapshotSource( + conf: SparkConf, + kubernetesClient: KubernetesClient, + snapshotsStore: ExecutorPodsSnapshotsStore, + pollingExecutor: ScheduledExecutorService) extends Logging { + + private val pollingInterval = conf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + private var pollingFuture: Future[_] = _ + + def start(applicationId: String): Unit = { + require(pollingFuture == null, "Cannot start polling more than once.") + logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") + pollingFuture = pollingExecutor.scheduleWithFixedDelay( + new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) + } + + def stop(): Unit = { + if (pollingFuture != null) { + pollingFuture.cancel(true) + pollingFuture = null + } + ThreadUtils.shutdown(pollingExecutor) + } + + private class PollRunnable(applicationId: String) extends Runnable { + override def run(): Unit = { + logDebug(s"Resynchronizing full executor pod state from Kubernetes.") + snapshotsStore.replaceSnapshot(kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .list() + .getItems + .asScala) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala new file mode 100644 index 0000000000000..26be918043412 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging + +/** + * An immutable view of the current executor pods that are running in the cluster. + */ +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { + + import ExecutorPodsSnapshot._ + + def withUpdate(updatedPod: Pod): ExecutorPodsSnapshot = { + val newExecutorPods = executorPods ++ toStatesByExecutorId(Seq(updatedPod)) + new ExecutorPodsSnapshot(newExecutorPods) + } +} + +object ExecutorPodsSnapshot extends Logging { + + def apply(executorPods: Seq[Pod]): ExecutorPodsSnapshot = { + ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) + } + + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + executorPods.map { pod => + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + }.toMap + } + + private def toState(pod: Pod): ExecutorPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..dd264332cf9e8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStore.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +private[spark] trait ExecutorPodsSnapshotsStore { + + def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) + + def stop(): Unit + + def updatePod(updatedPod: Pod): Unit + + def replaceSnapshot(newSnapshot: Seq[Pod]): Unit +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala new file mode 100644 index 0000000000000..5583b4617eeb2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreImpl.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent._ + +import io.fabric8.kubernetes.api.model.Pod +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Controls the propagation of the Spark application's executor pods state to subscribers that + * react to that state. + *
    + * Roughly follows a producer-consumer model. Producers report states of executor pods, and these + * states are then published to consumers that can perform any actions in response to these states. + *
    + * Producers push updates in one of two ways. An incremental update sent by updatePod() represents + * a known new state of a single executor pod. A full sync sent by replaceSnapshot() indicates that + * the passed pods are all of the most up to date states of all executor pods for the application. + * The combination of the states of all executor pods for the application is collectively known as + * a snapshot. The store keeps track of the most up to date snapshot, and applies updates to that + * most recent snapshot - either by incrementally updating the snapshot with a single new pod state, + * or by replacing the snapshot entirely on a full sync. + *
    + * Consumers, or subscribers, register that they want to be informed about all snapshots of the + * executor pods. Every time the store replaces its most up to date snapshot from either an + * incremental update or a full sync, the most recent snapshot after the update is posted to the + * subscriber's buffer. Subscribers receive blocks of snapshots produced by the producers in + * time-windowed chunks. Each subscriber can choose to receive their snapshot chunks at different + * time intervals. + */ +private[spark] class ExecutorPodsSnapshotsStoreImpl(subscribersExecutor: ScheduledExecutorService) + extends ExecutorPodsSnapshotsStore { + + private val SNAPSHOT_LOCK = new Object() + + private val subscribers = mutable.Buffer.empty[SnapshotsSubscriber] + private val pollingTasks = mutable.Buffer.empty[Future[_]] + + @GuardedBy("SNAPSHOT_LOCK") + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber( + processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + val newSubscriber = SnapshotsSubscriber( + new LinkedBlockingQueue[ExecutorPodsSnapshot](), onNewSnapshots) + SNAPSHOT_LOCK.synchronized { + newSubscriber.snapshotsBuffer.add(currentSnapshot) + } + subscribers += newSubscriber + pollingTasks += subscribersExecutor.scheduleWithFixedDelay( + toRunnable(() => callSubscriber(newSubscriber)), + 0L, + processBatchIntervalMillis, + TimeUnit.MILLISECONDS) + } + + override def stop(): Unit = { + pollingTasks.foreach(_.cancel(true)) + ThreadUtils.shutdown(subscribersExecutor) + } + + override def updatePod(updatedPod: Pod): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + addCurrentSnapshotToSubscribers() + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = SNAPSHOT_LOCK.synchronized { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + addCurrentSnapshotToSubscribers() + } + + private def addCurrentSnapshotToSubscribers(): Unit = { + subscribers.foreach { subscriber => + subscriber.snapshotsBuffer.add(currentSnapshot) + } + } + + private def callSubscriber(subscriber: SnapshotsSubscriber): Unit = { + Utils.tryLogNonFatalError { + val currentSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot].asJava + subscriber.snapshotsBuffer.drainTo(currentSnapshots) + subscriber.onNewSnapshots(currentSnapshots.asScala) + } + } + + private def toRunnable[T](runnable: () => Unit): Runnable = new Runnable { + override def run(): Unit = runnable() + } + + private case class SnapshotsSubscriber( + snapshotsBuffer: BlockingQueue[ExecutorPodsSnapshot], + onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit) +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala new file mode 100644 index 0000000000000..a6749a644e00c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.io.Closeable + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +private[spark] class ExecutorPodsWatchSnapshotSource( + snapshotsStore: ExecutorPodsSnapshotsStore, + kubernetesClient: KubernetesClient) extends Logging { + + private var watchConnection: Closeable = _ + + def start(applicationId: String): Unit = { + require(watchConnection == null, "Cannot start the watcher twice.") + logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + + s" $SPARK_ROLE_LABEL=$SPARK_POD_EXECUTOR_ROLE.") + watchConnection = kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .watch(new ExecutorPodsWatcher()) + } + + def stop(): Unit = { + if (watchConnection != null) { + Utils.tryLogNonFatalError { + watchConnection.close() + } + watchConnection = null + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + override def eventReceived(action: Action, pod: Pod): Unit = { + val podName = pod.getMetadata.getName + logDebug(s"Received executor pod update for pod named $podName, action $action") + snapshotsStore.updatePod(pod) + } + + override def onClose(e: KubernetesClientException): Unit = { + logWarning("Kubernetes client has been closed (this is expected if the application is" + + " shutting down.)", e) + } + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index 0ea80dfbc0d97..c6e931a38405f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -17,7 +17,9 @@ package org.apache.spark.scheduler.cluster.k8s import java.io.File +import java.util.concurrent.TimeUnit +import com.google.common.cache.CacheBuilder import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} @@ -26,7 +28,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{SystemClock, ThreadUtils} private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { @@ -56,17 +58,45 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val allocatorExecutor = ThreadUtils - .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( "kubernetes-executor-requests") + + val subscribersExecutor = ThreadUtils + .newDaemonThreadPoolScheduledExecutor( + "kubernetes-executor-snapshots-subscribers", 2) + val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) + val removedExecutorsCache = CacheBuilder.newBuilder() + .expireAfterWrite(3, TimeUnit.MINUTES) + .build[java.lang.Long, java.lang.Long]() + val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( + sc.conf, + new KubernetesExecutorBuilder(), + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + + val executorPodsAllocator = new ExecutorPodsAllocator( + sc.conf, new KubernetesExecutorBuilder(), kubernetesClient, snapshotsStore, new SystemClock()) + + val podsWatchEventSource = new ExecutorPodsWatchSnapshotSource( + snapshotsStore, + kubernetesClient) + + val eventsPollingExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "kubernetes-executor-pod-polling-sync") + val podsPollingEventSource = new ExecutorPodsPollingSnapshotSource( + sc.conf, kubernetesClient, snapshotsStore, eventsPollingExecutor) + new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - new KubernetesExecutorBuilder, kubernetesClient, - allocatorExecutor, - requestExecutorsService) + requestExecutorsService, + snapshotsStore, + executorPodsAllocator, + executorPodsLifecycleEventHandler, + podsWatchEventSource, + podsPollingEventSource) } override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index d86664c81071b..fa6dc2c479bbf 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -16,60 +16,32 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.Closeable -import java.net.InetAddress -import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.ExecutorService -import io.fabric8.kubernetes.api.model._ -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import scala.collection.JavaConverters._ -import scala.collection.mutable +import io.fabric8.kubernetes.client.KubernetesClient import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesConf -import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} -import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, - allocatorExecutor: ScheduledExecutorService, - requestExecutorsService: ExecutorService) + requestExecutorsService: ExecutorService, + snapshotsStore: ExecutorPodsSnapshotsStore, + podAllocator: ExecutorPodsAllocator, + lifecycleEventHandler: ExecutorPodsLifecycleManager, + watchEvents: ExecutorPodsWatchSnapshotSource, + pollEvents: ExecutorPodsPollingSnapshotSource) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - import KubernetesClusterSchedulerBackend._ - - private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) - private val RUNNING_EXECUTOR_PODS_LOCK = new Object - @GuardedBy("RUNNING_EXECUTOR_PODS_LOCK") - private val runningExecutorsToPods = new mutable.HashMap[String, Pod] - private val executorPodsByIPs = new ConcurrentHashMap[String, Pod]() - private val podsWithKnownExitReasons = new ConcurrentHashMap[String, ExecutorExited]() - private val disconnectedPodsByExecutorIdPendingRemoval = new ConcurrentHashMap[String, Pod]() - - private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) - - private val kubernetesDriverPodName = conf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(throw new SparkException("Must specify the driver pod name")) private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( requestExecutorsService) - private val driverPod = kubernetesClient.pods() - .inNamespace(kubernetesNamespace) - .withName(kubernetesDriverPodName) - .get() - protected override val minRegisteredRatio = if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { 0.8 @@ -77,372 +49,93 @@ private[spark] class KubernetesClusterSchedulerBackend( super.minRegisteredRatio } - private val executorWatchResource = new AtomicReference[Closeable] - private val totalExpectedExecutors = new AtomicInteger(0) - - private val driverUrl = RpcEndpointAddress( - conf.get("spark.driver.host"), - conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString - private val initialExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf) - private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) - - private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) - - private val executorLostReasonCheckMaxAttempts = conf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - private val allocatorRunnable = new Runnable { - - // Maintains a map of executor id to count of checks performed to learn the loss reason - // for an executor. - private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] - - override def run(): Unit = { - handleDisconnectedExecutors() - - val executorsToAllocate = mutable.Map[String, Pod]() - val currentTotalRegisteredExecutors = totalRegisteredExecutors.get - val currentTotalExpectedExecutors = totalExpectedExecutors.get - val currentNodeToLocalTaskCount = getNodesWithLocalTaskCounts() - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - if (currentTotalRegisteredExecutors < runningExecutorsToPods.size) { - logDebug("Waiting for pending executors before scaling") - } else if (currentTotalExpectedExecutors <= runningExecutorsToPods.size) { - logDebug("Maximum allowed executor limit reached. Not scaling up further.") - } else { - for (_ <- 0 until math.min( - currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { - val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorConf = KubernetesConf.createExecutorConf( - conf, - executorId, - applicationId(), - driverPod) - val executorPod = executorBuilder.buildFromFeatures(executorConf) - val podWithAttachedContainer = new PodBuilder(executorPod.pod) - .editOrNewSpec() - .addToContainers(executorPod.container) - .endSpec() - .build() - - executorsToAllocate(executorId) = podWithAttachedContainer - logInfo( - s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") - } - } - } - - val allocatedExecutors = executorsToAllocate.mapValues { pod => - Utils.tryLog { - kubernetesClient.pods().create(pod) - } - } - - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - allocatedExecutors.map { - case (executorId, attemptedAllocatedExecutor) => - attemptedAllocatedExecutor.map { successfullyAllocatedExecutor => - runningExecutorsToPods.put(executorId, successfullyAllocatedExecutor) - } - } - } - } - - def handleDisconnectedExecutors(): Unit = { - // For each disconnected executor, synchronize with the loss reasons that may have been found - // by the executor pod watcher. If the loss reason was discovered by the watcher, - // inform the parent class with removeExecutor. - disconnectedPodsByExecutorIdPendingRemoval.asScala.foreach { - case (executorId, executorPod) => - val knownExitReason = Option(podsWithKnownExitReasons.remove( - executorPod.getMetadata.getName)) - knownExitReason.fold { - removeExecutorOrIncrementLossReasonCheckCount(executorId) - } { executorExited => - logWarning(s"Removing executor $executorId with loss reason " + executorExited.message) - removeExecutor(executorId, executorExited) - // We don't delete the pod running the executor that has an exit condition caused by - // the application from the Kubernetes API server. This allows users to debug later on - // through commands such as "kubectl logs " and - // "kubectl describe pod ". Note that exited containers have terminated and - // therefore won't take CPU and memory resources. - // Otherwise, the executor pod is marked to be deleted from the API server. - if (executorExited.exitCausedByApp) { - logInfo(s"Executor $executorId exited because of the application.") - deleteExecutorFromDataStructures(executorId) - } else { - logInfo(s"Executor $executorId failed because of a framework error.") - deleteExecutorFromClusterAndDataStructures(executorId) - } - } - } - } - - def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { - val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) - if (reasonCheckCount >= executorLostReasonCheckMaxAttempts) { - removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) - deleteExecutorFromClusterAndDataStructures(executorId) - } else { - executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) - } - } - - def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { - deleteExecutorFromDataStructures(executorId).foreach { pod => - kubernetesClient.pods().delete(pod) - } - } - - def deleteExecutorFromDataStructures(executorId: String): Option[Pod] = { - disconnectedPodsByExecutorIdPendingRemoval.remove(executorId) - executorReasonCheckAttemptCounts -= executorId - podsWithKnownExitReasons.remove(executorId) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.remove(executorId).orElse { - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - } - - override def sufficientResourcesRegistered(): Boolean = { - totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler + private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + removeExecutor(executorId, reason) } override def start(): Unit = { super.start() - executorWatchResource.set( - kubernetesClient - .pods() - .withLabel(SPARK_APP_ID_LABEL, applicationId()) - .watch(new ExecutorPodsWatcher())) - - allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable, 0L, podAllocationInterval, TimeUnit.MILLISECONDS) - if (!Utils.isDynamicAllocationEnabled(conf)) { - doRequestTotalExecutors(initialExecutors) + podAllocator.setTotalExpectedExecutors(initialExecutors) } + lifecycleEventHandler.start(this) + podAllocator.start(applicationId()) + watchEvents.start(applicationId()) + pollEvents.start(applicationId()) } override def stop(): Unit = { - // stop allocation of new resources and caches. - allocatorExecutor.shutdown() - allocatorExecutor.awaitTermination(30, TimeUnit.SECONDS) - - // send stop message to executors so they shut down cleanly super.stop() - try { - val resource = executorWatchResource.getAndSet(null) - if (resource != null) { - resource.close() - } - } catch { - case e: Throwable => logWarning("Failed to close the executor pod watcher", e) + Utils.tryLogNonFatalError { + snapshotsStore.stop() } - // then delete the executor pods Utils.tryLogNonFatalError { - deleteExecutorPodsOnStop() - executorPodsByIPs.clear() + watchEvents.stop() } + Utils.tryLogNonFatalError { - logInfo("Closing kubernetes client") - kubernetesClient.close() + pollEvents.stop() } - } - /** - * @return A map of K8s cluster nodes to the number of tasks that could benefit from data - * locality if an executor launches on the cluster node. - */ - private def getNodesWithLocalTaskCounts() : Map[String, Int] = { - val nodeToLocalTaskCount = synchronized { - mutable.Map[String, Int]() ++ hostToLocalTaskCount + Utils.tryLogNonFatalError { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .delete() } - for (pod <- executorPodsByIPs.values().asScala) { - // Remove cluster nodes that are running our executors already. - // TODO: This prefers spreading out executors across nodes. In case users want - // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut - // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html - nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || - nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || - nodeToLocalTaskCount.remove( - InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + Utils.tryLogNonFatalError { + ThreadUtils.shutdown(requestExecutorsService) + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() } - nodeToLocalTaskCount.toMap[String, Int] } override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { - totalExpectedExecutors.set(requestedTotal) + // TODO when we support dynamic allocation, the pod allocator should be told to process the + // current snapshot in order to decrease/increase the number of executors accordingly. + podAllocator.setTotalExpectedExecutors(requestedTotal) true } - override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { - val podsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - executorIds.flatMap { executorId => - runningExecutorsToPods.remove(executorId) match { - case Some(pod) => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - Some(pod) - - case None => - logWarning(s"Unable to remove pod for unknown executor $executorId") - None - } - } - } - - kubernetesClient.pods().delete(podsToDelete: _*) - true + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio } - private def deleteExecutorPodsOnStop(): Unit = { - val executorPodsToDelete = RUNNING_EXECUTOR_PODS_LOCK.synchronized { - val runningExecutorPodsCopy = Seq(runningExecutorsToPods.values.toSeq: _*) - runningExecutorsToPods.clear() - runningExecutorPodsCopy - } - kubernetesClient.pods().delete(executorPodsToDelete: _*) + override def getExecutorIds(): Seq[String] = synchronized { + super.getExecutorIds() } - private class ExecutorPodsWatcher extends Watcher[Pod] { - - private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 - - override def eventReceived(action: Action, pod: Pod): Unit = { - val podName = pod.getMetadata.getName - val podIP = pod.getStatus.getPodIP - - action match { - case Action.MODIFIED if (pod.getStatus.getPhase == "Running" - && pod.getMetadata.getDeletionTimestamp == null) => - val clusterNodeName = pod.getSpec.getNodeName - logInfo(s"Executor pod $podName ready, launched at $clusterNodeName as IP $podIP.") - executorPodsByIPs.put(podIP, pod) - - case Action.DELETED | Action.ERROR => - val executorId = getExecutorId(pod) - logDebug(s"Executor pod $podName at IP $podIP was at $action.") - if (podIP != null) { - executorPodsByIPs.remove(podIP) - } - - val executorExitReason = if (action == Action.ERROR) { - logWarning(s"Received error event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnError(pod) - } else if (action == Action.DELETED) { - logWarning(s"Received delete event of executor pod $podName. Reason: " + - pod.getStatus.getReason) - executorExitReasonOnDelete(pod) - } else { - throw new IllegalStateException( - s"Unknown action that should only be DELETED or ERROR: $action") - } - podsWithKnownExitReasons.put(pod.getMetadata.getName, executorExitReason) - - if (!disconnectedPodsByExecutorIdPendingRemoval.containsKey(executorId)) { - log.warn(s"Executor with id $executorId was not marked as disconnected, but the " + - s"watch received an event of type $action for this executor. The executor may " + - "have failed to start in the first place and never registered with the driver.") - } - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - - case _ => logDebug(s"Received event of executor pod $podName: " + action) - } - } - - override def onClose(cause: KubernetesClientException): Unit = { - logDebug("Executor pod watch closed.", cause) - } - - private def getExecutorExitStatus(pod: Pod): Int = { - val containerStatuses = pod.getStatus.getContainerStatuses - if (!containerStatuses.isEmpty) { - // we assume the first container represents the pod status. This assumption may not hold - // true in the future. Revisit this if side-car containers start running inside executor - // pods. - getExecutorExitStatus(containerStatuses.get(0)) - } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS - } - - private def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { - Option(containerStatus.getState).map { containerState => - Option(containerState.getTerminated).map { containerStateTerminated => - containerStateTerminated.getExitCode.intValue() - }.getOrElse(UNKNOWN_EXIT_CODE) - }.getOrElse(UNKNOWN_EXIT_CODE) - } - - private def isPodAlreadyReleased(pod: Pod): Boolean = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - !runningExecutorsToPods.contains(executorId) - } - } - - private def executorExitReasonOnError(pod: Pod): ExecutorExited = { - val containerExitStatus = getExecutorExitStatus(pod) - // container was probably actively killed by the driver. - if (isPodAlreadyReleased(pod)) { - ExecutorExited(containerExitStatus, exitCausedByApp = false, - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination " + - "request.") - } else { - val containerExitReason = s"Pod ${pod.getMetadata.getName}'s executor container " + - s"exited with exit status code $containerExitStatus." - ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) - } - } - - private def executorExitReasonOnDelete(pod: Pod): ExecutorExited = { - val exitMessage = if (isPodAlreadyReleased(pod)) { - s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." - } else { - s"Pod ${pod.getMetadata.getName} deleted or lost." - } - ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) - } - - private def getExecutorId(pod: Pod): String = { - val executorId = pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) - require(executorId != null, "Unexpected pod metadata; expected all executor pods " + - s"to have label $SPARK_EXECUTOR_ID_LABEL.") - executorId - } + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + kubernetesClient.pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*) + .delete() + // Don't do anything else - let event handling from the Kubernetes API do the Spark changes } override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { new KubernetesDriverEndpoint(rpcEnv, properties) } - private class KubernetesDriverEndpoint( - rpcEnv: RpcEnv, - sparkProperties: Seq[(String, String)]) + private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends DriverEndpoint(rpcEnv, sparkProperties) { override def onDisconnected(rpcAddress: RpcAddress): Unit = { - addressToExecutorId.get(rpcAddress).foreach { executorId => - if (disableExecutor(executorId)) { - RUNNING_EXECUTOR_PODS_LOCK.synchronized { - runningExecutorsToPods.get(executorId).foreach { pod => - disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) - } - } - } - } + // Don't do anything besides disabling the executor - allow the Kubernetes API events to + // drive the rest of the lifecycle decisions + // TODO what if we disconnect from a networking issue? Probably want to mark the executor + // to be deleted eventually. + addressToExecutorId.get(rpcAddress).foreach(disableExecutor) } } -} -private object KubernetesClusterSchedulerBackend { - private val UNKNOWN_EXIT_CODE = -1 } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala new file mode 100644 index 0000000000000..527fc6b0d8f87 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/Fabric8Aliases.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, HasMetadata, Pod, PodList} +import io.fabric8.kubernetes.client.{Watch, Watcher} +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} + +object Fabric8Aliases { + type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + type LABELED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + type SINGLE_POD = PodResource[Pod, DoneablePod] + type RESOURCE_LIST = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ + HasMetadata, Boolean] +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index a8a8218c621ea..d045d9ae89c07 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -103,15 +104,11 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { .build() } - private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ - HasMetadata, Boolean] - private type Pods = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - @Mock private var kubernetesClient: KubernetesClient = _ @Mock - private var podOperations: Pods = _ + private var podOperations: PODS = _ @Mock private var namedPods: PodResource[Pod, DoneablePod] = _ @@ -123,7 +120,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private var driverBuilder: KubernetesDriverBuilder = _ @Mock - private var resourceList: ResourceList = _ + private var resourceList: RESOURCE_LIST = _ private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala new file mode 100644 index 0000000000000..f7721e6fd6388 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/DeterministicExecutorPodsSnapshotsStore.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import scala.collection.mutable + +class DeterministicExecutorPodsSnapshotsStore extends ExecutorPodsSnapshotsStore { + + private val snapshotsBuffer = mutable.Buffer.empty[ExecutorPodsSnapshot] + private val subscribers = mutable.Buffer.empty[Seq[ExecutorPodsSnapshot] => Unit] + + private var currentSnapshot = ExecutorPodsSnapshot() + + override def addSubscriber + (processBatchIntervalMillis: Long) + (onNewSnapshots: Seq[ExecutorPodsSnapshot] => Unit): Unit = { + subscribers += onNewSnapshots + } + + override def stop(): Unit = {} + + def notifySubscribers(): Unit = { + subscribers.foreach(_(snapshotsBuffer)) + snapshotsBuffer.clear() + } + + override def updatePod(updatedPod: Pod): Unit = { + currentSnapshot = currentSnapshot.withUpdate(updatedPod) + snapshotsBuffer += currentSnapshot + } + + override def replaceSnapshot(newSnapshot: Seq[Pod]): Unit = { + currentSnapshot = ExecutorPodsSnapshot(newSnapshot) + snapshotsBuffer += currentSnapshot + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala new file mode 100644 index 0000000000000..c6b667ed85e8c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, Pod, PodBuilder} + +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkPod + +object ExecutorLifecycleTestUtils { + + val TEST_SPARK_APP_ID = "spark-app-id" + + def failedExecutorWithoutDeletion(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("failed") + .addNewContainerStatus() + .withName("spark-executor") + .withImage("k8s-spark") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(1) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def pendingExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("pending") + .endStatus() + .build() + } + + def runningExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() + } + + def succeededExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("succeeded") + .endStatus() + .build() + } + + def deletedExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewMetadata() + .withNewDeletionTimestamp("523012521") + .endMetadata() + .build() + } + + def unknownExecutor(executorId: Long): Pod = { + new PodBuilder(podWithAttachedContainerForId(executorId)) + .editOrNewStatus() + .withPhase("unknown") + .endStatus() + .build() + } + + def podWithAttachedContainerForId(executorId: Long): Pod = { + val sparkPod = executorPodWithId(executorId) + val podWithAttachedContainer = new PodBuilder(sparkPod.pod) + .editOrNewSpec() + .addToContainers(sparkPod.container) + .endSpec() + .build() + podWithAttachedContainer + } + + def executorPodWithId(executorId: Long): SparkPod = { + val pod = new PodBuilder() + .withNewMetadata() + .withName(s"spark-executor-$executorId") + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE) + .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) + .endMetadata() + .build() + val container = new ContainerBuilder() + .withName("spark-executor") + .withImage("k8s-spark") + .build() + SparkPod(pod, container) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala new file mode 100644 index 0000000000000..0c19f5946b75f --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ +import org.apache.spark.util.ManualClock + +class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter { + + private val driverPodName = "driver" + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .addToLabels(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .withUid("driver-pod-uid") + .endMetadata() + .build() + + private val conf = new SparkConf().set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L) + + private var waitForExecutorPodsClock: ManualClock = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var labeledPods: LABELED_PODS = _ + + @Mock + private var driverPodOperations: PodResource[Pod, DoneablePod] = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + + private var podsAllocatorUnderTest: ExecutorPodsAllocator = _ + + before { + MockitoAnnotations.initMocks(this) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations) + when(driverPodOperations.get).thenReturn(driverPod) + when(executorBuilder.buildFromFeatures(kubernetesConfWithCorrectFields())) + .thenAnswer(executorPodAnswer()) + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + waitForExecutorPodsClock = new ManualClock(0L) + podsAllocatorUnderTest = new ExecutorPodsAllocator( + conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock) + podsAllocatorUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Initially request executors in batches. Do not request another batch if the" + + " first has not finished.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (nextId <- 1 to podAllocationSize) { + verify(podOperations).create(podWithAttachedContainerForId(nextId)) + } + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("Request executors in batches. Allow another batch to be requested if" + + " all pending executors start running.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize + 1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + snapshotsStore.notifySubscribers() + verify(podOperations, never()).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + snapshotsStore.updatePod(runningExecutor(podAllocationSize)) + snapshotsStore.notifySubscribers() + verify(podOperations, times(podAllocationSize + 1)).create(any(classOf[Pod])) + } + + test("When a current batch reaches error states immediately, re-request" + + " them on the next batch.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(podAllocationSize) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + for (execId <- 1 until podAllocationSize) { + snapshotsStore.updatePod(runningExecutor(execId)) + } + val failedPod = failedExecutorWithoutDeletion(podAllocationSize) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + verify(podOperations).create(podWithAttachedContainerForId(podAllocationSize + 1)) + } + + test("When an executor is requested but the API does not report it in a reasonable time, retry" + + " requesting that executor.") { + podsAllocatorUnderTest.setTotalExpectedExecutors(1) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + waitForExecutorPodsClock.setTime(podCreationTimeout + 1) + when(podOperations.withLabel(SPARK_EXECUTOR_ID_LABEL, "1")).thenReturn(labeledPods) + snapshotsStore.notifySubscribers() + verify(labeledPods).delete() + verify(podOperations).create(podWithAttachedContainerForId(2)) + } + + private def executorPodAnswer(): Answer[SparkPod] = { + new Answer[SparkPod] { + override def answer(invocation: InvocationOnMock): SparkPod = { + val k8sConf = invocation.getArgumentAt( + 0, classOf[KubernetesConf[KubernetesExecutorSpecificConf]]) + executorPodWithId(k8sConf.roleSpecificConf.executorId.toInt) + } + } + } + + private def kubernetesConfWithCorrectFields(): KubernetesConf[KubernetesExecutorSpecificConf] = + Matchers.argThat(new ArgumentMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any): Boolean = { + if (!argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]]) { + false + } else { + val k8sConf = argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + val executorSpecificConf = k8sConf.roleSpecificConf + val expectedK8sConf = KubernetesConf.createExecutorConf( + conf, + executorSpecificConf.executorId, + TEST_SPARK_APP_ID, + driverPod) + k8sConf.sparkConf.getAll.toMap == conf.getAll.toMap && + // Since KubernetesConf.createExecutorConf clones the SparkConf object, force + // deep equality comparison for the SparkConf object and use object equality + // comparison on all other fields. + k8sConf.copy(sparkConf = conf) == expectedK8sConf.copy(sparkConf = conf) + } + } + }) + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala new file mode 100644 index 0000000000000..562ace9f49d4d --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import com.google.common.cache.CacheBuilder +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod} +import io.fabric8.kubernetes.client.KubernetesClient +import io.fabric8.kubernetes.client.dsl.PodResource +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, times, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.ExecutorExited +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfter { + + private var namedExecutorPods: mutable.Map[String, PodResource[Pod, DoneablePod]] = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var executorBuilder: KubernetesExecutorBuilder = _ + + @Mock + private var schedulerBackend: KubernetesClusterSchedulerBackend = _ + + private var snapshotsStore: DeterministicExecutorPodsSnapshotsStore = _ + private var eventHandlerUnderTest: ExecutorPodsLifecycleManager = _ + + before { + MockitoAnnotations.initMocks(this) + val removedExecutorsCache = CacheBuilder.newBuilder().build[java.lang.Long, java.lang.Long] + snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() + namedExecutorPods = mutable.Map.empty[String, PodResource[Pod, DoneablePod]] + when(schedulerBackend.getExecutorIds()).thenReturn(Seq.empty[String]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) + eventHandlerUnderTest = new ExecutorPodsLifecycleManager( + new SparkConf(), + executorBuilder, + kubernetesClient, + snapshotsStore, + removedExecutorsCache) + eventHandlerUnderTest.start(schedulerBackend) + } + + test("When an executor reaches error states immediately, remove from the scheduler backend.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName)).delete() + } + + test("Don't remove executors twice from Spark but remove from K8s repeatedly.") { + val failedPod = failedExecutorWithoutDeletion(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason) + verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete() + } + + test("When the scheduler backend lists executor ids that aren't present in the cluster," + + " remove those executors from Spark.") { + when(schedulerBackend.getExecutorIds()).thenReturn(Seq("1")) + val msg = s"The executor with ID 1 was not found in the cluster but we didn't" + + s" get a reason why. Marking the executor as failed. The executor may have been" + + s" deleted but the driver missed the deletion event." + val expectedLossReason = ExecutorExited(-1, exitCausedByApp = false, msg) + snapshotsStore.replaceSnapshot(Seq.empty[Pod]) + snapshotsStore.notifySubscribers() + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + + private def exitReasonMessage(failedExecutorId: Int, failedPod: Pod): String = { + s""" + |The executor with id $failedExecutorId exited with exit code 1. + |The API gave the following brief reason: ${failedPod.getStatus.getReason} + |The API gave the following message: ${failedPod.getStatus.getMessage} + |The API gave the following container statuses: + | + |${failedPod.getStatus.getContainerStatuses.asScala.map(_.toString).mkString("\n===\n")} + """.stripMargin + } + + private def namedPodsAnswer(): Answer[PodResource[Pod, DoneablePod]] = { + new Answer[PodResource[Pod, DoneablePod]] { + override def answer(invocation: InvocationOnMock): PodResource[Pod, DoneablePod] = { + val podName = invocation.getArgumentAt(0, classOf[String]) + namedExecutorPods.getOrElseUpdate( + podName, mock(classOf[PodResource[Pod, DoneablePod]])) + } + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..1b26d6af296a5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSourceSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit + +import io.fabric8.kubernetes.api.model.PodListBuilder +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsPollingSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + private val sparkConf = new SparkConf + + private val pollingInterval = sparkConf.get(KUBERNETES_EXECUTOR_API_POLLING_INTERVAL) + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + private var pollingExecutor: DeterministicScheduler = _ + private var pollingSourceUnderTest: ExecutorPodsPollingSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + pollingExecutor = new DeterministicScheduler() + pollingSourceUnderTest = new ExecutorPodsPollingSnapshotSource( + sparkConf, + kubernetesClient, + eventQueue, + pollingExecutor) + pollingSourceUnderTest.start(TEST_SPARK_APP_ID) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + } + + test("Items returned by the API should be pushed to the event queue") { + when(executorRoleLabeledPods.list()) + .thenReturn(new PodListBuilder() + .addToItems( + runningExecutor(1), + runningExecutor(2)) + .build()) + pollingExecutor.tick(pollingInterval, TimeUnit.MILLISECONDS) + verify(eventQueue).replaceSnapshot(Seq(runningExecutor(1), runningExecutor(2))) + + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala new file mode 100644 index 0000000000000..70e19c904eddb --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsSnapshotSuite extends SparkFunSuite { + + test("States are interpreted correctly from pod metadata.") { + val pods = Seq( + pendingExecutor(0), + runningExecutor(1), + succeededExecutor(2), + failedExecutorWithoutDeletion(3), + deletedExecutor(4), + unknownExecutor(5)) + val snapshot = ExecutorPodsSnapshot(pods) + assert(snapshot.executorPods === + Map( + 0L -> PodPending(pods(0)), + 1L -> PodRunning(pods(1)), + 2L -> PodSucceeded(pods(2)), + 3L -> PodFailed(pods(3)), + 4L -> PodDeleted(pods(4)), + 5L -> PodUnknown(pods(5)))) + } + + test("Updates add new pods for non-matching ids and edit existing pods for matching ids") { + val originalPods = Seq( + pendingExecutor(0), + runningExecutor(1)) + val originalSnapshot = ExecutorPodsSnapshot(originalPods) + val snapshotWithUpdatedPod = originalSnapshot.withUpdate(succeededExecutor(1)) + assert(snapshotWithUpdatedPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)))) + val snapshotWithNewPod = snapshotWithUpdatedPod.withUpdate(pendingExecutor(2)) + assert(snapshotWithNewPod.executorPods === + Map( + 0L -> PodPending(originalPods(0)), + 1L -> PodSucceeded(succeededExecutor(1)), + 2L -> PodPending(pendingExecutor(2)))) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala new file mode 100644 index 0000000000000..cf54b3c4eb329 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshotsStoreSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference + +import io.fabric8.kubernetes.api.model.{Pod, PodBuilder} +import org.jmock.lib.concurrent.DeterministicScheduler +import org.scalatest.BeforeAndAfter +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ + +class ExecutorPodsSnapshotsStoreSuite extends SparkFunSuite with BeforeAndAfter { + + private var eventBufferScheduler: DeterministicScheduler = _ + private var eventQueueUnderTest: ExecutorPodsSnapshotsStoreImpl = _ + + before { + eventBufferScheduler = new DeterministicScheduler() + eventQueueUnderTest = new ExecutorPodsSnapshotsStoreImpl(eventBufferScheduler) + } + + test("Subscribers get notified of events periodically.") { + val receivedSnapshots1 = mutable.Buffer.empty[ExecutorPodsSnapshot] + val receivedSnapshots2 = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots1 ++= _ + } + eventQueueUnderTest.addSubscriber(2000) { + receivedSnapshots2 ++= _ + } + + eventBufferScheduler.runUntilIdle() + assert(receivedSnapshots1 === Seq(ExecutorPodsSnapshot())) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + pushPodWithIndex(1) + // Force time to move forward so that the buffer is emitted, scheduling the + // processing task on the subscription executor... + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + // ... then actually execute the subscribers. + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq(ExecutorPodsSnapshot())) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + // Don't repeat snapshots + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + pushPodWithIndex(2) + pushPodWithIndex(3) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots2 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots1 === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2))), + ExecutorPodsSnapshot(Seq(podWithIndex(1), podWithIndex(2), podWithIndex(3))))) + assert(receivedSnapshots1 === receivedSnapshots2) + } + + test("Even without sending events, initially receive an empty buffer.") { + val receivedInitialSnapshot = new AtomicReference[Seq[ExecutorPodsSnapshot]](null) + eventQueueUnderTest.addSubscriber(1000) { + receivedInitialSnapshot.set + } + assert(receivedInitialSnapshot.get == null) + eventBufferScheduler.runUntilIdle() + assert(receivedInitialSnapshot.get === Seq(ExecutorPodsSnapshot())) + } + + test("Replacing the snapshot passes the new snapshot to subscribers.") { + val receivedSnapshots = mutable.Buffer.empty[ExecutorPodsSnapshot] + eventQueueUnderTest.addSubscriber(1000) { + receivedSnapshots ++= _ + } + eventQueueUnderTest.updatePod(podWithIndex(1)) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))))) + eventQueueUnderTest.replaceSnapshot(Seq(podWithIndex(2))) + eventBufferScheduler.tick(1000, TimeUnit.MILLISECONDS) + assert(receivedSnapshots === Seq( + ExecutorPodsSnapshot(), + ExecutorPodsSnapshot(Seq(podWithIndex(1))), + ExecutorPodsSnapshot(Seq(podWithIndex(2))))) + } + + private def pushPodWithIndex(index: Int): Unit = + eventQueueUnderTest.updatePod(podWithIndex(index)) + + private def podWithIndex(index: Int): Pod = + new PodBuilder() + .editOrNewMetadata() + .withName(s"pod-$index") + .addToLabels(SPARK_EXECUTOR_ID_LABEL, index.toString) + .endMetadata() + .editOrNewStatus() + .withPhase("running") + .endStatus() + .build() +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala new file mode 100644 index 0000000000000..ac1968b4ff810 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Mockito.{verify, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils._ + +class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndAfter { + + @Mock + private var eventQueue: ExecutorPodsSnapshotsStore = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private var appIdLabeledPods: LABELED_PODS = _ + + @Mock + private var executorRoleLabeledPods: LABELED_PODS = _ + + @Mock + private var watchConnection: Watch = _ + + private var watch: ArgumentCaptor[Watcher[Pod]] = _ + + private var watchSourceUnderTest: ExecutorPodsWatchSnapshotSource = _ + + before { + MockitoAnnotations.initMocks(this) + watch = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)) + .thenReturn(appIdLabeledPods) + when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)) + .thenReturn(executorRoleLabeledPods) + when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection) + watchSourceUnderTest = new ExecutorPodsWatchSnapshotSource( + eventQueue, kubernetesClient) + watchSourceUnderTest.start(TEST_SPARK_APP_ID) + } + + test("Watch events should be pushed to the snapshots store as snapshot updates.") { + watch.getValue.eventReceived(Action.ADDED, runningExecutor(1)) + watch.getValue.eventReceived(Action.MODIFIED, runningExecutor(2)) + verify(eventQueue).updatePod(runningExecutor(1)) + verify(eventQueue).updatePod(runningExecutor(2)) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 96065e83f069c..52e7a12dbaf06 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -16,85 +16,36 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} -import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} -import io.fabric8.kubernetes.client.Watcher.Action -import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.hamcrest.{BaseMatcher, Description, Matcher} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} -import org.mockito.Matchers.{any, eq => mockitoEq} -import org.mockito.Mockito.{doNothing, never, times, verify, when} +import io.fabric8.kubernetes.client.KubernetesClient +import org.jmock.lib.concurrent.DeterministicScheduler +import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{eq => mockitoEq} +import org.mockito.Mockito.{never, verify, when} import org.scalatest.BeforeAndAfter -import org.scalatest.mockito.MockitoSugar._ -import scala.collection.JavaConverters._ -import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.ThreadUtils +import org.apache.spark.scheduler.cluster.k8s.ExecutorLifecycleTestUtils.TEST_SPARK_APP_ID class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAndAfter { - private val APP_ID = "test-spark-app" - private val DRIVER_POD_NAME = "spark-driver-pod" - private val NAMESPACE = "test-namespace" - private val SPARK_DRIVER_HOST = "localhost" - private val SPARK_DRIVER_PORT = 7077 - private val POD_ALLOCATION_INTERVAL = "1m" - private val FIRST_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod1") - .endMetadata() - .withNewSpec() - .withNodeName("node1") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private val SECOND_EXECUTOR_POD = new PodBuilder() - .withNewMetadata() - .withName("pod2") - .endMetadata() - .withNewSpec() - .withNodeName("node2") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.101") - .endStatus() - .build() - - private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - private type LABELED_PODS = FilterWatchListDeletable[ - Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] - private type IN_NAMESPACE_PODS = NonNamespaceOperation[ - Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] - - @Mock - private var sparkContext: SparkContext = _ - - @Mock - private var listenerBus: LiveListenerBus = _ - - @Mock - private var taskSchedulerImpl: TaskSchedulerImpl = _ + private val requestExecutorsService = new DeterministicScheduler() + private val sparkConf = new SparkConf(false) + .set("spark.executor.instances", "3") @Mock - private var allocatorExecutor: ScheduledExecutorService = _ + private var sc: SparkContext = _ @Mock - private var requestExecutorsService: ExecutorService = _ + private var rpcEnv: RpcEnv = _ @Mock - private var executorBuilder: KubernetesExecutorBuilder = _ + private var driverEndpointRef: RpcEndpointRef = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -103,347 +54,97 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var podOperations: PODS = _ @Mock - private var podsWithLabelOperations: LABELED_PODS = _ + private var labeledPods: LABELED_PODS = _ @Mock - private var podsInNamespace: IN_NAMESPACE_PODS = _ + private var taskScheduler: TaskSchedulerImpl = _ @Mock - private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + private var eventQueue: ExecutorPodsSnapshotsStore = _ @Mock - private var rpcEnv: RpcEnv = _ + private var podAllocator: ExecutorPodsAllocator = _ @Mock - private var driverEndpointRef: RpcEndpointRef = _ + private var lifecycleEventHandler: ExecutorPodsLifecycleManager = _ @Mock - private var executorPodsWatch: Watch = _ + private var watchEvents: ExecutorPodsWatchSnapshotSource = _ @Mock - private var successFuture: Future[Boolean] = _ + private var pollEvents: ExecutorPodsPollingSnapshotSource = _ - private var sparkConf: SparkConf = _ - private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ - private var allocatorRunnable: ArgumentCaptor[Runnable] = _ - private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ - - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(DRIVER_POD_NAME) - .addToLabels(SPARK_APP_ID_LABEL, APP_ID) - .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) - .endMetadata() - .build() + private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ before { MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) - .set(KUBERNETES_NAMESPACE, NAMESPACE) - .set("spark.driver.host", SPARK_DRIVER_HOST) - .set("spark.driver.port", SPARK_DRIVER_PORT.toString) - .set(KUBERNETES_ALLOCATION_BATCH_DELAY.key, POD_ALLOCATION_INTERVAL) - executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) - allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) - requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + when(taskScheduler.sc).thenReturn(sc) + when(sc.conf).thenReturn(sparkConf) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) - when(sparkContext.conf).thenReturn(sparkConf) - when(sparkContext.listenerBus).thenReturn(listenerBus) - when(taskSchedulerImpl.sc).thenReturn(sparkContext) - when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) - when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) - .thenReturn(executorPodsWatch) - when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) - when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) - when(podsWithDriverName.get()).thenReturn(driverPod) - when(allocatorExecutor.scheduleWithFixedDelay( - allocatorRunnable.capture(), - mockitoEq(0L), - mockitoEq(TimeUnit.MINUTES.toMillis(1)), - mockitoEq(TimeUnit.MILLISECONDS))).thenReturn(null) - // Creating Futures in Scala backed by a Java executor service resolves to running - // ExecutorService#execute (as opposed to submit) - doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) when(rpcEnv.setupEndpoint( mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) .thenReturn(driverEndpointRef) - - // Used by the CoarseGrainedSchedulerBackend when making RPC calls. - when(driverEndpointRef.ask[Boolean] - (any(classOf[Any])) - (any())).thenReturn(successFuture) - when(successFuture.failed).thenReturn(Future[Throwable] { - // emulate behavior of the Future.failed method. - throw new NoSuchElementException() - }(ThreadUtils.sameThread)) - } - - test("Basic lifecycle expectations when starting and stopping the scheduler.") { - val scheduler = newSchedulerBackend() - scheduler.start() - assert(executorPodsWatcherArgument.getValue != null) - assert(allocatorRunnable.getValue != null) - scheduler.stop() - verify(executorPodsWatch).close() - } - - test("Static allocation should request executors upon first allocator run.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations).create(secondResolvedPod) - } - - test("Killing executors deletes the executor pods") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - allocatorRunnable.getValue.run() - scheduler.doKillExecutors(Seq("2")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations).delete(secondResolvedPod) - verify(podOperations, never()).delete(firstResolvedPod) - } - - test("Executors should be requested in batches.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) - val scheduler = newSchedulerBackend() - scheduler.start() - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - val secondResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(firstResolvedPod) - verify(podOperations, never()).create(secondResolvedPod) - val registerFirstExecutorMessage = RegisterExecutor( - "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - allocatorRunnable.getValue.run() - verify(podOperations).create(secondResolvedPod) - } - - test("Scaled down executors should be cleaned up") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - - // The scheduler backend spins up one executor pod. - requestExecutorRunnable.getValue.run() - when(podOperations.create(any(classOf[Pod]))) - .thenAnswer(AdditionalAnswers.returnsFirstArg()) - val resolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - // Request that there are 0 executors and trigger deletion from driver. - scheduler.doRequestTotalExecutors(0) - requestExecutorRunnable.getAllValues.asScala.last.run() - scheduler.doKillExecutors(Seq("1")) - requestExecutorRunnable.getAllValues.asScala.last.run() - verify(podOperations, times(1)).delete(resolvedPod) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - - val exitedPod = exitPod(resolvedPod, 0) - executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) - allocatorRunnable.getValue.run() - - // No more deletion attempts of the executors. - // This is graceful termination and should not be detected as a failure. - verify(podOperations, times(1)).delete(resolvedPod) - verify(driverEndpointRef, times(1)).send( - RemoveExecutor("1", ExecutorExited( - 0, - exitCausedByApp = false, - s"Container in pod ${exitedPod.getMetadata.getName} exited from" + - s" explicit termination request."))) - } - - test("Executors that fail should not be deleted.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - executorPodsWatcherArgument.getValue.eventReceived( - Action.ERROR, exitPod(firstResolvedPod, 1)) - - // A replacement executor should be created but the error pod should persist. - val replacementPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - scheduler.doRequestTotalExecutors(1) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getAllValues.asScala.last.run() - verify(podOperations, never()).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", ExecutorExited( - 1, - exitCausedByApp = true, - s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + - " exit status code 1."))) - } - - test("Executors disconnected due to unknown reasons are deleted and replaced.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val executorLostReasonCheckMaxAttempts = sparkConf.get( - KUBERNETES_EXECUTOR_LOST_REASON_CHECK_MAX_ATTEMPTS) - - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - val executorEndpointRef = mock[RpcEndpointRef] - when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) - val registerFirstExecutorMessage = RegisterExecutor( - "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) - when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) - driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) - .apply(registerFirstExecutorMessage) - - driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) - 1 to executorLostReasonCheckMaxAttempts foreach { _ => - allocatorRunnable.getValue.run() - verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + when(kubernetesClient.pods()).thenReturn(podOperations) + schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend( + taskScheduler, + rpcEnv, + kubernetesClient, + requestExecutorsService, + eventQueue, + podAllocator, + lifecycleEventHandler, + watchEvents, + pollEvents) { + override def applicationId(): String = TEST_SPARK_APP_ID } - - val recreatedResolvedPod = expectPodCreationWithId(2, SECOND_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).delete(firstResolvedPod) - verify(driverEndpointRef).send( - RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) } - test("Executors that fail to start on the Kubernetes API call rebuild in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(firstResolvedPod)) - .thenThrow(new RuntimeException("test")) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Start all components") { + schedulerBackendUnderTest.start() + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator).start(TEST_SPARK_APP_ID) + verify(lifecycleEventHandler).start(schedulerBackendUnderTest) + verify(watchEvents).start(TEST_SPARK_APP_ID) + verify(pollEvents).start(TEST_SPARK_APP_ID) } - test("Executors that are initially created but the watch notices them fail are rebuilt" + - " in the next batch.") { - sparkConf - .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) - .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) - val scheduler = newSchedulerBackend() - scheduler.start() - val firstResolvedPod = expectPodCreationWithId(1, FIRST_EXECUTOR_POD) - when(podOperations.create(FIRST_EXECUTOR_POD)).thenAnswer(AdditionalAnswers.returnsFirstArg()) - requestExecutorRunnable.getValue.run() - allocatorRunnable.getValue.run() - verify(podOperations, times(1)).create(firstResolvedPod) - executorPodsWatcherArgument.getValue.eventReceived(Action.ERROR, firstResolvedPod) - val recreatedResolvedPod = expectPodCreationWithId(2, FIRST_EXECUTOR_POD) - allocatorRunnable.getValue.run() - verify(podOperations).create(recreatedResolvedPod) + test("Stop all components") { + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + schedulerBackendUnderTest.stop() + verify(eventQueue).stop() + verify(watchEvents).stop() + verify(pollEvents).stop() + verify(labeledPods).delete() + verify(kubernetesClient).close() } - private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { - new KubernetesClusterSchedulerBackend( - taskSchedulerImpl, - rpcEnv, - executorBuilder, - kubernetesClient, - allocatorExecutor, - requestExecutorsService) { - - override def applicationId(): String = APP_ID - } + test("Remove executor") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRemoveExecutor( + "1", ExecutorKilled) + verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled)) } - private def exitPod(basePod: Pod, exitCode: Int): Pod = { - new PodBuilder(basePod) - .editStatus() - .addNewContainerStatus() - .withNewState() - .withNewTerminated() - .withExitCode(exitCode) - .endTerminated() - .endState() - .endContainerStatus() - .endStatus() - .build() + test("Kill executors") { + schedulerBackendUnderTest.start() + when(podOperations.withLabel(SPARK_APP_ID_LABEL, TEST_SPARK_APP_ID)).thenReturn(labeledPods) + when(labeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)).thenReturn(labeledPods) + when(labeledPods.withLabelIn(SPARK_EXECUTOR_ID_LABEL, "1", "2")).thenReturn(labeledPods) + schedulerBackendUnderTest.doKillExecutors(Seq("1", "2")) + verify(labeledPods, never()).delete() + requestExecutorsService.runNextPendingCommand() + verify(labeledPods).delete() } - private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Pod = { - val resolvedPod = new PodBuilder(expectedPod) - .editMetadata() - .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) - .endMetadata() - .build() - val resolvedContainer = new ContainerBuilder().build() - when(executorBuilder.buildFromFeatures(Matchers.argThat( - new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { - override def matches(argument: scala.Any) - : Boolean = { - argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && - argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] - .roleSpecificConf.executorId == executorId.toString - } - - override def describeTo(description: Description): Unit = {} - }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) - new PodBuilder(resolvedPod) - .editSpec() - .addToContainers(resolvedContainer) - .endSpec() - .build() + test("Request total executors") { + schedulerBackendUnderTest.start() + schedulerBackendUnderTest.doRequestTotalExecutors(5) + verify(podAllocator).setTotalExpectedExecutors(3) + verify(podAllocator, never()).setTotalExpectedExecutors(5) + requestExecutorsService.runNextPendingCommand() + verify(podAllocator).setTotalExpectedExecutors(5) } + } From 22daeba59b3ffaccafc9ff4b521abc265d0e58dd Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 14 Jun 2018 20:59:42 -0700 Subject: [PATCH 0967/1161] [SPARK-24478][SQL] Move projection and filter push down to physical conversion ## What changes were proposed in this pull request? This removes the v2 optimizer rule for push-down and instead pushes filters and required columns when converting to a physical plan, as suggested by marmbrus. This makes the v2 relation cleaner because the output and filters do not change in the logical plan. A side-effect of this change is that the stats from the logical (optimized) plan no longer reflect pushed filters and projection. This is a temporary state, until the planner gathers stats from the physical plan instead. An alternative to this approach is https://github.com/rdblue/spark/commit/9d3a11e68bca6c5a56a2be47fb09395350362ac5. The first commit was proposed in #21262. This PR replaces #21262. ## How was this patch tested? Existing tests. Author: Ryan Blue Closes #21503 from rdblue/SPARK-24478-move-push-down-to-physical-conversion. --- .../v2/reader/SupportsReportStatistics.java | 5 ++ .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../datasources/v2/DataSourceV2Relation.scala | 85 +++++-------------- .../datasources/v2/DataSourceV2Strategy.scala | 47 +++++++++- .../v2/PushDownOperatorsToDataSource.scala | 66 -------------- .../sql/sources/v2/DataSourceV2Suite.scala | 11 +-- 6 files changed, 79 insertions(+), 139 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index 11bb13fd3b211..a79080a249ec8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -22,6 +22,11 @@ /** * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. + * + * Statistics are reported to the optimizer before a projection or any filters are pushed to the + * DataSourceReader. Implementations that return more accurate statistics based on projection and + * filters will not improve query performance until the planner can push operators before getting + * stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978dc..00ff4c8ac310b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -32,8 +31,7 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ - Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 90fb5a14c9fc9..e08af218513fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources.{DataSourceRegister, Filter} @@ -32,69 +31,27 @@ import org.apache.spark.sql.types.StructType case class DataSourceV2Relation( source: DataSourceV2, + output: Seq[AttributeReference], options: Map[String, String], - projection: Seq[AttributeReference], - filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ - override def simpleString: String = "RelationV2 " + metadataString - - override lazy val schema: StructType = reader.readSchema() - - override lazy val output: Seq[AttributeReference] = { - // use the projection attributes to avoid assigning new ids. fields that are not projected - // will be assigned new ids, which is okay because they are not projected. - val attrMap = projection.map(a => a.name -> a).toMap - schema.map(f => attrMap.getOrElse(f.name, - AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())) - } - - private lazy val v2Options: DataSourceOptions = makeV2Options(options) + override def pushedFilters: Seq[Expression] = Seq.empty - // postScanFilters: filters that need to be evaluated after the scan. - // pushedFilters: filters that will be pushed down and evaluated in the underlying data sources. - // Note: postScanFilters and pushedFilters can overlap, e.g. the parquet row group filter. - lazy val ( - reader: DataSourceReader, - postScanFilters: Seq[Expression], - pushedFilters: Seq[Expression]) = { - val newReader = userSpecifiedSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) - } - - DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType) - - val (postScanFilters, pushedFilters) = filters match { - case Some(filterSeq) => - DataSourceV2Relation.pushFilters(newReader, filterSeq) - case _ => - (Nil, Nil) - } - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") - - (newReader, postScanFilters, pushedFilters) - } - - override def doCanonicalize(): LogicalPlan = { - val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation] + override def simpleString: String = "RelationV2 " + metadataString - // override output with canonicalized output to avoid attempting to configure a reader - val canonicalOutput: Seq[AttributeReference] = this.output - .map(a => QueryPlan.normalizeExprId(a, projection)) + lazy val v2Options: DataSourceOptions = makeV2Options(options) - new DataSourceV2Relation(c.source, c.options, c.projection) { - override lazy val output: Seq[AttributeReference] = canonicalOutput - } + def newReader: DataSourceReader = userSpecifiedSchema match { + case Some(userSchema) => + source.asReadSupportWithSchema.createReader(userSchema, v2Options) + case None => + source.asReadSupport.createReader(v2Options) } - override def computeStats(): Statistics = reader match { + override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) case _ => @@ -102,9 +59,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - // projection is used to maintain id assignment. - // if projection is not set, use output so the copy is not equal to the original - copy(projection = projection.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -206,21 +161,27 @@ object DataSourceV2Relation { def create( source: DataSourceV2, options: Map[String, String], - filters: Option[Seq[Expression]] = None, userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, options, projection, filters, userSpecifiedSchema) + val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes + DataSourceV2Relation(source, output, options, userSpecifiedSchema) } - private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = { + def pushRequiredColumns( + relation: DataSourceV2Relation, + reader: DataSourceReader, + struct: StructType): Seq[AttributeReference] = { reader match { case projectionSupport: SupportsPushDownRequiredColumns => projectionSupport.pruneColumns(struct) + // return the output columns from the relation that were projected + val attrMap = relation.output.map(a => a.name -> a).toMap + projectionSupport.readSchema().map(f => attrMap(f.name)) case _ => + relation.output } } - private def pushFilters( + def pushFilters( reader: DataSourceReader, filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { reader match { @@ -248,7 +209,7 @@ object DataSourceV2Relation { // the data source cannot guarantee the rows returned can pass these filters. // As a result we must return it so Spark can plan an extra filter operator. val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) + r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) // The filters which are marked as pushed to this data source val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1b7c639f10f98..8bf858c38d76c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,15 +17,56 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.Strategy +import org.apache.spark.sql.{execution, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: DataSourceV2Relation => - DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val projectSet = AttributeSet(project.flatMap(_.references)) + val filterSet = AttributeSet(filters.flatMap(_.references)) + + val projection = if (filterSet.subsetOf(projectSet) && + AttributeSet(relation.output) == projectSet) { + // When the required projection contains all of the filter columns and column pruning alone + // can produce the required projection, push the required projection. + // A final projection may still be needed if the data source produces a different column + // order or if it cannot prune all of the nested columns. + relation.output + } else { + // When there are filter columns not already in the required projection or when the required + // projection is more complicated than column pruning, base column pruning on the set of + // all columns needed by both. + (projectSet ++ filterSet).toSeq + } + + val reader = relation.newReader + + val output = DataSourceV2Relation.pushRequiredColumns(relation, reader, + projection.asInstanceOf[Seq[AttributeReference]].toStructType) + + val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters) + + logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") + logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") + + val scan = DataSourceV2ScanExec( + output, relation.source, relation.options, pushedFilters, reader) + + val filter = postScanFilters.reduceLeftOption(And) + val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan) + + val withProjection = if (withFilter.output != project) { + execution.ProjectExec(project, withFilter) + } else { + withFilter + } + + withProjection :: Nil case r: StreamingDataSourceV2Relation => DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala deleted file mode 100644 index e894f8afd6762..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.rules.Rule - -object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan match { - // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - assert(relation.filters.isEmpty, "data source v2 should do push down only once.") - - val projectAttrs = project.map(_.toAttribute) - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(projectAttrs) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - projectAttrs - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - - val newRelation = relation.copy( - projection = projection.asInstanceOf[Seq[AttributeReference]], - filters = Some(filters)) - - // Add a Filter for any filters that need to be evaluated after scan. - val postScanFilterCond = newRelation.postScanFilters.reduceLeftOption(And) - val filtered = postScanFilterCond.map(Filter(_, newRelation)).getOrElse(newRelation) - - // Add a Project to ensure the output matches the required projection - if (newRelation.output != projectAttrs) { - Project(project, filtered) - } else { - filtered - } - - case other => other.mapChildren(apply) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 505a3f3465c02..e96cd4500458d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -323,21 +323,22 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23315: get output from canonicalized data source v2 related plans") { - def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + def checkCanonicalizedOutput( + df: DataFrame, logicalNumOutput: Int, physicalNumOutput: Int): Unit = { val logical = df.queryExecution.optimizedPlan.collect { case d: DataSourceV2Relation => d }.head - assert(logical.canonicalized.output.length == numOutput) + assert(logical.canonicalized.output.length == logicalNumOutput) val physical = df.queryExecution.executedPlan.collect { case d: DataSourceV2ScanExec => d }.head - assert(physical.canonicalized.output.length == numOutput) + assert(physical.canonicalized.output.length == physicalNumOutput) } val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - checkCanonicalizedOutput(df, 2) - checkCanonicalizedOutput(df.select('i), 1) + checkCanonicalizedOutput(df, 2, 2) + checkCanonicalizedOutput(df.select('i), 2, 1) } } From 6567fc43aca75b41900cde976594e21c8b0ca98a Mon Sep 17 00:00:00 2001 From: Ruben Berenguel Montoro Date: Fri, 15 Jun 2018 16:59:00 +0800 Subject: [PATCH 0968/1161] [PYTHON] Fix typo in serializer exception ## What changes were proposed in this pull request? Fix typo in exception raised in Python serializer ## How was this patch tested? No code changes Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Ruben Berenguel Montoro Closes #21566 from rberenguel/fix_typo_pyspark_serializers. --- python/pyspark/serializers.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd903..4c16b5fc26f3d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,8 +33,9 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -PySpark serialize objects in batches; By default, the batch size is chosen based -on the size of objects, also configurable by SparkContext's C{batchSize} parameter: +PySpark serializes objects in batches; by default, the batch size is chosen based +on the size of objects and is also configurable by SparkContext's C{batchSize} +parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -100,7 +101,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ Return an iterator of deserialized batches (iterable) of objects from the input stream. - if the serializer does not operate on batches the default implementation returns an + If the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ return map(lambda x: [x], self.load_stream(stream)) @@ -461,7 +462,7 @@ def dumps(self, obj): return obj -# Hook namedtuple, make it picklable +# Hack namedtuple, make it picklable __cls = {} @@ -525,15 +526,15 @@ def namedtuple(*args, **kwargs): cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) - # replace namedtuple with new one + # replace namedtuple with the new one collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 - # hack the cls already generated by namedtuple - # those created in other module can be pickled as normal, + # hack the cls already generated by namedtuple. + # Those created in other modules can be pickled as normal, # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple @@ -627,7 +628,7 @@ def loads(self, obj): elif _type == b'P': return pickle.loads(obj[1:]) else: - raise ValueError("invalid sevialization type: %s" % _type) + raise ValueError("invalid serialization type: %s" % _type) class CompressedSerializer(FramedSerializer): From 495d8cf09ae7134aa6d2feb058612980e02955fa Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 15 Jun 2018 09:59:02 -0700 Subject: [PATCH 0969/1161] [SPARK-24490][WEBUI] Use WebUI.addStaticHandler in web UIs `WebUI` defines `addStaticHandler` that web UIs don't use (and simply introduce duplication). Let's clean them up and remove duplications. Local build and waiting for Jenkins Author: Jacek Laskowski Closes #21510 from jaceklaskowski/SPARK-24490-Use-WebUI.addStaticHandler. --- .../spark/deploy/history/HistoryServer.scala | 2 +- .../spark/deploy/master/ui/MasterWebUI.scala | 2 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 2 +- .../scala/org/apache/spark/ui/SparkUI.scala | 2 +- .../scala/org/apache/spark/ui/WebUI.scala | 52 ++++++++++--------- .../deploy/mesos/ui/MesosClusterUI.scala | 2 +- .../spark/streaming/ui/StreamingTab.scala | 2 +- 7 files changed, 33 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 066275e8f8425..56f3f59504a7d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -124,7 +124,7 @@ class HistoryServer( attachHandler(ApiRootResource.getServletHandler(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) val contextHandler = new ServletContextHandler contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 35b7ddd46e4db..e87b2240564bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,7 +43,7 @@ class MasterWebUI( val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) attachPage(masterPage) - attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index db696b04384bd..ea67b7434a769 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -47,7 +47,7 @@ class WorkerWebUI( val logPage = new LogPage(this) attachPage(logPage) attachPage(new WorkerPage(this)) - attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr, diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b44ac0ea1febc..d315ef66e0dc0 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -65,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new StorageTab(this, store)) attachTab(new EnvironmentTab(this, store)) attachTab(new ExecutorsTab(this)) - attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(SparkUI.STATIC_RESOURCE_DIR) attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath)) attachHandler(ApiRootResource.getServletHandler(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 8b75f5d8fe1a8..2e43f17e6a8e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -60,23 +60,25 @@ private[spark] abstract class WebUI( def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager - /** Attach a tab to this UI, along with all of its attached pages. */ - def attachTab(tab: WebUITab) { + /** Attaches a tab to this UI, along with all of its attached pages. */ + def attachTab(tab: WebUITab): Unit = { tab.pages.foreach(attachPage) tabs += tab } - def detachTab(tab: WebUITab) { + /** Detaches a tab from this UI, along with all of its attached pages. */ + def detachTab(tab: WebUITab): Unit = { tab.pages.foreach(detachPage) tabs -= tab } - def detachPage(page: WebUIPage) { + /** Detaches a page from this UI, along with all of its attached handlers. */ + def detachPage(page: WebUIPage): Unit = { pageToHandlers.remove(page).foreach(_.foreach(detachHandler)) } - /** Attach a page to this UI. */ - def attachPage(page: WebUIPage) { + /** Attaches a page to this UI. */ + def attachPage(page: WebUIPage): Unit = { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) @@ -88,41 +90,41 @@ private[spark] abstract class WebUI( handlers += renderHandler } - /** Attach a handler to this UI. */ - def attachHandler(handler: ServletContextHandler) { + /** Attaches a handler to this UI. */ + def attachHandler(handler: ServletContextHandler): Unit = { handlers += handler serverInfo.foreach(_.addHandler(handler)) } - /** Detach a handler from this UI. */ - def detachHandler(handler: ServletContextHandler) { + /** Detaches a handler from this UI. */ + def detachHandler(handler: ServletContextHandler): Unit = { handlers -= handler serverInfo.foreach(_.removeHandler(handler)) } /** - * Add a handler for static content. + * Detaches the content handler at `path` URI. * - * @param resourceBase Root of where to find resources to serve. - * @param path Path in UI where to mount the resources. + * @param path Path in UI to unmount. */ - def addStaticHandler(resourceBase: String, path: String): Unit = { - attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + def detachHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) } /** - * Remove a static content handler. + * Adds a handler for static content. * - * @param path Path in UI to unmount. + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. */ - def removeStaticHandler(path: String): Unit = { - handlers.find(_.getContextPath() == path).foreach(detachHandler) + def addStaticHandler(resourceBase: String, path: String = "/static"): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) } - /** Initialize all components of the server. */ + /** A hook to initialize components of the UI */ def initialize(): Unit - /** Bind to the HTTP server behind this web interface. */ + /** Binds to the HTTP server behind this web interface. */ def bind(): Unit = { assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { @@ -136,17 +138,17 @@ private[spark] abstract class WebUI( } } - /** Return the url of web interface. Only valid after bind(). */ + /** @return The url of web interface. Only valid after [[bind]]. */ def webUrl: String = s"http://$publicHostName:$boundPort" - /** Return the actual port to which this server is bound. Only valid after bind(). */ + /** @return The actual port to which this server is bound. Only valid after [[bind]]. */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - /** Stop the server behind this web interface. Only valid after bind(). */ + /** Stops the server behind this web interface. Only valid after [[bind]]. */ def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") - serverInfo.get.stop() + serverInfo.foreach(_.stop()) } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 604978967d6db..15bbe60d6c8fb 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -40,7 +40,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) attachPage(new DriverPage(this)) - attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + addStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 9d1b82a6341b1..25e71258b9369 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -49,7 +49,7 @@ private[spark] class StreamingTab(val ssc: StreamingContext) def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).removeStaticHandler("/static/streaming") + getSparkUI(ssc).detachHandler("/static/streaming") } } From b5ccf0d3957a444db93893c0ce4417bfbbb11822 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 15 Jun 2018 12:56:39 -0700 Subject: [PATCH 0970/1161] [SPARK-24396][SS][PYSPARK] Add Structured Streaming ForeachWriter for python ## What changes were proposed in this pull request? This PR adds `foreach` for streaming queries in Python. Users will be able to specify their processing logic in two different ways. - As a function that takes a row as input. - As an object that has methods `open`, `process`, and `close` methods. See the python docs in this PR for more details. ## How was this patch tested? Added java and python unit tests Author: Tathagata Das Closes #21477 from tdas/SPARK-24396. --- python/pyspark/sql/streaming.py | 162 +++++++++++ python/pyspark/sql/tests.py | 257 ++++++++++++++++++ python/pyspark/tests.py | 4 +- .../org/apache/spark/sql/ForeachWriter.scala | 62 ++++- .../python/PythonForeachWriter.scala | 161 +++++++++++ .../sources/ForeachWriterProvider.scala | 52 +++- .../sql/streaming/DataStreamWriter.scala | 48 +--- .../python/PythonForeachWriterSuite.scala | 137 ++++++++++ 8 files changed, 811 insertions(+), 72 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fae50b3d5d532..4984593bab491 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -854,6 +854,168 @@ def trigger(self, processingTime=None, once=None, continuous=None): self._jwrite = self._jwrite.trigger(jTrigger) return self + @since(2.4) + def foreach(self, f): + """ + Sets the output of the streaming query to be processed using the provided writer ``f``. + This is often used to write the output of a streaming query to arbitrary storage systems. + The processing logic can be specified in two ways. + + #. A **function** that takes a row as input. + This is a simple way to express your processing logic. Note that this does + not allow you to deduplicate generated data when failures cause reprocessing of + some input data. That would require you to specify the processing logic in the next + way. + + #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods. + The object can have the following methods. + + * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing + (for example, open a connection, start a transaction, etc). Additionally, you can + use the `partition_id` and `epoch_id` to deduplicate regenerated data + (discussed later). + + * ``process(row)``: *Non-optional* method that processes each :class:`Row`. + + * ``close(error)``: *Optional* method that finalizes and cleans up (for example, + close connection, commit transaction, etc.) after all rows have been processed. + + The object will be used by Spark in the following way. + + * A single copy of this object is responsible of all the data generated by a + single task in a query. In other words, one instance is responsible for + processing one partition of the data generated in a distributed manner. + + * This object must be serializable because each task will get a fresh + serialized-deserialized copy of the provided object. Hence, it is strongly + recommended that any initialization for writing data (e.g. opening a + connection or starting a transaction) is done after the `open(...)` + method has been called, which signifies that the task is ready to generate data. + + * The lifecycle of the methods are as follows. + + For each partition with ``partition_id``: + + ... For each batch/epoch of streaming data with ``epoch_id``: + + ....... Method ``open(partitionId, epochId)`` is called. + + ....... If ``open(...)`` returns true, for each row in the partition and + batch/epoch, method ``process(row)`` is called. + + ....... Method ``close(errorOrNull)`` is called with error (if any) seen while + processing rows. + + Important points to note: + + * The `partitionId` and `epochId` can be used to deduplicate generated data when + failures cause reprocessing of some input data. This depends on the execution + mode of the query. If the streaming query is being executed in the micro-batch + mode, then every partition represented by a unique tuple (partition_id, epoch_id) + is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used + to deduplicate and/or transactionally commit data and achieve exactly-once + guarantees. However, if the streaming query is being executed in the continuous + mode, then this guarantee does not hold and therefore should not be used for + deduplication. + + * The ``close()`` method (if exists) will be called if `open()` method exists and + returns successfully (irrespective of the return value), except if the Python + crashes in the middle. + + .. note:: Evolving. + + >>> # Print every row using a function + >>> def print_row(row): + ... print(row) + ... + >>> writer = sdf.writeStream.foreach(print_row) + >>> # Print every row using a object with process() method + >>> class RowPrinter: + ... def open(self, partition_id, epoch_id): + ... print("Opened %d, %d" % (partition_id, epoch_id)) + ... return True + ... def process(self, row): + ... print(row) + ... def close(self, error): + ... print("Closed with error: %s" % str(error)) + ... + >>> writer = sdf.writeStream.foreach(RowPrinter()) + """ + + from pyspark.rdd import _wrap_function + from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.taskcontext import TaskContext + + if callable(f): + # The provided object is a callable function that is supposed to be called on each row. + # Construct a function that takes an iterator and calls the provided function on each + # row. + def func_without_process(_, iterator): + for x in iterator: + f(x) + return iter([]) + + func = func_without_process + + else: + # The provided object is not a callable function. Then it is expected to have a + # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and + # 'close(error)' methods. + + if not hasattr(f, 'process'): + raise Exception("Provided object does not have a 'process' method") + + if not callable(getattr(f, 'process')): + raise Exception("Attribute 'process' in provided object is not callable") + + def doesMethodExist(method_name): + exists = hasattr(f, method_name) + if exists and not callable(getattr(f, method_name)): + raise Exception( + "Attribute '%s' in provided object is not callable" % method_name) + return exists + + open_exists = doesMethodExist('open') + close_exists = doesMethodExist('close') + + def func_with_open_process_close(partition_id, iterator): + epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId') + if epoch_id: + epoch_id = int(epoch_id) + else: + raise Exception("Could not get batch id from TaskContext") + + # Check if the data should be processed + should_process = True + if open_exists: + should_process = f.open(partition_id, epoch_id) + + error = None + + try: + if should_process: + for x in iterator: + f.process(x) + except Exception as ex: + error = ex + finally: + if close_exists: + f.close(error) + if error: + raise error + + return iter([]) + + func = func_with_open_process_close + + serializer = AutoBatchedSerializer(PickleSerializer()) + wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) + jForeachWriter = \ + self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( + wrapped_func, self._df._jdf.schema()) + self._jwrite.foreach(jForeachWriter) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2d7a4f62d4ee8..4e5fafa77e109 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1869,6 +1869,263 @@ def test_query_manager_await_termination(self): q.stop() shutil.rmtree(tmpPath) + class ForeachWriterTester: + + def __init__(self, spark): + self.spark = spark + + def write_open_event(self, partitionId, epochId): + self._write_event( + self.open_events_dir, + {'partition': partitionId, 'epoch': epochId}) + + def write_process_event(self, row): + self._write_event(self.process_events_dir, {'value': 'text'}) + + def write_close_event(self, error): + self._write_event(self.close_events_dir, {'error': str(error)}) + + def write_input_file(self): + self._write_event(self.input_dir, "text") + + def open_events(self): + return self._read_events(self.open_events_dir, 'partition INT, epoch INT') + + def process_events(self): + return self._read_events(self.process_events_dir, 'value STRING') + + def close_events(self): + return self._read_events(self.close_events_dir, 'error STRING') + + def run_streaming_query_on_writer(self, writer, num_files): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + for i in range(num_files): + self.write_input_file() + sq.processAllAvailable() + finally: + self.stop_all() + + def assert_invalid_writer(self, writer, msg=None): + self._reset() + try: + sdf = self.spark.readStream.format('text').load(self.input_dir) + sq = sdf.writeStream.foreach(writer).start() + self.write_input_file() + sq.processAllAvailable() + self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected + except Exception as e: + if msg: + assert(msg in str(e), "%s not in %s" % (msg, str(e))) + + finally: + self.stop_all() + + def stop_all(self): + for q in self.spark._wrapped.streams.active: + q.stop() + + def _reset(self): + self.input_dir = tempfile.mkdtemp() + self.open_events_dir = tempfile.mkdtemp() + self.process_events_dir = tempfile.mkdtemp() + self.close_events_dir = tempfile.mkdtemp() + + def _read_events(self, dir, json): + rows = self.spark.read.schema(json).json(dir).collect() + dicts = [row.asDict() for row in rows] + return dicts + + def _write_event(self, dir, event): + import uuid + with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f: + f.write("%s\n" % str(event)) + + def __getstate__(self): + return (self.open_events_dir, self.process_events_dir, self.close_events_dir) + + def __setstate__(self, state): + self.open_events_dir, self.process_events_dir, self.close_events_dir = state + + def test_streaming_foreach_with_simple_function(self): + tester = self.ForeachWriterTester(self.spark) + + def foreach_func(row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(foreach_func, 2) + self.assertEqual(len(tester.process_events()), 2) + + def test_streaming_foreach_with_basic_open_process_close(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partitionId, epochId): + tester.write_open_event(partitionId, epochId) + return True + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + open_events = tester.open_events() + self.assertEqual(len(open_events), 2) + self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1}) + + self.assertEqual(len(tester.process_events()), 2) + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_with_open_returning_false(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return False + + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + + self.assertEqual(len(tester.open_events()), 2) + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + + close_events = tester.close_events() + self.assertEqual(len(close_events), 2) + self.assertSetEqual(set([e['error'] for e in close_events]), {'None'}) + + def test_streaming_foreach_without_open_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + def close(self, error): + tester.write_close_event(error) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 2) + + def test_streaming_foreach_without_close_method(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def open(self, partition_id, epoch_id): + tester.write_open_event(partition_id, epoch_id) + return True + + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 2) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_without_open_and_close_methods(self): + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + tester.write_process_event(row) + + tester.run_streaming_query_on_writer(ForeachWriter(), 2) + self.assertEqual(len(tester.open_events()), 0) # no open events + self.assertEqual(len(tester.process_events()), 2) + self.assertEqual(len(tester.close_events()), 0) + + def test_streaming_foreach_with_process_throwing_error(self): + from pyspark.sql.utils import StreamingQueryException + + tester = self.ForeachWriterTester(self.spark) + + class ForeachWriter: + def process(self, row): + raise Exception("test error") + + def close(self, error): + tester.write_close_event(error) + + try: + tester.run_streaming_query_on_writer(ForeachWriter(), 1) + self.fail("bad writer did not fail the query") # this is not expected + except StreamingQueryException as e: + # TODO: Verify whether original error message is inside the exception + pass + + self.assertEqual(len(tester.process_events()), 0) # no row was processed + close_events = tester.close_events() + self.assertEqual(len(close_events), 1) + # TODO: Verify whether original error message is inside the exception + + def test_streaming_foreach_with_invalid_writers(self): + + tester = self.ForeachWriterTester(self.spark) + + def func_with_iterator_input(iter): + for x in iter: + print(x) + + tester.assert_invalid_writer(func_with_iterator_input) + + class WriterWithoutProcess: + def open(self, partition): + pass + + tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'") + + class WriterWithNonCallableProcess(): + process = True + + tester.assert_invalid_writer(WriterWithNonCallableProcess(), + "'process' in provided object is not callable") + + class WriterWithNoParamProcess(): + def process(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamProcess()) + + # Abstract class for tests below + class WithProcess(): + def process(self, row): + pass + + class WriterWithNonCallableOpen(WithProcess): + open = True + + tester.assert_invalid_writer(WriterWithNonCallableOpen(), + "'open' in provided object is not callable") + + class WriterWithNoParamOpen(WithProcess): + def open(self): + pass + + tester.assert_invalid_writer(WriterWithNoParamOpen()) + + class WriterWithNonCallableClose(WithProcess): + close = True + + tester.assert_invalid_writer(WriterWithNonCallableClose(), + "'close' in provided object is not callable") + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 18b2f251dc9fd..a4c5fb1db8b37 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -581,9 +581,9 @@ def test_get_local_property(self): self.sc.setLocalProperty(key, value) try: rdd = self.sc.parallelize(range(1), 1) - prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0] self.assertEqual(prop1, value) - prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0] self.assertTrue(prop2 is None) finally: self.sc.setLocalProperty(key, None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 86e02e98c01f3..b21c50af18433 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -20,10 +20,48 @@ package org.apache.spark.sql import org.apache.spark.annotation.InterfaceStability /** - * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the - * generated data to external systems. Each partition will use a new deserialized instance, so you - * usually should do all the initialization (e.g. opening a connection or initiating a transaction) - * in the `open` method. + * The abstract class for writing custom logic to process data generated by a query. + * This is often used to write the output of a streaming query to arbitrary storage systems. + * Any implementation of this base class will be used by Spark in the following way. + * + *
      + *
    • A single instance of this class is responsible of all the data generated by a single task + * in a query. In other words, one instance is responsible for processing one partition of the + * data generated in a distributed manner. + * + *
    • Any implementation of this class must be serializable because each task will get a fresh + * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that + * any initialization for writing data (e.g. opening a connection or starting a transaction) + * is done after the `open(...)` method has been called, which signifies that the task is + * ready to generate data. + * + *
    • The lifecycle of the methods are as follows. + * + *
      + *   For each partition with `partitionId`:
      + *       For each batch/epoch of streaming data (if its streaming query) with `epochId`:
      + *           Method `open(partitionId, epochId)` is called.
      + *           If `open` returns true:
      + *                For each row in the partition and batch/epoch, method `process(row)` is called.
      + *           Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
      + *   
      + * + *
    + * + * Important points to note: + *
      + *
    • The `partitionId` and `epochId` can be used to deduplicate generated data when failures + * cause reprocessing of some input data. This depends on the execution mode of the query. If + * the streaming query is being executed in the micro-batch mode, then every partition + * represented by a unique tuple (partitionId, epochId) is guaranteed to have the same data. + * Hence, (partitionId, epochId) can be used to deduplicate and/or transactionally commit data + * and achieve exactly-once guarantees. However, if the streaming query is being executed in the + * continuous mode, then this guarantee does not hold and therefore should not be used for + * deduplication. + * + *
    • The `close()` method will be called if `open()` method returns successfully (irrespective + * of the return value), except if the JVM crashes in the middle. + *
    * * Scala example: * {{{ @@ -63,6 +101,7 @@ import org.apache.spark.annotation.InterfaceStability * } * }); * }}} + * * @since 2.0.0 */ @InterfaceStability.Evolving @@ -71,23 +110,18 @@ abstract class ForeachWriter[T] extends Serializable { // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API. /** - * Called when starting to process one partition of new data in the executor. The `version` is - * for data deduplication when there are failures. When recovering from a failure, some data may - * be generated multiple times but they will always have the same version. - * - * If this method finds using the `partitionId` and `version` that this partition has already been - * processed, it can return `false` to skip the further data processing. However, `close` still - * will be called for cleaning up resources. + * Called when starting to process one partition of new data in the executor. See the class + * docs for more information on how to use the `partitionId` and `epochId`. * * @param partitionId the partition id. - * @param version a unique id for data deduplication. + * @param epochId a unique id for data deduplication. * @return `true` if the corresponding partition and version id should be processed. `false` * indicates the partition should be skipped. */ - def open(partitionId: Long, version: Long): Boolean + def open(partitionId: Long, epochId: Long): Boolean /** - * Called to process the data in the executor side. This method will be called only when `open` + * Called to process the data in the executor side. This method will be called only if `open` * returns `true`. */ def process(value: T): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala new file mode 100644 index 0000000000000..a58773122922f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python._ +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{NextIterator, Utils} + +class PythonForeachWriter(func: PythonFunction, schema: StructType) + extends ForeachWriter[UnsafeRow] { + + private lazy val context = TaskContext.get() + private lazy val buffer = new PythonForeachWriter.UnsafeRowBuffer( + context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) + private lazy val inputRowIterator = buffer.iterator + + private lazy val inputByteIterator = { + EvaluatePython.registerPicklers() + val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } + new SerDeUtil.AutoBatchedPickler(objIterator) + } + + private lazy val pythonRunner = { + val conf = SparkEnv.get.conf + val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) + PythonRunner(func, bufferSize, reuseWorker) + } + + private lazy val outputIterator = + pythonRunner.compute(inputByteIterator, context.partitionId(), context) + + override def open(partitionId: Long, version: Long): Boolean = { + outputIterator // initialize everything + TaskContext.get.addTaskCompletionListener { _ => buffer.close() } + true + } + + override def process(value: UnsafeRow): Unit = { + buffer.add(value) + } + + override def close(errorOrNull: Throwable): Unit = { + buffer.allRowsAdded() + if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + } +} + +object PythonForeachWriter { + + /** + * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. + * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader + * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python + * worker stdin). Adds to the buffer are non-blocking, and reads through the buffer's iterator + * are blocking, that is, it blocks until new data is available or all data has been added. + * + * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue + * across memory and local disk. However, HybridRowQueue is designed to be used only with + * EvalPythonExec where the reader is always behind the the writer, that is, the reader does not + * try to read n+1 rows if the writer has only written n rows at any point of time. This + * assumption is not true for PythonForeachWriter where rows may be added at a different rate as + * they are consumed by the python worker. Hence, to maintain the invariant of the reader being + * behind the writer while using HybridRowQueue, the buffer does the following + * - Keeps a count of the rows in the HybridRowQueue + * - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not + * try to read more rows than what has been written. + * + * The implementation of the blocking iterator (ReentrantLock, Condition, etc.) has been borrowed + * from that of ArrayBlockingQueue. + */ + class UnsafeRowBuffer(taskMemoryManager: TaskMemoryManager, tempDir: File, numFields: Int) + extends Logging { + private val queue = HybridRowQueue(taskMemoryManager, tempDir, numFields) + private val lock = new ReentrantLock() + private val unblockRemove = lock.newCondition() + + // All of these are guarded by `lock` + private var count = 0L + private var allAdded = false + private var exception: Throwable = null + + val iterator = new NextIterator[UnsafeRow] { + override protected def getNext(): UnsafeRow = { + val row = remove() + if (row == null) finished = true + row + } + override protected def close(): Unit = { } + } + + def add(row: UnsafeRow): Unit = withLock { + assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count += 1 + unblockRemove.signal() + logTrace(s"Added $row, $count left") + } + + private def remove(): UnsafeRow = withLock { + while (count == 0 && !allAdded && exception == null) { + unblockRemove.await(100, TimeUnit.MILLISECONDS) + } + + // If there was any error in the adding thread, then rethrow it in the removing thread + if (exception != null) throw exception + + if (count > 0) { + val row = queue.remove() + assert(row != null, "HybridRowQueue.remove() returned null " + + s"[count = $count, allAdded = $allAdded, exception = $exception]") + count -= 1 + logTrace(s"Removed $row, $count left") + row + } else { + null + } + } + + def allRowsAdded(): Unit = withLock { + allAdded = true + unblockRemove.signal() + } + + def close(): Unit = { queue.close() } + + private def withLock[T](f: => T): T = { + lock.lockInterruptibly() + try { f } catch { + case e: Throwable => + if (exception == null) exception = e + throw e + } finally { lock.unlock() } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index df5d69d57e36f..f677f25f116a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter @@ -31,9 +33,14 @@ import org.apache.spark.sql.types.StructType * [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + override def createStreamWriter( queryId: String, schema: StructType, @@ -44,10 +51,16 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { - val encoder = encoderFor[T].resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - ForeachWriterFactory(writer, encoder) + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) } override def toString: String = "ForeachSink" @@ -55,29 +68,44 @@ case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends S } } -case class ForeachWriterFactory[T: Encoder]( +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]) + rowConverter: InternalRow => T) extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, attemptNumber: Int, epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, encoder, partitionId, epochId) + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } } /** * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * * @param writer The [[ForeachWriter]] to process all data. - * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] * @param partitionId * @param epochId * @tparam T The type expected by the writer. */ -class ForeachDataWriter[T : Encoder]( +class ForeachDataWriter[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T], + rowConverter: InternalRow => T, partitionId: Int, epochId: Long) extends DataWriter[InternalRow] { @@ -89,7 +117,7 @@ class ForeachDataWriter[T : Encoder]( if (!opened) return try { - writer.process(encoder.fromRow(record)) + writer.process(rowConverter(record)) } catch { case t: Throwable => writer.close(t) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index effc1471e8e12..e035c9cdc379e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -307,49 +307,9 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /** - * Starts the execution of the streaming query, which will continually send results to the given - * `ForeachWriter` as new data arrives. The `ForeachWriter` can be used to send the data - * generated by the `DataFrame`/`Dataset` to an external system. - * - * Scala example: - * {{{ - * datasetOfString.writeStream.foreach(new ForeachWriter[String] { - * - * def open(partitionId: Long, version: Long): Boolean = { - * // open connection - * } - * - * def process(record: String) = { - * // write string to connection - * } - * - * def close(errorOrNull: Throwable): Unit = { - * // close the connection - * } - * }).start() - * }}} - * - * Java example: - * {{{ - * datasetOfString.writeStream().foreach(new ForeachWriter() { - * - * @Override - * public boolean open(long partitionId, long version) { - * // open connection - * } - * - * @Override - * public void process(String value) { - * // write string to connection - * } - * - * @Override - * public void close(Throwable errorOrNull) { - * // close the connection - * } - * }).start(); - * }}} - * + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. * @since 2.0.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala new file mode 100644 index 0000000000000..07e6034770127 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer +import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.util.Utils + +class PythonForeachWriterSuite extends SparkFunSuite with Eventually { + + testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => + b.assertIteratorBlocked() + + b.add(Seq(1)) + b.assertOutput(Seq(1)) + b.assertIteratorBlocked() + + b.add(2 to 100) + b.assertOutput(1 to 100) + b.assertIteratorBlocked() + } + + testWithBuffer("UnsafeRowBuffer: iterator unblocks when all data added") { b => + b.assertIteratorBlocked() + b.add(Seq(1)) + b.assertIteratorBlocked() + + b.allAdded() + b.assertThreadTerminated() + b.assertOutput(Seq(1)) + } + + testWithBuffer( + "UnsafeRowBuffer: handles more data than memory", + memBytes = 5, + sleepPerRowReadMs = 1) { b => + + b.assertIteratorBlocked() + b.add(1 to 2000) + b.assertOutput(1 to 2000) + } + + def testWithBuffer( + name: String, + memBytes: Long = 4 << 10, + sleepPerRowReadMs: Int = 0 + )(f: BufferTester => Unit): Unit = { + + test(name) { + var tester: BufferTester = null + try { + tester = new BufferTester(memBytes, sleepPerRowReadMs) + f(tester) + } finally { + if (tester == null) tester.close() + } + } + } + + + class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { + private val buffer = { + val mem = new TestMemoryManager(new SparkConf()) + mem.limit(memBytes) + val taskM = new TaskMemoryManager(mem, 0) + new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1) + } + private val iterator = buffer.iterator + private val outputBuffer = new ArrayBuffer[Int] + private val testTimeout = timeout(20.seconds) + private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) + private val thread = new Thread() { + override def run(): Unit = { + while (iterator.hasNext) { + outputBuffer.synchronized { + outputBuffer += iterator.next().getInt(0) + } + Thread.sleep(sleepPerRowReadMs) + } + } + } + thread.start() + + def add(ints: Seq[Int]): Unit = { + ints.foreach { i => buffer.add(intProj.apply(new GenericInternalRow(Array[Any](i)))) } + } + + def allAdded(): Unit = { buffer.allRowsAdded() } + + def assertOutput(expectedOutput: Seq[Int]): Unit = { + eventually(testTimeout) { + val output = outputBuffer.synchronized { outputBuffer.toArray }.toSeq + assert(output == expectedOutput) + } + } + + def assertIteratorBlocked(): Unit = { + import Thread.State._ + eventually(testTimeout) { + assert(thread.isAlive) + assert(thread.getState == TIMED_WAITING || thread.getState == WAITING) + } + } + + def assertThreadTerminated(): Unit = { + eventually(testTimeout) { assert(!thread.isAlive) } + } + + def close(): Unit = { + thread.interrupt() + thread.join() + } + } +} From 90da7dc241f8eec2348c0434312c97c116330bc4 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 15 Jun 2018 13:47:48 -0700 Subject: [PATCH 0971/1161] [SPARK-24452][SQL][CORE] Avoid possible overflow in int add or multiple ## What changes were proposed in this pull request? This PR fixes possible overflow in int add or multiply. In particular, their overflows in multiply are detected by [Spotbugs](https://spotbugs.github.io/) The following assignments may cause overflow in right hand side. As a result, the result may be negative. ``` long = int * int long = int + int ``` To avoid this problem, this PR performs cast from int to long in right hand side. ## How was this patch tested? Existing UTs. Author: Kazuaki Ishizaki Closes #21481 from kiszk/SPARK-24452. --- .../spark/unsafe/map/BytesToBytesMap.java | 2 +- .../spark/deploy/worker/DriverRunner.scala | 2 +- .../apache/spark/rdd/AsyncRDDActions.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../catalyst/expressions/UnsafeArrayData.java | 14 +-- .../VariableLengthRowBasedKeyValueBatch.java | 2 +- .../vectorized/OffHeapColumnVector.java | 106 +++++++++--------- .../vectorized/OnHeapColumnVector.java | 10 +- .../sources/RateStreamMicroBatchReader.scala | 2 +- .../spark/sql/hive/client/HiveShim.scala | 2 +- .../util/FileBasedWriteAheadLog.scala | 2 +- 11 files changed, 73 insertions(+), 73 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 5f0045507aaab..9a767dd739b91 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -703,7 +703,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff // must be stored in the same memory page. // (8 byte key length) (key) (value) (8 byte pointer to next value) int uaoSize = UnsafeAlignedOffset.getUaoSize(); - final long recordLength = (2 * uaoSize) + klen + vlen + 8; + final long recordLength = (2L * uaoSize) + klen + vlen + 8; if (currentPage == null || currentPage.size() - pageCursor < recordLength) { if (!acquireNewPage(recordLength + uaoSize)) { return false; diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 58a181128eb4d..a6d13d12fc28d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -225,7 +225,7 @@ private[deploy] class DriverRunner( // check if attempting another run keepTrying = supervise && exitCode != 0 && !killed if (keepTrying) { - if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000) { + if (clock.getTimeMillis() - processStart > successfulRunDuration * 1000L) { waitSeconds = 1 } logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 13db4985b0b80..ba9dae4ad48ec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -95,7 +95,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..df1a4bef616b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -291,7 +291,7 @@ private[spark] class BlockManager( case e: Exception if i < MAX_ATTEMPTS => logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) - Thread.sleep(SLEEP_TIME_SECS * 1000) + Thread.sleep(SLEEP_TIME_SECS * 1000L) case NonFatal(e) => throw new SparkException("Unable to register with external shuffle server due to : " + e.getMessage, e) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index d5d934bc91cab..4dd2b7365652a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -83,7 +83,7 @@ public static long calculateSizeOfUnderlyingByteArray(long numFields, int elemen private long elementOffset; private long getElementOffset(int ordinal, int elementSize) { - return elementOffset + ordinal * elementSize; + return elementOffset + ordinal * (long)elementSize; } public Object getBaseObject() { return baseObject; } @@ -414,7 +414,7 @@ public byte[] toByteArray() { public short[] toShortArray() { short[] values = new short[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2); + baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2L); return values; } @@ -422,7 +422,7 @@ public short[] toShortArray() { public int[] toIntArray() { int[] values = new int[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -430,7 +430,7 @@ public int[] toIntArray() { public long[] toLongArray() { long[] values = new long[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8L); return values; } @@ -438,7 +438,7 @@ public long[] toLongArray() { public float[] toFloatArray() { float[] values = new float[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4); + baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4L); return values; } @@ -446,14 +446,14 @@ public float[] toFloatArray() { public double[] toDoubleArray() { double[] values = new double[numElements]; Platform.copyMemory( - baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8); + baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8L); return values; } private static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); - final long valueRegionInBytes = elementSize * length; + final long valueRegionInBytes = (long)elementSize * length; final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; if (totalSizeInLongs > Integer.MAX_VALUE / 8) { throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index 905e6820ce6e2..c823de4810f2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -41,7 +41,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB @Override public UnsafeRow appendRow(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) { - final long recordLength = 8 + klen + vlen + 8; + final long recordLength = 8L + klen + vlen + 8; // if run out of max supported rows or page size, return null if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) { return null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 4733f36174f42..6fdadde628551 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -216,12 +216,12 @@ protected UTF8String getBytesAsUTF8String(int rowId, int count) { @Override public void putShort(int rowId, short value) { - Platform.putShort(null, data + 2 * rowId, value); + Platform.putShort(null, data + 2L * rowId, value); } @Override public void putShorts(int rowId, int count, short value) { - long offset = data + 2 * rowId; + long offset = data + 2L * rowId; for (int i = 0; i < count; ++i, offset += 2) { Platform.putShort(null, offset, value); } @@ -229,20 +229,20 @@ public void putShorts(int rowId, int count, short value) { @Override public void putShorts(int rowId, int count, short[] src, int srcIndex) { - Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, - null, data + 2 * rowId, count * 2); + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2L, + null, data + 2L * rowId, count * 2L); } @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 2, count * 2); + null, data + rowId * 2L, count * 2L); } @Override public short getShort(int rowId) { if (dictionary == null) { - return Platform.getShort(null, data + 2 * rowId); + return Platform.getShort(null, data + 2L * rowId); } else { return (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -252,7 +252,7 @@ public short getShort(int rowId) { public short[] getShorts(int rowId, int count) { assert(dictionary == null); short[] array = new short[count]; - Platform.copyMemory(null, data + rowId * 2, array, Platform.SHORT_ARRAY_OFFSET, count * 2); + Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L); return array; } @@ -262,12 +262,12 @@ public short[] getShorts(int rowId, int count) { @Override public void putInt(int rowId, int value) { - Platform.putInt(null, data + 4 * rowId, value); + Platform.putInt(null, data + 4L * rowId, value); } @Override public void putInts(int rowId, int count, int value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putInt(null, offset, value); } @@ -275,24 +275,24 @@ public void putInts(int rowId, int count, int value) { @Override public void putInts(int rowId, int count, int[] src, int srcIndex) { - Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + null, data + 4L * rowId, count * 4L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { Platform.putInt(null, offset, java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); @@ -303,7 +303,7 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public int getInt(int rowId) { if (dictionary == null) { - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } else { return dictionary.decodeToInt(dictionaryIds.getDictId(rowId)); } @@ -313,7 +313,7 @@ public int getInt(int rowId) { public int[] getInts(int rowId, int count) { assert(dictionary == null); int[] array = new int[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.INT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L); return array; } @@ -325,7 +325,7 @@ public int[] getInts(int rowId, int count) { public int getDictId(int rowId) { assert(dictionary == null) : "A ColumnVector dictionary should not have a dictionary for itself."; - return Platform.getInt(null, data + 4 * rowId); + return Platform.getInt(null, data + 4L * rowId); } // @@ -334,12 +334,12 @@ public int getDictId(int rowId) { @Override public void putLong(int rowId, long value) { - Platform.putLong(null, data + 8 * rowId, value); + Platform.putLong(null, data + 8L * rowId, value); } @Override public void putLongs(int rowId, int count, long value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putLong(null, offset, value); } @@ -347,24 +347,24 @@ public void putLongs(int rowId, int count, long value) { @Override public void putLongs(int rowId, int count, long[] src, int srcIndex) { - Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + null, data + 8L * rowId, count * 8L); } else { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { Platform.putLong(null, offset, java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); @@ -375,7 +375,7 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) @Override public long getLong(int rowId) { if (dictionary == null) { - return Platform.getLong(null, data + 8 * rowId); + return Platform.getLong(null, data + 8L * rowId); } else { return dictionary.decodeToLong(dictionaryIds.getDictId(rowId)); } @@ -385,7 +385,7 @@ public long getLong(int rowId) { public long[] getLongs(int rowId, int count) { assert(dictionary == null); long[] array = new long[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.LONG_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L); return array; } @@ -395,12 +395,12 @@ public long[] getLongs(int rowId, int count) { @Override public void putFloat(int rowId, float value) { - Platform.putFloat(null, data + rowId * 4, value); + Platform.putFloat(null, data + rowId * 4L, value); } @Override public void putFloats(int rowId, int count, float value) { - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, value); } @@ -408,18 +408,18 @@ public void putFloats(int rowId, int count, float value) { @Override public void putFloats(int rowId, int count, float[] src, int srcIndex) { - Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, - null, data + 4 * rowId, count * 4); + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4L, + null, data + 4L * rowId, count * 4L); } @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + null, data + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 4 * rowId; + long offset = data + 4L * rowId; for (int i = 0; i < count; ++i, offset += 4) { Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); } @@ -429,7 +429,7 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public float getFloat(int rowId) { if (dictionary == null) { - return Platform.getFloat(null, data + rowId * 4); + return Platform.getFloat(null, data + rowId * 4L); } else { return dictionary.decodeToFloat(dictionaryIds.getDictId(rowId)); } @@ -439,7 +439,7 @@ public float getFloat(int rowId) { public float[] getFloats(int rowId, int count) { assert(dictionary == null); float[] array = new float[count]; - Platform.copyMemory(null, data + rowId * 4, array, Platform.FLOAT_ARRAY_OFFSET, count * 4); + Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L); return array; } @@ -450,12 +450,12 @@ public float[] getFloats(int rowId, int count) { @Override public void putDouble(int rowId, double value) { - Platform.putDouble(null, data + rowId * 8, value); + Platform.putDouble(null, data + rowId * 8L, value); } @Override public void putDoubles(int rowId, int count, double value) { - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, value); } @@ -463,18 +463,18 @@ public void putDoubles(int rowId, int count, double value) { @Override public void putDoubles(int rowId, int count, double[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8, - null, data + 8 * rowId, count * 8); + Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8L, + null, data + 8L * rowId, count * 8L); } @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 8, count * 8); + null, data + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); - long offset = data + 8 * rowId; + long offset = data + 8L * rowId; for (int i = 0; i < count; ++i, offset += 8) { Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); } @@ -484,7 +484,7 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public double getDouble(int rowId) { if (dictionary == null) { - return Platform.getDouble(null, data + rowId * 8); + return Platform.getDouble(null, data + rowId * 8L); } else { return dictionary.decodeToDouble(dictionaryIds.getDictId(rowId)); } @@ -494,7 +494,7 @@ public double getDouble(int rowId) { public double[] getDoubles(int rowId, int count) { assert(dictionary == null); double[] array = new double[count]; - Platform.copyMemory(null, data + rowId * 8, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8); + Platform.copyMemory(null, data + rowId * 8L, array, Platform.DOUBLE_ARRAY_OFFSET, count * 8L); return array; } @@ -504,26 +504,26 @@ public double[] getDoubles(int rowId, int count) { @Override public void putArray(int rowId, int offset, int length) { assert(offset >= 0 && offset + length <= childColumns[0].capacity); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, offset); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, offset); } @Override public int getArrayLength(int rowId) { - return Platform.getInt(null, lengthData + 4 * rowId); + return Platform.getInt(null, lengthData + 4L * rowId); } @Override public int getArrayOffset(int rowId) { - return Platform.getInt(null, offsetData + 4 * rowId); + return Platform.getInt(null, offsetData + 4L * rowId); } // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { int result = arrayData().appendBytes(length, value, offset); - Platform.putInt(null, lengthData + 4 * rowId, length); - Platform.putInt(null, offsetData + 4 * rowId, result); + Platform.putInt(null, lengthData + 4L * rowId, length); + Platform.putInt(null, offsetData + 4L * rowId, result); return result; } @@ -533,19 +533,19 @@ protected void reserveInternal(int newCapacity) { int oldCapacity = (nulls == 0L) ? 0 : capacity; if (isArray() || type instanceof MapType) { this.lengthData = - Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = - Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2L, newCapacity * 2L); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8L, newCapacity * 8L); } else if (childColumns != null) { // Nothing to store. } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 23dcc104e67c4..577eab6ed14c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -231,7 +231,7 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public void putShorts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, - Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + Platform.SHORT_ARRAY_OFFSET + rowId * 2L, count * 2L); } @Override @@ -276,7 +276,7 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putInts(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, - Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.INT_ARRAY_OFFSET + rowId * 4L, count * 4L); } @Override @@ -342,7 +342,7 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongs(int rowId, int count, byte[] src, int srcIndex) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, - Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.LONG_ARRAY_OFFSET + rowId * 8L, count * 8L); } @Override @@ -394,7 +394,7 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { public void putFloats(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4L, count * 4L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { @@ -443,7 +443,7 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8L, count * 8L); } else { ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); for (int i = 0; i < count; ++i) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index fbff8db987110..b393c48baee8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -202,7 +202,7 @@ class RateStreamMicroBatchInputPartitionReader( rangeEnd: Long, localStartTimeMs: Long, relativeMsPerValue: Double) extends InputPartitionReader[Row] { - private var count = 0 + private var count: Long = 0 override def next(): Boolean = { rangeStart + partitionId + numPartitions * count < rangeEnd diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 130e258e78ca2..8620f3f6d99fb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -342,7 +342,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { - conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000L } override def loadPartition( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index ab7c8558321c8..2e8599026ea1d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -222,7 +222,7 @@ private[streaming] class FileBasedWriteAheadLog( pastLogs += LogInfo(currentLogWriterStartTime, currentLogWriterStopTime, _) } currentLogWriterStartTime = currentTime - currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000) + currentLogWriterStopTime = currentTime + (rollingIntervalSecs * 1000L) val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString) From e4fee395ecd93ad4579d9afbf0861f82a303e563 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Fri, 15 Jun 2018 13:56:48 -0700 Subject: [PATCH 0972/1161] [SPARK-24525][SS] Provide an option to limit number of rows in a MemorySink ## What changes were proposed in this pull request? Provide an option to limit number of rows in a MemorySink. Currently, MemorySink and MemorySinkV2 have unbounded size, meaning that if they're used on big data, they can OOM the stream. This change adds a maxMemorySinkRows option to limit how many rows MemorySink and MemorySinkV2 can hold. By default, they are still unbounded. ## How was this patch tested? Added new unit tests. Author: Mukul Murthy Closes #21559 from mukulmurthy/SPARK-24525. --- .../sql/execution/streaming/memory.scala | 70 ++++++++++++++-- .../streaming/sources/memoryV2.scala | 44 +++++++--- .../sql/streaming/DataStreamWriter.scala | 4 +- .../execution/streaming/MemorySinkSuite.scala | 62 +++++++++++++- .../streaming/MemorySinkV2Suite.scala | 80 ++++++++++++++++++- .../spark/sql/streaming/StreamTest.scala | 4 +- 6 files changed, 239 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b137f98045c5a..7fa13c4aa2c01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -221,19 +222,60 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { +trait MemorySinkBase extends BaseStreamingSink with Logging { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] + + /** + * Truncates the given rows to return at most maxRows rows. + * @param rows The data that may need to be truncated. + * @param batchLimit Number of rows to keep in this batch; the rest will be truncated + * @param sinkLimit Total number of rows kept in this sink, for logging purposes. + * @param batchId The ID of the batch that sent these rows, for logging purposes. + * @return Truncated rows. + */ + protected def truncateRowsIfNeeded( + rows: Array[Row], + batchLimit: Int, + sinkLimit: Int, + batchId: Long): Array[Row] = { + if (rows.length > batchLimit && batchLimit >= 0) { + logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") + rows.take(batchLimit) + } else { + rows + } + } +} + +/** + * Companion object to MemorySinkBase. + */ +object MemorySinkBase { + val MAX_MEMORY_SINK_ROWS = "maxRows" + val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 + + /** + * Gets the max number of rows a MemorySink should store. This number is based on the memory + * sink row limit option if it is set. If not, we use a large value so that data truncates + * rather than causing out of memory errors. + * @param options Options for writing from which we get the max rows option + * @return The maximum number of rows a memorySink should store. + */ + def getMemorySinkCapacity(options: DataSourceOptions): Int = { + val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) + if (maxRows >= 0) maxRows else Int.MaxValue - 10 + } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) + extends Sink with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -241,6 +283,12 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + + /** The capacity in rows of this sink. */ + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -273,14 +321,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } + var rowsToAdd = data.collect() + synchronized { + rowsToAdd = + truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, data.collect()) + var rowsToAdd = data.collect() synchronized { + rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -294,6 +351,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 468313bfe8c3c..47b482007822d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode) + new MemoryStreamWriter(this, mode, options) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,6 +55,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() + /** The number of rows in this MemorySink. */ + private var numRows = 0 + /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -81,7 +84,11 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write( + batchId: Long, + outputMode: OutputMode, + newRows: Array[Row], + sinkCapacity: Int): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -89,14 +96,21 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } + synchronized { + val rowsToAdd = + truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) + batches += rows + numRows += rowsToAdd.length + } case Complete => - val rows = AddedData(batchId, newRows) synchronized { + val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) + val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows + numRows = rowsToAdd.length } case _ => @@ -110,6 +124,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() + numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -117,16 +132,22 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) +class MemoryWriter( + sink: MemorySinkV2, + batchId: Long, + outputMode: OutputMode, + options: DataSourceOptions) extends DataSourceWriter with Logging { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows) + sink.write(batchId, outputMode, newRows, sinkCapacity) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -134,16 +155,21 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) } } -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) +class MemoryStreamWriter( + val sink: MemorySinkV2, + outputMode: OutputMode, + options: DataSourceOptions) extends StreamWriter { + val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, outputMode, newRows, sinkCapacity) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index e035c9cdc379e..43e80e4e54239 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -249,7 +249,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode) + val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d902..b2fd6ba27ebb8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +38,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -68,9 +70,35 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } + test("directly add data in Append output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Append, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 5) + checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit + } + test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -104,7 +132,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) // Before adding data, check output assert(sink.latestBatchId === None) @@ -136,6 +164,32 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } + test("directly add data in Complete output mode with row limit") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + + var optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + var options = new DataSourceOptions(optionsMap.toMap.asJava) + val sink = new MemorySink(schema, OutputMode.Complete, options) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 10) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 8) + checkAnswer(sink.allData, 4 to 8) // new data should replace old data + } + test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -211,7 +265,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 9be22d94b5654..e539510e15755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.streaming +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} +import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -40,7 +45,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -62,7 +67,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append()).commit( + new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -70,7 +75,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append()).commit( + new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -80,4 +85,73 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } + + test("continuous writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) + appendWriter.commit(0, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + appendWriter.commit(19, Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))))) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) + + val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) + completeWriter.commit(20, Array( + MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(5, Seq(Row(33))))) + assert(sink.latestBatchId.contains(20)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + completeWriter.commit(21, Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), + MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), + MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) + assert(sink.latestBatchId.contains(21)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) + } + + test("microbatch writer with row limit") { + val sink = new MemorySinkV2 + val optionsMap = new scala.collection.mutable.HashMap[String, String] + optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) + val options = new DataSourceOptions(optionsMap.toMap.asJava) + + new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) + assert(sink.latestBatchId.contains(25)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) + new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( + MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), + MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) + assert(sink.latestBatchId.contains(26)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) + + new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), + MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) + assert(sink.latestBatchId.contains(27)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) + new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( + MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), + MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) + assert(sink.latestBatchId.contains(28)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4c3fd58cb2e45..e41b4534ed51d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -337,7 +338,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val sink = if (useV2Sink) new MemorySinkV2 + else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath From c7c0b086a0b18424725433ade840d5121ac2b86e Mon Sep 17 00:00:00 2001 From: James Yu Date: Fri, 15 Jun 2018 21:04:04 -0700 Subject: [PATCH 0973/1161] add one supported type missing from the javadoc ## What changes were proposed in this pull request? The supported java.math.BigInteger type is not mentioned in the javadoc of Encoders.bean() ## How was this patch tested? only Javadoc fix Please review http://spark.apache.org/contributing.html before opening a pull request. Author: James Yu Closes #21544 from yuj/master. --- sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 0b95a8821b05a..b47ec0b72c638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -132,7 +132,7 @@ object Encoders { * - primitive types: boolean, int, double, etc. * - boxed types: Boolean, Integer, Double, etc. * - String - * - java.math.BigDecimal + * - java.math.BigDecimal, java.math.BigInteger * - time related: java.sql.Date, java.sql.Timestamp * - collection types: only array and java.util.List currently, map support is in progress * - nested java bean. From b0a935255951280b49c39968f6234163e2f0e379 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 18 Jun 2018 15:32:34 +0800 Subject: [PATCH 0974/1161] [SPARK-24573][INFRA] Runs SBT checkstyle after the build to work around a side-effect ## What changes were proposed in this pull request? Seems checkstyle affects the build in the PR builder in Jenkins. I can't reproduce in my local and seems it can only be reproduced in the PR builder. I was checking the places it goes through and this is just a speculation that checkstyle's compilation in SBT has a side effect to the assembly build. This PR proposes to run the SBT checkstyle after the build. ## How was this patch tested? Jenkins tests. Author: hyukjinkwon Closes #21579 from HyukjinKwon/investigate-javastyle. --- dev/run-tests.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5e8c8590b5c34..cd4590864b7d7 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -357,7 +357,7 @@ def build_spark_unidoc_sbt(hadoop_version): exec_sbt(profiles_and_goals) -def build_spark_assembly_sbt(hadoop_version): +def build_spark_assembly_sbt(hadoop_version, checkstyle=False): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["assembly/package"] @@ -366,6 +366,9 @@ def build_spark_assembly_sbt(hadoop_version): " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) + if checkstyle: + run_java_style_checks() + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the # documentation build fails on a specific machine & environment in Jenkins but it was unable @@ -570,11 +573,13 @@ def main(): or f.endswith("scalastyle-config.xml") for f in changed_files): run_scala_style_checks() + should_run_java_style_checks = False if not changed_files or any(f.endswith(".java") or f.endswith("checkstyle.xml") or f.endswith("checkstyle-suppressions.xml") for f in changed_files): - run_java_style_checks() + # Run SBT Checkstyle after the build to prevent a side-effect to the build. + should_run_java_style_checks = True if not changed_files or any(f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py") @@ -603,7 +608,7 @@ def main(): detect_binary_inop_with_mima(hadoop_version) # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. - build_spark_assembly_sbt(hadoop_version) + build_spark_assembly_sbt(hadoop_version, should_run_java_style_checks) # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) From e219e692ef70c161f37a48bfdec2a94b29260004 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 19 Jun 2018 00:24:54 +0800 Subject: [PATCH 0975/1161] [SPARK-23772][SQL] Provide an option to ignore column of all null values or empty array during JSON schema inference ## What changes were proposed in this pull request? This pr added a new JSON option `dropFieldIfAllNull ` to ignore column of all null values or empty array/struct during JSON schema inference. ## How was this patch tested? Added tests in `JsonSuite`. Author: Takeshi Yamamuro Author: Xiangrui Meng Closes #20929 from maropu/SPARK-23772. --- python/pyspark/sql/readwriter.py | 5 +- .../spark/sql/catalyst/json/JSONOptions.scala | 3 ++ .../apache/spark/sql/DataFrameReader.scala | 2 + .../datasources/json/JsonInferSchema.scala | 40 +++++++-------- .../sql/streaming/DataStreamReader.scala | 2 + .../datasources/json/JsonSuite.scala | 49 +++++++++++++++++++ 6 files changed, 80 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a0e20d39c20da..3efe2adb6e2a4 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None, - encoding=None): + dropFieldIfAllNull=None, encoding=None): """ Loads JSON files and returns the results as a :class:`DataFrame`. @@ -246,6 +246,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``. :param samplingRatio: defines fraction of input JSON objects used for schema inferring. If None is set, it uses the default value, ``1.0``. + :param dropFieldIfAllNull: whether to ignore column of all null values or empty + array/struct during schema inference. If None is set, it + uses the default value, ``false``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 2ff12acb2946f..c081772116f84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -73,6 +73,9 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + // Whether to ignore column of all null values or empty array/struct during schema inference + val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index de6be5f76e15a..ec9352a7fa055 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -381,6 +381,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * that should be used for parsing. *
  • `samplingRatio` (default is 1.0): defines fraction of input JSON objects used * for schema inferring.
  • + *
  • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e7eed95a560a3..f6edc7bfb3750 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -75,7 +75,7 @@ private[sql] object JsonInferSchema { // active SparkSession and `SQLConf.get` may point to the wrong configs. val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) - canonicalizeType(rootType) match { + canonicalizeType(rootType, configOptions) match { case Some(st: StructType) => st case _ => // canonicalizeType erases all empty structs, including the only one we want to keep @@ -181,33 +181,33 @@ private[sql] object JsonInferSchema { } /** - * Convert NullType to StringType and remove StructTypes with no fields + * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields, + * drops NullTypes or converts them to StringType based on provided options. */ - private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => - for { - canonicalType <- canonicalizeType(elementType) - } yield { - at.copy(canonicalType) - } + private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match { + case at: ArrayType => + canonicalizeType(at.elementType, options) + .map(t => at.copy(elementType = t)) case StructType(fields) => - val canonicalFields: Array[StructField] = for { - field <- fields - if field.name.length > 0 - canonicalType <- canonicalizeType(field.dataType) - } yield { - field.copy(dataType = canonicalType) + val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f => + canonicalizeType(f.dataType, options) + .map(t => f.copy(dataType = t)) } - - if (canonicalFields.length > 0) { - Some(StructType(canonicalFields)) + // SPARK-8093: empty structs should be deleted + if (canonicalFields.isEmpty) { + None } else { - // per SPARK-8093: empty structs should be deleted + Some(StructType(canonicalFields)) + } + + case NullType => + if (options.dropFieldIfAllNull) { None + } else { + Some(StringType) } - case NullType => Some(StringType) case other => Some(other) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ae93965bc50ed..ef8dc3a325a33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -270,6 +270,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * per file *
  • `lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator * that should be used for parsing.
  • + *
  • `dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or + * empty array/struct during schema inference.
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..a8a4a524a97f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2427,4 +2427,53 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { spark.read.option("mode", "PERMISSIVE").option("encoding", "UTF-8").json(Seq(badJson).toDS()), Row(badJson)) } + + test("SPARK-23772 ignore column of all null values or empty array during schema inference") { + withTempPath { tempDir => + val path = tempDir.getAbsolutePath + + // primitive types + Seq( + """{"a":null, "b":1, "c":3.0}""", + """{"a":null, "b":null, "c":"string"}""", + """{"a":null, "b":null, "c":null}""") + .toDS().write.text(path) + var df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + var expectedSchema = new StructType() + .add("b", LongType).add("c", StringType) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(1, "3.0") :: Row(null, "string") :: Row(null, null) :: Nil) + + // arrays + Seq( + """{"a":[2, 1], "b":[null, null], "c":null, "d":[[], [null]], "e":[[], null, [[]]]}""", + """{"a":[null], "b":[null], "c":[], "d":[null, []], "e":null}""", + """{"a":null, "b":null, "c":[], "d":null, "e":[null, [], null]}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", ArrayType(LongType)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Array(2, 1)) :: Row(Array(null)) :: Row(null) :: Nil) + + // structs + Seq( + """{"a":{"a1": 1, "a2":"string"}, "b":{}}""", + """{"a":{"a1": 2, "a2":null}, "b":{"b1":[null]}}""", + """{"a":null, "b":null}""") + .toDS().write.mode("overwrite").text(path) + df = spark.read.format("json") + .option("dropFieldIfAllNull", true) + .load(path) + expectedSchema = new StructType() + .add("a", StructType(StructField("a1", LongType) :: StructField("a2", StringType) + :: Nil)) + assert(df.schema === expectedSchema) + checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) + } + } } From bce177552564a4862bc979d39790cf553a477d74 Mon Sep 17 00:00:00 2001 From: trystanleftwich Date: Tue, 19 Jun 2018 00:34:24 +0800 Subject: [PATCH 0976/1161] [SPARK-24526][BUILD][TEST-MAVEN] Spaces in the build dir causes failures in the build/mvn script ## What changes were proposed in this pull request? Fix the call to ${MVN_BIN} to be wrapped in quotes so it will handle having spaces in the path. ## How was this patch tested? Ran the following to confirm using the build/mvn tool with a space in the build dir now works without error ``` mkdir /tmp/test\ spaces cd /tmp/test\ spaces git clone https://github.com/apache/spark.git cd spark # Remove all mvn references in PATH so the script will download mvn to the local dir ./build/mvn -DskipTests clean package ``` Please review http://spark.apache.org/contributing.html before opening a pull request. Author: trystanleftwich Closes #21534 from trystanleftwich/SPARK-24526. --- build/mvn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/mvn b/build/mvn index efa4f9364ea52..1405983982d4c 100755 --- a/build/mvn +++ b/build/mvn @@ -154,4 +154,4 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} echo "Using \`mvn\` from path: $MVN_BIN" 1>&2 # Last, call the `mvn` command as usual -${MVN_BIN} -DzincPort=${ZINC_PORT} "$@" +"${MVN_BIN}" -DzincPort=${ZINC_PORT} "$@" From 8f225e055c2031ca85d61721ab712170ab4e50c1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 18 Jun 2018 11:01:17 -0700 Subject: [PATCH 0977/1161] [SPARK-24548][SQL] Fix incorrect schema of Dataset with tuple encoders ## What changes were proposed in this pull request? When creating tuple expression encoders, we should give the serializer expressions of tuple items correct names, so we can have correct output schema when we use such tuple encoders. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21576 from viirya/SPARK-24548. --- .../catalyst/encoders/ExpressionEncoder.scala | 3 ++- .../org/apache/spark/sql/JavaDatasetSuite.java | 18 ++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index efc2882f0a3d3..cbea3c017a265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -128,7 +128,7 @@ object ExpressionEncoder { case b: BoundReference if b == originalInputObject => newInputObject }) - if (enc.flat) { + val serializerExpr = if (enc.flat) { newSerializer.head } else { // For non-flat encoder, the input object is not top level anymore after being combined to @@ -146,6 +146,7 @@ object ExpressionEncoder { Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) If(nullCheck, Literal.create(null, struct.dataType), struct) } + Alias(serializerExpr, s"_${index + 1}")() } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index c132cab1b38cf..2c695fc58fd8c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -34,6 +34,7 @@ import org.junit.*; import org.junit.rules.ExpectedException; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; @@ -336,6 +337,23 @@ public void testTupleEncoder() { Assert.assertEquals(data5, ds5.collectAsList()); } + @Test + public void testTupleEncoderSchema() { + Encoder>> encoder = + Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING())); + List>> data = Arrays.asList(tuple2("1", tuple2("a", "b")), + tuple2("2", tuple2("c", "d"))); + Dataset ds1 = spark.createDataset(data, encoder).toDF("value1", "value2"); + + JavaPairRDD> pairRDD = jsc.parallelizePairs(data); + Dataset ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder) + .toDF("value1", "value2"); + + Assert.assertEquals(ds1.schema(), ds2.schema()); + Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(), + ds2.select(expr("value2._1")).collectAsList()); + } + @Test public void testNestedTupleEncoder() { // test ((int, string), string) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d477d78dc14e3..093cee91d2f49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1466,6 +1466,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { + val encoder = Encoders.tuple(newStringEncoder, + Encoders.tuple(newStringEncoder, newStringEncoder)) + + val data = Seq(("a", ("1", "2")), ("b", ("3", "4"))) + val rdd = sparkContext.parallelize(data) + + val ds1 = spark.createDataset(rdd) + val ds2 = spark.createDataset(rdd)(encoder) + assert(ds1.schema == ds2.schema) + checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 1737d45e08a5f1fb78515b14321721d7197b443a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Jun 2018 20:15:01 -0700 Subject: [PATCH 0978/1161] [SPARK-24478][SQL][FOLLOWUP] Move projection and filter push down to physical conversion ## What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/21503, to completely move operator pushdown to the planner rule. The code are mostly from https://github.com/apache/spark/pull/21319 ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21574 from cloud-fan/followup. --- .../v2/reader/SupportsReportStatistics.java | 7 +- .../datasources/v2/DataSourceV2Relation.scala | 109 ++++----------- .../datasources/v2/DataSourceV2Strategy.scala | 124 +++++++++++++----- 3 files changed, 123 insertions(+), 117 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java index a79080a249ec8..926396414816c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -23,10 +23,9 @@ * A mix in interface for {@link DataSourceReader}. Data source readers can implement this * interface to report statistics to Spark. * - * Statistics are reported to the optimizer before a projection or any filters are pushed to the - * DataSourceReader. Implementations that return more accurate statistics based on projection and - * filters will not improve query performance until the planner can push operators before getting - * stats. + * Statistics are reported to the optimizer before any operator is pushed to the DataSourceReader. + * Implementations that return more accurate statistics based on pushed operators will not improve + * query performance until the planner can push operators before getting stats. */ @InterfaceStability.Evolving public interface SupportsReportStatistics extends DataSourceReader { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index e08af218513fd..7613eb210c659 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,17 +23,24 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema} -import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics} import org.apache.spark.sql.types.StructType +/** + * A logical plan representing a data source v2 scan. + * + * @param source An instance of a [[DataSourceV2]] implementation. + * @param options The options for this scan. Used to create fresh [[DataSourceReader]]. + * @param userSpecifiedSchema The user-specified schema for this scan. Used to create fresh + * [[DataSourceReader]]. + */ case class DataSourceV2Relation( source: DataSourceV2, output: Seq[AttributeReference], options: Map[String, String], - userSpecifiedSchema: Option[StructType] = None) + userSpecifiedSchema: Option[StructType]) extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat { import DataSourceV2Relation._ @@ -42,14 +49,7 @@ case class DataSourceV2Relation( override def simpleString: String = "RelationV2 " + metadataString - lazy val v2Options: DataSourceOptions = makeV2Options(options) - - def newReader: DataSourceReader = userSpecifiedSchema match { - case Some(userSchema) => - source.asReadSupportWithSchema.createReader(userSchema, v2Options) - case None => - source.asReadSupport.createReader(v2Options) - } + def newReader(): DataSourceReader = source.createReader(options, userSpecifiedSchema) override def computeStats(): Statistics = newReader match { case r: SupportsReportStatistics => @@ -139,83 +139,26 @@ object DataSourceV2Relation { source.getClass.getSimpleName } } - } - - private def makeV2Options(options: Map[String, String]): DataSourceOptions = { - new DataSourceOptions(options.asJava) - } - private def schema( - source: DataSourceV2, - v2Options: DataSourceOptions, - userSchema: Option[StructType]): StructType = { - val reader = userSchema match { - case Some(s) => - source.asReadSupportWithSchema.createReader(s, v2Options) - case _ => - source.asReadSupport.createReader(v2Options) + def createReader( + options: Map[String, String], + userSpecifiedSchema: Option[StructType]): DataSourceReader = { + val v2Options = new DataSourceOptions(options.asJava) + userSpecifiedSchema match { + case Some(s) => + asReadSupportWithSchema.createReader(s, v2Options) + case _ => + asReadSupport.createReader(v2Options) + } } - reader.readSchema() } def create( source: DataSourceV2, options: Map[String, String], - userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = { - val output = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes - DataSourceV2Relation(source, output, options, userSpecifiedSchema) - } - - def pushRequiredColumns( - relation: DataSourceV2Relation, - reader: DataSourceReader, - struct: StructType): Seq[AttributeReference] = { - reader match { - case projectionSupport: SupportsPushDownRequiredColumns => - projectionSupport.pruneColumns(struct) - // return the output columns from the relation that were projected - val attrMap = relation.output.map(a => a.name -> a).toMap - projectionSupport.readSchema().map(f => attrMap(f.name)) - case _ => - relation.output - } - } - - def pushFilters( - reader: DataSourceReader, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - reader match { - case r: SupportsPushDownCatalystFilters => - val postScanFilters = r.pushCatalystFilters(filters.toArray) - val pushedFilters = r.pushedCatalystFilters() - (postScanFilters, pushedFilters) - - case r: SupportsPushDownFilters => - // A map from translated data source filters to original catalyst filter expressions. - val translatedFilterToExpr = scala.collection.mutable.HashMap.empty[Filter, Expression] - // Catalyst filter expression that can't be translated to data source filters. - val untranslatableExprs = scala.collection.mutable.ArrayBuffer.empty[Expression] - - for (filterExpr <- filters) { - val translated = DataSourceStrategy.translateFilter(filterExpr) - if (translated.isDefined) { - translatedFilterToExpr(translated.get) = filterExpr - } else { - untranslatableExprs += filterExpr - } - } - - // Data source filters that need to be evaluated again after scanning. which means - // the data source cannot guarantee the rows returned can pass these filters. - // As a result we must return it so Spark can plan an extra filter operator. - val postScanFilters = - r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr) - // The filters which are marked as pushed to this data source - val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) - - (untranslatableExprs ++ postScanFilters, pushedFilters) - - case _ => (filters, Nil) - } + userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = { + val reader = source.createReader(options, userSpecifiedSchema) + DataSourceV2Relation( + source, reader.readSchema().toAttributes, options, userSpecifiedSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 8bf858c38d76c..182aa2906cf1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,51 +17,115 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.{execution, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet} +import scala.collection.mutable + +import org.apache.spark.sql.{sources, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} object DataSourceV2Strategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => - val projectSet = AttributeSet(project.flatMap(_.references)) - val filterSet = AttributeSet(filters.flatMap(_.references)) - - val projection = if (filterSet.subsetOf(projectSet) && - AttributeSet(relation.output) == projectSet) { - // When the required projection contains all of the filter columns and column pruning alone - // can produce the required projection, push the required projection. - // A final projection may still be needed if the data source produces a different column - // order or if it cannot prune all of the nested columns. - relation.output - } else { - // When there are filter columns not already in the required projection or when the required - // projection is more complicated than column pruning, base column pruning on the set of - // all columns needed by both. - (projectSet ++ filterSet).toSeq - } - val reader = relation.newReader + /** + * Pushes down filters to the data source reader + * + * @return pushed filter and post-scan filters. + */ + private def pushFilters( + reader: DataSourceReader, + filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + reader match { + case r: SupportsPushDownCatalystFilters => + val postScanFilters = r.pushCatalystFilters(filters.toArray) + val pushedFilters = r.pushedCatalystFilters() + (pushedFilters, postScanFilters) + + case r: SupportsPushDownFilters => + // A map from translated data source filters to original catalyst filter expressions. + val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = DataSourceStrategy.translateFilter(filterExpr) + if (translated.isDefined) { + translatedFilterToExpr(translated.get) = filterExpr + } else { + untranslatableExprs += filterExpr + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = r.pushFilters(translatedFilterToExpr.keys.toArray) + .map(translatedFilterToExpr) + // The filters which are marked as pushed to this data source + val pushedFilters = r.pushedFilters().map(translatedFilterToExpr) + (pushedFilters, untranslatableExprs ++ postScanFilters) + + case _ => (Nil, filters) + } + } - val output = DataSourceV2Relation.pushRequiredColumns(relation, reader, - projection.asInstanceOf[Seq[AttributeReference]].toStructType) + /** + * Applies column pruning to the data source, w.r.t. the references of the given expressions. + * + * @return new output attributes after column pruning. + */ + // TODO: nested column pruning. + private def pruneColumns( + reader: DataSourceReader, + relation: DataSourceV2Relation, + exprs: Seq[Expression]): Seq[AttributeReference] = { + reader match { + case r: SupportsPushDownRequiredColumns => + val requiredColumns = AttributeSet(exprs.flatMap(_.references)) + val neededOutput = relation.output.filter(requiredColumns.contains) + if (neededOutput != relation.output) { + r.pruneColumns(neededOutput.toStructType) + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + r.readSchema().toAttributes.map { + // We have to keep the attribute id during transformation. + a => a.withExprId(nameToAttr(a.name).exprId) + } + } else { + relation.output + } + + case _ => relation.output + } + } - val (postScanFilters, pushedFilters) = DataSourceV2Relation.pushFilters(reader, filters) - logInfo(s"Post-Scan Filters: ${postScanFilters.mkString(",")}") - logInfo(s"Pushed Filters: ${pushedFilters.mkString(", ")}") + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + val reader = relation.newReader() + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. + // `postScanFilters` need to be evaluated after the scan. + // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. + val (pushedFilters, postScanFilters) = pushFilters(reader, filters) + val output = pruneColumns(reader, relation, project ++ postScanFilters) + logInfo( + s""" + |Pushing operators to ${relation.source.getClass} + |Pushed Filters: ${pushedFilters.mkString(", ")} + |Post-Scan Filters: ${postScanFilters.mkString(",")} + |Output: ${output.mkString(", ")} + """.stripMargin) val scan = DataSourceV2ScanExec( output, relation.source, relation.options, pushedFilters, reader) - val filter = postScanFilters.reduceLeftOption(And) - val withFilter = filter.map(execution.FilterExec(_, scan)).getOrElse(scan) + val filterCondition = postScanFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) val withProjection = if (withFilter.output != project) { - execution.ProjectExec(project, withFilter) + ProjectExec(project, withFilter) } else { withFilter } From 9a75c18290fff7d116cf88a44f9120bf67d8bd27 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 18 Jun 2018 20:17:04 -0700 Subject: [PATCH 0979/1161] [SPARK-24542][SQL] UDF series UDFXPathXXXX allow users to pass carefully crafted XML to access arbitrary files ## What changes were proposed in this pull request? UDF series UDFXPathXXXX allow users to pass carefully crafted XML to access arbitrary files. Spark does not have built-in access control. When users use the external access control library, users might bypass them and access the file contents. This PR basically patches the Hive fix to Apache Spark. https://issues.apache.org/jira/browse/HIVE-18879 ## How was this patch tested? A unit test case Author: Xiao Li Closes #21549 from gatorsmile/xpathSecurity. --- .../expressions/xml/UDFXPathUtil.java | 28 ++++++++++++++++++- .../expressions/xml/UDFXPathUtilSuite.scala | 21 ++++++++++++++ .../xml/XPathExpressionSuite.scala | 5 ++-- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java index d224332d8a6c9..023ec139652c5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -21,6 +21,9 @@ import java.io.Reader; import javax.xml.namespace.QName; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpression; @@ -37,9 +40,15 @@ * This is based on Hive's UDFXPathUtil implementation. */ public class UDFXPathUtil { + public static final String SAX_FEATURE_PREFIX = "http://xml.org/sax/features/"; + public static final String EXTERNAL_GENERAL_ENTITIES_FEATURE = "external-general-entities"; + public static final String EXTERNAL_PARAMETER_ENTITIES_FEATURE = "external-parameter-entities"; + private DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); + private DocumentBuilder builder = null; private XPath xpath = XPathFactory.newInstance().newXPath(); private ReusableStringReader reader = new ReusableStringReader(); private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; private String oldPath = null; @@ -65,14 +74,31 @@ public Object eval(String xml, String path, QName qname) throws XPathExpressionE return null; } + if (builder == null){ + try { + initializeDocumentBuilderFactory(); + builder = dbf.newDocumentBuilder(); + } catch (ParserConfigurationException e) { + throw new RuntimeException( + "Error instantiating DocumentBuilder, cannot build xml parser", e); + } + } + reader.set(xml); try { - return expression.evaluate(inputSource, qname); + return expression.evaluate(builder.parse(inputSource), qname); } catch (XPathExpressionException e) { throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } catch (Exception e) { + throw new RuntimeException("Error loading expression '" + oldPath + "'", e); } } + private void initializeDocumentBuilderFactory() throws ParserConfigurationException { + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_GENERAL_ENTITIES_FEATURE, false); + dbf.setFeature(SAX_FEATURE_PREFIX + EXTERNAL_PARAMETER_ENTITIES_FEATURE, false); + } + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala index c4cde7091154b..0fec15bc42c17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -77,6 +77,27 @@ class UDFXPathUtilSuite extends SparkFunSuite { assert(ret == "foo") } + test("embedFailure") { + import org.apache.commons.io.FileUtils + import java.io.File + val secretValue = String.valueOf(Math.random) + val tempFile = File.createTempFile("verifyembed", ".tmp") + tempFile.deleteOnExit() + val fname = tempFile.getAbsolutePath + + FileUtils.writeStringToFile(tempFile, secretValue) + + val xml = + s""" + | + |]> + |&embed; + """.stripMargin + val evaled = new UDFXPathUtil().evalString(xml, "/foo") + assert(evaled.isEmpty) + } + test("number eval") { var ret = util.evalNumber("truefalseb3c1-77", "a/c[2]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala index bfa18a0919e45..c6f6d3abb860c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala @@ -40,8 +40,9 @@ class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { // Test error message for invalid XML document val e1 = intercept[RuntimeException] { testExpr("/a>", "a", null.asInstanceOf[T]) } - assert(e1.getCause.getMessage.contains("Invalid XML document") && - e1.getCause.getMessage.contains("/a>")) + assert(e1.getCause.getCause.getMessage.contains( + "XML document structures must start and end within the same entity.")) + assert(e1.getMessage.contains("/a>")) // Test error message for invalid xpath val e2 = intercept[RuntimeException] { testExpr("", "!#$", null.asInstanceOf[T]) } From a78a9046413255756653f70165520efd486fb493 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 19 Jun 2018 10:42:08 -0700 Subject: [PATCH 0980/1161] [SPARK-24521][SQL][TEST] Fix ineffective test in CachedTableSuite ## What changes were proposed in this pull request? test("withColumn doesn't invalidate cached dataframe") in CachedTableSuite doesn't not work because: The UDF is executed and test count incremented when "df.cache()" is called and the subsequent "df.collect()" has no effect on the test result. This PR fixed this test and add another test for caching UDF. ## How was this patch tested? Add new tests. Author: Li Jin Closes #21531 from icexelloss/fix-cache-test. --- .../apache/spark/sql/CachedTableSuite.scala | 19 ---------- .../apache/spark/sql/DatasetCacheSuite.scala | 38 ++++++++++++++++++- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 81b7e18773f81..6982c22f4771d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -83,25 +83,6 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext }.sum } - test("withColumn doesn't invalidate cached dataframe") { - var evalCount = 0 - val myUDF = udf((x: String) => { evalCount += 1; "result" }) - val df = Seq(("test", 1)).toDF("s", "i").select(myUDF($"s")) - df.cache() - - df.collect() - assert(evalCount === 1) - - df.collect() - assert(evalCount === 1) - - val df2 = df.withColumn("newColumn", lit(1)) - df2.collect() - - // We should not reevaluate the cached dataframe - assert(evalCount === 1) - } - test("cache temp table") { withTempView("tempTable") { testData.select('key).createOrReplaceTempView("tempTable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e0561ee2797a5..82a93f74dd76c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql +import org.scalatest.concurrent.TimeLimits +import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel -class DatasetCacheSuite extends QueryTest with SharedSQLContext { +class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ test("get storage level") { @@ -96,4 +99,37 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { agged.unpersist() assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.") } + + test("persist and then withColumn") { + val df = Seq(("test", 1)).toDF("s", "i") + val df2 = df.withColumn("newColumn", lit(1)) + + df.cache() + assertCached(df) + assertCached(df2) + + df.count() + assertCached(df2) + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } + + test("cache UDF result correctly") { + val expensiveUDF = udf({x: Int => Thread.sleep(10000); x}) + val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + assertCached(df2) + + // udf has been evaluated during caching, and thus should not be re-evaluated here + failAfter(5 seconds) { + df2.collect() + } + + df.unpersist() + assert(df.storageLevel == StorageLevel.NONE) + } } From 9dbe53eb6bb5916d28000f2c0d646cf23094ac11 Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 19 Jun 2018 10:52:51 -0700 Subject: [PATCH 0981/1161] [SPARK-24556][SQL] Always rewrite output partitioning in ReusedExchangeExec and InMemoryTableScanExec ## What changes were proposed in this pull request? Currently, ReusedExchange and InMemoryTableScanExec only rewrite output partitioning if child's partitioning is HashPartitioning and do nothing for other partitioning, e.g., RangePartitioning. We should always rewrite it, otherwise, unnecessary shuffle could be introduced like https://issues.apache.org/jira/browse/SPARK-24556. ## How was this patch tested? Add new tests. Author: yucai Closes #21564 from yucai/SPARK-24556. --- .../columnar/InMemoryTableScanExec.scala | 6 +- .../sql/execution/exchange/Exchange.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 64 ++++++++++++++++++- 3 files changed, 67 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 0b4dd76c7d860..997cf92449c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -169,8 +169,8 @@ case class InMemoryTableScanExec( // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { relation.cachedPlan.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.cachedPlan.outputPartitioning + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 37d468739c613..d254345e8fa54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -703,6 +703,66 @@ class PlannerSuite extends SharedSQLContext { Range(1, 2, 1, 1))) df.queryExecution.executedPlan.execute() } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements From 13092d733791b19cd7994084178306e0c449f2ed Mon Sep 17 00:00:00 2001 From: rimolive Date: Tue, 19 Jun 2018 13:25:00 -0700 Subject: [PATCH 0982/1161] [SPARK-24534][K8S] Bypass non spark-on-k8s commands ## What changes were proposed in this pull request? This PR changes the entrypoint.sh to provide an option to run non spark-on-k8s commands (init, driver, executor) in order to let the user keep with the normal workflow without hacking the image to bypass the entrypoint ## How was this patch tested? This patch was built manually in my local machine and I ran some tests with a combination of ```docker run``` commands. Author: rimolive Closes #21572 from rimolive/rimolive-spark-24534. --- .../src/main/dockerfiles/spark/entrypoint.sh | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index acdb4b1f09e0a..2f4e115e84ecd 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -37,11 +37,17 @@ if [ -z "$uidentry" ] ; then fi SPARK_K8S_CMD="$1" -if [ -z "$SPARK_K8S_CMD" ]; then - echo "No command to execute has been provided." 1>&2 - exit 1 -fi -shift 1 +case "$SPARK_K8S_CMD" in + driver | driver-py | executor) + shift 1 + ;; + "") + ;; + *) + echo "Non-spark-on-k8s command provided, proceeding in pass-through mode..." + exec /sbin/tini -s -- "$@" + ;; +esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt @@ -92,7 +98,6 @@ case "$SPARK_K8S_CMD" in "$@" $PYSPARK_PRIMARY $PYSPARK_ARGS ) ;; - executor) CMD=( ${JAVA_HOME}/bin/java From 2cb976355c615eee4ebd0a86f3911fa9284fccf6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 19 Jun 2018 13:56:51 -0700 Subject: [PATCH 0983/1161] [SPARK-24565][SS] Add API for in Structured Streaming for exposing output rows of each microbatch as a DataFrame ## What changes were proposed in this pull request? Currently, the micro-batches in the MicroBatchExecution is not exposed to the user through any public API. This was because we did not want to expose the micro-batches, so that all the APIs we expose, we can eventually support them in the Continuous engine. But now that we have better sense of buiding a ContinuousExecution, I am considering adding APIs which will run only the MicroBatchExecution. I have quite a few use cases where exposing the microbatch output as a dataframe is useful. - Pass the output rows of each batch to a library that is designed only the batch jobs (example, uses many ML libraries need to collect() while learning). - Reuse batch data sources for output whose streaming version does not exists (e.g. redshift data source). - Writer the output rows to multiple places by writing twice for each batch. This is not the most elegant thing to do for multiple-output streaming queries but is likely to be better than running two streaming queries processing the same data twice. The proposal is to add a method `foreachBatch(f: Dataset[T] => Unit)` to Scala/Java/Python `DataStreamWriter`. ## How was this patch tested? New unit tests. Author: Tathagata Das Closes #21571 from tdas/foreachBatch. --- python/pyspark/java_gateway.py | 25 ++- python/pyspark/sql/streaming.py | 33 +++- python/pyspark/sql/tests.py | 36 +++++ python/pyspark/sql/utils.py | 23 +++ python/pyspark/streaming/context.py | 18 +-- .../streaming/sources/ForeachBatchSink.scala | 58 +++++++ .../sql/streaming/DataStreamWriter.scala | 63 +++++++- .../sources/ForeachBatchSinkSuite.scala | 148 ++++++++++++++++++ 8 files changed, 383 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0afbe9dc6aa3e..fa2d5e8db716a 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -31,7 +31,7 @@ if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayParameters +from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from pyspark.find_spark_home import _find_spark_home from pyspark.serializers import read_int, write_with_length, UTF8Deserializer @@ -145,3 +145,26 @@ def do_server_auth(conn, auth_secret): if reply != "ok": conn.close() raise Exception("Unexpected reply from iterator server.") + + +def ensure_callback_server_started(gw): + """ + Start callback server if not already started. The callback server is needed if the Java + driver process needs to callback into the Python driver process to execute Python code. + """ + + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__ or gw._callback_server is None: + gw.callback_server_parameters.eager_load = True + gw.callback_server_parameters.daemonize = True + gw.callback_server_parameters.daemonize_connections = True + gw.callback_server_parameters.port = 0 + gw.start_callback_server(gw.callback_server_parameters) + cbport = gw._callback_server.server_socket.getsockname()[1] + gw._callback_server.port = cbport + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 4984593bab491..8c1fd4af674d7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -24,12 +24,14 @@ else: intlike = (int, long) +from py4j.java_gateway import java_import + from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * -from pyspark.sql.utils import StreamingQueryException +from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -1016,6 +1018,35 @@ def func_with_open_process_close(partition_id, iterator): self._jwrite.foreach(jForeachWriter) return self + @since(2.4) + def foreachBatch(self, func): + """ + Sets the output of the streaming query to be processed using the provided + function. This is supported only the in the micro-batch execution modes (that is, when the + trigger is not continuous). In every micro-batch, the provided function will be called in + every micro-batch with (i) the output rows as a DataFrame and (ii) the batch identifier. + The batchId can be used deduplicate and transactionally write the output + (that is, the provided Dataset) to external systems. The output DataFrame is guaranteed + to exactly same for the same batchId (assuming all operations are deterministic in the + query). + + .. note:: Evolving. + + >>> def func(batch_df, batch_id): + ... batch_df.collect() + ... + >>> writer = sdf.writeStream.foreach(func) + """ + + from pyspark.java_gateway import ensure_callback_server_started + gw = self._spark._sc._gateway + java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*") + + wrapped_func = ForeachBatchFunction(self._spark, func) + gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func) + ensure_callback_server_started(gw) + return self + @ignore_unicode_prefix @since(2.0) def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4e5fafa77e109..94ab867f0bd9b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2126,6 +2126,42 @@ class WriterWithNonCallableClose(WithProcess): tester.assert_invalid_writer(WriterWithNonCallableClose(), "'close' in provided object is not callable") + def test_streaming_foreachBatch(self): + q = None + collected = dict() + + def collectBatch(batch_df, batch_id): + collected[batch_id] = batch_df.collect() + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.assertTrue(0 in collected) + self.assertTrue(len(collected[0]), 2) + finally: + if q: + q.stop() + + def test_streaming_foreachBatch_propagates_python_errors(self): + from pyspark.sql.utils import StreamingQueryException + + q = None + + def collectBatch(df, id): + raise Exception("this should fail the query") + + try: + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + q = df.writeStream.foreachBatch(collectBatch).start() + q.processAllAvailable() + self.fail("Expected a failure") + except StreamingQueryException as e: + self.assertTrue("this should fail" in str(e)) + finally: + if q: + q.stop() + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 45363f089a73d..bb9ce02c4b60f 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -150,3 +150,26 @@ def require_minimum_pyarrow_version(): if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): raise ImportError("PyArrow >= %s must be installed; however, " "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) + + +class ForeachBatchFunction(object): + """ + This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps + the user-defined 'foreachBatch' function such that it can be called from the JVM when + the query is active. + """ + + def __init__(self, sql_ctx, func): + self.sql_ctx = sql_ctx + self.func = func + + def call(self, jdf, batch_id): + from pyspark.sql.dataframe import DataFrame + try: + self.func(DataFrame(jdf, self.sql_ctx), batch_id) + except Exception as e: + self.error = e + raise e + + class Java: + implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index dd924ef89868e..a4515828d180c 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -79,22 +79,8 @@ def _ensure_initialized(cls): java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") - # start callback server - # getattr will fallback to JVM, so we cannot test by hasattr() - if "_callback_server" not in gw.__dict__ or gw._callback_server is None: - gw.callback_server_parameters.eager_load = True - gw.callback_server_parameters.daemonize = True - gw.callback_server_parameters.daemonize_connections = True - gw.callback_server_parameters.port = 0 - gw.start_callback_server(gw.callback_server_parameters) - cbport = gw._callback_server.server_socket.getsockname()[1] - gw._callback_server.port = cbport - # gateway with real port - gw._python_proxy_port = gw._callback_server.port - # get the GatewayServer object in JVM by ID - jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) - # update the port of CallbackClient with real port - jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port) + from pyspark.java_gateway import ensure_callback_server_started + ensure_callback_server_started(gw) # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala new file mode 100644 index 0000000000000..03c567c58d46a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.api.python.PythonException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.streaming.DataStreamWriter + +class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T]) + extends Sink { + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + val resolvedEncoder = encoder.resolveAndBind( + data.logicalPlan.output, + data.sparkSession.sessionState.analyzer) + val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag) + val ds = data.sparkSession.createDataset(rdd)(encoder) + batchWriter(ds, batchId) + } + + override def toString(): String = "ForeachBatchSink" +} + + +/** + * Interface that is meant to be extended by Python classes via Py4J. + * Py4J allows Python classes to implement Java interfaces so that the JVM can call back + * Python objects. In this case, this allows the user-defined Python `foreachBatch` function + * to be called from JVM when the query is active. + * */ +trait PythonForeachBatchFunction { + /** Call the Python implementation of this function */ + def call(batchDF: DataFrame, batchId: Long): Unit +} + +object PythonForeachBatchHelper { + def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = { + dsw.foreachBatch(pythonFunc.call _) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 43e80e4e54239..926c0b69a03fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,14 +21,15 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} +import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.api.java.function.VoidFunction2 +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} /** @@ -279,6 +280,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { outputMode, useTempCheckpointLocation = true, trigger = trigger) + } else if (source == "foreachBatch") { + assertNotPartitioned("foreachBatch") + if (trigger.isInstanceOf[ContinuousTrigger]) { + throw new AnalysisException("'foreachBatch' is not supported with continuous trigger") + } + val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + extraOptions.toMap, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) } else { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") @@ -322,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } + /** + * :: Experimental :: + * + * (Scala-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + this.source = "foreachBatch" + if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") + this.foreachBatchWriter = function + this + } + + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only the in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. + * The batchId can be used deduplicate and transactionally write the output + * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed + * to exactly same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @InterfaceStability.Evolving + def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = { + foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) + } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -358,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var foreachWriter: ForeachWriter[T] = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var partitioningColumns: Option[Seq[String]] = None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala new file mode 100644 index 0000000000000..a4233e15e4ffd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import scala.collection.mutable + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming._ + +case class KV(key: Int, value: Long) + +class ForeachBatchSinkSuite extends StreamTest { + import testImplicits._ + + test("foreachBatch with non-stateful query") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), // out = in + 2 (i.e. 1 in query, 1 in writer) + check(in = 5, 6, 7)(out = 7, 8, 9)) + } + + test("foreachBatch with stateful query in update mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Update)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (1, 1L)), + check(in = 2, 3)(out = (0, 2L), (1, 2L))) + } + + test("foreachBatch with stateful query in complete mode") { + val mem = MemoryStream[Int] + val ds = mem.toDF() + .select($"value" % 2 as "key") + .groupBy("key") + .agg(count("*") as "value") + .toDF.as[KV] + + val tester = new ForeachBatchTester[KV](mem) + val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS) + + import tester._ + testWriter(ds, writer, outputMode = OutputMode.Complete)( + check(in = 0)(out = (0, 1L)), + check(in = 1)(out = (0, 1L), (1, 1L)), + check(in = 2)(out = (0, 2L), (1, 1L))) + } + + test("foreachBatchSink does not affect metric generation") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1)) + + import tester._ + testWriter(ds, writer)( + check(in = 1, 2, 3)(out = 3, 4, 5), + checkMetrics) + } + + test("throws errors in invalid situations") { + val ds = MemoryStream[Int].toDS + val ex1 = intercept[IllegalArgumentException] { + ds.writeStream.foreachBatch(null.asInstanceOf[(Dataset[Int], Long) => Unit]).start() + } + assert(ex1.getMessage.contains("foreachBatch function cannot be null")) + val ex2 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start() + } + assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger")) + val ex3 = intercept[AnalysisException] { + ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start() + } + assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) + } + + // ============== Helper classes and methods ================= + + private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { + trait Test + private case class Check(in: Seq[Int], out: Seq[T]) extends Test + private case object CheckMetrics extends Test + + private val recordedOutput = new mutable.HashMap[Long, Seq[T]] + + def testWriter( + ds: Dataset[T], + outputBatchWriter: (Dataset[T], Long) => Unit, + outputMode: OutputMode = OutputMode.Append())(tests: Test*): Unit = { + try { + var expectedBatchId = -1 + val query = ds.writeStream.outputMode(outputMode).foreachBatch(outputBatchWriter).start() + + tests.foreach { + case Check(in, out) => + expectedBatchId += 1 + memoryStream.addData(in) + query.processAllAvailable() + assert(recordedOutput.contains(expectedBatchId)) + val ds: Dataset[T] = spark.createDataset[T](recordedOutput(expectedBatchId)) + checkDataset[T](ds, out: _*) + case CheckMetrics => + assert(query.recentProgress.exists(_.numInputRows > 0)) + } + } finally { + sqlContext.streams.active.foreach(_.stop()) + } + } + + def check(in: Int*)(out: T*): Test = Check(in, out) + def checkMetrics: Test = CheckMetrics + def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect()) + implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2) + } +} From bc0498d5820ded2b428277e396502e74ef0ce36d Mon Sep 17 00:00:00 2001 From: Maryann Xue Date: Tue, 19 Jun 2018 15:27:20 -0700 Subject: [PATCH 0984/1161] [SPARK-24583][SQL] Wrong schema type in InsertIntoDataSourceCommand ## What changes were proposed in this pull request? Change insert input schema type: "insertRelationType" -> "insertRelationType.asNullable", in order to avoid nullable being overridden. ## How was this patch tested? Added one test in InsertSuite. Author: Maryann Xue Closes #21585 from maryannxue/spark-24583. --- .../InsertIntoDataSourceCommand.scala | 5 +- .../spark/sql/sources/InsertSuite.scala | 51 ++++++++++++++++++- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50cb1..80d7608a22891 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -38,9 +38,8 @@ case class InsertIntoDataSourceCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = Dataset.ofRows(sparkSession, query) - // Apply the schema of the existing table to the new data. - val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite) + // Data has been casted to the target relation's schema by the PreprocessTableInsertion rule. + relation.insert(data, overwrite) // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this // data source relation. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index fef01c860db6e..438d5d8176b8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -20,12 +20,36 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class SimpleInsertSource extends SchemaRelationProvider { + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + SimpleInsert(schema)(sqlContext.sparkSession) + } +} + +case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession) + extends BaseRelation with InsertableRelation { + + override def sqlContext: SQLContext = sparkSession.sqlContext + + override def schema: StructType = userSpecifiedSchema + + override def insert(input: DataFrame, overwrite: Boolean): Unit = { + input.collect + } +} + class InsertSuite extends DataSourceTest with SharedSQLContext { import testImplicits._ @@ -520,4 +544,29 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } } } + + test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") { + withTable("test_table") { + val schema = new StructType() + .add("i", LongType, false) + .add("s", StringType, false) + val newTable = CatalogTable( + identifier = TableIdentifier("test_table", None), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty), + schema = schema, + provider = Some(classOf[SimpleInsertSource].getName)) + + spark.sessionState.catalog.createTable(newTable, false) + + sql("INSERT INTO TABLE test_table SELECT 1, 'a'") + sql("INSERT INTO TABLE test_table SELECT 2, null") + } + } } From bc111463a766a5619966a282fbe0fec991088ceb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 19 Jun 2018 22:29:00 -0700 Subject: [PATCH 0985/1161] [SPARK-23778][CORE] Avoid unneeded shuffle when union gets an empty RDD ## What changes were proposed in this pull request? When a `union` is invoked on several RDDs of which one is an empty RDD, the result of the operation is a `UnionRDD`. This causes an unneeded extra-shuffle when all the other RDDs have the same partitioning. The PR ignores incoming empty RDDs in the union method. ## How was this patch tested? added UT Author: Marco Gaido Closes #21333 from mgaido91/SPARK-23778. --- .../main/scala/org/apache/spark/SparkContext.scala | 9 +++++---- .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 14 +++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5e8595603cc90..74bfb5d6d2ea3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1306,11 +1306,12 @@ class SparkContext(config: SparkConf) extends Logging { /** Build the union of a list of RDDs. */ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope { - val partitioners = rdds.flatMap(_.partitioner).toSet - if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { - new PartitionerAwareUnionRDD(this, rdds) + val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty) + val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet + if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { + new PartitionerAwareUnionRDD(this, nonEmptyRdds) } else { - new UnionRDD(this, rdds) + new UnionRDD(this, nonEmptyRdds) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 191c61250ce21..5148ce05bd918 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -154,6 +154,16 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("SPARK-23778: empty RDD in union should not produce a UnionRDD") { + val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) + val emptyRDD = sc.emptyRDD[(Int, Boolean)] + val unionRDD = sc.union(emptyRDD, rddWithPartitioner) + assert(unionRDD.isInstanceOf[PartitionerAwareUnionRDD[_]]) + val unionAllEmptyRDD = sc.union(emptyRDD, emptyRDD) + assert(unionAllEmptyRDD.isInstanceOf[UnionRDD[_]]) + assert(unionAllEmptyRDD.collect().isEmpty) + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) @@ -1047,7 +1057,9 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) { private val mutableDependencies: ArrayBuffer[Dependency[_]] = ArrayBuffer.empty override def compute(p: Partition, c: TaskContext): Iterator[T] = Iterator.empty - override def getPartitions: Array[Partition] = Array.empty + override def getPartitions: Array[Partition] = Array(new Partition { + override def index: Int = 0 + }) override def getDependencies: Seq[Dependency[_]] = mutableDependencies def addDependency(dep: Dependency[_]) { mutableDependencies += dep From c8ef9232cf8b8ef262404b105cea83c1f393d8c3 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 20 Jun 2018 18:38:42 +0200 Subject: [PATCH 0986/1161] [MINOR][SQL] Remove invalid comment from SparkStrategies ## What changes were proposed in this pull request? This patch is removing invalid comment from SparkStrategies, given that TODO-like comment is no longer preferred one as the comment: https://github.com/apache/spark/pull/21388#issuecomment-396856235 Removing invalid comment will prevent contributors to spend their times which is not going to be merged. ## How was this patch tested? N/A Author: Jungtaek Lim Closes #21595 from HeartSaVioR/MINOR-remove-invalid-comment-on-spark-strategies. --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d6951ad01fb0c..07a6fcae83b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -494,7 +494,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil From c5a0d1132a5608f2110781763f4c2229c6cd7175 Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Wed, 20 Jun 2018 18:57:13 +0200 Subject: [PATCH 0987/1161] [SPARK-24575][SQL] Prohibit window expressions inside WHERE and HAVING clauses ## What changes were proposed in this pull request? As discussed [before](https://github.com/apache/spark/pull/19193#issuecomment-393726964), this PR prohibits window expressions inside WHERE and HAVING clauses. ## How was this patch tested? This PR comes with a dedicated unit test. Author: aokolnychyi Closes #21580 from aokolnychyi/spark-24575. --- .../sql/catalyst/analysis/Analyzer.scala | 3 ++ .../sql/DataFrameWindowFunctionsSuite.scala | 42 +++++++++++++++++-- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6e3107f1c6f75..e187133d03b17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1923,6 +1923,9 @@ class Analyzer( // "Aggregate with Having clause" will be triggered. def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case Filter(condition, _) if hasWindowFunction(condition) => + failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses") + // Aggregate with Having clause. This rule works with an unresolved Aggregate because // a resolved Aggregate will not have Window Functions. case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3ea398aad7375..97a843978f0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import java.sql.{Date, Timestamp} - -import scala.collection.mutable +import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} @@ -27,7 +25,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval /** * Window function testing for DataFrame API. @@ -624,4 +621,41 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24575: Window functions inside WHERE and HAVING clauses") { + def checkAnalysisError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses")) + } + + checkAnalysisError(testData2.select('a).where(rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError(testData2.where('b === 2 && rank().over(Window.orderBy('b)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(avg('b).as("avgb")) + .where('a > 'avgb && rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where(rank().over(Window.orderBy('a)) === 1)) + checkAnalysisError( + testData2.groupBy('a) + .agg(max('b).as("maxb"), sum('b).as("sumb")) + .where('sumb === 5 && rank().over(Window.orderBy('a)) === 1)) + + checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) + checkAnalysisError( + sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) + checkAnalysisError( + sql( + s"""SELECT a, MAX(b) + |FROM testData2 + |GROUP BY a + |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) + } } From 3f4bda7289f1bfbbe8b9bc4b516007f569c44d2e Mon Sep 17 00:00:00 2001 From: Wenbo Zhao Date: Wed, 20 Jun 2018 14:26:04 -0700 Subject: [PATCH 0988/1161] [SPARK-24578][CORE] Cap sub-region's size of returned nio buffer ## What changes were proposed in this pull request? This PR tries to fix the performance regression introduced by SPARK-21517. In our production job, we performed many parallel computations, with high possibility, some task could be scheduled to a host-2 where it needs to read the cache block data from host-1. Often, this big transfer makes the cluster suffer time out issue (it will retry 3 times, each with 120s timeout, and then do recompute to put the cache block into the local MemoryStore). The root cause is that we don't do `consolidateIfNeeded` anymore as we are using ``` Unpooled.wrappedBuffer(chunks.length, getChunks(): _*) ``` in ChunkedByteBuffer. If we have many small chunks, it could cause the `buf.notBuffer(...)` have very bad performance in the case that we have to call `copyByteBuf(...)` many times. ## How was this patch tested? Existing unit tests and also test in production Author: Wenbo Zhao Closes #21593 from WenboZhao/spark-24578. --- .../network/protocol/MessageWithHeader.java | 25 ++++--------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index a5337656cbd84..e7b66a6f33a82 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -137,30 +137,15 @@ protected void deallocate() { } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - ByteBuffer buffer = buf.nioBuffer(); - int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - target.write(buffer) : writeNioBuffer(target, buffer); + // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance + // for the case that the passed-in buffer has too many components. + int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + int written = target.write(buffer); buf.skipBytes(written); return written; } - private int writeNioBuffer( - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - int originalLimit = buf.limit(); - int ret = 0; - - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - ret = writeCh.write(buf); - } finally { - buf.limit(originalLimit); - } - - return ret; - } - @Override public MessageWithHeader touch(Object o) { super.touch(o); From 15747cfd3246385ffb23e19e28d2e4effa710bf6 Mon Sep 17 00:00:00 2001 From: Ray Burgemeestre Date: Wed, 20 Jun 2018 17:09:37 -0700 Subject: [PATCH 0989/1161] [SPARK-24547][K8S] Allow for building spark on k8s docker images without cache and don't forget to push spark-py container. ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-24547 TL;DR from JIRA issue: - First time I generated images for 2.4.0 Docker was using it's cache, so actually when running jobs, old jars where still in the Docker image. This produces errors like this in the executors: `java.io.InvalidClassException: org.apache.spark.storage.BlockManagerId; local class incompatible: stream classdesc serialVersionUID = 6155820641931972169, local class serialVersionUID = -3720498261147521051` - The second problem was that the spark container is pushed, but the spark-py container wasn't yet. This was just forgotten in the initial PR. - A third problem I also ran into because I had an older docker was https://github.com/apache/spark/pull/21551 so I have not included a fix for that in this ticket. ## How was this patch tested? I've tested it on my own Spark on k8s deployment. Author: Ray Burgemeestre Closes #21555 from rayburgemeestre/SPARK-24547. --- bin/docker-image-tool.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a871ab5d448c3..a3f1bcffaea57 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -70,17 +70,18 @@ function build { local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"} local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"} - docker build "${BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BUILD_ARGS[@]}" \ -t $(image_ref spark) \ -f "$BASEDOCKERFILE" . - docker build "${BINDING_BUILD_ARGS[@]}" \ + docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \ -t $(image_ref spark-py) \ -f "$PYDOCKERFILE" . } function push { docker push "$(image_ref spark)" + docker push "$(image_ref spark-py)" } function usage { @@ -99,6 +100,7 @@ Options: -r repo Repository address. -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. + -n Build docker image with --no-cache Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -127,7 +129,8 @@ REPO= TAG= BASEDOCKERFILE= PYDOCKERFILE= -while getopts f:mr:t: option +NOCACHEARG= +while getopts f:mr:t:n option do case "${option}" in @@ -135,6 +138,7 @@ do p) PYDOCKERFILE=${OPTARG};; r) REPO=${OPTARG};; t) TAG=${OPTARG};; + n) NOCACHEARG="--no-cache";; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." From 9de11d3f901bc206a33b9da3e7499bcd43e0142a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 21 Jun 2018 12:24:53 +0900 Subject: [PATCH 0990/1161] [SPARK-23912][SQL] add array_distinct ## What changes were proposed in this pull request? Add array_distinct to remove duplicate value from the array. ## How was this patch tested? Add unit tests Author: Huaxin Gao Closes #21050 from huaxingao/spark-23912. --- python/pyspark/sql/functions.py | 14 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 279 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 45 +++ .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 22 ++ 6 files changed, 368 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index e6346691fb1d4..11b179fe26bfc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1999,6 +1999,20 @@ def array_remove(col, element): return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) +@since(2.4) +def array_distinct(col): + """ + Collection function: removes duplicate values from the array. + :param col: name of column or expression + + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3700c63d817ea..4b09b9a7e75df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -433,6 +433,7 @@ object FunctionRegistry { expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), + expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d76f3013f0c41..7c064a130ff35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.util.collection.OpenHashSet /** * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit @@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: Expression) override def prettyName: String = "array_remove" } + +/** + * Removes duplicate values from the array. + */ +@ExpressionDescription( + usage = "_FUNC_(array) - Removes duplicate values from the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3)); + [1,2,3,null] + """, since = "2.4.0") +case class ArrayDistinct(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + override def dataType: DataType = child.dataType + + @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + } + } + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + override def nullSafeEval(array: Any): Any = { + val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) + if (elementTypeSupportEquals) { + new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) + } else { + var foundNullElement = false + var pos = 0 + for (i <- 0 until data.length) { + if (data(i) == null) { + if (!foundNullElement) { + foundNullElement = true + pos = pos + 1 + } + } else { + var j = 0 + var done = false + while (j <= i && !done) { + if (data(j) != null && ordering.equiv(data(j), data(i))) { + done = true + } + j = j + 1 + } + if (i == j - 1) { + pos = pos + 1 + } + } + } + new GenericArrayData(data.slice(0, pos)) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val getValue1 = CodeGenerator.getValue(array, elementType, i) + val getValue2 = CodeGenerator.getValue(array, elementType, j) + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + if (elementTypeSupportEquals) { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add($getValue1); + | } + |} + |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } else { + s""" + |int $sizeOfDistinctArray = 0; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $array.numElements(); $i ++) { + | if ($array.isNullAt($i)) { + | if (!($foundNullElement)) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | $foundNullElement = true; + | } + | } else { + | int $j; + | for ($j = 0; $j < $i; $j ++) { + | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { + | break; + | } + | } + | if ($i == $j) { + | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; + | } + | } + |} + | + |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} + """.stripMargin + } + }) + } + + private def setNull( + isPrimitive: Boolean, + foundNullElement: String, + distinctArray: String, + pos: String): String = { + val setNullValue = + if (!isPrimitive) { + s"$distinctArray[$pos] = null"; + } else { + s"$distinctArray.setNullAt($pos)"; + } + + s""" + |if (!($foundNullElement)) { + | $setNullValue; + | $pos = $pos + 1; + | $foundNullElement = true; + |} + """.stripMargin + } + + private def setNotNullValue(isPrimitive: Boolean, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + if (!isPrimitive) { + s"$distinctArray[$pos] = $getValue1"; + } else { + s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; + } + } + + private def setValueForFastEval( + isPrimitive: Boolean, + hs: String, + distinctArray: String, + pos: String, + getValue1: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |if (!($hs.contains($getValue1))) { + | $hs.add($getValue1); + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + private def setValueForBruteForceEval( + isPrimitive: Boolean, + i: String, + j: String, + inputArray: String, + distinctArray: String, + pos: String, + getValue1: String, + isEqual: String, + primitiveValueTypeName: String): String = { + val setValue = setNotNullValue(isPrimitive, + distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |int $j; + |for ($j = 0; $j < $i; $j ++) { + | if (!$inputArray.isNullAt($j) && $isEqual) { + | break; + | } + |} + |if ($i == $j) { + | $setValue; + | $pos = $pos + 1; + |} + """.stripMargin + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + size: String): String = { + val distinctArray = ctx.freshName("distinctArray") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val pos = ctx.freshName("pos") + val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) + val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) + val isEqual = ctx.genEqual(elementType, getValue1, getValue2) + val foundNullElement = ctx.freshName("foundNullElement") + val hs = ctx.freshName("hs") + val openHashSet = classOf[OpenHashSet[_]].getName + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" + val setNullForNonPrimitive = + setNull(false, foundNullElement, distinctArray, pos) + if (elementTypeSupportEquals) { + val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } else { + val setValueForBruteForce = setValueForBruteForceEval( + false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") + s""" + |int $pos = 0; + |Object[] $distinctArray = new Object[$size]; + |boolean $foundNullElement = false; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForNonPrimitive; + | } else { + | $setValueForBruteForce; + | } + |} + |${ev.value} = new $arrayClass($distinctArray); + """.stripMargin + } + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" + val setValueForFast = + setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $setNullForPrimitive; + | } else { + | $setValueForFast; + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin + } + } + + override def prettyName: String = "array_distinct" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 85e692bdc4ef1..f377f9c8cd533 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -766,4 +766,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Distinct") { + val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType)) + val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234), + ArrayType(DoubleType)) + val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), + ArrayType(FloatType)) + + checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) + checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) + checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c")) + checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a")) + checkEvaluation(new ArrayDistinct(a4), Seq(null)) + checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) + checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) + checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2), + null, Array[Byte](5, 6), null), ArrayType(BinaryType)) + + checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2))) + checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null, + Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) + checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8551058ec58ce..965dbb69c8efb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3189,6 +3189,13 @@ object functions { ArrayRemove(column.expr, Literal(element)) } + /** + * Removes duplicate values from the array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4e5c1c56e2673..3dc696bd01eeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1216,6 +1216,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) } + test("array_distinct functions") { + val df = Seq( + (Array[Int](2, 1, 3, 4, 3, 5), Array("b", "c", "a", "c", "b", "", "")), + (Array.empty[Int], Array.empty[String]), + (null, null) + ).toDF("a", "b") + checkAnswer( + df.select(array_distinct($"a"), array_distinct($"b")), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + checkAnswer( + df.selectExpr("array_distinct(a)", "array_distinct(b)"), + Seq( + Row(Seq(2, 1, 3, 4, 5), Seq("b", "c", "a", "")), + Row(Seq.empty[Int], Seq.empty[String]), + Row(null, null)) + ) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 54fcaafb094e299f21c18370fddb4a727c88d875 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 20 Jun 2018 23:38:37 -0700 Subject: [PATCH 0991/1161] [SPARK-24571][SQL] Support Char literals ## What changes were proposed in this pull request? In the PR, I propose to automatically convert a `Literal` with `Char` type to a `Literal` of `String` type. Currently, the following code: ```scala val df = Seq("Amsterdam", "San Francisco", "London").toDF("city") df.where($"city".contains('o')).show(false) ``` fails with the exception: ``` Unsupported literal type class java.lang.Character o java.lang.RuntimeException: Unsupported literal type class java.lang.Character o at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:78) ``` The PR fixes this issue by converting `char` to `string` of length `1`. I believe it makes sense to does not differentiate `char` and `string(1)` in _a unified, multi-language data platform_ like Spark which supports languages like Python/R. Author: Maxim Gekk Author: Maxim Gekk Closes #21578 from MaxGekk/support-char-literals. --- .../spark/sql/catalyst/CatalystTypeConverters.scala | 1 + .../apache/spark/sql/catalyst/expressions/literals.scala | 1 + .../spark/sql/catalyst/CatalystTypeConvertersSuite.scala | 8 ++++++++ .../sql/catalyst/expressions/LiteralExpressionSuite.scala | 7 +++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ 5 files changed, 25 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9e9105a157abe..93df73ab1eaf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -286,6 +286,7 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 + case chr: Char => UTF8String.fromString(chr.toString) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to the string type") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 246025b82d59e..0cc2a332f2c30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -57,6 +57,7 @@ object Literal { case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) + case c: Char => Literal(UTF8String.fromString(c.toString), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f99af9b84d959..89452ee05cff3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -139,4 +140,11 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { assert(exception.getMessage.contains("The value (0.1) of the type " + "(java.lang.Double) cannot be converted to the string type")) } + + test("SPARK-24571: convert Char to String") { + val chr: Char = 'X' + val converter = CatalystTypeConverters.createToCatalystConverter(StringType) + val expected = UTF8String.fromString("X") + assert(converter(chr) === expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index a9e0eb0e377a6..86f80fe66d28b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -219,4 +219,11 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkUnsupportedTypeInLiteral(Map("key1" -> 1, "key2" -> 2)) checkUnsupportedTypeInLiteral(("mike", 29, 1.0)) } + + test("SPARK-24571: char literals") { + checkEvaluation(Literal('X'), "X") + checkEvaluation(Literal.create('0'), "0") + checkEvaluation(Literal('\u0000'), "\u0000") + checkEvaluation(Literal.create('\n'), "\n") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 093cee91d2f49..2d20c50584c03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1479,6 +1479,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds1.schema == ds2.schema) checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*) } + + test("SPARK-24571: filtering of string values by char literal") { + val df = Seq("Amsterdam", "San Francisco", "X").toDF("city") + checkAnswer(df.where('city === 'X'), Seq(Row("X"))) + checkAnswer( + df.where($"city".contains(new java.lang.Character('A'))), + Seq(Row("Amsterdam"))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) From 7236e759c9856970116bf4dd20813dbf14440462 Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Thu, 21 Jun 2018 14:58:57 +0800 Subject: [PATCH 0992/1161] [SPARK-24574][SQL] array_contains, array_position, array_remove and element_at functions deal with Column type ## What changes were proposed in this pull request? For the function ```def array_contains(column: Column, value: Any): Column ``` , if we pass the `value` parameter as a Column type, it will yield a runtime exception. This PR proposes a pattern matching to detect if `value` is of type Column. If yes, it will use the .expr of the column, otherwise it will work as it used to. Same thing for ```array_position, array_remove and element_at``` functions ## How was this patch tested? Unit test modified to cover this code change. Ping ueshin Author: Chongguang LIU Closes #21581 from chongguang/SPARK-24574. --- .../org/apache/spark/sql/functions.scala | 8 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 69 +++++++++++++++---- 2 files changed, 58 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 965dbb69c8efb..c296a1bf9d69d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3093,7 +3093,7 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, Literal(value)) + ArrayContains(column.expr, lit(value).expr) } /** @@ -3157,7 +3157,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, Literal(value)) + ArrayPosition(column.expr, lit(value).expr) } /** @@ -3168,7 +3168,7 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, Literal(value)) + ElementAt(column.expr, lit(value).expr) } /** @@ -3186,7 +3186,7 @@ object functions { * @since 2.4.0 */ def array_remove(column: Column, element: Any): Column = withExpr { - ArrayRemove(column.expr, Literal(element)) + ArrayRemove(column.expr, lit(element).expr) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3dc696bd01eeb..fcdd33f544311 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -635,9 +635,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array contains function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") // Simple test cases checkAnswer( @@ -648,6 +648,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains(df("a"), df("c"))), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, c)"), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -862,9 +870,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array position function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") checkAnswer( df.select(array_position(df("a"), 1)), @@ -874,7 +882,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) - + checkAnswer( + df.selectExpr("array_position(a, c)"), + Seq(Row(1L), Row(0L)) + ) + checkAnswer( + df.select(array_position(df("a"), df("c"))), + Seq(Row(1L), Row(0L)) + ) checkAnswer( df.select(array_position(df("a"), null)), Seq(Row(null), Row(null)) @@ -901,10 +916,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("element_at function") { val df = Seq( - (Seq[String]("1", "2", "3")), - (Seq[String](null, "")), - (Seq[String]()) - ).toDF("a") + (Seq[String]("1", "2", "3"), 1), + (Seq[String](null, ""), -1), + (Seq[String](), 2) + ).toDF("a", "b") intercept[Exception] { checkAnswer( @@ -922,6 +937,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) + checkAnswer( + df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), @@ -1189,10 +1212,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array remove") { val df = Seq( - (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), - (Array.empty[Int], Array.empty[String], Array.empty[String]), - (null, null, null) - ).toDF("a", "b", "c") + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2) + ).toDF("a", "b", "c", "d") checkAnswer( df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), Seq( @@ -1201,6 +1224,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null, null)) ) + checkAnswer( + df.select(array_remove($"a", $"d")), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, d)"), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), From c0cad596b84feeb7840c1b7a51c1ef275940a5ed Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Thu, 21 Jun 2018 16:41:43 +0800 Subject: [PATCH 0993/1161] [SPARK-24614][PYSPARK] Fix for SyntaxWarning on tests.py ## What changes were proposed in this pull request? Fix for SyntaxWarning on tests.py ## How was this patch tested? ./dev/run-tests Author: Rekha Joshi Closes #21604 from rekhajoshm/SPARK-24614. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 94ab867f0bd9b..3c8a8fcf6e946 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1918,7 +1918,7 @@ def assert_invalid_writer(self, writer, msg=None): self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected except Exception as e: if msg: - assert(msg in str(e), "%s not in %s" % (msg, str(e))) + assert msg in str(e), "%s not in %s" % (msg, str(e)) finally: self.stop_all() From b56e9c613fb345472da3db1a567ee129621f6bf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cattilapiros=E2=80=9D?= Date: Thu, 21 Jun 2018 09:17:18 -0500 Subject: [PATCH 0994/1161] [SPARK-16630][YARN] Blacklist a node if executors won't launch on it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This change extends YARN resource allocation handling with blacklisting functionality. This handles cases when node is messed up or misconfigured such that a container won't launch on it. Before this change backlisting only focused on task execution but this change introduces YarnAllocatorBlacklistTracker which tracks allocation failures per host (when enabled via "spark.yarn.blacklist.executor.launch.blacklisting.enabled"). ## How was this patch tested? ### With unit tests Including a new suite: YarnAllocatorBlacklistTrackerSuite. #### Manually It was tested on a cluster by deleting the Spark jars on one of the node. #### Behaviour before these changes Starting Spark as: ``` spark2-shell --master yarn --deploy-mode client --num-executors 4 --conf spark.executor.memory=4g --conf "spark.yarn.max.executor.failures=6" ``` Log is: ``` 18/04/12 06:49:36 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 11, (reason: Max number of executor failures (6) reached) 18/04/12 06:49:39 INFO yarn.ApplicationMaster: Unregistering ApplicationMaster with FAILED (diag message: Max number of executor failures (6) reached) 18/04/12 06:49:39 INFO impl.AMRMClientImpl: Waiting for application to be successfully unregistered. 18/04/12 06:49:39 INFO yarn.ApplicationMaster: Deleting staging directory hdfs://apiros-1.gce.test.com:8020/user/systest/.sparkStaging/application_1523459048274_0016 18/04/12 06:49:39 INFO util.ShutdownHookManager: Shutdown hook called ``` #### Behaviour after these changes Starting Spark as: ``` spark2-shell --master yarn --deploy-mode client --num-executors 4 --conf spark.executor.memory=4g --conf "spark.yarn.max.executor.failures=6" --conf "spark.yarn.blacklist.executor.launch.blacklisting.enabled=true" ``` And the log is: ``` 18/04/13 05:37:43 INFO yarn.YarnAllocator: Will request 1 executor container(s), each with 1 core(s) and 4505 MB memory (including 409 MB of overhead) 18/04/13 05:37:43 INFO yarn.YarnAllocator: Submitted 1 unlocalized container requests. 18/04/13 05:37:43 INFO yarn.YarnAllocator: Launching container container_1523459048274_0025_01_000008 on host apiros-4.gce.test.com for executor with ID 6 18/04/13 05:37:43 INFO yarn.YarnAllocator: Received 1 containers from YARN, launching executors on 1 of them. 18/04/13 05:37:43 INFO yarn.YarnAllocator: Completed container container_1523459048274_0025_01_000007 on host: apiros-4.gce.test.com (state: COMPLETE, exit status: 1) 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: blacklisting host as YARN allocation failed: apiros-4.gce.test.com 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: adding nodes to YARN application master's blacklist: List(apiros-4.gce.test.com) 18/04/13 05:37:43 WARN yarn.YarnAllocator: Container marked as failed: container_1523459048274_0025_01_000007 on host: apiros-4.gce.test.com. Exit status: 1. Diagnostics: Exception from container-launch. Container id: container_1523459048274_0025_01_000007 Exit code: 1 Stack trace: ExitCodeException exitCode=1: at org.apache.hadoop.util.Shell.runCommand(Shell.java:604) at org.apache.hadoop.util.Shell.run(Shell.java:507) at org.apache.hadoop.util.Shell$ShellCommandExecutor.execute(Shell.java:789) at org.apache.hadoop.yarn.server.nodemanager.DefaultContainerExecutor.launchContainer(DefaultContainerExecutor.java:213) at org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:302) at org.apache.hadoop.yarn.server.nodemanager.containermanager.launcher.ContainerLaunch.call(ContainerLaunch.java:82) at java.util.concurrent.FutureTask.run(FutureTask.java:266) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) at java.lang.Thread.run(Thread.java:748) ``` Where the most important part is: ``` 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: blacklisting host as YARN allocation failed: apiros-4.gce.test.com 18/04/13 05:37:43 INFO yarn.YarnAllocatorBlacklistTracker: adding nodes to YARN application master's blacklist: List(apiros-4.gce.test.com) ``` And execution was continued (no shutdown called). ### Testing the backlisting of the whole cluster Starting Spark with YARN blacklisting enabled then removing a the Spark core jar one by one from all the cluster nodes. Then executing a simple spark job which fails checking the yarn log the expected exit status is contained: ``` 18/06/15 01:07:10 INFO yarn.ApplicationMaster: Final app status: FAILED, exitCode: 11, (reason: Due to executor failures all available nodes are blacklisted) 18/06/15 01:07:13 INFO util.ShutdownHookManager: Shutdown hook called ``` Author: “attilapiros” Closes #21068 from attilapiros/SPARK-16630. --- .../spark/scheduler/BlacklistTracker.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 3 +- .../apache/spark/HeartbeatReceiverSuite.scala | 6 +- docs/running-on-yarn.md | 10 + ...osCoarseGrainedSchedulerBackendSuite.scala | 1 + .../spark/deploy/yarn/ApplicationMaster.scala | 4 + .../spark/deploy/yarn/YarnAllocator.scala | 60 ++---- .../yarn/YarnAllocatorBlacklistTracker.scala | 187 ++++++++++++++++++ .../org/apache/spark/deploy/yarn/config.scala | 6 + .../deploy/yarn/FailureTrackerSuite.scala | 100 ++++++++++ .../YarnAllocatorBlacklistTrackerSuite.scala | 140 +++++++++++++ .../deploy/yarn/YarnAllocatorSuite.scala | 16 +- 12 files changed, 479 insertions(+), 56 deletions(-) create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index 30cf75d43ee09..980fbbe516b91 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -371,7 +371,7 @@ private[scheduler] class BlacklistTracker ( } -private[scheduler] object BlacklistTracker extends Logging { +private[spark] object BlacklistTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index d8794e8e551aa..9b90e309d2e04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -170,8 +170,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) - } else if (scheduler.nodeBlacklist != null && - scheduler.nodeBlacklist.contains(hostname)) { + } else if (scheduler.nodeBlacklist.contains(hostname)) { // If the cluster manager gives us an executor on a blacklisted node (because it // already started allocating those resources before we informed it of our blacklist, // or if it ignored our blacklist), then we reject that executor immediately. diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 88916488c0def..b705556e54b14 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} -import scala.collection.Map import scala.collection.mutable import scala.concurrent.Future import scala.concurrent.duration._ @@ -73,6 +72,7 @@ class HeartbeatReceiverSuite sc = spy(new SparkContext(conf)) scheduler = mock(classOf[TaskSchedulerImpl]) when(sc.taskScheduler).thenReturn(scheduler) + when(scheduler.nodeBlacklist).thenReturn(Predef.Set[String]()) when(scheduler.sc).thenReturn(sc) heartbeatReceiverClock = new ManualClock heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) @@ -241,7 +241,7 @@ class HeartbeatReceiverSuite } === Some(true)) } - private def getTrackedExecutors: Map[String, Long] = { + private def getTrackedExecutors: collection.Map[String, Long] = { // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty)) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4dbcbeafbbd9d..575da7205b529 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -411,6 +411,16 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.blacklist.executor.launch.blacklisting.enabled + false + + Flag to enable blacklisting of nodes having YARN resource allocation problems. + The error limit for blacklisting can be configured by + spark.blacklist.application.maxFailedExecutorsPerNode. + + + # Important notes diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f4bd1ee9da6f7..b790c7cd27794 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -789,6 +789,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) taskScheduler = mock[TaskSchedulerImpl] + when(taskScheduler.nodeBlacklist).thenReturn(Set[String]()) when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3d6ee50b070a3..ecc576910db9e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -515,6 +515,10 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, s"Max number of executor failures ($maxNumExecutorFailures) reached") + } else if (allocator.isAllNodeBlacklisted) { + finish(FinalApplicationStatus.FAILED, + ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, + "Due to executor failures all available nodes are blacklisted") } else { logDebug("Sending progress") allocator.allocateResources() diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ebee3d431744d..fae054e0eea00 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -24,7 +24,7 @@ import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records._ @@ -66,7 +66,8 @@ private[yarn] class YarnAllocator( appAttemptId: ApplicationAttemptId, securityMgr: SecurityManager, localResources: Map[String, LocalResource], - resolver: SparkRackResolver) + resolver: SparkRackResolver, + clock: Clock = new SystemClock) extends Logging { import YarnAllocator._ @@ -102,18 +103,14 @@ private[yarn] class YarnAllocator( private var executorIdCounter: Int = driverRef.askSync[Int](RetrieveLastAllocatedExecutorId) - // Queue to store the timestamp of failed executors - private val failedExecutorsTimeStamps = new Queue[Long]() + private[spark] val failureTracker = new FailureTracker(sparkConf, clock) - private var clock: Clock = new SystemClock - - private val executorFailuresValidityInterval = - sparkConf.get(EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + private val allocatorBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClient, failureTracker) @volatile private var targetNumExecutors = SchedulerBackendUtils.getInitialTargetExecutorNumber(sparkConf) - private var currentNodeBlacklist = Set.empty[String] // Executor loss reason requests that are pending - maps from executor ID for inquiry to a // list of requesters that should be responded to once we find out why the given executor @@ -149,7 +146,6 @@ private[yarn] class YarnAllocator( private val labelExpression = sparkConf.get(EXECUTOR_NODE_LABEL_EXPRESSION) - // A map to store preferred hostname and possible task numbers running on it. private var hostToLocalTaskCounts: Map[String, Int] = Map.empty @@ -160,26 +156,11 @@ private[yarn] class YarnAllocator( private[yarn] val containerPlacementStrategy = new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver) - /** - * Use a different clock for YarnAllocator. This is mainly used for testing. - */ - def setClock(newClock: Clock): Unit = { - clock = newClock - } - def getNumExecutorsRunning: Int = runningExecutors.size() - def getNumExecutorsFailed: Int = synchronized { - val endTime = clock.getTimeMillis() + def getNumExecutorsFailed: Int = failureTracker.numFailedExecutors - while (executorFailuresValidityInterval > 0 - && failedExecutorsTimeStamps.nonEmpty - && failedExecutorsTimeStamps.head < endTime - executorFailuresValidityInterval) { - failedExecutorsTimeStamps.dequeue() - } - - failedExecutorsTimeStamps.size - } + def isAllNodeBlacklisted: Boolean = allocatorBlacklistTracker.isAllNodeBlacklisted /** * A sequence of pending container requests that have not yet been fulfilled. @@ -204,9 +185,8 @@ private[yarn] class YarnAllocator( * @param localityAwareTasks number of locality aware tasks to be used as container placement hint * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as * container placement hint. - * @param nodeBlacklist a set of blacklisted nodes, which is passed in to avoid allocating new - * containers on them. It will be used to update the application master's - * blacklist. + * @param nodeBlacklist blacklisted nodes, which is passed in to avoid allocating new containers + * on them. It will be used to update the application master's blacklist. * @return Whether the new requested total is different than the old value. */ def requestTotalExecutorsWithPreferredLocalities( @@ -220,19 +200,7 @@ private[yarn] class YarnAllocator( if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal - - // Update blacklist infomation to YARN ResouceManager for this application, - // in order to avoid allocating new Containers on the problematic nodes. - val blacklistAdditions = nodeBlacklist -- currentNodeBlacklist - val blacklistRemovals = currentNodeBlacklist -- nodeBlacklist - if (blacklistAdditions.nonEmpty) { - logInfo(s"adding nodes to YARN application master's blacklist: $blacklistAdditions") - } - if (blacklistRemovals.nonEmpty) { - logInfo(s"removing nodes from YARN application master's blacklist: $blacklistRemovals") - } - amClient.updateBlacklist(blacklistAdditions.toList.asJava, blacklistRemovals.toList.asJava) - currentNodeBlacklist = nodeBlacklist + allocatorBlacklistTracker.setSchedulerBlacklistedNodes(nodeBlacklist) true } else { false @@ -268,6 +236,7 @@ private[yarn] class YarnAllocator( val allocateResponse = amClient.allocate(progressIndicator) val allocatedContainers = allocateResponse.getAllocatedContainers() + allocatorBlacklistTracker.setNumClusterNodes(allocateResponse.getNumClusterNodes) if (allocatedContainers.size > 0) { logDebug(("Allocated containers: %d. Current executor count: %d. " + @@ -602,8 +571,9 @@ private[yarn] class YarnAllocator( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) case _ => - // Enqueue the timestamp of failed executor - failedExecutorsTimeStamps.enqueue(clock.getTimeMillis()) + // all the failures which not covered above, like: + // disk failure, kill by app master or resource manager, ... + allocatorBlacklistTracker.handleResourceAllocationFailure(hostOpt) (true, "Container marked as failed: " + containerId + onHostStr + ". Exit status: " + completedContainer.getExitStatus + ". Diagnostics: " + completedContainer.getDiagnostics) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala new file mode 100644 index 0000000000000..1b48a0ee7ad32 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTracker.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.HashMap + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.yarn.config._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.scheduler.BlacklistTracker +import org.apache.spark.util.{Clock, SystemClock, Utils} + +/** + * YarnAllocatorBlacklistTracker is responsible for tracking the blacklisted nodes + * and synchronizing the node list to YARN. + * + * Blacklisted nodes are coming from two different sources: + * + *
      + *
    • from the scheduler as task level blacklisted nodes + *
    • from this class (tracked here) as YARN resource allocation problems + *
    + * + * The reason to realize this logic here (and not in the driver) is to avoid possible delays + * between synchronizing the blacklisted nodes with YARN and resource allocations. + */ +private[spark] class YarnAllocatorBlacklistTracker( + sparkConf: SparkConf, + amClient: AMRMClient[ContainerRequest], + failureTracker: FailureTracker) + extends Logging { + + private val blacklistTimeoutMillis = BlacklistTracker.getBlacklistTimeout(sparkConf) + + private val launchBlacklistEnabled = sparkConf.get(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED) + + private val maxFailuresPerHost = sparkConf.get(MAX_FAILED_EXEC_PER_NODE) + + private val allocatorBlacklist = new HashMap[String, Long]() + + private var currentBlacklistedYarnNodes = Set.empty[String] + + private var schedulerBlacklist = Set.empty[String] + + private var numClusterNodes = Int.MaxValue + + def setNumClusterNodes(numClusterNodes: Int): Unit = { + this.numClusterNodes = numClusterNodes + } + + def handleResourceAllocationFailure(hostOpt: Option[String]): Unit = { + hostOpt match { + case Some(hostname) if launchBlacklistEnabled => + // failures on an already blacklisted nodes are not even tracked. + // otherwise, such failures could shutdown the application + // as resource requests are asynchronous + // and a late failure response could exceed MAX_EXECUTOR_FAILURES + if (!schedulerBlacklist.contains(hostname) && + !allocatorBlacklist.contains(hostname)) { + failureTracker.registerFailureOnHost(hostname) + updateAllocationBlacklistedNodes(hostname) + } + case _ => + failureTracker.registerExecutorFailure() + } + } + + private def updateAllocationBlacklistedNodes(hostname: String): Unit = { + val failuresOnHost = failureTracker.numFailuresOnHost(hostname) + if (failuresOnHost > maxFailuresPerHost) { + logInfo(s"blacklisting $hostname as YARN allocation failed $failuresOnHost times") + allocatorBlacklist.put( + hostname, + failureTracker.clock.getTimeMillis() + blacklistTimeoutMillis) + refreshBlacklistedNodes() + } + } + + def setSchedulerBlacklistedNodes(schedulerBlacklistedNodesWithExpiry: Set[String]): Unit = { + this.schedulerBlacklist = schedulerBlacklistedNodesWithExpiry + refreshBlacklistedNodes() + } + + def isAllNodeBlacklisted: Boolean = currentBlacklistedYarnNodes.size >= numClusterNodes + + private def refreshBlacklistedNodes(): Unit = { + removeExpiredYarnBlacklistedNodes() + val allBlacklistedNodes = schedulerBlacklist ++ allocatorBlacklist.keySet + synchronizeBlacklistedNodeWithYarn(allBlacklistedNodes) + } + + private def synchronizeBlacklistedNodeWithYarn(nodesToBlacklist: Set[String]): Unit = { + // Update blacklist information to YARN ResourceManager for this application, + // in order to avoid allocating new Containers on the problematic nodes. + val additions = (nodesToBlacklist -- currentBlacklistedYarnNodes).toList.sorted + val removals = (currentBlacklistedYarnNodes -- nodesToBlacklist).toList.sorted + if (additions.nonEmpty) { + logInfo(s"adding nodes to YARN application master's blacklist: $additions") + } + if (removals.nonEmpty) { + logInfo(s"removing nodes from YARN application master's blacklist: $removals") + } + amClient.updateBlacklist(additions.asJava, removals.asJava) + currentBlacklistedYarnNodes = nodesToBlacklist + } + + private def removeExpiredYarnBlacklistedNodes(): Unit = { + val now = failureTracker.clock.getTimeMillis() + allocatorBlacklist.retain { (_, expiryTime) => expiryTime > now } + } +} + +/** + * FailureTracker is responsible for tracking executor failures both for each host separately + * and for all hosts altogether. + */ +private[spark] class FailureTracker( + sparkConf: SparkConf, + val clock: Clock = new SystemClock) extends Logging { + + private val executorFailuresValidityInterval = + sparkConf.get(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS).getOrElse(-1L) + + // Queue to store the timestamp of failed executors for each host + private val failedExecutorsTimeStampsPerHost = mutable.Map[String, mutable.Queue[Long]]() + + private val failedExecutorsTimeStamps = new mutable.Queue[Long]() + + private def updateAndCountFailures(failedExecutorsWithTimeStamps: mutable.Queue[Long]): Int = { + val endTime = clock.getTimeMillis() + while (executorFailuresValidityInterval > 0 && + failedExecutorsWithTimeStamps.nonEmpty && + failedExecutorsWithTimeStamps.head < endTime - executorFailuresValidityInterval) { + failedExecutorsWithTimeStamps.dequeue() + } + failedExecutorsWithTimeStamps.size + } + + def numFailedExecutors: Int = synchronized { + updateAndCountFailures(failedExecutorsTimeStamps) + } + + def registerFailureOnHost(hostname: String): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + val failedExecutorsOnHost = + failedExecutorsTimeStampsPerHost.getOrElse(hostname, { + val failureOnHost = mutable.Queue[Long]() + failedExecutorsTimeStampsPerHost.put(hostname, failureOnHost) + failureOnHost + }) + failedExecutorsOnHost.enqueue(timeMillis) + } + + def registerExecutorFailure(): Unit = synchronized { + val timeMillis = clock.getTimeMillis() + failedExecutorsTimeStamps.enqueue(timeMillis) + } + + def numFailuresOnHost(hostname: String): Int = { + failedExecutorsTimeStampsPerHost.get(hostname).map { failedExecutorsOnHost => + updateAndCountFailures(failedExecutorsOnHost) + }.getOrElse(0) + } + +} + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 1a99b3bd57672..129084a86597a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -328,4 +328,10 @@ package object config { CACHED_FILES_TYPES, CACHED_CONF_ARCHIVE) + /* YARN allocator-level blacklisting related config entries. */ + private[spark] val YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED = + ConfigBuilder("spark.yarn.blacklist.executor.launch.blacklisting.enabled") + .booleanConf + .createWithDefault(false) + } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala new file mode 100644 index 0000000000000..4f77b9c99dd25 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/FailureTrackerSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.util.ManualClock + +class FailureTrackerSuite extends SparkFunSuite with Matchers { + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("failures expire if validity interval is set") { + val sparkConf = new SparkConf() + sparkConf.set(config.EXECUTOR_ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS, 100L) + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(101) + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(231) + failureTracker.numFailuresOnHost("host1") should be (0) + failureTracker.numFailuresOnHost("host2") should be (0) + failureTracker.numFailedExecutors should be (0) + } + + + test("failures never expire if validity interval is not set (-1)") { + val sparkConf = new SparkConf() + + val clock = new ManualClock() + val failureTracker = new FailureTracker(sparkConf, clock) + + clock.setTime(0) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (1) + failureTracker.numFailedExecutors should be (1) + + clock.setTime(10) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (1) + failureTracker.numFailedExecutors should be (2) + + clock.setTime(20) + failureTracker.registerFailureOnHost("host1") + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailedExecutors should be (3) + + clock.setTime(30) + failureTracker.registerFailureOnHost("host2") + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + + clock.setTime(1000) + failureTracker.numFailuresOnHost("host1") should be (2) + failureTracker.numFailuresOnHost("host2") should be (2) + failureTracker.numFailedExecutors should be (4) + } + +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala new file mode 100644 index 0000000000000..aeac68e6ed330 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorBlacklistTrackerSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.Arrays +import java.util.Collections + +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.yarn.config.YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED +import org.apache.spark.internal.config.{BLACKLIST_TIMEOUT_CONF, MAX_FAILED_EXEC_PER_NODE} +import org.apache.spark.util.ManualClock + +class YarnAllocatorBlacklistTrackerSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach { + + val BLACKLIST_TIMEOUT = 100L + val MAX_FAILED_EXEC_PER_NODE_VALUE = 2 + + var amClientMock: AMRMClient[ContainerRequest] = _ + var yarnBlacklistTracker: YarnAllocatorBlacklistTracker = _ + var failureTracker: FailureTracker = _ + var clock: ManualClock = _ + + override def beforeEach(): Unit = { + val sparkConf = new SparkConf() + sparkConf.set(BLACKLIST_TIMEOUT_CONF, BLACKLIST_TIMEOUT) + sparkConf.set(YARN_EXECUTOR_LAUNCH_BLACKLIST_ENABLED, true) + sparkConf.set(MAX_FAILED_EXEC_PER_NODE, MAX_FAILED_EXEC_PER_NODE_VALUE) + clock = new ManualClock() + + amClientMock = mock(classOf[AMRMClient[ContainerRequest]]) + failureTracker = new FailureTracker(sparkConf, clock) + yarnBlacklistTracker = + new YarnAllocatorBlacklistTracker(sparkConf, amClientMock, failureTracker) + yarnBlacklistTracker.setNumClusterNodes(4) + super.beforeEach() + } + + test("expiring its own blacklisted nodes") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // host should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + } + } + + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host")) + // the third failure on the host triggers the blacklisting + verify(amClientMock).updateBlacklist(Arrays.asList("host"), Collections.emptyList()) + + clock.advance(BLACKLIST_TIMEOUT) + + // trigger synchronisation of blacklisted nodes with YARN + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set()) + verify(amClientMock).updateBlacklist(Collections.emptyList(), Arrays.asList("host")) + } + + test("not handling the expiry of scheduler blacklisted nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2"), Collections.emptyList()) + + // advance timer more then host1, host2 expiry time + clock.advance(200L) + + // expired blacklisted nodes (simulating a resource request) + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2")) + // no change is communicated to YARN regarding the blacklisting + verify(amClientMock).updateBlacklist(Collections.emptyList(), Collections.emptyList()) + } + + test("combining scheduler and allocation blacklist") { + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + // host1 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + } + } + + // as this is the third failure on host1 the node will be blacklisted + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host1")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1"), Collections.emptyList()) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host2", "host3"), Collections.emptyList()) + + clock.advance(10L) + + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host3", "host4")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host4"), Arrays.asList("host2")) + } + + test("blacklist all available nodes") { + yarnBlacklistTracker.setSchedulerBlacklistedNodes(Set("host1", "host2", "host3")) + verify(amClientMock) + .updateBlacklist(Arrays.asList("host1", "host2", "host3"), Collections.emptyList()) + + clock.advance(60L) + (1 to MAX_FAILED_EXEC_PER_NODE_VALUE).foreach { + _ => { + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + // host4 should not be blacklisted at these failures as MAX_FAILED_EXEC_PER_NODE is 2 + verify(amClientMock, never()) + .updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + } + } + + // the third failure on the host triggers the blacklisting + yarnBlacklistTracker.handleResourceAllocationFailure(Some("host4")) + + verify(amClientMock).updateBlacklist(Arrays.asList("host4"), Collections.emptyList()) + assert(yarnBlacklistTracker.isAllNodeBlacklisted === true) + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 525abb6f2b350..3f783baed110d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -59,6 +59,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter var rmClient: AMRMClient[ContainerRequest] = _ + var clock: ManualClock = _ + var containerNum = 0 override def beforeEach() { @@ -66,6 +68,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter rmClient = AMRMClient.createAMRMClient() rmClient.init(conf) rmClient.start() + clock = new ManualClock() } override def afterEach() { @@ -101,7 +104,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter appAttemptId, new SecurityManager(sparkConf), Map(), - new MockResolver()) + new MockResolver(), + clock) } def createContainer(host: String): Container = { @@ -332,10 +336,14 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map(), Set("hostA")) verify(mockAmClient).updateBlacklist(Seq("hostA").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set("hostA", "hostB")) + val blacklistedNodes = Set( + "hostA", + "hostB" + ) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), blacklistedNodes) verify(mockAmClient).updateBlacklist(Seq("hostB").asJava, Seq[String]().asJava) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set()) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set.empty) verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) } @@ -353,8 +361,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter test("window based failure executor counting") { sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s") val handler = createAllocator(4) - val clock = new ManualClock(0L) - handler.setClock(clock) handler.updateResourceRequests() handler.getNumExecutorsRunning should be (0) From c8e909cd498b67b121fa920ceee7631c652dac38 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 21 Jun 2018 13:25:15 -0500 Subject: [PATCH 0995/1161] [SPARK-24589][CORE] Correctly identify tasks in output commit coordinator. When an output stage is retried, it's possible that tasks from the previous attempt are still running. In that case, there would be a new task for the same partition in the new attempt, and the coordinator would allow both tasks to commit their output since it did not keep track of stage attempts. The change adds more information to the stage state tracked by the coordinator, so that only one task is allowed to commit the output in the above case. The stage state in the coordinator is also maintained across stage retries, so that a stray speculative task from a previous stage attempt is not allowed to commit. This also removes some code added in SPARK-18113 that allowed for duplicate commit requests; with the RPC code used in Spark 2, that situation cannot happen, so there is no need to handle it. Author: Marcelo Vanzin Closes #21577 from vanzin/SPARK-24552. --- .../spark/mapred/SparkHadoopMapRedUtil.scala | 8 +- .../apache/spark/scheduler/DAGScheduler.scala | 23 ++-- .../scheduler/OutputCommitCoordinator.scala | 128 ++++++++++-------- .../OutputCommitCoordinatorSuite.scala | 116 +++++++++++----- .../datasources/v2/WriteToDataSourceV2.scala | 9 +- 5 files changed, 179 insertions(+), 105 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 764735dc4eae7..db8aff94ea1e1 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val taskAttemptNumber = TaskContext.get().attemptNumber() - val stageId = TaskContext.get().stageId() - val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber) + val ctx = TaskContext.get() + val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(), + splitId, ctx.attemptNumber()) if (canCommit) { performCommit() @@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber) + throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber()) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 041eade82d3ca..f74425d73b392 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1171,6 +1171,7 @@ class DAGScheduler( outputCommitCoordinator.taskCompleted( stageId, + task.stageAttemptId, task.partitionId, event.taskInfo.attemptNumber, // this is a task attempt number event.reason) @@ -1330,23 +1331,24 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) + markStageAsFinished(failedStage, errorMessage = Some(failureMessage), + willRetry = !shouldAbortStage) } else { logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " + s"longer running") } - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || - disallowStageRetryForTest - if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" @@ -1545,7 +1547,10 @@ class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. */ - private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = { + private def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { val serviceTime = stage.latestInfo.submissionTime match { case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0) case _ => "Unknown" @@ -1564,7 +1569,9 @@ class DAGScheduler( logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}") } - outputCommitCoordinator.stageEnd(stage.id) + if (!willRetry) { + outputCommitCoordinator.stageEnd(stage.id) + } listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 83d87b548a430..b382d623806e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) +private case class AskPermissionToCommitOutput( + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None - private type StageId = Int - private type PartitionId = Int - private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + // Class used to identify a committer. The task ID for a committer is implicitly defined by + // the partition being processed, but the coordinator needs to keep track of both the stage + // attempt and the task attempt, because in some situations the same task may be running + // concurrently in two different attempts of the same stage. + private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int) + private case class StageState(numPartitions: Int) { - val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) - val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null) + val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]() } /** @@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val stageStates = mutable.Map[StageId, StageState]() + private val stageStates = mutable.Map[Int, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. @@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @return true if this task is authorized to commit, false otherwise */ def canCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = { + val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg), @@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } /** - * Called by the DAGScheduler when a stage starts. + * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't + * yet been initialized. * * @param stage the stage id. * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { - stageStates(stage) = new StageState(maxPartitionId + 1) + private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized { + stageStates.get(stage) match { + case Some(state) => + require(state.authorizedCommitters.length == maxPartitionId + 1) + logInfo(s"Reusing state from previous attempt of stage $stage.") + + case _ => + stageStates(stage) = new StageState(maxPartitionId + 1) + } } // Called by DAGScheduler - private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { + private[scheduler] def stageEnd(stage: Int): Unit = synchronized { stageStates.remove(stage) } // Called by DAGScheduler private[scheduler] def taskCompleted( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber, + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int, reason: TaskEndReason): Unit = synchronized { val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) reason match { case Success => // The task output has been committed successfully - case denied: TaskCommitDenied => - logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + - s"attempt: $attemptNumber") - case otherReason => + case _: TaskCommitDenied => + logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " + + s"partition: $partition, attempt: $attemptNumber") + case _ => // Mark the attempt as failed to blacklist from future commit protocol - stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber - if (stageState.authorizedCommitters(partition) == attemptNumber) { + val taskId = TaskIdentifier(stageAttempt, attemptNumber) + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId + if (stageState.authorizedCommitters(partition) == taskId) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = null } } } @@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Marked private[scheduler] instead of private so this can be mocked in tests private[scheduler] def handleAskPermissionToCommit( - stage: StageId, - partition: PartitionId, - attemptNumber: TaskAttemptNumber): Boolean = synchronized { + stage: Int, + stageAttempt: Int, + partition: Int, + attemptNumber: Int): Boolean = synchronized { stageStates.get(stage) match { - case Some(state) if attemptFailed(state, partition, attemptNumber) => - logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition as task attempt $attemptNumber has already failed.") + case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) => + logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"task attempt $attemptNumber already marked as failed.") false case Some(state) => - state.authorizedCommitters(partition) match { - case NO_AUTHORIZED_COMMITTER => - logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition") - state.authorizedCommitters(partition) = attemptNumber - true - case existingCommitter => - // Coordinator should be idempotent when receiving AskPermissionToCommit. - if (existingCommitter == attemptNumber) { - logWarning(s"Authorizing duplicate request to commit for " + - s"attemptNumber=$attemptNumber to commit for stage=$stage," + - s" partition=$partition; existingCommitter = $existingCommitter." + - s" This can indicate dropped network traffic.") - true - } else { - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false - } + val existing = state.authorizedCommitters(partition) + if (existing == null) { + logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " + + s"task attempt $attemptNumber") + state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber) + true + } else { + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + s"already committed by $existing") + false } case None => - logDebug(s"Stage $stage has completed, so not allowing" + - s" attempt number $attemptNumber of partition $partition to commit") + logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " + + "stage already marked as completed.") false } } private def attemptFailed( stageState: StageState, - partition: PartitionId, - attempt: TaskAttemptNumber): Boolean = synchronized { - stageState.failures.get(partition).exists(_.contains(attempt)) + stageAttempt: Int, + partition: Int, + attempt: Int): Boolean = synchronized { + val failInfo = TaskIdentifier(stageAttempt, attempt) + stageState.failures.get(partition).exists(_.contains(failInfo)) } } @@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, attemptNumber) => + case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition, + attemptNumber)) } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 03b1903902491..158c9eb75f2b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -153,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Job should not complete if all commits are denied") { // Create a mock OutputCommitCoordinator that denies all attempts to commit doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit( - Matchers.any(), Matchers.any(), Matchers.any()) + Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any()) val rdd: RDD[Int] = sc.parallelize(Seq(1), 1) def resultHandler(x: Int, y: Unit): Unit = {} val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd, @@ -169,45 +170,106 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 2 val authorizedCommitter: Int = 3 val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage, maxPartitionId = 2) - assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter)) // The non-authorized committer fails - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock - outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer - assert( - outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 2)) // There can only be one authorized committer - assert( - !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) - } - - test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") { - val rdd = sc.parallelize(Seq(1), 1) - sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, - 0 until rdd.partitions.size) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, + nonAuthorizedCommitter + 3)) } test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { val stage: Int = 1 + val stageAttempt: Int = 1 val partition: Int = 1 val failedAttempt: Int = 0 outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) - outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition, + attemptNumber = failedAttempt, reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) - assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) - assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1)) + } + + test("SPARK-24589: Differentiate tasks from different stage attempts") { + var stage = 1 + val taskAttempt = 1 + val partition = 1 + + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Fail the task in the first attempt, the task in the second attempt should succeed. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt)) + + // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit, + // then fail the 1st attempt and make sure the 4th one can commit again. + stage += 1 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt)) + outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt, + ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt)) + } + + test("SPARK-24589: Make sure stage state is cleaned up") { + // Normal application without stage failures. + sc.parallelize(1 to 100, 100) + .map { i => (i % 10, i) } + .reduceByKey(_ + _) + .collect() + + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + + // Force failures in a few tasks so that a stage is retried. Collect the ID of the failing + // stage so that we can check the state of the output committer. + val retriedStage = sc.parallelize(1 to 100, 10) + .map { i => (i % 10, i) } + .reduceByKey { case (_, _) => + val ctx = TaskContext.get() + if (ctx.stageAttemptNumber() == 0) { + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, + new Exception("Failure for test.")) + } else { + ctx.stageId() + } + } + .collect() + .map { case (k, v) => v } + .toSet + + assert(retriedStage.size === 1) + assert(sc.dagScheduler.outputCommitCoordinator.isEmpty) + verify(sc.env.outputCommitCoordinator, times(2)) + .stageStart(Matchers.eq(retriedStage.head), Matchers.any()) + verify(sc.env.outputCommitCoordinator).stageEnd(Matchers.eq(retriedStage.head)) } } @@ -243,16 +305,6 @@ private case class OutputCommitFunctions(tempDirPath: String) { if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) } - // Receiver should be idempotent for AskPermissionToCommitOutput - def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = { - val ctx = TaskContext.get() - val canCommit1 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - val canCommit2 = SparkEnv.get.outputCommitCoordinator - .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) - assert(canCommit1 && canCommit2) - } - private def runCommitWithProvidedCommitter( ctx: TaskContext, iter: Iterator[Int], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index ea4bda327f36f..11ed7131e7e3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -109,6 +109,7 @@ object DataWritingSparkTask extends Logging { iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { val stageId = context.stageId() + val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") @@ -122,12 +123,14 @@ object DataWritingSparkTask extends Logging { val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator - val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId) + val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) if (commitAuthorized) { - logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.") + logInfo(s"Writer for stage $stageId / $stageAttempt, " + + s"task $partId.$attemptId is authorized to commit.") dataWriter.commit() } else { - val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit" + val message = s"Stage $stageId / $stageAttempt, " + + s"task $partId.$attemptId: driver did not authorize commit" logInfo(message) // throwing CommitDeniedException will trigger the catch block for abort throw new CommitDeniedException(message, stageId, partId, attemptId) From b9a6f7499a7f4d95efaf6a753c480f3642e22187 Mon Sep 17 00:00:00 2001 From: Maryann Xue Date: Thu, 21 Jun 2018 11:45:30 -0700 Subject: [PATCH 0996/1161] [SPARK-24613][SQL] Cache with UDF could not be matched with subsequent dependent caches ## What changes were proposed in this pull request? Wrap the logical plan with a `AnalysisBarrier` for execution plan compilation in CacheManager, in order to avoid the plan being analyzed again. ## How was this patch tested? Add one test in `DatasetCacheSuite` Author: Maryann Xue Closes #21602 from maryannxue/cache-mismatch. --- .../spark/sql/execution/CacheManager.scala | 6 +++--- .../org/apache/spark/sql/DatasetCacheSuite.scala | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 93bf91e56f1bd..2db7c02e86014 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, LogicalPlan, ResolvedHint} import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.storage.StorageLevel @@ -97,7 +97,7 @@ class CacheManager extends Logging { val inMemoryRelation = InMemoryRelation( sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, + sparkSession.sessionState.executePlan(AnalysisBarrier(planToCache)).executedPlan, tableName, planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) @@ -142,7 +142,7 @@ class CacheManager extends Logging { // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() - val plan = spark.sessionState.executePlan(cd.plan).executedPlan + val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan val newCache = InMemoryRelation( cacheBuilder = cd.cachedRepresentation .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 82a93f74dd76c..c4f056334cd1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalatest.concurrent.TimeLimits import org.scalatest.time.SpanSugar._ +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.StorageLevel @@ -132,4 +133,19 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits df.unpersist() assert(df.storageLevel == StorageLevel.NONE) } + + test("SPARK-24613 Cache with UDF could not be matched with subsequent dependent caches") { + val udf1 = udf({x: Int => x + 1}) + val df = spark.range(0, 10).toDF("a").withColumn("b", udf1($"a")) + val df2 = df.agg(sum(df("b"))) + + df.cache() + df.count() + df2.cache() + + val plan = df2.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined) + } } From dc8a6befa5dad861a731b4d7865f3ccf37482ae0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 21 Jun 2018 15:38:46 -0700 Subject: [PATCH 0997/1161] [SPARK-24588][SS] streaming join should require HashClusteredPartitioning from children ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/19080 we simplified the distribution/partitioning framework, and make all the join-like operators require `HashClusteredDistribution` from children. Unfortunately streaming join operator was missed. This can cause wrong result. Think about ``` val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) val joined = df1.join(df2, Seq("a", "b")).select('a) ``` The physical plan is ``` *(3) Project [a#5] +- StreamingSymmetricHashJoin [a#5, b#6], [a#10, b#11], Inner, condition = [ leftOnly = null, rightOnly = null, both = null, full = null ], state info [ checkpoint = , runId = 54e31fce-f055-4686-b75d-fcd2b076f8d8, opId = 0, ver = 0, numPartitions = 5], 0, state cleanup [ left = null, right = null ] :- Exchange hashpartitioning(a#5, b#6, 5) : +- *(1) Project [value#1 AS a#5, (value#1 * 2) AS b#6] : +- StreamingRelation MemoryStream[value#1], [value#1] +- Exchange hashpartitioning(b#11, 5) +- *(2) Project [value#3 AS a#10, (value#3 * 2) AS b#11] +- StreamingRelation MemoryStream[value#3], [value#3] ``` The left table is hash partitioned by `a, b`, while the right table is hash partitioned by `b`. This means, we may have a matching record that is in different partitions, which should be in the output but not. ## How was this patch tested? N/A Author: Wenchen Fan Closes #21587 from cloud-fan/join. --- .../plans/physical/partitioning.scala | 53 +++-- .../sql/catalyst/DistributionSuite.scala | 201 +++++++++++++++--- .../v2/DataSourcePartitioning.scala | 4 +- .../StreamingSymmetricHashJoinExec.scala | 4 +- .../sql/streaming/StreamingJoinSuite.scala | 14 ++ 5 files changed, 217 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 4d9a9925fe3ff..cc1a5e835d9cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -99,16 +99,19 @@ case class ClusteredDistribution( * This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the * number of partitions, this distribution strictly requires which partition the tuple should be in. */ -case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution { +case class HashClusteredDistribution( + expressions: Seq[Expression], + requiredNumPartitions: Option[Int] = None) extends Distribution { require( expressions != Nil, - "The expressions for hash of a HashPartitionedDistribution should not be Nil. " + + "The expressions for hash of a HashClusteredDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") - override def requiredNumPartitions: Option[Int] = None - override def createPartitioning(numPartitions: Int): Partitioning = { + assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions, + s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " + + s"the actual number of partitions is $numPartitions.") HashPartitioning(expressions, numPartitions) } } @@ -163,11 +166,22 @@ trait Partitioning { * i.e. the current dataset does not need to be re-partitioned for the `required` * Distribution (it is possible that tuples within a partition need to be reorganized). * + * A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` does't match + * [[Distribution.requiredNumPartitions]]. + */ + final def satisfies(required: Distribution): Boolean = { + required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required) + } + + /** + * The actual method that defines whether this [[Partitioning]] can satisfy the given + * [[Distribution]], after the `numPartitions` check. + * * By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if - * the [[Partitioning]] only have one partition. Implementations can overwrite this method with - * special logic. + * the [[Partitioning]] only have one partition. Implementations can also overwrite this method + * with special logic. */ - def satisfies(required: Distribution): Boolean = required match { + protected def satisfies0(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true case AllTuples => numPartitions == 1 case _ => false @@ -186,9 +200,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false - case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1 case _ => true } } @@ -205,16 +218,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case h: HashClusteredDistribution => expressions.length == h.expressions.length && expressions.zip(h.expressions).forall { case (l, r) => l.semanticEquals(r) } - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -246,15 +258,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def satisfies(required: Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: Distribution): Boolean = { + super.satisfies0(required) || { required match { case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering, requiredNumPartitions) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && - (requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions) + case ClusteredDistribution(requiredClustering, _) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) case _ => false } } @@ -295,7 +306,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Returns true if any `partitioning` of this collection satisfies the given * [[Distribution]]. */ - override def satisfies(required: Distribution): Boolean = + override def satisfies0(required: Distribution): Boolean = partitionings.exists(_.satisfies(required)) override def toString: String = { @@ -310,7 +321,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { override val numPartitions: Int = 1 - override def satisfies(required: Distribution): Boolean = required match { + override def satisfies0(required: Distribution): Boolean = required match { case BroadcastDistribution(m) if m == mode => true case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index b47b8adfe5d55..39228102682b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -41,34 +41,127 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning (with nullSafe = true) is the output partitioning") { - // Cases which do not need an exchange between two data properties. + test("UnspecifiedDistribution and AllTuples") { + // except `BroadcastPartitioning`, all other partitioning can satisfy UnspecifiedDistribution checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), + UnknownPartitioning(-1), UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('a, 'b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + RoundRobinPartitioning(10), + UnspecifiedDistribution, true) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), - ClusteredDistribution(Seq('a, 'b, 'c)), + SinglePartition, + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + UnspecifiedDistribution, + false) + + // except `BroadcastPartitioning`, all other partitioning can satisfy AllTuples if they have + // only one partition. + checkSatisfied( + UnknownPartitioning(1), + AllTuples, + true) + + checkSatisfied( + UnknownPartitioning(10), + AllTuples, + false) + + checkSatisfied( + RoundRobinPartitioning(1), + AllTuples, + true) + + checkSatisfied( + RoundRobinPartitioning(10), + AllTuples, + false) + + checkSatisfied( + SinglePartition, + AllTuples, + true) + + checkSatisfied( + HashPartitioning(Seq('a), 1), + AllTuples, true) + checkSatisfied( + HashPartitioning(Seq('a), 10), + AllTuples, + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 1), + AllTuples, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc), 10), + AllTuples, + false) + + checkSatisfied( + BroadcastPartitioning(IdentityBroadcastMode), + AllTuples, + false) + } + + test("SinglePartition is the output partitioning") { + // SinglePartition can satisfy all the distributions except `BroadcastDistribution` checkSatisfied( SinglePartition, ClusteredDistribution(Seq('a, 'b, 'c)), true) + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( SinglePartition, OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), true) - // Cases which need an exchange between two data properties. + checkSatisfied( + SinglePartition, + BroadcastDistribution(IdentityBroadcastMode), + false) + } + + test("HashPartitioning is the output partitioning") { + // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are a subset of + // the required clustering expressions. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), ClusteredDistribution(Seq('b, 'c)), @@ -79,37 +172,43 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('d, 'e)), false) + // HashPartitioning can satisfy HashClusteredDistribution iff its hash expressions are exactly + // same with the required hash clustering expressions. checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), - AllTuples, + HashClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('c, 'b, 'a), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + checkSatisfied( + HashPartitioning(Seq('a, 'b), 10), + HashClusteredDistribution(Seq('a, 'b, 'c)), + false) + + // HashPartitioning cannot satisfy OrderedDistribution checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), false) checkSatisfied( - HashPartitioning(Seq('b, 'c), 10), + HashPartitioning(Seq('a, 'b, 'c), 1), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), - false) + false) // TODO: this can be relaxed. - // TODO: We should check functional dependencies - /* checkSatisfied( - ClusteredDistribution(Seq('b)), - ClusteredDistribution(Seq('b + 1)), - true) - */ + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) } test("RangePartitioning is the output partitioning") { - // Cases which do not need an exchange between two data properties. - checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - UnspecifiedDistribution, - true) - + // RangePartitioning can satisfy OrderedDistribution iff its ordering is a prefix + // of the required ordering, or the required ordering is a prefix of its ordering. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), @@ -125,6 +224,27 @@ class DistributionSuite extends SparkFunSuite { OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), true) + // TODO: We can have an optimization to first sort the dataset + // by a.asc and then sort b, and c in a partition. This optimization + // should tradeoff the benefit of a less number of Exchange operators + // and the parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'd.desc)), + false) + + // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are a subset + // of the required clustering expressions. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), ClusteredDistribution(Seq('a, 'b, 'c)), @@ -140,34 +260,47 @@ class DistributionSuite extends SparkFunSuite { ClusteredDistribution(Seq('b, 'c, 'a, 'd)), true) - // Cases which need an exchange between two data properties. - // TODO: We can have an optimization to first sort the dataset - // by a.asc and then sort b, and c in a partition. This optimization - // should tradeoff the benefit of a less number of Exchange operators - // and the parallelism. checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + ClusteredDistribution(Seq('a, 'b)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - OrderedDistribution(Seq('b.asc, 'a.asc)), + ClusteredDistribution(Seq('c, 'd)), false) + // RangePartitioning cannot satisfy HashClusteredDistribution checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('a, 'b)), + HashClusteredDistribution(Seq('a, 'b, 'c)), false) + } + test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") { checkSatisfied( - RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - ClusteredDistribution(Seq('c, 'd)), + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + SinglePartition, + HashClusteredDistribution(Seq('a, 'b, 'c), Some(10)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + HashClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) checkSatisfied( RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), - AllTuples, + ClusteredDistribution(Seq('a, 'b, 'c), Some(5)), false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala index 017a6737161a6..33079d5912506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -30,8 +30,8 @@ class DataSourcePartitioning( override val numPartitions: Int = partitioning.numPartitions() - override def satisfies(required: physical.Distribution): Boolean = { - super.satisfies(required) || { + override def satisfies0(required: physical.Distribution): Boolean = { + super.satisfies0(required) || { required match { case d: physical.ClusteredDistribution if isCandidate(d.clustering) => val attrs = d.clustering.map(_.asInstanceOf[Attribute]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index afa664eb76525..50cf971e4ec3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -167,8 +167,8 @@ case class StreamingSymmetricHashJoinExec( val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: - ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil + HashClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) :: + HashClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil override def output: Seq[Attribute] = joinType match { case _: InnerLike => left.output ++ right.output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 1f62357e6d09e..c5cc8df4356a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -404,6 +404,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input3, 5, 10), CheckNewAnswer((5, 10, 5, 15, 5, 25))) } + + test("streaming join should require HashClusteredDistribution from children") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) + val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) + val joined = df1.join(df2, Seq("a", "b")).select('a) + + testStream(joined)( + AddData(input1, 1.to(1000): _*), + AddData(input2, 1.to(1000): _*), + CheckAnswer(1.to(1000): _*)) + } } From 92c2f00bd275a90b6912fb8c8cf542002923629c Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Fri, 22 Jun 2018 16:18:22 +0900 Subject: [PATCH 0998/1161] [SPARK-23934][SQL] Adding map_from_entries function ## What changes were proposed in this pull request? The PR adds the `map_from_entries` function that returns a map created from the given array of entries. ## How was this patch tested? New tests added into: - `CollectionExpressionSuite` - `DataFrameFunctionSuite` ## CodeGen Examples ### Primitive-type Keys and Values ``` val idf = Seq( Seq((1, 10), (2, 20), (3, 10)), Seq((1, 10), null, (2, 20)) ).toDF("a") idf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_2 = 0; !project_isNull_0 && project_idx_2 < inputadapter_value_0.numElements(); project_idx_2++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_2); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final long project_keySectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 052 */ final long project_valueSectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 053 */ final long project_byteArraySize_0 = 8 + project_keySectionSize_0 + project_valueSectionSize_0; /* 054 */ if (project_byteArraySize_0 > 2147483632) { /* 055 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 056 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 057 */ /* 058 */ for (int project_idx_1 = 0; project_idx_1 < project_numEntries_0; project_idx_1++) { /* 059 */ InternalRow project_entry_1 = inputadapter_value_0.getStruct(project_idx_1, 2); /* 060 */ /* 061 */ project_keys_0[project_idx_1] = project_entry_1.getInt(0); /* 062 */ project_values_0[project_idx_1] = project_entry_1.getInt(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } else { /* 068 */ final byte[] project_byteArray_0 = new byte[(int)project_byteArraySize_0]; /* 069 */ UnsafeMapData project_unsafeMapData_0 = new UnsafeMapData(); /* 070 */ Platform.putLong(project_byteArray_0, 16, project_keySectionSize_0); /* 071 */ Platform.putLong(project_byteArray_0, 24, project_numEntries_0); /* 072 */ Platform.putLong(project_byteArray_0, 24 + project_keySectionSize_0, project_numEntries_0); /* 073 */ project_unsafeMapData_0.pointTo(project_byteArray_0, 16, (int)project_byteArraySize_0); /* 074 */ ArrayData project_keyArrayData_0 = project_unsafeMapData_0.keyArray(); /* 075 */ ArrayData project_valueArrayData_0 = project_unsafeMapData_0.valueArray(); /* 076 */ /* 077 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 078 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 079 */ /* 080 */ project_keyArrayData_0.setInt(project_idx_0, project_entry_0.getInt(0)); /* 081 */ project_valueArrayData_0.setInt(project_idx_0, project_entry_0.getInt(1)); /* 082 */ } /* 083 */ /* 084 */ project_value_0 = project_unsafeMapData_0; /* 085 */ } /* 086 */ /* 087 */ } ``` ### Non-primitive-type Keys and Values ``` val sdf = Seq( Seq(("a", null), ("b", "bb"), ("c", "aa")), Seq(("a", "aa"), null, (null, "bb")) ).toDF("a") sdf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_1 = 0; !project_isNull_0 && project_idx_1 < inputadapter_value_0.numElements(); project_idx_1++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_1); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 052 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 053 */ /* 054 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 055 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 056 */ /* 057 */ if (project_entry_0.isNullAt(0)) { /* 058 */ throw new RuntimeException("The first field from a struct (key) can't be null."); /* 059 */ } /* 060 */ /* 061 */ project_keys_0[project_idx_0] = project_entry_0.getUTF8String(0); /* 062 */ project_values_0[project_idx_0] = project_entry_0.getUTF8String(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } ``` Author: Marek Novotny Closes #21282 from mn-mikke/feature/array-api-map_from_entries-to-master. --- python/pyspark/sql/functions.py | 20 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 30 +++ .../expressions/collectionOperations.scala | 235 ++++++++++++++++-- .../CollectionExpressionsSuite.scala | 51 ++++ .../org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 50 ++++ 7 files changed, 378 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 11b179fe26bfc..5f5d73307e9c8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2412,6 +2412,26 @@ def map_entries(col): return Column(sc._jvm.functions.map_entries(_to_java_column(col))) +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4b09b9a7e75df..8abc616c1a3f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -421,6 +421,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), + expression[MapFromEntries]("map_from_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e5906253..4cc0968911cb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -819,6 +819,36 @@ class CodegenContext { } } + /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7c064a130ff35..3afabe14606e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -475,6 +475,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def prettyName: String = "map_entries" } +/** + * Returns a map created from the given array of entries. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", + examples = """ + Examples: + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression { + + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None + } + + private def nullEntries: Boolean = dataTypeDetails.get._3 + + override def nullable: Boolean = child.nullable || nullEntries + + override def dataType: MapType = dataTypeDetails.get._1 + + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.") + } + + override protected def nullSafeEval(input: Any): Any = { + val arrayData = input.asInstanceOf[ArrayData] + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 + } + } + val keyArray = new Array[AnyRef](numEntries) + val valueArray = new Array[AnyRef](numEntries) + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numEntries = ctx.freshName("numEntries") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, numEntries) + } + ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { + s""" + |final int $numEntries = $c.numElements(); + |$code + """.stripMargin + } + }) + } + + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val i = ctx.freshName("idx") + + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin + } else { + "" + } + + s""" + |for (int $i = 0; $i < $numEntries; $i++) { + | InternalRow $entry = $childVariable.getStruct($i, 2); + | $nullKeyCheck + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} + |} + """.stripMargin + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment + } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment + ) + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin + } + + override def prettyName: String = "map_from_entries" +} + + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ @@ -1990,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) } - if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } - private def nullElementsProtection( - ev: ExprCode, - childVariableName: String, - coreLogic: String): String = { - s""" - |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { - | ${ev.isNull} |= $childVariableName.isNullAt(z); - |} - |if (!${ev.isNull}) { - | $coreLogic - |} - """.stripMargin - } - private def genCodeForNumberOfElements( ctx: CodegenContext, childVariableName: String) : (String, String) = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f377f9c8cd533..5b8cf5128fe21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -80,6 +80,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), null) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(as6), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c296a1bf9d69d..f792a6fba1d8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3526,6 +3526,13 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + /** + * Returns a map created from the given array of entries. + * @group collection_funcs + * @since 2.4.0 + */ + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + /** * Returns a merged array of structs in which the N-th struct contains all N-th values of input * arrays. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fcdd33f544311..25fdbab745128 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -633,6 +633,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) } + test("map_from_entries function") { + def dummyFilter(c: Column): Column = c.isNull || c.isNotNull + val oneRowDF = Seq(3215).toDF("i") + + // Test cases with primitive-type keys and values + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(null), + Row(Map.empty), + Row(null)) + + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) + checkAnswer( + oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)) + .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + + // Test cases with non-primitive-type keys and values + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(null), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x", 1), From 39dfaf2fd167cafc84ec9cc637c114ed54a331e3 Mon Sep 17 00:00:00 2001 From: Hieu Huynh <“Hieu.huynh@oath.com”> Date: Fri, 22 Jun 2018 09:16:14 -0500 Subject: [PATCH 0999/1161] [SPARK-24519] Make the threshold for highly compressed map status configurable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Problem** MapStatus uses hardcoded value of 2000 partitions to determine if it should use highly compressed map status. We should make it configurable to allow users to more easily tune their jobs with respect to this without having for them to modify their code to change the number of partitions. Note we can leave this as an internal/undocumented config for now until we have more advise for the users on how to set this config. Some of my reasoning: The config gives you a way to easily change something without the user having to change code, redeploy jar, and then run again. You can simply change the config and rerun. It also allows for easier experimentation. Changing the # of partitions has other side affects, whether good or bad is situation dependent. It can be worse are you could be increasing # of output files when you don't want to be, affects the # of tasks needs and thus executors to run in parallel, etc. There have been various talks about this number at spark summits where people have told customers to increase it to be 2001 partitions. Note if you just do a search for spark 2000 partitions you will fine various things all talking about this number. This shows that people are modifying their code to take this into account so it seems to me having this configurable would be better. Once we have more advice for users we could expose this and document information on it. **What changes were proposed in this pull request?** I make the hardcoded value mentioned above to be configurable under the name _SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS_, which has default value to be 2000. Users can set it to the value they want by setting the property name _spark.shuffle.minNumPartitionsToHighlyCompress_ **How was this patch tested?** I wrote a unit test to make sure that the default value is 2000, and _IllegalArgumentException_ will be thrown if user set it to a non-positive value. The unit test also checks that highly compressed map status is correctly used when the number of partition is greater than _SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS_. Author: Hieu Huynh <“Hieu.huynh@oath.com”> Closes #21527 from hthuynh2/spark_branch_1. --- .../spark/internal/config/package.scala | 7 +++++ .../apache/spark/scheduler/MapStatus.scala | 4 ++- .../spark/scheduler/MapStatusSuite.scala | 28 +++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a54b091a64d50..38a043c85ae33 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -552,4 +552,11 @@ package object config { .timeConf(TimeUnit.SECONDS) .createWithDefaultString("1h") + private[spark] val SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS = + ConfigBuilder("spark.shuffle.minNumPartitionsToHighlyCompress") + .internal() + .doc("Number of partitions to determine if MapStatus should use HighlyCompressedMapStatus") + .intConf + .checkValue(v => v > 0, "The value should be a positive integer.") + .createWithDefault(2000) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 2ec2f2031aa45..659694dd189ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -50,7 +50,9 @@ private[spark] sealed trait MapStatus { private[spark] object MapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { - if (uncompressedSizes.length > 2000) { + if (uncompressedSizes.length > Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) + .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)) { HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 2155a0f2b6c21..354e6386fa60e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -188,4 +188,32 @@ class MapStatusSuite extends SparkFunSuite { assert(count === 3000) } } + + test("SPARK-24519: HighlyCompressedMapStatus has configurable threshold") { + val conf = new SparkConf() + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + val sizes = Array.fill[Long](500)(150L) + // Test default value + val status = MapStatus(null, sizes) + assert(status.isInstanceOf[CompressedMapStatus]) + // Test Non-positive values + for (s <- -1 to 0) { + assertThrows[IllegalArgumentException] { + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + } + } + // Test positive values + Seq(1, 100, 499, 500, 501).foreach { s => + conf.set(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS, s) + val status = MapStatus(null, sizes) + if(sizes.length > s) { + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + } else { + assert(status.isInstanceOf[CompressedMapStatus]) + } + } + } } From 33e77fa89b5805ecb1066fc534723527f70d37c7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 22 Jun 2018 10:14:12 -0700 Subject: [PATCH 1000/1161] [SPARK-24518][CORE] Using Hadoop credential provider API to store password ## What changes were proposed in this pull request? In our distribution, because we don't do such fine-grained access control of config file, also configuration file is world readable shared between different components, so password may leak to different users. Hadoop credential provider API support storing password in a secure way, in which Spark could read it in a secure way, so here propose to add support of using credential provider API to get password. ## How was this patch tested? Adding tests and verified locally. Author: jerryshao Closes #21548 from jerryshao/SPARK-24518. --- .../scala/org/apache/spark/SSLOptions.scala | 11 ++- .../org/apache/spark/SecurityManager.scala | 9 ++- .../org/apache/spark/SSLOptionsSuite.scala | 75 +++++++++++++++++-- docs/security.md | 23 +++++- 4 files changed, 107 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 04c38f12acc78..1632e0c69eef5 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -21,6 +21,7 @@ import java.io.File import java.security.NoSuchAlgorithmException import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration import org.eclipse.jetty.util.ssl.SslContextFactory import org.apache.spark.internal.Logging @@ -163,11 +164,16 @@ private[spark] object SSLOptions extends Logging { * missing in SparkConf, the corresponding setting is used from the default configuration. * * @param conf Spark configuration object where the settings are collected from + * @param hadoopConf Hadoop configuration to get settings * @param ns the namespace name * @param defaults the default configuration * @return [[org.apache.spark.SSLOptions]] object */ - def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { + def parse( + conf: SparkConf, + hadoopConf: Configuration, + ns: String, + defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) val port = conf.getWithSubstitution(s"$ns.port").map(_.toInt) @@ -179,9 +185,11 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.keyStore)) val keyStorePassword = conf.getWithSubstitution(s"$ns.keyStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyStorePassword)) val keyPassword = conf.getWithSubstitution(s"$ns.keyPassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.keyPassword")).map(new String(_))) .orElse(defaults.flatMap(_.keyPassword)) val keyStoreType = conf.getWithSubstitution(s"$ns.keyStoreType") @@ -194,6 +202,7 @@ private[spark] object SSLOptions extends Logging { .orElse(defaults.flatMap(_.trustStore)) val trustStorePassword = conf.getWithSubstitution(s"$ns.trustStorePassword") + .orElse(Option(hadoopConf.getPassword(s"$ns.trustStorePassword")).map(new String(_))) .orElse(defaults.flatMap(_.trustStorePassword)) val trustStoreType = conf.getWithSubstitution(s"$ns.trustStoreType") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index b87476322573d..3cfafeb951105 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -19,11 +19,11 @@ package org.apache.spark import java.net.{Authenticator, PasswordAuthentication} import java.nio.charset.StandardCharsets.UTF_8 -import javax.net.ssl._ import org.apache.hadoop.io.Text import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher @@ -111,11 +111,14 @@ private[spark] class SecurityManager( ) } + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) // the default SSL configuration - it will be used by all communication layers unless overwritten - private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None) + private val defaultSSLOptions = + SSLOptions.parse(sparkConf, hadoopConf, "spark.ssl", defaults = None) def getSSLOptions(module: String): SSLOptions = { - val opts = SSLOptions.parse(sparkConf, s"spark.ssl.$module", Some(defaultSSLOptions)) + val opts = + SSLOptions.parse(sparkConf, hadoopConf, s"spark.ssl.$module", Some(defaultSSLOptions)) logDebug(s"Created SSL options for $module: $opts") opts } diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 8eabc2b3cb958..5dbfc5c10a6f8 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,8 +18,11 @@ package org.apache.spark import java.io.File +import java.util.UUID import javax.net.ssl.SSLContext +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.security.alias.{CredentialProvider, CredentialProviderFactory} import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.SparkConfWithEnv @@ -40,6 +43,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { .toSet val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -49,7 +53,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) conf.set("spark.ssl.protocol", "TLSv1.2") - val opts = SSLOptions.parse(conf, "spark.ssl") + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl") assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -70,6 +74,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -80,8 +85,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -103,6 +108,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath val conf = new SparkConf + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.ui.port", "4242") @@ -117,8 +123,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") - val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.port === Some(4242)) @@ -139,14 +145,71 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConfWithEnv(Map( "ENV1" -> "val1", "ENV2" -> "val2")) + val hadoopConf = new Configuration() conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", "${env:ENV1}") conf.set("spark.ssl.trustStore", "${env:ENV2}") - val opts = SSLOptions.parse(conf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) assert(opts.keyStore === Some(new File("val1"))) assert(opts.trustStore === Some(new File("val2"))) } + test("get password from Hadoop credential provider") { + val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath + val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + + val conf = new SparkConf + val hadoopConf = new Configuration() + val tmpPath = s"localjceks://file${sys.props("java.io.tmpdir")}/test-" + + s"${UUID.randomUUID().toString}.jceks" + val provider = createCredentialProvider(tmpPath, hadoopConf) + + conf.set("spark.ssl.enabled", "true") + conf.set("spark.ssl.keyStore", keyStorePath) + storePassword(provider, "spark.ssl.keyStorePassword", "password") + storePassword(provider, "spark.ssl.keyPassword", "password") + conf.set("spark.ssl.trustStore", trustStorePath) + storePassword(provider, "spark.ssl.trustStorePassword", "password") + conf.set("spark.ssl.enabledAlgorithms", + "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") + conf.set("spark.ssl.protocol", "SSLv3") + + val defaultOpts = SSLOptions.parse(conf, hadoopConf, "spark.ssl", defaults = None) + val opts = SSLOptions.parse(conf, hadoopConf, "spark.ssl.ui", defaults = Some(defaultOpts)) + + assert(opts.enabled === true) + assert(opts.trustStore.isDefined === true) + assert(opts.trustStore.get.getName === "truststore") + assert(opts.trustStore.get.getAbsolutePath === trustStorePath) + assert(opts.keyStore.isDefined === true) + assert(opts.keyStore.get.getName === "keystore") + assert(opts.keyStore.get.getAbsolutePath === keyStorePath) + assert(opts.trustStorePassword === Some("password")) + assert(opts.keyStorePassword === Some("password")) + assert(opts.keyPassword === Some("password")) + assert(opts.protocol === Some("SSLv3")) + assert(opts.enabledAlgorithms === + Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + } + + private def createCredentialProvider(tmpPath: String, conf: Configuration): CredentialProvider = { + conf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, tmpPath) + + val provider = CredentialProviderFactory.getProviders(conf).get(0) + if (provider == null) { + throw new IllegalStateException(s"Fail to get credential provider with path $tmpPath") + } + + provider + } + + private def storePassword( + provider: CredentialProvider, + passwordKey: String, + password: String): Unit = { + provider.createCredentialEntry(passwordKey, password.toCharArray) + provider.flush() + } } diff --git a/docs/security.md b/docs/security.md index 8c0c66fb5a285..6ef3a808e0471 100644 --- a/docs/security.md +++ b/docs/security.md @@ -177,7 +177,7 @@ ACLs can be configured for either users or groups. Configuration entries accept lists as input, meaning multiple users or groups can be given the desired privileges. This can be used if you run on a shared cluster and have a set of administrators or developers who need to monitor applications they may not have started themselves. A wildcard (`*`) added to specific ACL -means that all users will have the respective pivilege. By default, only the user submitting the +means that all users will have the respective privilege. By default, only the user submitting the application is added to the ACLs. Group membership is established by using a configurable group mapping provider. The mapper is @@ -446,6 +446,27 @@ replaced with one of the above namespaces. +Spark also supports retrieving `${ns}.keyPassword`, `${ns}.keyStorePassword` and `${ns}.trustStorePassword` from +[Hadoop Credential Providers](https://hadoop.apache.org/docs/current/hadoop-project-dist/hadoop-common/CredentialProviderAPI.html). +User could store password into credential file and make it accessible by different components, like: + +``` +hadoop credential create spark.ssl.keyPassword -value password \ + -provider jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks +``` + +To configure the location of the credential provider, set the `hadoop.security.credential.provider.path` +config option in the Hadoop configuration used by Spark, like: + +``` + + hadoop.security.credential.provider.path + jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks + +``` + +Or via SparkConf "spark.hadoop.hadoop.security.credential.provider.path=jceks://hdfs@nn1.example.com:9001/user/backup/ssl.jceks". + ## Preparing the key stores Key stores can be generated by `keytool` program. The reference documentation for this tool for From 4e7d8678a3d9b12797d07f5497e0ed9e471428dd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 22 Jun 2018 12:38:34 -0500 Subject: [PATCH 1001/1161] [SPARK-24372][BUILD] Add scripts to help with preparing releases. The "do-release.sh" script asks questions about the RC being prepared, trying to find out as much as possible automatically, and then executes the existing scripts with proper arguments to prepare the release. This script was used to prepare the 2.3.1 release candidates, so was tested in that context. The docker version runs that same script inside a docker image especially crafted for building Spark releases. That image is based on the work by Felix C. linked in the bug. At this point is has been only midly tested. I also added a template for the vote e-mail, with placeholders for things that need to be replaced, although there is no automation around that for the moment. It shouldn't be hard to hook up certain things like version and tags to this, or to figure out certain things like the repo URL from the output of the release scripts. Author: Marcelo Vanzin Closes #21515 from vanzin/SPARK-24372. --- dev/.rat-excludes | 1 + dev/create-release/do-release-docker.sh | 143 +++++++++++++++ dev/create-release/do-release.sh | 81 +++++++++ dev/create-release/release-build.sh | 190 ++++++++++++-------- dev/create-release/release-tag.sh | 26 ++- dev/create-release/release-util.sh | 228 ++++++++++++++++++++++++ dev/create-release/spark-rm/Dockerfile | 87 +++++++++ dev/create-release/vote.tmpl | 65 +++++++ 8 files changed, 736 insertions(+), 85 deletions(-) create mode 100755 dev/create-release/do-release-docker.sh create mode 100755 dev/create-release/do-release.sh create mode 100644 dev/create-release/release-util.sh create mode 100644 dev/create-release/spark-rm/Dockerfile create mode 100644 dev/create-release/vote.tmpl diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 9552d001a079c..23b24212b4d29 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -106,3 +106,4 @@ spark-warehouse structured-streaming/* kafka-source-initial-offset-version-2.1.0.bin kafka-source-initial-offset-future-version.bin +vote.tmpl diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh new file mode 100755 index 0000000000000..fa7b73cdb40ec --- /dev/null +++ b/dev/create-release/do-release-docker.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Creates a Spark release candidate. The script will update versions, tag the branch, +# build Spark binary packages and documentation, and upload maven artifacts to a staging +# repository. There is also a dry run mode where only local builds are performed, and +# nothing is uploaded to the ASF repos. +# +# Run with "-h" for options. +# + +set -e +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +function usage { + local NAME=$(basename $0) + cat < "$GPG_KEY_FILE" + +run_silent "Building spark-rm image with tag $IMGTAG..." "docker-build.log" \ + docker build -t "spark-rm:$IMGTAG" --build-arg UID=$UID "$SELF/spark-rm" + +# Write the release information to a file with environment variables to be used when running the +# image. +ENVFILE="$WORKDIR/env.list" +fcreate_secure "$ENVFILE" + +function cleanup { + rm -f "$ENVFILE" + rm -f "$GPG_KEY_FILE" +} + +trap cleanup EXIT + +cat > $ENVFILE <> $ENVFILE + JAVA_VOL="--volume $JAVA:/opt/spark-java" +fi + +echo "Building $RELEASE_TAG; output will be at $WORKDIR/output" +docker run -ti \ + --env-file "$ENVFILE" \ + --volume "$WORKDIR:/opt/spark-rm" \ + $JAVA_VOL \ + "spark-rm:$IMGTAG" diff --git a/dev/create-release/do-release.sh b/dev/create-release/do-release.sh new file mode 100755 index 0000000000000..f1d4f3ab5ddec --- /dev/null +++ b/dev/create-release/do-release.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + +while getopts "bn" opt; do + case $opt in + b) GIT_BRANCH=$OPTARG ;; + n) DRY_RUN=1 ;; + ?) error "Invalid option: $OPTARG" ;; + esac +done + +if [ "$RUNNING_IN_DOCKER" = "1" ]; then + # Inside docker, need to import the GPG key stored in the current directory. + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --import "$SELF/gpg.key" + + # We may need to adjust the path since JAVA_HOME may be overridden by the driver script. + if [ -n "$JAVA_HOME" ]; then + export PATH="$JAVA_HOME/bin:$PATH" + else + # JAVA_HOME for the openjdk package. + export JAVA_HOME=/usr + fi +else + # Outside docker, need to ask for information about the release. + get_release_info +fi + +function should_build { + local WHAT=$1 + [ -z "$RELEASE_STEP" ] || [ "$WHAT" = "$RELEASE_STEP" ] +} + +if should_build "tag" && [ $SKIP_TAG = 0 ]; then + run_silent "Creating release tag $RELEASE_TAG..." "tag.log" \ + "$SELF/release-tag.sh" + echo "It may take some time for the tag to be synchronized to github." + echo "Press enter when you've verified that the new tag ($RELEASE_TAG) is available." + read +else + echo "Skipping tag creation for $RELEASE_TAG." +fi + +if should_build "build"; then + run_silent "Building Spark..." "build.log" \ + "$SELF/release-build.sh" package +else + echo "Skipping build step." +fi + +if should_build "docs"; then + run_silent "Building documentation..." "docs.log" \ + "$SELF/release-build.sh" docs +else + echo "Skipping docs step." +fi + +if should_build "publish"; then + run_silent "Publishing release" "publish.log" \ + "$SELF/release-build.sh" publish-release +else + echo "Skipping publish step." +fi diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 5faa3d3260a56..24a62a8f4c7d3 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: release-build.sh @@ -87,49 +90,56 @@ NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) -MVN="build/mvn --force" - -# Hive-specific profiles for some builds -HIVE_PROFILES="-Phive -Phive-thriftserver" -# Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" -# Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pkubernetes -Pflume -Psparkr" -# Scala 2.11 only profiles for some builds -SCALA_2_11_PROFILES="-Pkafka-0-8" -# Scala 2.12 only profiles for some builds -SCALA_2_12_PROFILES="-Pscala-2.12" +init_java +init_maven_sbt rm -rf spark -git clone https://git-wip-us.apache.org/repos/asf/spark.git +git clone "$ASF_REPO" cd spark git checkout $GIT_REF git_hash=`git rev-parse --short HEAD` echo "Checked out Spark git hash $git_hash" if [ -z "$SPARK_VERSION" ]; then - SPARK_VERSION=$($MVN help:evaluate -Dexpression=project.version \ - | grep -v INFO | grep -v WARNING | grep -v Download) + # Run $MVN in a separate command so that 'set -e' does the right thing. + TMP=$(mktemp) + $MVN help:evaluate -Dexpression=project.version > $TMP + SPARK_VERSION=$(cat $TMP | grep -v INFO | grep -v WARNING | grep -v Download) + rm $TMP fi -# Verify we have the right java version set -if [ -z "$JAVA_HOME" ]; then - echo "Please set JAVA_HOME." - exit 1 +# Depending on the version being built, certain extra profiles need to be activated, and +# different versions of Scala are supported. +BASE_PROFILES="-Pmesos -Pyarn" +PUBLISH_SCALA_2_10=0 +SCALA_2_10_PROFILES="-Pscala-2.10" +SCALA_2_11_PROFILES= +SCALA_2_12_PROFILES="-Pscala-2.12" + +if [[ $SPARK_VERSION > "2.3" ]]; then + BASE_PROFILES="$BASE_PROFILES -Pkubernetes -Pflume" + SCALA_2_11_PROFILES="-Pkafka-0-8" +else + PUBLISH_SCALA_2_10=1 fi -java_version=$("${JAVA_HOME}"/bin/javac -version 2>&1 | cut -d " " -f 2) +# Hive-specific profiles for some builds +HIVE_PROFILES="-Phive -Phive-thriftserver" +# Profiles for publishing snapshots and release to Maven Central +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +# Profiles for building binary releases +BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" if [[ ! $SPARK_VERSION < "2.2." ]]; then - if [[ $java_version < "1.8." ]]; then - echo "Java version $java_version is less than required 1.8 for 2.2+" + if [[ $JAVA_VERSION < "1.8." ]]; then + echo "Java version $JAVA_VERSION is less than required 1.8 for 2.2+" echo "Please set JAVA_HOME correctly." exit 1 fi else - if [[ $java_version > "1.7." ]]; then + if ! [[ $JAVA_VERSION =~ 1\.7\..* ]]; then if [ -z "$JAVA_7_HOME" ]; then - echo "Java version $java_version is higher than required 1.7 for pre-2.2" + echo "Java version $JAVA_VERSION is higher than required 1.7 for pre-2.2" echo "Please set JAVA_HOME correctly." exit 1 else @@ -174,8 +184,9 @@ if [[ "$1" == "package" ]]; then FLAGS=$2 ZINC_PORT=$3 BUILD_PACKAGE=$4 - cp -r spark spark-$SPARK_VERSION-bin-$NAME + echo "Building binary dist $NAME" + cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME # TODO There should probably be a flag to make-distribution to allow 2.12 support @@ -244,31 +255,39 @@ if [[ "$1" == "package" ]]; then spark-$SPARK_VERSION-bin-$NAME.tgz.sha512 } - # TODO: Check exit codes of children here: - # http://stackoverflow.com/questions/1570262/shell-get-exit-code-of-background-process - # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop2.6" "-Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr" & - make_binary_release "hadoop2.7" "-Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip" & - make_binary_release "without-hadoop" "-Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3038" & - wait - rm -rf spark-$SPARK_VERSION-bin-*/ + if ! make_binary_release "hadoop2.6" "$MVN_EXTRA_OPTS -B -Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr"; then + error "Failed to build hadoop2.6 package. Check logs for details." + fi + if ! is_dry_run; then + if ! make_binary_release "hadoop2.7" "$MVN_EXTRA_OPTS -B -Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip"; then + error "Failed to build hadoop2.7 package. Check logs for details." + fi + if ! make_binary_release "without-hadoop" "$MVN_EXTRA_OPTS -B -Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3037"; then + error "Failed to build without-hadoop package. Check logs for details." + fi + fi - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-bin" - mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + rm -rf spark-$SPARK_VERSION-bin-*/ - echo "Copying release tarballs" - cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" - cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" - svn add "svn-spark/${DEST_DIR_NAME}-bin" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-bin" + mkdir -p "svn-spark/${DEST_DIR_NAME}-bin" + + echo "Copying release tarballs" + cp spark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp pyspark-* "svn-spark/${DEST_DIR_NAME}-bin/" + cp SparkR_* "svn-spark/${DEST_DIR_NAME}-bin/" + svn add "svn-spark/${DEST_DIR_NAME}-bin" + + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" + cd .. + rm -rf svn-spark + fi - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION" - cd .. - rm -rf svn-spark exit 0 fi @@ -282,18 +301,22 @@ if [[ "$1" == "docs" ]]; then cd .. cd .. - svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark - rm -rf "svn-spark/${DEST_DIR_NAME}-docs" - mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" + if ! is_dry_run; then + svn co --depth=empty $RELEASE_STAGING_LOCATION svn-spark + rm -rf "svn-spark/${DEST_DIR_NAME}-docs" + mkdir -p "svn-spark/${DEST_DIR_NAME}-docs" - echo "Copying release documentation" - cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" - svn add "svn-spark/${DEST_DIR_NAME}-docs" + echo "Copying release documentation" + cp -R "spark/docs/_site" "svn-spark/${DEST_DIR_NAME}-docs/" + svn add "svn-spark/${DEST_DIR_NAME}-docs" - cd svn-spark - svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" - cd .. - rm -rf svn-spark + cd svn-spark + svn ci --username $ASF_USERNAME --password "$ASF_PASSWORD" -m"Apache Spark $SPARK_PACKAGE_VERSION docs" + cd .. + rm -rf svn-spark + fi + + mv "spark/docs/_site" docs/ exit 0 fi @@ -341,13 +364,15 @@ if [[ "$1" == "publish-release" ]]; then # Using Nexus API documented here: # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API - echo "Creating Nexus staging repository" - repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) - staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") - echo "Created Nexus staging repository: $staged_repo_id" + if ! is_dry_run; then + echo "Creating Nexus staging repository" + repo_request="Apache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + fi tmp_repo=$(mktemp -d spark-repo-XXXXX) @@ -356,6 +381,12 @@ if [[ "$1" == "publish-release" ]]; then $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install + if ! is_dry_run && [[ $PUBLISH_SCALA_2_10 = 1 ]]; then + ./dev/change-scala-version.sh 2.10 + $MVN -DzincPort=$((ZINC_PORT + 1)) -Dmaven.repo.local=$tmp_repo -Dscala-2.10 \ + -DskipTests $PUBLISH_PROFILES $SCALA_2_10_PROFILES clean install + fi + #./dev/change-scala-version.sh 2.12 #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \ # -DskipTests $SCALA_2_12_PROFILES §$PUBLISH_PROFILES clean install @@ -386,23 +417,26 @@ if [[ "$1" == "publish-release" ]]; then sha1sum $file | cut -f1 -d' ' > $file.sha1 done - nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id - echo "Uplading files to $nexus_upload" - for file in $(find . -type f) - do - # strip leading ./ - file_short=$(echo $file | sed -e "s/\.\///") - dest_url="$nexus_upload/org/apache/spark/$file_short" - echo " Uploading $file_short" - curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url - done + if ! is_dry_run; then + nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id + echo "Uplading files to $nexus_upload" + for file in $(find . -type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$nexus_upload/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + fi - echo "Closing nexus staging repository" - repo_request="$staged_repo_idApache Spark $SPARK_VERSION (commit $git_hash)" - out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ - -H "Content-Type:application/xml" -v \ - $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) - echo "Closed Nexus staging repository: $staged_repo_id" popd rm -rf $tmp_repo cd .. diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index a05716a5f66bb..628bc0504c9c8 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -17,6 +17,9 @@ # limitations under the License. # +SELF=$(cd $(dirname $0) && pwd) +. "$SELF/release-util.sh" + function exit_with_usage { cat << EOF usage: tag-release.sh @@ -36,6 +39,7 @@ EOF } set -e +set -o pipefail if [[ $@ == *"help"* ]]; then exit_with_usage @@ -54,8 +58,10 @@ for env in ASF_USERNAME ASF_PASSWORD RELEASE_VERSION RELEASE_TAG NEXT_VERSION GI fi done +init_java +init_maven_sbt + ASF_SPARK_REPO="git-wip-us.apache.org/repos/asf/spark.git" -MVN="build/mvn --force" rm -rf spark git clone "https://$ASF_USERNAME:$ASF_PASSWORD@$ASF_SPARK_REPO" -b $GIT_BRANCH @@ -94,9 +100,15 @@ sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION" git commit -a -m "Preparing development version $NEXT_VERSION" -# Push changes -git push origin $RELEASE_TAG -git push origin HEAD:$GIT_BRANCH - -cd .. -rm -rf spark +if ! is_dry_run; then + # Push changes + git push origin $RELEASE_TAG + git push origin HEAD:$GIT_BRANCH + + cd .. + rm -rf spark +else + cd .. + mv spark spark.tag + echo "Clone with version changes and tag available as spark.tag in the output directory." +fi diff --git a/dev/create-release/release-util.sh b/dev/create-release/release-util.sh new file mode 100644 index 0000000000000..7426b0d6ca08d --- /dev/null +++ b/dev/create-release/release-util.sh @@ -0,0 +1,228 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +DRY_RUN=${DRY_RUN:-0} +GPG="gpg --no-tty --batch" +ASF_REPO="https://git-wip-us.apache.org/repos/asf/spark.git" +ASF_REPO_WEBUI="https://git-wip-us.apache.org/repos/asf?p=spark.git" + +function error { + echo "$*" + exit 1 +} + +function read_config { + local PROMPT="$1" + local DEFAULT="$2" + local REPLY= + + read -p "$PROMPT [$DEFAULT]: " REPLY + local RETVAL="${REPLY:-$DEFAULT}" + if [ -z "$RETVAL" ]; then + error "$PROMPT is must be provided." + fi + echo "$RETVAL" +} + +function parse_version { + grep -e '.*' | \ + head -n 2 | tail -n 1 | cut -d'>' -f2 | cut -d '<' -f1 +} + +function run_silent { + local BANNER="$1" + local LOG_FILE="$2" + shift 2 + + echo "========================" + echo "= $BANNER" + echo "Command: $@" + echo "Log file: $LOG_FILE" + + "$@" 1>"$LOG_FILE" 2>&1 + + local EC=$? + if [ $EC != 0 ]; then + echo "Command FAILED. Check full logs for details." + tail "$LOG_FILE" + exit $EC + fi +} + +function fcreate_secure { + local FPATH="$1" + rm -f "$FPATH" + touch "$FPATH" + chmod 600 "$FPATH" +} + +function check_for_tag { + curl -s --head --fail "$ASF_REPO_WEBUI;a=commit;h=$1" >/dev/null +} + +function get_release_info { + if [ -z "$GIT_BRANCH" ]; then + # If no branch is specified, found out the latest branch from the repo. + GIT_BRANCH=$(git ls-remote --heads "$ASF_REPO" | + grep -v refs/heads/master | + awk '{print $2}' | + sort -r | + head -n 1 | + cut -d/ -f3) + fi + + export GIT_BRANCH=$(read_config "Branch" "$GIT_BRANCH") + + # Find the current version for the branch. + local VERSION=$(curl -s "$ASF_REPO_WEBUI;a=blob_plain;f=pom.xml;hb=refs/heads/$GIT_BRANCH" | + parse_version) + echo "Current branch version is $VERSION." + + if [[ ! $VERSION =~ .*-SNAPSHOT ]]; then + error "Not a SNAPSHOT version: $VERSION" + fi + + NEXT_VERSION="$VERSION" + RELEASE_VERSION="${VERSION/-SNAPSHOT/}" + SHORT_VERSION=$(echo "$VERSION" | cut -d . -f 1-2) + local REV=$(echo "$VERSION" | cut -d . -f 3) + + # Find out what rc is being prepared. + # - If the current version is "x.y.0", then this is rc1 of the "x.y.0" release. + # - If not, need to check whether the previous version has been already released or not. + # - If it has, then we're building rc1 of the current version. + # - If it has not, we're building the next RC of the previous version. + local RC_COUNT + if [ $REV != 0 ]; then + local PREV_REL_REV=$((REV - 1)) + local PREV_REL_TAG="v${SHORT_VERSION}.${PREV_REL_REV}" + if check_for_tag "$PREV_REL_TAG"; then + RC_COUNT=1 + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + else + RELEASE_VERSION="${SHORT_VERSION}.${PREV_REL_REV}" + RC_COUNT=$(git ls-remote --tags "$ASF_REPO" "v${RELEASE_VERSION}-rc*" | wc -l) + RC_COUNT=$((RC_COUNT + 1)) + fi + else + REV=$((REV + 1)) + NEXT_VERSION="${SHORT_VERSION}.${REV}-SNAPSHOT" + RC_COUNT=1 + fi + + export NEXT_VERSION + export RELEASE_VERSION=$(read_config "Release" "$RELEASE_VERSION") + + RC_COUNT=$(read_config "RC #" "$RC_COUNT") + + # Check if the RC already exists, and if re-creating the RC, skip tag creation. + RELEASE_TAG="v${RELEASE_VERSION}-rc${RC_COUNT}" + SKIP_TAG=0 + if check_for_tag "$RELEASE_TAG"; then + read -p "$RELEASE_TAG already exists. Continue anyway [y/n]? " ANSWER + if [ "$ANSWER" != "y" ]; then + error "Exiting." + fi + SKIP_TAG=1 + fi + + + export RELEASE_TAG + + GIT_REF="$RELEASE_TAG" + if is_dry_run; then + echo "This is a dry run. Please confirm the ref that will be built for testing." + GIT_REF=$(read_config "Ref" "$GIT_REF") + fi + export GIT_REF + export SPARK_PACKAGE_VERSION="$RELEASE_TAG" + + # Gather some user information. + export ASF_USERNAME=$(read_config "ASF user" "$LOGNAME") + + GIT_NAME=$(git config user.name || echo "") + export GIT_NAME=$(read_config "Full name" "$GIT_NAME") + + export GIT_EMAIL="$ASF_USERNAME@apache.org" + export GPG_KEY=$(read_config "GPG key" "$GIT_EMAIL") + + cat <&1 | cut -d " " -f 2) + export JAVA_VERSION +} + +# Initializes MVN_EXTRA_OPTS and SBT_OPTS depending on the JAVA_VERSION in use. Requires init_java. +function init_maven_sbt { + MVN="build/mvn -B" + MVN_EXTRA_OPTS= + SBT_OPTS= + if [[ $JAVA_VERSION < "1.8." ]]; then + # Needed for maven central when using Java 7. + SBT_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN_EXTRA_OPTS="-Dhttps.protocols=TLSv1.1,TLSv1.2" + MVN="$MVN $MVN_EXTRA_OPTS" + fi + export MVN MVN_EXTRA_OPTS SBT_OPTS +} diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile new file mode 100644 index 0000000000000..07ce320177f5a --- /dev/null +++ b/dev/create-release/spark-rm/Dockerfile @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Image for building Spark releases. Based on Ubuntu 16.04. +# +# Includes: +# * Java 8 +# * Ivy +# * Python/PyPandoc (2.7.12/3.5.2) +# * R-base/R-base-dev (3.3.2+) +# * Ruby 2.3 build utilities + +FROM ubuntu:16.04 + +# These arguments are just for reuse and not really meant to be customized. +ARG APT_INSTALL="apt-get install --no-install-recommends -y" + +ARG BASE_PIP_PKGS="setuptools wheel virtualenv" +ARG PIP_PKGS="pyopenssl pypandoc numpy pygments sphinx" + +# Install extra needed repos and refresh. +# - CRAN repo +# - Ruby repo (for doc generation) +# +# This is all in a single "RUN" command so that if anything changes, "apt update" is run to fetch +# the most current package versions (instead of potentially using old versions cached by docker). +RUN echo 'deb http://cran.cnr.Berkeley.edu/bin/linux/ubuntu xenial/' >> /etc/apt/sources.list && \ + gpg --keyserver keyserver.ubuntu.com --recv-key E084DAB9 && \ + gpg -a --export E084DAB9 | apt-key add - && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean && \ + apt-get update && \ + $APT_INSTALL software-properties-common && \ + apt-add-repository -y ppa:brightbox/ruby-ng && \ + apt-get update && \ + # Install openjdk 8. + $APT_INSTALL openjdk-8-jdk && \ + update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java && \ + # Install build / source control tools + $APT_INSTALL curl wget git maven ivy subversion make gcc lsof libffi-dev \ + pandoc pandoc-citeproc libssl-dev libcurl4-openssl-dev libxml2-dev && \ + ln -s -T /usr/share/java/ivy.jar /usr/share/ant/lib/ivy.jar && \ + curl -sL https://deb.nodesource.com/setup_4.x | bash && \ + $APT_INSTALL nodejs && \ + # Install needed python packages. Use pip for installing packages (for consistency). + $APT_INSTALL libpython2.7-dev libpython3-dev python-pip python3-pip && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + cd && \ + virtualenv -p python3 p35 && \ + . p35/bin/activate && \ + pip install $BASE_PIP_PKGS && \ + pip install $PIP_PKGS && \ + # Install R packages and dependencies used when building. + # R depends on pandoc*, libssl (which are installed above). + $APT_INSTALL r-base r-base-dev && \ + $APT_INSTALL texlive-latex-base texlive texlive-fonts-extra texinfo qpdf && \ + Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='http://cran.us.r-project.org/')" && \ + Rscript -e "devtools::install_github('jimhester/lintr')" && \ + # Install tools needed to build the documentation. + $APT_INSTALL ruby2.3 ruby2.3-dev && \ + gem install jekyll --no-rdoc --no-ri && \ + gem install jekyll-redirect-from && \ + gem install pygments.rb + +WORKDIR /opt/spark-rm/output + +ARG UID +RUN useradd -m -s /bin/bash -p spark-rm -u $UID spark-rm +USER spark-rm:spark-rm + +ENTRYPOINT [ "/opt/spark-rm/do-release.sh" ] diff --git a/dev/create-release/vote.tmpl b/dev/create-release/vote.tmpl new file mode 100644 index 0000000000000..2ce953c2f7ec4 --- /dev/null +++ b/dev/create-release/vote.tmpl @@ -0,0 +1,65 @@ +Please vote on releasing the following candidate as Apache Spark version {version}. + +The vote is open until {deadline} and passes if a majority +1 PMC votes are cast, with +a minimum of 3 +1 votes. + +[ ] +1 Release this package as Apache Spark {version} +[ ] -1 Do not release this package because ... + +To learn more about Apache Spark, please see http://spark.apache.org/ + +The tag to be voted on is {tag} (commit {tag_commit}): +https://github.com/apache/spark/tree/{tag} + +The release files, including signatures, digests, etc. can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-bin/ + +Signatures used for Spark RCs can be found in this file: +https://dist.apache.org/repos/dist/dev/spark/KEYS + +The staging repository for this release can be found at: +https://repository.apache.org/content/repositories/orgapachespark-{repo_id}/ + +The documentation corresponding to this release can be found at: +https://dist.apache.org/repos/dist/dev/spark/{tag}-docs/ + +The list of bug fixes going into {version} can be found at the following URL: +https://issues.apache.org/jira/projects/SPARK/versions/{jira_version_id} + +FAQ + +========================= +How can I help test this release? +========================= + +If you are a Spark user, you can help us test this release by taking +an existing Spark workload and running on this release candidate, then +reporting any regressions. + +If you're working in PySpark you can set up a virtual env and install +the current RC and see if anything important breaks, in the Java/Scala +you can add the staging repository to your projects resolvers and test +with the RC (make sure to clean up the artifact cache before/after so +you don't end up building with a out of date RC going forward). + +=========================================== +What should happen to JIRA tickets still targeting {version}? +=========================================== + +The current list of open tickets targeted at {version} can be found at: +{open_issues_link} + +Committers should look at those and triage. Extremely important bug +fixes, documentation, and API tweaks that impact compatibility should +be worked on immediately. Everything else please retarget to an +appropriate release. + +================== +But my bug isn't fixed? +================== + +In order to make timely releases, we will typically not hold the +release unless the bug in question is a regression from the previous +release. That being said, if there is something which is a regression +that has not been correctly targeted please ping me or a committer to +help target the issue. \ No newline at end of file From c7e2742f9bce2fcb7c717df80761939272beff54 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 23 Jun 2018 17:40:20 -0700 Subject: [PATCH 1002/1161] [SPARK-24190][SQL] Allow saving of JSON files in UTF-16 and UTF-32 ## What changes were proposed in this pull request? Currently, restrictions in JSONOptions for `encoding` and `lineSep` are the same for read and for write. For example, a requirement for `lineSep` in the code: ``` df.write.option("encoding", "UTF-32BE").json(file) ``` doesn't allow to skip `lineSep` and use its default value `\n` because it throws the exception: ``` equirement failed: The lineSep option must be specified for the UTF-32BE encoding java.lang.IllegalArgumentException: requirement failed: The lineSep option must be specified for the UTF-32BE encoding ``` In the PR, I propose to separate JSONOptions in read and write, and make JSONOptions in write less restrictive. ## How was this patch tested? Added new test for blacklisted encodings in read. And the `lineSep` option was removed in write for some tests. Author: Maxim Gekk Author: Maxim Gekk Closes #21247 from MaxGekk/json-options-in-write. --- .../spark/sql/catalyst/json/JSONOptions.scala | 71 +++++++++++++------ .../datasources/json/JsonFileFormat.scala | 13 ++-- .../datasources/json/JsonSuite.scala | 56 ++++++++++----- 3 files changed, 95 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index c081772116f84..47eeb70e00427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -97,32 +97,16 @@ private[sql] class JSONOptions( sep } + protected def checkedEncoding(enc: String): String = enc + /** * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE. - * If the encoding is not specified (None), it will be detected automatically - * when the multiLine option is set to `true`. + * If the encoding is not specified (None) in read, it will be detected automatically + * when the multiLine option is set to `true`. If encoding is not specified in write, + * UTF-8 is used by default. */ val encoding: Option[String] = parameters.get("encoding") - .orElse(parameters.get("charset")).map { enc => - // The following encodings are not supported in per-line mode (multiline is false) - // because they cause some problems in reading files with BOM which is supposed to - // present in the files with such encodings. After splitting input files by lines, - // only the first lines will have the BOM which leads to impossibility for reading - // the rest lines. Besides of that, the lineSep option must have the BOM in such - // encodings which can never present between lines. - val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32")) - val isBlacklisted = blacklist.contains(Charset.forName(enc)) - require(multiLine || !isBlacklisted, - s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled. - |Blacklist: ${blacklist.mkString(", ")}""".stripMargin) - - val isLineSepRequired = - multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty - - require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") - - enc - } + .orElse(parameters.get("charset")).map(checkedEncoding) val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => lineSep.getBytes(encoding.getOrElse("UTF-8")) @@ -141,3 +125,46 @@ private[sql] class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars) } } + +private[sql] class JSONOptionsInRead( + @transient override val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) + extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) { + + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } + + protected override def checkedEncoding(enc: String): String = { + val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc)) + require(multiLine || !isBlacklisted, + s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled: + |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin) + + val isLineSepRequired = + multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty + require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding") + + enc + } +} + +private[sql] object JSONOptionsInRead { + // The following encodings are not supported in per-line mode (multiline is false) + // because they cause some problems in reading files with BOM which is supposed to + // present in the files with such encodings. After splitting input files by lines, + // only the first lines will have the BOM which leads to impossibility for reading + // the rest lines. Besides of that, the lineSep option must have the BOM in such + // encodings which can never present between lines. + val blacklist = Seq( + Charset.forName("UTF-16"), + Charset.forName("UTF-32") + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 3b04510d29695..e9a0b383b5f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ @@ -40,7 +40,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -52,7 +52,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -99,7 +99,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions = new JSONOptions( + val parsedOptions = new JSONOptionsInRead( options, sparkSession.sessionState.conf.sessionLocalTimeZone, sparkSession.sessionState.conf.columnNameOfCorruptRecord) @@ -158,6 +158,11 @@ private[json] class JsonOutputWriter( case None => StandardCharsets.UTF_8 } + if (JSONOptionsInRead.blacklist.contains(encoding)) { + logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" + + " which can be read back by Spark only if multiLine is enabled.") + } + private val writer = CodecStreams.createOutputStreamWriter( context, new Path(path), encoding) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index a8a4a524a97f9..897424daca0cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.json -import java.io.{File, FileOutputStream, StringWriter} -import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} -import java.nio.file.{Files, Paths, StandardOpenOption} +import java.io._ +import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException} +import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.Locale @@ -2262,7 +2262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath { path => val df = spark.createDataset(Seq(("Dog", 42))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) checkEncoding( @@ -2286,16 +2286,22 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: wrong output encoding") { val encoding = "UTF-128" - val exception = intercept[UnsupportedCharsetException] { + val exception = intercept[SparkException] { withTempPath { path => val df = spark.createDataset(Seq((0))) df.write - .options(Map("encoding" -> encoding, "lineSep" -> "\n")) + .options(Map("encoding" -> encoding)) .json(path.getCanonicalPath) } } - assert(exception.getMessage == encoding) + val baos = new ByteArrayOutputStream() + val ps = new PrintStream(baos, true, "UTF-8") + exception.printStackTrace(ps) + ps.flush() + + assert(baos.toString.contains( + "java.nio.charset.UnsupportedCharsetException: UTF-128")) } test("SPARK-23723: read back json in UTF-16LE") { @@ -2316,18 +2322,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-23723: write json in UTF-16/32 with multiline off") { Seq("UTF-16", "UTF-32").foreach { encoding => withTempPath { path => - val ds = spark.createDataset(Seq( - ("a", 1), ("b", 2), ("c", 3)) - ).repartition(2) - val e = intercept[IllegalArgumentException] { - ds.write - .option("encoding", encoding) - .option("multiline", "false") - .format("json").mode("overwrite") - .save(path.getCanonicalPath) - }.getMessage - assert(e.contains( - s"$encoding encoding in the blacklist is not allowed when multiLine is disabled")) + val ds = spark.createDataset(Seq(("a", 1))).repartition(1) + ds.write + .option("encoding", encoding) + .option("multiline", false) + .json(path.getCanonicalPath) + val jsonFiles = path.listFiles().filter(_.getName.endsWith("json")) + jsonFiles.foreach { jsonFile => + val readback = Files.readAllBytes(jsonFile.toPath) + val expected = ("""{"_1":"a","_2":1}""" + "\n").getBytes(Charset.forName(encoding)) + assert(readback === expected) + } } } } @@ -2476,4 +2481,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil) } } + + test("SPARK-24190: restrictions for JSONOptions in read") { + for (encoding <- Set("UTF-16", "UTF-32")) { + val exception = intercept[IllegalArgumentException] { + spark.read + .option("encoding", encoding) + .option("multiLine", false) + .json(testFile("test-data/utf16LE.json")) + .count() + } + assert(exception.getMessage.contains("encoding must not be included in the blacklist")) + } + } } From 98f363b77488792009f5b97bf831ede280f232e2 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 23 Jun 2018 17:51:18 -0700 Subject: [PATCH 1003/1161] [SPARK-24206][SQL] Improve FilterPushdownBenchmark benchmark code ## What changes were proposed in this pull request? This pr added benchmark code `FilterPushdownBenchmark` for string pushdown and updated performance results on the AWS `r3.xlarge`. ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #21288 from maropu/UpdateParquetBenchmark. --- .../spark/sql/FilterPushdownBenchmark.scala | 243 ---------- .../benchmark/FilterPushdownBenchmark.scala | 442 ++++++++++++++++++ 2 files changed, 442 insertions(+), 243 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala deleted file mode 100644 index c6dd7dadc9d93..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/FilterPushdownBenchmark.scala +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.io.File - -import scala.util.{Random, Try} - -import org.apache.spark.SparkConf -import org.apache.spark.sql.functions.monotonically_increasing_id -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{Benchmark, Utils} - - -/** - * Benchmark to measure read performance with Filter pushdown. - */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - conf.set("orc.compression", "snappy") - conf.set("spark.sql.parquet.compression.codec", "snappy") - - private val spark = SparkSession.builder() - .master("local[1]") - .appName("FilterPushdownBenchmark") - .config(conf) - .getOrCreate() - - def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(spark.catalog.dropTempView) - } - - def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) - (keys, values).zipped.foreach(spark.conf.set) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) - } - } - } - - private def prepareTable(dir: File, numRows: Int, width: Int): Unit = { - import spark.implicits._ - val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") - val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) - .withColumn("id", monotonically_increasing_id()) - - val dirORC = dir.getCanonicalPath + "/orc" - val dirParquet = dir.getCanonicalPath + "/parquet" - - df.write.mode("overwrite").orc(dirORC) - df.write.mode("overwrite").parquet(dirParquet) - - spark.read.orc(dirORC).createOrReplaceTempView("orcTable") - spark.read.parquet(dirParquet).createOrReplaceTempView("parquetTable") - } - - def filterPushDownBenchmark( - values: Int, - title: String, - whereExpr: String, - selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() - } - } - } - - Seq(false, true).foreach { pushDownEnabled => - val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" - benchmark.addCase(name) { _ => - withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { - spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() - } - } - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_152-b16 on Mac OS X 10.13.2 - Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz - - Select 0 row (id IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7882 / 7957 2.0 501.1 1.0X - Parquet Vectorized (Pushdown) 55 / 60 285.2 3.5 142.9X - Native ORC Vectorized 5592 / 5627 2.8 355.5 1.4X - Native ORC Vectorized (Pushdown) 66 / 70 237.2 4.2 118.9X - - Select 0 row (7864320 < id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7884 / 7909 2.0 501.2 1.0X - Parquet Vectorized (Pushdown) 739 / 752 21.3 47.0 10.7X - Native ORC Vectorized 5614 / 5646 2.8 356.9 1.4X - Native ORC Vectorized (Pushdown) 81 / 83 195.2 5.1 97.8X - - Select 1 row (id = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7905 / 8027 2.0 502.6 1.0X - Parquet Vectorized (Pushdown) 740 / 766 21.2 47.1 10.7X - Native ORC Vectorized 5684 / 5738 2.8 361.4 1.4X - Native ORC Vectorized (Pushdown) 78 / 81 202.4 4.9 101.7X - - Select 1 row (id <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7928 / 7993 2.0 504.1 1.0X - Parquet Vectorized (Pushdown) 747 / 772 21.0 47.5 10.6X - Native ORC Vectorized 5728 / 5753 2.7 364.2 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 104.8X - - Select 1 row (7864320 <= id <= 7864320):Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7939 / 8021 2.0 504.8 1.0X - Parquet Vectorized (Pushdown) 746 / 770 21.1 47.4 10.6X - Native ORC Vectorized 5690 / 5734 2.8 361.7 1.4X - Native ORC Vectorized (Pushdown) 76 / 79 206.7 4.8 104.3X - - Select 1 row (7864319 < id < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 7972 / 8019 2.0 506.9 1.0X - Parquet Vectorized (Pushdown) 742 / 764 21.2 47.2 10.7X - Native ORC Vectorized 5704 / 5743 2.8 362.6 1.4X - Native ORC Vectorized (Pushdown) 76 / 78 207.9 4.8 105.4X - - Select 10% rows (id < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 8733 / 8808 1.8 555.2 1.0X - Parquet Vectorized (Pushdown) 2213 / 2267 7.1 140.7 3.9X - Native ORC Vectorized 6420 / 6463 2.4 408.2 1.4X - Native ORC Vectorized (Pushdown) 1313 / 1331 12.0 83.5 6.7X - - Select 50% rows (id < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 11518 / 11591 1.4 732.3 1.0X - Parquet Vectorized (Pushdown) 7962 / 7991 2.0 506.2 1.4X - Native ORC Vectorized 8927 / 8985 1.8 567.6 1.3X - Native ORC Vectorized (Pushdown) 6102 / 6160 2.6 387.9 1.9X - - Select 90% rows (id < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14255 / 14389 1.1 906.3 1.0X - Parquet Vectorized (Pushdown) 13564 / 13594 1.2 862.4 1.1X - Native ORC Vectorized 11442 / 11608 1.4 727.5 1.2X - Native ORC Vectorized (Pushdown) 10991 / 11029 1.4 698.8 1.3X - - Select all rows (id IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14917 / 14938 1.1 948.4 1.0X - Parquet Vectorized (Pushdown) 14910 / 14964 1.1 948.0 1.0X - Native ORC Vectorized 11986 / 12069 1.3 762.0 1.2X - Native ORC Vectorized (Pushdown) 12037 / 12123 1.3 765.3 1.2X - - Select all rows (id > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14951 / 14976 1.1 950.6 1.0X - Parquet Vectorized (Pushdown) 14934 / 15016 1.1 949.5 1.0X - Native ORC Vectorized 12000 / 12156 1.3 763.0 1.2X - Native ORC Vectorized (Pushdown) 12079 / 12113 1.3 767.9 1.2X - - Select all rows (id != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ----------------------------------------------------------------------------------------------- - Parquet Vectorized 14930 / 14972 1.1 949.3 1.0X - Parquet Vectorized (Pushdown) 15015 / 15047 1.0 954.6 1.0X - Native ORC Vectorized 12090 / 12259 1.3 768.7 1.2X - Native ORC Vectorized (Pushdown) 12021 / 12096 1.3 764.2 1.2X - */ - benchmark.run() - } - - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - val mid = numRows / 2 - - withTempPath { dir => - withTempTable("orcTable", "patquetTable") { - prepareTable(dir, numRows, width) - - Seq("id IS NULL", s"$mid < id AND id < $mid").foreach { whereExpr => - val title = s"Select 0 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - Seq( - s"id = $mid", - s"id <=> $mid", - s"$mid <= id AND id <= $mid", - s"${mid - 1} < id AND id < ${mid + 1}" - ).foreach { whereExpr => - val title = s"Select 1 row ($whereExpr)".replace("id AND id", "id") - filterPushDownBenchmark(numRows, title, whereExpr) - } - - val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(id)") - - Seq(10, 50, 90).foreach { percent => - filterPushDownBenchmark( - numRows, - s"Select $percent% rows (id < ${numRows * percent / 100})", - s"id < ${numRows * percent / 100}", - selectExpr - ) - } - - Seq("id IS NOT NULL", "id > -1", "id != -1").foreach { whereExpr => - filterPushDownBenchmark( - numRows, - s"Select all rows ($whereExpr)", - whereExpr, - selectExpr) - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala new file mode 100644 index 0000000000000..6d7c7de9a856e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import scala.util.{Random, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.monotonically_increasing_id +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{Benchmark, Utils} + + +/** + * Benchmark to measure read performance with Filter pushdown. + * To run this: + * spark-submit --class + */ +object FilterPushdownBenchmark { + val conf = new SparkConf() + .setAppName("FilterPushdownBenchmark") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") + .setIfMissing("orc.compression", "snappy") + .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + + private val spark = SparkSession.builder().config(conf).getOrCreate() + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + private def prepareTable( + dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = { + import spark.implicits._ + val selectExpr = (1 to width).map(i => s"CAST(value AS STRING) c$i") + val valueCol = if (useStringForValue) { + monotonically_increasing_id().cast("string") + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).map(_ => Random.nextLong).selectExpr(selectExpr: _*) + .withColumn("value", valueCol) + .sort("value") + + saveAsOrcTable(df, dir.getCanonicalPath + "/orc") + saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + } + + private def prepareStringDictTable( + dir: File, numRows: Int, numDistinctValues: Int, width: Int): Unit = { + val selectExpr = (0 to width).map { + case 0 => s"CAST(id % $numDistinctValues AS STRING) AS value" + case i => s"CAST(rand() AS STRING) c$i" + } + val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") + + saveAsOrcTable(df, dir.getCanonicalPath + "/orc") + saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + } + + private def saveAsOrcTable(df: DataFrame, dir: String): Unit = { + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite").option("orc.dictionary.key.threshold", 1.0).orc(dir) + spark.read.orc(dir).createOrReplaceTempView("orcTable") + } + + private def saveAsParquetTable(df: DataFrame, dir: String): Unit = { + df.write.mode("overwrite").parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + } + + def filterPushDownBenchmark( + values: Int, + title: String, + whereExpr: String, + selectExpr: String = "*"): Unit = { + val benchmark = new Benchmark(title, values, minNumIters = 5) + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").collect() + } + } + } + + Seq(false, true).foreach { pushDownEnabled => + val name = s"Native ORC Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" + benchmark.addCase(name) { _ => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> s"$pushDownEnabled") { + spark.sql(s"SELECT $selectExpr FROM orcTable WHERE $whereExpr").collect() + } + } + } + + /* + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz + Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9201 / 9300 1.7 585.0 1.0X + Parquet Vectorized (Pushdown) 89 / 105 176.3 5.7 103.1X + Native ORC Vectorized 8886 / 8898 1.8 564.9 1.0X + Native ORC Vectorized (Pushdown) 110 / 128 143.4 7.0 83.9X + + + Select 0 string row + ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9336 / 9357 1.7 593.6 1.0X + Parquet Vectorized (Pushdown) 927 / 937 17.0 58.9 10.1X + Native ORC Vectorized 9026 / 9041 1.7 573.9 1.0X + Native ORC Vectorized (Pushdown) 257 / 272 61.1 16.4 36.3X + + + Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9209 / 9223 1.7 585.5 1.0X + Parquet Vectorized (Pushdown) 908 / 925 17.3 57.7 10.1X + Native ORC Vectorized 8878 / 8904 1.8 564.4 1.0X + Native ORC Vectorized (Pushdown) 248 / 261 63.4 15.8 37.1X + + + Select 1 string row + (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9194 / 9216 1.7 584.5 1.0X + Parquet Vectorized (Pushdown) 899 / 908 17.5 57.2 10.2X + Native ORC Vectorized 8934 / 8962 1.8 568.0 1.0X + Native ORC Vectorized (Pushdown) 249 / 254 63.3 15.8 37.0X + + + Select 1 string row + ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9332 / 9351 1.7 593.3 1.0X + Parquet Vectorized (Pushdown) 915 / 934 17.2 58.2 10.2X + Native ORC Vectorized 9049 / 9057 1.7 575.3 1.0X + Native ORC Vectorized (Pushdown) 248 / 258 63.5 15.8 37.7X + + + Select all string rows + (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 20478 / 20497 0.8 1301.9 1.0X + Parquet Vectorized (Pushdown) 20461 / 20550 0.8 1300.9 1.0X + Native ORC Vectorized 27464 / 27482 0.6 1746.1 0.7X + Native ORC Vectorized (Pushdown) 27454 / 27488 0.6 1745.5 0.7X + + + Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8489 / 8519 1.9 539.7 1.0X + Parquet Vectorized (Pushdown) 64 / 69 246.1 4.1 132.8X + Native ORC Vectorized 8064 / 8099 2.0 512.7 1.1X + Native ORC Vectorized (Pushdown) 88 / 94 178.6 5.6 96.4X + + + Select 0 int row + (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8494 / 8514 1.9 540.0 1.0X + Parquet Vectorized (Pushdown) 835 / 840 18.8 53.1 10.2X + Native ORC Vectorized 8090 / 8106 1.9 514.4 1.0X + Native ORC Vectorized (Pushdown) 249 / 257 63.2 15.8 34.1X + + + Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8552 / 8560 1.8 543.7 1.0X + Parquet Vectorized (Pushdown) 837 / 841 18.8 53.2 10.2X + Native ORC Vectorized 8178 / 8188 1.9 519.9 1.0X + Native ORC Vectorized (Pushdown) 249 / 258 63.2 15.8 34.4X + + + Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8562 / 8580 1.8 544.3 1.0X + Parquet Vectorized (Pushdown) 833 / 836 18.9 53.0 10.3X + Native ORC Vectorized 8164 / 8185 1.9 519.0 1.0X + Native ORC Vectorized (Pushdown) 245 / 254 64.3 15.6 35.0X + + + Select 1 int row + (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8540 / 8555 1.8 542.9 1.0X + Parquet Vectorized (Pushdown) 837 / 839 18.8 53.2 10.2X + Native ORC Vectorized 8182 / 8231 1.9 520.2 1.0X + Native ORC Vectorized (Pushdown) 250 / 259 62.9 15.9 34.1X + + + Select 1 int row + (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8535 / 8555 1.8 542.6 1.0X + Parquet Vectorized (Pushdown) 835 / 841 18.8 53.1 10.2X + Native ORC Vectorized 8159 / 8179 1.9 518.8 1.0X + Native ORC Vectorized (Pushdown) 244 / 250 64.5 15.5 35.0X + + + Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 9609 / 9634 1.6 610.9 1.0X + Parquet Vectorized (Pushdown) 2663 / 2672 5.9 169.3 3.6X + Native ORC Vectorized 9824 / 9850 1.6 624.6 1.0X + Native ORC Vectorized (Pushdown) 2717 / 2722 5.8 172.7 3.5X + + + Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 13592 / 13613 1.2 864.2 1.0X + Parquet Vectorized (Pushdown) 9720 / 9738 1.6 618.0 1.4X + Native ORC Vectorized 16366 / 16397 1.0 1040.5 0.8X + Native ORC Vectorized (Pushdown) 12437 / 12459 1.3 790.7 1.1X + + + Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 17580 / 17617 0.9 1117.7 1.0X + Parquet Vectorized (Pushdown) 16803 / 16827 0.9 1068.3 1.0X + Native ORC Vectorized 24169 / 24187 0.7 1536.6 0.7X + Native ORC Vectorized (Pushdown) 22147 / 22341 0.7 1408.1 0.8X + + + Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18461 / 18491 0.9 1173.7 1.0X + Parquet Vectorized (Pushdown) 18466 / 18530 0.9 1174.1 1.0X + Native ORC Vectorized 24231 / 24270 0.6 1540.6 0.8X + Native ORC Vectorized (Pushdown) 24207 / 24304 0.6 1539.0 0.8X + + + Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18414 / 18453 0.9 1170.7 1.0X + Parquet Vectorized (Pushdown) 18435 / 18464 0.9 1172.1 1.0X + Native ORC Vectorized 24430 / 24454 0.6 1553.2 0.8X + Native ORC Vectorized (Pushdown) 24410 / 24465 0.6 1552.0 0.8X + + + Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 18446 / 18457 0.9 1172.8 1.0X + Parquet Vectorized (Pushdown) 18428 / 18440 0.9 1171.6 1.0X + Native ORC Vectorized 24414 / 24450 0.6 1552.2 0.8X + Native ORC Vectorized (Pushdown) 24385 / 24472 0.6 1550.4 0.8X + + + Select 0 distinct string row + (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8322 / 8352 1.9 529.1 1.0X + Parquet Vectorized (Pushdown) 53 / 57 296.3 3.4 156.7X + Native ORC Vectorized 7903 / 7953 2.0 502.4 1.1X + Native ORC Vectorized (Pushdown) 80 / 82 197.2 5.1 104.3X + + + Select 0 distinct string row + ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8712 / 8743 1.8 553.9 1.0X + Parquet Vectorized (Pushdown) 995 / 1030 15.8 63.3 8.8X + Native ORC Vectorized 8345 / 8362 1.9 530.6 1.0X + Native ORC Vectorized (Pushdown) 84 / 87 187.6 5.3 103.9X + + + Select 1 distinct string row + (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8574 / 8610 1.8 545.1 1.0X + Parquet Vectorized (Pushdown) 1127 / 1135 14.0 71.6 7.6X + Native ORC Vectorized 8163 / 8181 1.9 519.0 1.1X + Native ORC Vectorized (Pushdown) 426 / 433 36.9 27.1 20.1X + + + Select 1 distinct string row + (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8549 / 8568 1.8 543.5 1.0X + Parquet Vectorized (Pushdown) 1124 / 1131 14.0 71.4 7.6X + Native ORC Vectorized 8163 / 8210 1.9 519.0 1.0X + Native ORC Vectorized (Pushdown) 426 / 436 36.9 27.1 20.1X + + + Select 1 distinct string row + ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 8889 / 8896 1.8 565.2 1.0X + Parquet Vectorized (Pushdown) 1161 / 1168 13.6 73.8 7.7X + Native ORC Vectorized 8519 / 8554 1.8 541.6 1.0X + Native ORC Vectorized (Pushdown) 430 / 437 36.6 27.3 20.7X + + + Select all distinct string rows + (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Parquet Vectorized 20433 / 20533 0.8 1299.1 1.0X + Parquet Vectorized (Pushdown) 20433 / 20456 0.8 1299.1 1.0X + Native ORC Vectorized 25435 / 25513 0.6 1617.1 0.8X + Native ORC Vectorized (Pushdown) 25435 / 25507 0.6 1617.1 0.8X + */ + + benchmark.run() + } + + private def runIntBenchmark(numRows: Int, width: Int, mid: Int): Unit = { + Seq("value IS NULL", s"$mid < value AND value < $mid").foreach { whereExpr => + val title = s"Select 0 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = $mid", + s"value <=> $mid", + s"$mid <= value AND value <= $mid", + s"${mid - 1} < value AND value < ${mid + 1}" + ).foreach { whereExpr => + val title = s"Select 1 int row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% int rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + + Seq("value IS NOT NULL", "value > -1", "value != -1").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all int rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + private def runStringBenchmark( + numRows: Int, width: Int, searchValue: Int, colType: String): Unit = { + Seq("value IS NULL", s"'$searchValue' < value AND value < '$searchValue'") + .foreach { whereExpr => + val title = s"Select 0 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + Seq( + s"value = '$searchValue'", + s"value <=> '$searchValue'", + s"'$searchValue' <= value AND value <= '$searchValue'" + ).foreach { whereExpr => + val title = s"Select 1 $colType row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + + Seq("value IS NOT NULL").foreach { whereExpr => + filterPushDownBenchmark( + numRows, + s"Select all $colType rows ($whereExpr)", + whereExpr, + selectExpr) + } + } + + def main(args: Array[String]): Unit = { + val numRows = 1024 * 1024 * 15 + val width = 5 + + // Pushdown for many distinct value case + withTempPath { dir => + val mid = numRows / 2 + + withTempTable("orcTable", "patquetTable") { + Seq(true, false).foreach { useStringForValue => + prepareTable(dir, numRows, width, useStringForValue) + if (useStringForValue) { + runStringBenchmark(numRows, width, mid, "string") + } else { + runIntBenchmark(numRows, width, mid) + } + } + } + } + + // Pushdown for few distinct value case (use dictionary encoding) + withTempPath { dir => + val numDistinctValues = 200 + val mid = numDistinctValues / 2 + + withTempTable("orcTable", "patquetTable") { + prepareStringDictTable(dir, numRows, numDistinctValues, width) + runStringBenchmark(numRows, width, mid, "distinct string") + } + } + } +} From a5849ad9a3e5d41b5938faa7c592bcc6aec36044 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Sun, 24 Jun 2018 09:28:46 +0800 Subject: [PATCH 1004/1161] [SPARK-24324][PYTHON] Pandas Grouped Map UDF should assign result columns by name ## What changes were proposed in this pull request? Currently, a `pandas_udf` of type `PandasUDFType.GROUPED_MAP` will assign the resulting columns based on index of the return pandas.DataFrame. If a new DataFrame is returned and constructed using a dict, then the order of the columns could be arbitrary and be different than the defined schema for the UDF. If the schema types still match, then no error will be raised and the user will see column names and column data mixed up. This change will first try to assign columns using the return type field names. If a KeyError occurs, then the column index is checked if it is string based. If so, then the error is raised as it is most likely a naming mistake, else it will fallback to assign columns by position and raise a TypeError if the field types do not match. ## How was this patch tested? Added a test that returns a new DataFrame with column order different than the schema. Author: Bryan Cutler Closes #21427 from BryanCutler/arrow-grouped-map-mixesup-cols-SPARK-24324. --- docs/sql-programming-guide.md | 12 +- python/pyspark/sql/functions.py | 7 +- python/pyspark/sql/tests.py | 104 +++++++++++++++++- python/pyspark/worker.py | 55 ++++++--- .../apache/spark/sql/internal/SQLConf.scala | 13 +++ .../sql/execution/arrow/ArrowUtils.scala | 16 +++ .../python/AggregateInPandasExec.scala | 15 ++- .../python/ArrowEvalPythonExec.scala | 15 ++- .../execution/python/ArrowPythonRunner.scala | 15 ++- .../python/FlatMapGroupsInPandasExec.scala | 15 ++- .../execution/python/WindowInPandasExec.scala | 14 ++- 11 files changed, 226 insertions(+), 55 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 4d8a738507bd1..d2db067989434 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1752,14 +1752,10 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`. -The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position, -not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their -position matches the corresponding field in the schema. - -Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column -can differ from the order that it was placed in the dictionary. It is recommended in this case to -explicitly define the column order using the `columns` keyword, e.g. -`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`. +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptons, especially if the group sizes are skewed. The configuration for diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5f5d73307e9c8..9652d3e79b875 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2584,9 +2584,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. - The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be - indexed so that their position matches the corresponding field in the schema. + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined returnType schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3c8a8fcf6e946..35a0636e5cfc0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4742,7 +4742,6 @@ def test_vectorized_udf_chained(self): def test_vectorized_udf_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, col - df = self.spark.range(10) with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, @@ -5327,6 +5326,109 @@ def foo3(key, pdf): expected4 = udf3.func((), pdf) self.assertPandasEqual(expected4, result4) + def test_column_order(self): + from collections import OrderedDict + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + # Helper function to set column names from a list + def rename_pdf(pdf, names): + pdf.rename(columns={old: new for old, new in + zip(pd_result.columns, names)}, inplace=True) + + df = self.data + grouped_df = df.groupby('id') + grouped_pdf = df.toPandas().groupby('id') + + # Function returns a pdf with required column names, but order could be arbitrary using dict + def change_col_order(pdf): + # Constructing a DataFrame from a dict should result in the same order, + # but use from_items to ensure the pdf column order is different than schema + return pd.DataFrame.from_items([ + ('id', pdf.id), + ('u', pdf.v * 2), + ('v', pdf.v)]) + + ordered_udf = pandas_udf( + change_col_order, + 'id long, v int, u int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by name from the pdf + result = grouped_df.apply(ordered_udf).sort('id', 'v')\ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(change_col_order) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with positional columns, indexed by range + def range_col_order(pdf): + # Create a DataFrame with positional columns, fix types to long + return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64') + + range_udf = pandas_udf( + range_col_order, + 'id long, u long, v long', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result uses positional columns from the pdf + result = grouped_df.apply(range_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(range_col_order) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + # Function returns a pdf with columns indexed with integers + def int_index(pdf): + return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)])) + + int_index_udf = pandas_udf( + int_index, + 'id long, u int, v int', + PandasUDFType.GROUPED_MAP + ) + + # The UDF result should assign columns by position of integer index + result = grouped_df.apply(int_index_udf).sort('id', 'v') \ + .select('id', 'u', 'v').toPandas() + pd_result = grouped_pdf.apply(int_index) + rename_pdf(pd_result, ['id', 'u', 'v']) + expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected, result) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def column_name_typo(pdf): + return pd.DataFrame({'iid': pdf.id, 'v': pdf.v}) + + @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP) + def invalid_positional_types(pdf): + return pd.DataFrame([(u'a', 1.2)]) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "KeyError: 'id'"): + grouped_df.apply(column_name_typo).collect() + with self.assertRaisesRegexp(Exception, "No cast implemented"): + grouped_df.apply(invalid_positional_types).collect() + + def test_positional_assignment_conf(self): + import pandas as pd + from pyspark.sql.functions import pandas_udf, PandasUDFType + + with self.sql_conf({"spark.sql.execution.pandas.groupedMap.assignColumnsByPosition": True}): + + @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP) + def foo(_): + return pd.DataFrame([('hi', 1)], columns=['x', 'y']) + + df = self.data + result = df.groupBy('id').apply(foo).select('a', 'b').collect() + for r in result: + self.assertEqual(r.a, 'hi') + self.assertEqual(r.b, 1) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 38fe2ef06eac5..eaaae2b14e107 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,6 +38,9 @@ from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle +if sys.version >= '3': + basestring = str + pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -92,7 +95,10 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + assign_cols_by_pos = runner_conf.get( + "spark.sql.execution.pandas.groupedMap.assignColumnsByPosition", False) + def wrapped(key_series, value_series): import pandas as pd @@ -110,9 +116,13 @@ def wrapped(key_series, value_series): "Number of columns of the returned pandas.DataFrame " "doesn't match specified schema. " "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) - arrow_return_types = (to_arrow_type(field.dataType) for field in return_type) - return [(result[result.columns[i]], arrow_type) - for i, arrow_type in enumerate(arrow_return_types)] + + # Assign result columns by schema name if user labeled with strings, else use position + if not assign_cols_by_pos and any(isinstance(name, basestring) for name in result.columns): + return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type] + else: + return [(result[result.columns[i]], to_arrow_type(field.dataType)) + for i, field in enumerate(return_type)] return wrapped @@ -143,7 +153,7 @@ def wrapped(*series): return lambda *a: (wrapped(*a), arrow_return_type) -def read_single_udf(pickleSer, infile, eval_type): +def read_single_udf(pickleSer, infile, eval_type, runner_conf): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] row_func = None @@ -163,7 +173,7 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -175,6 +185,26 @@ def read_single_udf(pickleSer, infile, eval_type): def read_udfs(pickleSer, infile, eval_type): + runner_conf = {} + + if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): + + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + ser = ArrowStreamPandasSerializer(timezone) + else: + ser = BatchedSerializer(PickleSerializer(), 100) + num_udfs = read_int(infile) udfs = {} call_udf = [] @@ -189,7 +219,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f'] = udf split_offset = arg_offsets[0] + 1 arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] @@ -201,7 +231,7 @@ def read_udfs(pickleSer, infile, eval_type): # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. for i in range(num_udfs): - arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf) udfs['f%d' % i] = udf args = ["a[%d]" % o for o in arg_offsets] call_udf.append("f%d(%s)" % (i, ", ".join(args))) @@ -210,15 +240,6 @@ def read_udfs(pickleSer, infile, eval_type): mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) - if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): - timezone = utf8_deserializer.loads(infile) - ser = ArrowStreamPandasSerializer(timezone) - else: - ser = BatchedSerializer(PickleSerializer(), 100) - # profiling is not supported for UDF return func, None, ser, ser diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8d2320d8a6ed7..d5fb524a1396f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1161,6 +1161,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION = + buildConf("spark.sql.execution.pandas.groupedMap.assignColumnsByPosition") + .internal() + .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + + "Pandas DataFrame based on position, regardless of column label type. When false, " + + "columns will be looked up by name if labeled with a string and fallback to use " + + "position if not.") + .booleanConf + .createWithDefault(false) + val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter") .internal() .doc("When true, the apply function of the rule verifies whether the right node of the" + @@ -1647,6 +1657,9 @@ class SQLConf extends Serializable with Logging { def pandasRespectSessionTimeZone: Boolean = getConf(PANDAS_RESPECT_SESSION_LOCAL_TIMEZONE) + def pandasGroupedMapAssignColumnssByPosition: Boolean = + getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION) + def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 6ad11bda84bf6..93c8127681b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ object ArrowUtils { @@ -120,4 +121,19 @@ object ArrowUtils { StructField(field.getName, dt, field.isNullable) }) } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = if (conf.pandasRespectSessionTimeZone) { + Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + } else { + Nil + } + val pandasColsByPosition = if (conf.pandasGroupedMapAssignColumnssByPosition) { + Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_POSITION.key -> "true") + } else { + Nil + } + Map(timeZoneConf ++ pandasColsByPosition: _*) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 8e01e8e56a5bd..d00f6f042d6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -81,7 +82,7 @@ case class AggregateInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip @@ -135,10 +136,14 @@ case class AggregateInPandasExec( } val columnarBatchIter = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(projectedRowIter, context.partitionId(), context) + pyFuncs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + argOffsets, + aggInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index c4de214679ae4..0bc21c0986e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -63,7 +64,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -80,10 +81,14 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(batchIter, context.partitionId(), context) + funcs, + bufferSize, + reuseWorker, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + argOffsets, + schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 01e19bddbfb66..ca665652f204d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -45,7 +45,7 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], schema: StructType, timeZoneId: String, - respectTimeZone: Boolean) + conf: Map[String, String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { @@ -58,12 +58,15 @@ class ArrowPythonRunner( new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - if (respectTimeZone) { - PythonRDD.writeUTF(timeZoneId, dataOut) - } else { - dataOut.writeInt(SpecialLengths.NULL) + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 513e174c7733e..f5a563baf52df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType /** @@ -77,7 +78,7 @@ case class FlatMapGroupsInPandasExec( val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Deduplicate the grouping attributes. // If a grouping attribute also appears in data attributes, then we don't need to send the @@ -139,10 +140,14 @@ case class FlatMapGroupsInPandasExec( val context = TaskContext.get() val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(grouped, context.partitionId(), context) + chainedFunc, + bufferSize, + reuseWorker, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + argOffsets, + dedupSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index c76832a1a3829..628029b13a6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -97,7 +98,7 @@ case class WindowInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) // Extract window expressions and window functions val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e }) @@ -154,11 +155,14 @@ case class WindowInPandasExec( } val windowFunctionResult = new ArrowPythonRunner( - pyFuncs, bufferSize, reuseWorker, + pyFuncs, + bufferSize, + reuseWorker, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, - argOffsets, windowInputSchema, - sessionLocalTimeZone, pandasRespectSessionTimeZone) - .compute(pythonInput, context.partitionId(), context) + argOffsets, + windowInputSchema, + sessionLocalTimeZone, + pythonRunnerConf).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow val resultProj = createResultProjection(expressions) From f596ebe4d3170590b6fce34c179e51ee80c965d3 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 24 Jun 2018 23:14:42 -0700 Subject: [PATCH 1005/1161] [SPARK-24327][SQL] Verify and normalize a partition column name based on the JDBC resolved schema ## What changes were proposed in this pull request? This pr modified JDBC datasource code to verify and normalize a partition column based on the JDBC resolved schema before building `JDBCRelation`. Closes #20370 ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro Closes #21379 from maropu/SPARK-24327. --- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../datasources/jdbc/JDBCRelation.scala | 76 +++++++++++++++---- .../jdbc/JdbcRelationProvider.scala | 6 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 51 ++++++++++++- 4 files changed, 118 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c139db46b63a3..a6fd3637663e8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -100,7 +100,7 @@ private[spark] object Utils extends Logging { */ val DEFAULT_MAX_TO_STRING_FIELDS = 25 - private def maxNumToStringFields = { + private[spark] def maxNumToStringFields = { if (SparkEnv.get != null) { SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b23e5a7722004..b84543ccd7869 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Instructions on how to partition the table among workers. @@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging { * Null value predicate is added to the first partition where clause to include * the rows with null value for the partitions column. * + * @param schema resolved schema of a JDBC table * @param partitioning partition information to generate the where clause for each partition + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url * @return an array of partitions with where clause for each partition */ - def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { + def columnPartition( + schema: StructType, + partitioning: JDBCPartitioningInfo, + resolver: Resolver, + jdbcOptions: JDBCOptions): Array[Partition] = { if (partitioning == null || partitioning.numPartitions <= 1 || partitioning.lowerBound == partitioning.upperBound) { return Array[Partition](JDBCPartition(null, 0)) @@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging { // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. val stride: Long = upperBound / numPartitions - lowerBound / numPartitions - val column = partitioning.column + + val column = verifyAndGetNormalizedColumnName( + schema, partitioning.column, resolver, jdbcOptions) + var i: Int = 0 var currentValue: Long = lowerBound val ans = new ArrayBuffer[Partition]() @@ -99,10 +111,57 @@ private[sql] object JDBCRelation extends Logging { } ans.toArray } + + // Verify column name based on the JDBC resolved schema + private def verifyAndGetNormalizedColumnName( + schema: StructType, + columnName: String, + resolver: Resolver, + jdbcOptions: JDBCOptions): String = { + val dialect = JdbcDialects.get(jdbcOptions.url) + schema.map(_.name).find { fieldName => + resolver(fieldName, columnName) || + resolver(dialect.quoteIdentifier(fieldName), columnName) + }.map(dialect.quoteIdentifier).getOrElse { + throw new AnalysisException(s"User-defined partition column $columnName not " + + s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}") + } + } + + /** + * Takes a (schema, table) specification and returns the table's Catalyst schema. + * If `customSchema` defined in the JDBC options, replaces the schema's dataType with the + * custom schema's type. + * + * @param resolver function used to determine if two identifiers are equal + * @param jdbcOptions JDBC options that contains url, table and other information. + * @return resolved Catalyst schema of a JDBC table + */ + def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = { + val tableSchema = JDBCRDD.resolveTable(jdbcOptions) + jdbcOptions.customSchema match { + case Some(customSchema) => JdbcUtils.getCustomSchema( + tableSchema, customSchema, resolver) + case None => tableSchema + } + } + + /** + * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema. + */ + def apply( + parts: Array[Partition], + jdbcOptions: JDBCOptions)( + sparkSession: SparkSession): JDBCRelation = { + val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sparkSession) + } } private[sql] case class JDBCRelation( - parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) + override val schema: StructType, + parts: Array[Partition], + jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) extends BaseRelation with PrunedFilteredScan with InsertableRelation { @@ -111,15 +170,6 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - override val schema: StructType = { - val tableSchema = JDBCRDD.resolveTable(jdbcOptions) - jdbcOptions.customSchema match { - case Some(customSchema) => JdbcUtils.getCustomSchema( - tableSchema, customSchema, sparkSession.sessionState.conf.resolver) - case None => tableSchema - } - } - // Check if JDBCRDD.compileFilter can accept input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index f8c5677ea0f2a..2b488bb7121dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider JDBCPartitioningInfo( partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) } - val parts = JDBCRelation.columnPartition(partitionInfo) - JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) + val resolver = sqlContext.conf.resolver + val schema = JDBCRelation.getSchema(resolver, jdbcOptions) + val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions) + JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession) } override def createRelation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index bc2aca65e803f..6ea61f02a8206 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite |OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " + + "AS SELECT 1, 1") + .executeUpdate() + conn.commit() + // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types. } @@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite }.getMessage assert(errMsg.contains("Statement was canceled or the session timed out")) } + + test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") { + def testJdbcParitionColumn(partColName: String, expectedColumnName: String): Unit = { + val df = spark.read.format("jdbc") + .option("url", urlWithUserAndPass) + .option("dbtable", "TEST.PARTITION") + .option("partitionColumn", partColName) + .option("lowerBound", 1) + .option("upperBound", 4) + .option("numPartitions", 3) + .load() + + val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName) + df.logicalPlan match { + case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) => + val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet + assert(whereClauses === Set( + s"$quotedPrtColName < 2 or $quotedPrtColName is null", + s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3", + s"$quotedPrtColName >= 3")) + } + } + + testJdbcParitionColumn("THEID", "THEID") + testJdbcParitionColumn("\"THEID\"", "THEID") + withSQLConf("spark.sql.caseSensitive" -> "false") { + testJdbcParitionColumn("ThEiD", "THEID") + } + testJdbcParitionColumn("THE ID", "THE ID") + + def testIncorrectJdbcPartitionColumn(partColName: String): Unit = { + val errMsg = intercept[AnalysisException] { + testJdbcParitionColumn(partColName, "THEID") + }.getMessage + assert(errMsg.contains(s"User-defined partition column $partColName not found " + + "in the JDBC relation:")) + } + + testIncorrectJdbcPartitionColumn("NoExistingColumn") + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) + } + } } From 6e0596e2639117d7a0a58b644b0600086f45c7f9 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sun, 24 Jun 2018 23:56:47 -0700 Subject: [PATCH 1006/1161] [SPARK-23931][SQL][FOLLOW-UP] Make `arrays_zip` in function.scala `@scala.annotation.varargs`. ## What changes were proposed in this pull request? This is a follow-up pr of #21045 which added `arrays_zip`. The `arrays_zip` in functions.scala should've been `scala.annotation.varargs`. This pr makes it `scala.annotation.varargs`. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #21630 from ueshin/issues/SPARK-23931/fup1. --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f792a6fba1d8f..40c40e7083d1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3539,6 +3539,7 @@ object functions { * @group collection_funcs * @since 2.4.0 */ + @scala.annotation.varargs def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } ////////////////////////////////////////////////////////////////////////////////////////////// From 8ab8ef7733b42e94f687a5520332814ac9caeda8 Mon Sep 17 00:00:00 2001 From: Jim Kleckner Date: Mon, 25 Jun 2018 16:23:23 +0800 Subject: [PATCH 1007/1161] Fix minor typo in docs/cloud-integration.md ## What changes were proposed in this pull request? Minor typo in docs/cloud-integration.md ## How was this patch tested? This is trivial enough that it should not affect tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Jim Kleckner Closes #21629 from jkleckner/fix-doc-typo. --- docs/cloud-integration.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/cloud-integration.md b/docs/cloud-integration.md index ac1c336988930..18e8fe77bbdbe 100644 --- a/docs/cloud-integration.md +++ b/docs/cloud-integration.md @@ -70,7 +70,7 @@ be safely used as the direct destination of work with the normal rename-based co ### Installation With the relevant libraries on the classpath and Spark configured with valid credentials, -objects can be can be read or written by using their URLs as the path to data. +objects can be read or written by using their URLs as the path to data. For example `sparkContext.textFile("s3a://landsat-pds/scene_list.gz")` will create an RDD of the file `scene_list.gz` stored in S3, using the s3a connector. @@ -184,7 +184,8 @@ is no need for a workflow of write-then-rename to ensure that files aren't picke while they are still being written. Applications can write straight to the monitored directory. 1. Streams should only be checkpointed to a store implementing a fast and -atomic `rename()` operation Otherwise the checkpointing may be slow and potentially unreliable. +atomic `rename()` operation. +Otherwise the checkpointing may be slow and potentially unreliable. ## Further Reading From bac50aa37168a7612702a4503750a78ed5d59c78 Mon Sep 17 00:00:00 2001 From: Maryann Xue Date: Mon, 25 Jun 2018 07:17:30 -0700 Subject: [PATCH 1008/1161] [SPARK-24596][SQL] Non-cascading Cache Invalidation ## What changes were proposed in this pull request? 1. Add parameter 'cascade' in CacheManager.uncacheQuery(). Under 'cascade=false' mode, only invalidate the current cache, and for other dependent caches, rebuild execution plan and reuse cached buffer. 2. Pass true/false from callers in different uncache scenarios: - Drop tables and regular (persistent) views: regular mode - Drop temporary views: non-cascading mode - Modify table contents (INSERT/UPDATE/MERGE/DELETE): regular mode - Call `DataSet.unpersist()`: non-cascading mode - Call `Catalog.uncacheTable()`: follow the same convention as drop tables/view, which is, use non-cascading mode for temporary views and regular mode for the rest Note that a regular (persistent) view is a database object just like a table, so after dropping a regular view (whether cached or not cached), any query referring to that view should no long be valid. Hence if a cached persistent view is dropped, we need to invalidate the all dependent caches so that exceptions will be thrown for any later reference. On the other hand, a temporary view is in fact equivalent to an unnamed DataSet, and dropping a temporary view should have no impact on queries referencing that view. Thus we should do non-cascading uncaching for temporary views, which also guarantees a consistent uncaching behavior between temporary views and unnamed DataSets. ## How was this patch tested? New tests in CachedTableSuite and DatasetCacheSuite. Author: Maryann Xue Closes #21594 from maryannxue/noncascading-cache. --- docs/sql-programming-guide.md | 1 + .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/execution/CacheManager.scala | 50 ++++++++++--- .../execution/columnar/InMemoryRelation.scala | 10 +++ .../spark/sql/execution/command/ddl.scala | 8 ++- .../spark/sql/execution/command/tables.scala | 2 +- .../spark/sql/internal/CatalogImpl.scala | 12 ++-- .../apache/spark/sql/CachedTableSuite.scala | 66 ++++++++++++++++- .../apache/spark/sql/DatasetCacheSuite.scala | 70 +++++++++++++++++-- 9 files changed, 197 insertions(+), 26 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d2db067989434..196b814420be1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1827,6 +1827,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. + - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f5526104690d2..57f1e173211af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2964,6 +2964,7 @@ class Dataset[T] private[sql]( /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @param blocking Whether to block until all blocks are deleted. * @@ -2971,12 +2972,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking) this } /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. + * This will not un-persist any cached data that is built upon this Dataset. * * @group basic * @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 2db7c02e86014..39d9a95ca4710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -105,24 +105,50 @@ class CacheManager extends Logging { } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param query The [[Dataset]] to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * [[Dataset]]; otherwise un-cache the given [[Dataset]] only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + def uncacheQuery( + query: Dataset[_], + cascade: Boolean, + blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking) } /** - * Un-cache all the cache entries that refer to the given plan. + * Un-cache the given plan or all the cache entries that refer to the given plan. + * @param spark The Spark session. + * @param plan The plan to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. */ - def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + def uncacheQuery( + spark: SparkSession, + plan: LogicalPlan, + cascade: Boolean, + blocking: Boolean): Unit = writeLock { + val shouldRemove: LogicalPlan => Boolean = + if (cascade) { + _.find(_.sameResult(plan)).isDefined + } else { + _.sameResult(plan) + } val it = cachedData.iterator() while (it.hasNext) { val cd = it.next() - if (cd.plan.find(_.sameResult(plan)).isDefined) { + if (shouldRemove(cd.plan)) { cd.cachedRepresentation.cacheBuilder.clearCache(blocking) it.remove() } } + // Re-compile dependent cached queries after removing the cached query. + if (!cascade) { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined, clearCache = false) + } } /** @@ -132,20 +158,24 @@ class CacheManager extends Logging { recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) } - private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + private def recacheByCondition( + spark: SparkSession, + condition: LogicalPlan => Boolean, + clearCache: Boolean = true): Unit = { val it = cachedData.iterator() val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] while (it.hasNext) { val cd = it.next() if (condition(cd.plan)) { - cd.cachedRepresentation.cacheBuilder.clearCache() + if (clearCache) { + cd.cachedRepresentation.cacheBuilder.clearCache() + } // Remove the cache entry before we create a new one, so that we can have a different // physical plan. it.remove() val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan val newCache = InMemoryRelation( - cacheBuilder = cd.cachedRepresentation - .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null), + cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan), logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index da35a4734e65a..7c8faec53a828 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -74,6 +74,16 @@ case class CachedRDDBuilder( } } + def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = { + new CachedRDDBuilder( + useCompression, + batchSize, + storageLevel, + cachedPlan = cachedPlan, + tableName + )(_cachedColumnBuffers) + } + private def buildBuffers(): RDD[CachedBatch] = { val output = cachedPlan.output val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bf4d96fa18d0d..04bf8c6dd917f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -189,8 +189,9 @@ case class DropTableCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog + val isTempView = catalog.isTemporaryTable(tableName) - if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) { + if (!isTempView && catalog.tableExists(tableName)) { // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. catalog.getTableMetadata(tableName).tableType match { @@ -204,9 +205,10 @@ case class DropTableCommand( } } - if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) { + if (isTempView || catalog.tableExists(tableName)) { try { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName), cascade = !isTempView) } catch { case NonFatal(e) => log.warn(e.toString, e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 44749190c79eb..ec3961f84bd8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -493,7 +493,7 @@ case class TruncateTableCommand( spark.sessionState.refreshTable(tableName.unquotedString) // Also try to drop the contents of the table from the columnar cache try { - spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), cascade = true) } catch { case NonFatal(e) => log.warn(s"Exception when attempting to uncache table $tableIdentWithDB", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6ae307bce10c8..4698e8ab13ce3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession, viewDef, cascade = false, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + val cascade = !sessionCatalog.isTemporaryTable(tableIdent) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName), cascade) } /** @@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // cached version and make the new version cached lazily. if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, cascade = true, blocking = true) // Cache it again. sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 6982c22f4771d..60c73df88896b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -801,4 +800,69 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } assert(cachedData.collect === Seq(1001)) } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t1") + assert(!spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") { + withTempView("t1", "t2") { + sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1") + sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(spark.catalog.isCached("t2")) + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withView("t1") { + withTempView("t2") { + sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1") + + sql("CACHE TABLE t1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("DROP VIEW t1") + assert(!spark.catalog.isCached("t2")) + } + } + } + } + + test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") { + withTable("t") { + spark.range(1, 10).toDF("key").withColumn("value", 'key * 2) + .write.format("json").saveAsTable("t") + withTempView("t1", "t2") { + sql("CACHE TABLE t") + sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1") + sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1") + + assert(spark.catalog.isCached("t")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) + sql("UNCACHE TABLE t") + assert(!spark.catalog.isCached("t")) + assert(!spark.catalog.isCached("t1")) + assert(!spark.catalog.isCached("t2")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index c4f056334cd1a..5c6a021d5b767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -29,6 +29,16 @@ import org.apache.spark.storage.StorageLevel class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits { import testImplicits._ + /** + * Asserts that a cached [[Dataset]] will be built using the given number of other cached results. + */ + private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = { + val plan = df.queryExecution.withCachedData + assert(plan.isInstanceOf[InMemoryRelation]) + val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon) + } + test("get storage level") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -117,7 +127,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits } test("cache UDF result correctly") { - val expensiveUDF = udf({x: Int => Thread.sleep(10000); x}) + val expensiveUDF = udf({x: Int => Thread.sleep(5000); x}) val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a")) val df2 = df.agg(sum(df("b"))) @@ -126,7 +136,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits assertCached(df2) // udf has been evaluated during caching, and thus should not be re-evaluated here - failAfter(5 seconds) { + failAfter(3 seconds) { df2.collect() } @@ -143,9 +153,57 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits df.count() df2.cache() - val plan = df2.queryExecution.withCachedData - assert(plan.isInstanceOf[InMemoryRelation]) - val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan - assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined) + assertCacheDependency(df2) + } + + test("SPARK-24596 Non-cascading Cache Invalidation") { + val df = Seq(("a", 1), ("b", 2)).toDF("s", "i") + val df2 = df.filter('i > 1) + val df3 = df.filter('i < 2) + + df2.cache() + df.cache() + df.count() + df3.cache() + + df.unpersist() + + // df un-cached; df2 and df3's cache plan re-compiled + assert(df.storageLevel == StorageLevel.NONE) + assertCacheDependency(df2, 0) + assertCacheDependency(df3, 0) + } + + test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data reuse") { + val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x }) + val df = spark.range(0, 5).toDF("a") + val df1 = df.withColumn("b", expensiveUDF($"a")) + val df2 = df1.groupBy('a).agg(sum('b)) + val df3 = df.agg(sum('a)) + + df1.cache() + df2.cache() + df2.collect() + df3.cache() + + assertCacheDependency(df2) + + df1.unpersist(blocking = true) + + // df1 un-cached; df2's cache plan re-compiled + assert(df1.storageLevel == StorageLevel.NONE) + assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0) + + val df4 = df1.groupBy('a).agg(sum('b)).agg(sum("sum(b)")) + assertCached(df4) + // reuse loaded cache + failAfter(3 seconds) { + checkDataset(df4, Row(10)) + } + + val df5 = df.agg(sum('a)).filter($"sum(a)" > 1) + assertCached(df5) + // first time use, load cache + checkDataset(df5, Row(10)) } } From 594ac4f7b816488091202918c409487058e6d8ac Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 25 Jun 2018 23:44:20 +0800 Subject: [PATCH 1009/1161] [SPARK-24633][SQL] Fix codegen when split is required for arrays_zip ## What changes were proposed in this pull request? In function array_zip, when split is required by the high number of arguments, a codegen error can happen. The PR fixes codegen for cases when splitting the code is required. ## How was this patch tested? added UT Author: Marco Gaido Closes #21621 from mgaido91/SPARK-24633. --- .../catalyst/expressions/collectionOperations.scala | 4 ++-- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3afabe14606e4..b6137b07555f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -200,7 +200,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """.stripMargin } - val splittedGetValuesAndCardinalities = ctx.splitExpressions( + val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs( expressions = getValuesAndCardinalities, funcName = "getValuesAndCardinalities", returnType = "int", @@ -210,7 +210,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |return $biggestCardinality; """.stripMargin, foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), - arguments = + extraArguments = ("ArrayData[]", arrVals) :: ("int", biggestCardinality) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 25fdbab745128..47fe67d8daea3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -556,6 +556,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) } + test("SPARK-24633: arrays_zip splits input processing correctly") { + Seq("true", "false").foreach { wholestageCodegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) + } + } + } + test("map size function") { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), From 5264164a67df498b73facae207eda12ee133be7d Mon Sep 17 00:00:00 2001 From: Stacy Kerkela Date: Mon, 25 Jun 2018 23:41:39 +0200 Subject: [PATCH 1010/1161] [SPARK-24648][SQL] SqlMetrics should be threadsafe Use LongAdder to make SQLMetrics thread safe. ## What changes were proposed in this pull request? Replace += with LongAdder.add() for concurrent counting ## How was this patch tested? Unit tests with local threads Author: Stacy Kerkela Closes #21634 from dbkerkela/sqlmetrics-concurrency-stacy. --- .../sql/execution/metric/SQLMetrics.scala | 33 ++++++++++------- .../execution/metric/SQLMetricsSuite.scala | 36 ++++++++++++++++++- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 77b907870d678..b4f0ae1eb1a18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import java.util.concurrent.atomic.LongAdder import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo @@ -32,40 +33,45 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { + // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. - private[this] var _value = initValue - private var _zeroValue = initValue + private[this] val _value = new LongAdder + private val _zeroValue = initValue + _value.add(initValue) override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, _value) - newAcc._zeroValue = initValue + val newAcc = new SQLMetric(metricType, initValue) + newAcc.add(_value.sum()) newAcc } - override def reset(): Unit = _value = _zeroValue + override def reset(): Unit = this.set(_zeroValue) override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { - case o: SQLMetric => _value += o.value + case o: SQLMetric => _value.add(o.value) case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value == _zeroValue + override def isZero(): Boolean = _value.sum() == _zeroValue - override def add(v: Long): Unit = _value += v + override def add(v: Long): Unit = _value.add(v) // We can set a double value to `SQLMetric` which stores only long value, if it is // average metrics. def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v) - def set(v: Long): Unit = _value = v + def set(v: Long): Unit = { + _value.reset() + _value.add(v) + } - def +=(v: Long): Unit = _value += v + def +=(v: Long): Unit = _value.add(v) - override def value: Long = _value + override def value: Long = _value.sum() // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -153,7 +159,7 @@ object SQLMetrics { Seq.fill(3)(0L) } else { val sorted = validValues.sorted - Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.head, sorted(validValues.length / 2), sorted(validValues.length - 1)) } metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric)) } @@ -173,7 +179,8 @@ object SQLMetrics { Seq.fill(4)(0L) } else { val sorted = validValues.sorted - Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + Seq(sorted.sum, sorted.head, sorted(validValues.length / 2), + sorted(validValues.length - 1)) } metric.map(strFormat) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index a3a3f3851e21c..8263c9c81c49e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.metric import java.io.File +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.ui.SQLAppStatusStore import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -504,4 +504,38 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared test("writing data out metrics with dynamic partition: parquet") { testMetricsDynamicPartition("parquet", "parquet", "t1") } + + test("writing metrics from single thread") { + val nAdds = 10 + val acc = new SQLMetric("test", -10) + assert(acc.isZero()) + acc.set(0) + for (i <- 1 to nAdds) acc.add(1) + assert(!acc.isZero()) + assert(nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + + test("writing metrics from multiple threads") { + implicit val ec: ExecutionContextExecutor = ExecutionContext.global + val nFutures = 1000 + val nAdds = 100 + val acc = new SQLMetric("test", -10) + assert(acc.isZero() === true) + acc.set(0) + val l = for ( i <- 1 to nFutures ) yield { + Future { + for (j <- 1 to nAdds) acc.add(1) + i + } + } + for (futures <- Future.sequence(l)) { + assert(nFutures === futures.length) + assert(!acc.isZero()) + assert(nFutures * nAdds === acc.value) + acc.reset() + assert(acc.isZero()) + } + } } From baa01c8ca9e8ea456f986fbb223c61ad541b52b0 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 25 Jun 2018 15:12:33 -0700 Subject: [PATCH 1011/1161] [INFRA] Close stale PR. Closes #21614 From 6d16b9885d6ad01e1cc56d5241b7ebad99487a0c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 25 Jun 2018 16:54:57 -0700 Subject: [PATCH 1012/1161] [SPARK-24552][CORE][SQL] Use task ID instead of attempt number for writes. This passes the unique task attempt id instead of attempt number to v2 data sources because attempt number is reused when stages are retried. When attempt numbers are reused, sources that track data by partition id and attempt number may incorrectly clean up data because the same attempt number can be both committed and aborted. For v1 / Hadoop writes, generate a unique ID based on available attempt numbers to avoid a similar problem. Closes #21558 Author: Marcelo Vanzin Author: Ryan Blue Closes #21606 from vanzin/SPARK-24552.2. --- .../spark/internal/io/SparkHadoopWriter.scala | 6 +++- .../sql/kafka010/KafkaStreamWriter.scala | 2 +- .../sources/v2/writer/DataSourceWriter.java | 8 +++--- .../sql/sources/v2/writer/DataWriter.java | 8 +++--- .../sources/v2/writer/DataWriterFactory.java | 11 +++----- .../datasources/v2/WriteToDataSourceV2.scala | 28 ++++++++++--------- .../continuous/ContinuousWriteRDD.scala | 2 +- .../sources/ForeachWriterProvider.scala | 2 +- .../sources/PackedRowWriterFactory.scala | 2 +- .../streaming/sources/memoryV2.scala | 2 +- .../sources/v2/SimpleWritableDataSource.scala | 8 +++--- 11 files changed, 41 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index abf39213fa0d2..9ebd0aa301592 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging { // Try to write all RDD partitions as a Hadoop OutputFormat. try { val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers. + // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently. + val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber + executeTask( context = context, config = config, jobTrackerId = jobTrackerId, commitJobId = commitJobId, sparkPartitionId = context.partitionId, - sparkAttemptNumber = context.attemptNumber, + sparkAttemptNumber = attemptId, committer = committer, iterator = iter) }) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala index ae5b5c52d514e..32923dc9f5a6b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala @@ -67,7 +67,7 @@ case class KafkaStreamWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0030a9f05dba7..7eedc85a5d6f3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -64,8 +64,8 @@ public interface DataSourceWriter { DataWriterFactory createWriterFactory(); /** - * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for - * each task commits. + * Returns whether Spark should use the commit coordinator to ensure that at most one task for + * each partition commits. * * @return true if commit coordinator should be used, false otherwise. */ @@ -90,9 +90,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {} * is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it. * * Note that speculative execution may cause multiple tasks to run for a partition. By default, - * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can + * Spark uses the commit coordinator to allow at most one task to commit. Implementations can * disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple - * attempts may have committed successfully and one successful commit message per task will be + * tasks may have committed successfully and one successful commit message per task will be * passed to this commit method. The remaining commit messages are ignored by Spark. */ void commit(WriterCommitMessage[] messages); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java index 39bf458298862..1626c0013e4e7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -22,7 +22,7 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is + * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is * responsible for writing data for an input RDD partition. * * One Spark task has one exclusive data writer, so there is no thread-safe concern. @@ -39,14 +39,14 @@ * {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an * exception will be sent to the driver side, and Spark may retry this writing task a few times. - * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a - * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} + * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a + * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])} * when the configured number of retries is exhausted. * * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task * takes too long to finish. Different from retried tasks, which are launched one by one after the * previous one fails, speculative tasks are running simultaneously. It's possible that one input - * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * RDD partition has multiple data writers with different `taskId` running at the same time, * and data sources should guarantee that these data writers don't conflict and can work together. * Implementations can coordinate with driver during {@link #commit()} to make sure only one of * these data writers can commit successfully. Or implementations can allow all of them to commit diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index 7527bcc0c4027..0932ff8f8f8a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -42,15 +42,12 @@ public interface DataWriterFactory extends Serializable { * Usually Spark processes many RDD partitions at the same time, * implementations should use the partition id to distinguish writers for * different partitions. - * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task - * failed, Spark launches a new task wth the same task id but different - * attempt number. Or a task is too slow, Spark launches new tasks wth the - * same task id but different attempt number, which means there are multiple - * tasks with the same task id running at the same time. Implementations can - * use this attempt number to distinguish writers of different task attempts. + * @param taskId A unique identifier for a task that is performing the write of the partition + * data. Spark may run multiple tasks for the same partition (due to speculation + * or task failures, for example). * @param epochId A monotonically increasing id for streaming queries that are split in to * discrete periods of execution. For non-streaming queries, * this ID will always be 0. */ - DataWriter createDataWriter(int partitionId, int attemptNumber, long epochId); + DataWriter createDataWriter(int partitionId, long taskId, long epochId); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 11ed7131e7e3d..b1148c0f62f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution} -import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions} +import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -111,9 +109,10 @@ object DataWritingSparkTask extends Logging { val stageId = context.stageId() val stageAttempt = context.stageAttemptNumber() val partId = context.partitionId() + val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong) + val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) // write the data and commit this writer. Utils.tryWithSafeFinallyAndFailureCallbacks(block = { @@ -125,12 +124,12 @@ object DataWritingSparkTask extends Logging { val coordinator = SparkEnv.get.outputCommitCoordinator val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId) if (commitAuthorized) { - logInfo(s"Writer for stage $stageId / $stageAttempt, " + - s"task $partId.$attemptId is authorized to commit.") + logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.commit() } else { - val message = s"Stage $stageId / $stageAttempt, " + - s"task $partId.$attemptId: driver did not authorize commit" + val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)" logInfo(message) // throwing CommitDeniedException will trigger the catch block for abort throw new CommitDeniedException(message, stageId, partId, attemptId) @@ -141,15 +140,18 @@ object DataWritingSparkTask extends Logging { dataWriter.commit() } - logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.") + logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") msg })(catchBlock = { // If there is an error, abort this writer - logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.") + logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") dataWriter.abort() - logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.") + logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId" + + s"stage $stageId.$stageAttempt)") }) } } @@ -160,10 +162,10 @@ class InternalRowDataWriterFactory( override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { new InternalRowDataWriter( - rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId), + rowWriterFactory.createDataWriter(partitionId, taskId, epochId), RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index ef5f0da1e7cc2..76f3f5baa8d56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -56,7 +56,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( context.partitionId(), - context.attemptNumber(), + context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala index f677f25f116a2..bc9b6d93ce7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -88,7 +88,7 @@ case class ForeachWriterFactory[T]( extends DataWriterFactory[InternalRow] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): ForeachDataWriter[T] = { new ForeachDataWriter(writer, rowConverter, partitionId, epochId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index e07355aa37dba..b501d90c81f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat case object PackedRowWriterFactory extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new PackedRowDataWriter() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 47b482007822d..29f8cca476722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -180,7 +180,7 @@ class MemoryStreamWriter( case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] { override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { new MemoryDataWriter(partitionId, outputMode) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 694bb3b95b0f0..1334cf71ae988 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -209,10 +209,10 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[Row] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new SimpleCSVDataWriter(fs, filePath) } @@ -245,10 +245,10 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali override def createDataWriter( partitionId: Int, - attemptNumber: Int, + taskId: Long, epochId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) - val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) new InternalRowCSVDataWriter(fs, filePath) } From d48803bf64dc0fccd6f560738b4682f0c05e767a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 25 Jun 2018 17:08:23 -0700 Subject: [PATCH 1013/1161] [SPARK-24324][PYTHON][FOLLOWUP] Grouped Map positional conf should have deprecation note ## What changes were proposed in this pull request? Followup to the discussion of the added conf in SPARK-24324 which allows assignment by column position only. This conf is to preserve old behavior and will be removed in future releases, so it should have a note to indicate that. ## How was this patch tested? NA Author: Bryan Cutler Closes #21637 from BryanCutler/arrow-groupedMap-conf-deprecate-followup-SPARK-24324. --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d5fb524a1396f..e768416f257c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1167,7 +1167,7 @@ object SQLConf { .doc("When true, a grouped map Pandas UDF will assign columns from the returned " + "Pandas DataFrame based on position, regardless of column label type. When false, " + "columns will be looked up by name if labeled with a string and fallback to use " + - "position if not.") + "position if not. This configuration will be deprecated in future releases.") .booleanConf .createWithDefault(false) From 4c059ebc6008b4e78cbebc87a421cb87d1b800ed Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Tue, 26 Jun 2018 09:48:15 +0800 Subject: [PATCH 1014/1161] [SPARK-23776][DOC] Update instructions for running PySpark after building with SBT ## What changes were proposed in this pull request? This update tells the reader how to build Spark with SBT such that pyspark-sql tests will succeed. If you follow the current instructions for building Spark with SBT, pyspark/sql/udf.py fails with:
    AnalysisException: u'Can not load class test.org.apache.spark.sql.JavaStringLength, please make sure it is on the classpath;'
    
    ## How was this patch tested? I ran the doc build command (SKIP_API=1 jekyll build) and eyeballed the result. Author: Bruce Robbins Closes #21628 from bersprockets/SPARK-23776_doc. --- docs/building-spark.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 0236bb05849ad..c3bcd90ccc78f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -215,19 +215,23 @@ If you are building Spark for use in a Python environment and you wish to pip in Alternatively, you can also run make-distribution with the --pip option. -## PySpark Tests with Maven +## PySpark Tests with Maven or SBT If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. ./build/mvn -DskipTests clean package -Phive ./python/run-tests +If you are building PySpark with SBT and wish to run the PySpark tests, you will need to build Spark with Hive support and also build the test components: + + ./build/sbt -Phive clean package + ./build/sbt test:compile + ./python/run-tests + The run-tests script also can be limited to a specific Python version or a specific module ./python/run-tests --python-executables=python --modules=pyspark-sql -**Note:** You can also run Python tests with an sbt build, provided you build Spark with Hive support. - ## Running R Tests To run the SparkR tests you will need to install the [knitr](https://cran.r-project.org/package=knitr), [rmarkdown](https://cran.r-project.org/package=rmarkdown), [testthat](https://cran.r-project.org/package=testthat), [e1071](https://cran.r-project.org/package=e1071) and [survival](https://cran.r-project.org/package=survival) packages first: From c7967c6049327a03b63ea7a3b0001a97d31e309d Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 26 Jun 2018 09:48:52 +0800 Subject: [PATCH 1015/1161] [SPARK-24418][BUILD] Upgrade Scala to 2.11.12 and 2.12.6 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Scala is upgraded to `2.11.12` and `2.12.6`. We used `loadFIles()` in `ILoop` as a hook to initialize the Spark before REPL sees any files in Scala `2.11.8`. However, it was a hack, and it was not intended to be a public API, so it was removed in Scala `2.11.12`. From the discussion in Scala community, https://github.com/scala/bug/issues/10913 , we can use `initializeSynchronous` to initialize Spark instead. This PR implements the Spark initialization there. However, in Scala `2.11.12`'s `ILoop.scala`, in function `def startup()`, the first thing it calls is `printWelcome()`. As a result, Scala will call `printWelcome()` and `splash` before calling `initializeSynchronous`. Thus, the Spark shell will allow users to type commends first, and then show the Spark UI URL. It's working, but it will change the Spark Shell interface as the following. ```scala ➜ apache-spark git:(scala-2.11.12) ✗ ./bin/spark-shell Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 2.4.0-SNAPSHOT /_/ Using Scala version 2.11.12 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_161) Type in expressions to have them evaluated. Type :help for more information. scala> Spark context Web UI available at http://192.168.1.169:4040 Spark context available as 'sc' (master = local[*], app id = local-1528180279528). Spark session available as 'spark'. scala> ``` It seems there is no easy way to inject the Spark initialization code in the proper place as Scala doesn't provide a hook. Maybe som-snytt can comment on this. The following command is used to update the dep files. ```scala ./dev/test-dependencies.sh --replace-manifest ``` ## How was this patch tested? Existing tests Author: DB Tsai Closes #21495 from dbtsai/scala-2.11.12. --- LICENSE | 12 +++++----- dev/deps/spark-deps-hadoop-2.6 | 10 ++++---- dev/deps/spark-deps-hadoop-2.7 | 10 ++++---- dev/deps/spark-deps-hadoop-3.1 | 10 ++++---- pom.xml | 8 +++---- .../org/apache/spark/repl/SparkILoop.scala | 24 +++++++------------ .../spark/repl/SparkILoopInterpreter.scala | 18 ++++++++++++-- 7 files changed, 50 insertions(+), 42 deletions(-) diff --git a/LICENSE b/LICENSE index cc1f580207a75..6f5d9452e800d 100644 --- a/LICENSE +++ b/LICENSE @@ -243,18 +243,18 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) + (BSD) JLine (jline:jline:2.14.3 - https://github.com/jline/jline2) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.12 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.12 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 723180a14febb..96e9c27210d05 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -172,10 +172,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index ea08a001a1c9b..4a6ee027ec355 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -173,10 +173,10 @@ parquet-jackson-1.10.0.jar protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index da874026d7d10..e0b560c8ec71f 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -122,7 +122,7 @@ jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-webapp-9.3.20.v20170531.jar jetty-xml-9.3.20.v20170531.jar -jline-2.12.1.jar +jline-2.14.3.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar @@ -192,10 +192,10 @@ protobuf-java-2.5.0.jar py4j-0.10.7.jar pyrolite-4.13.jar re2j-1.1.jar -scala-compiler-2.11.8.jar -scala-library-2.11.8.jar -scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.8.jar +scala-compiler-2.11.12.jar +scala-library-2.11.12.jar +scala-parser-combinators_2.11-1.1.0.jar +scala-reflect-2.11.12.jar scala-xml_2.11-1.0.5.jar shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar diff --git a/pom.xml b/pom.xml index 4b4e6c13ea8fd..90e64ff71d229 100644 --- a/pom.xml +++ b/pom.xml @@ -155,7 +155,7 @@ 3.4.1 3.2.2 - 2.11.8 + 2.11.12 2.11 1.9.13 2.6.7 @@ -740,13 +740,13 @@ org.scala-lang.modules scala-parser-combinators_${scala.binary.version} - 1.0.4 + 1.1.0 jline jline - 2.12.1 + 2.14.3 org.scalatest @@ -2755,7 +2755,7 @@ scala-2.12 - 2.12.4 + 2.12.6 2.12 diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e69441a475e9a..a44051b351e19 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -36,7 +36,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def this() = this(None, new JPrintWriter(Console.out, true)) override def createInterpreter(): Unit = { - intp = new SparkILoopInterpreter(settings, out) + intp = new SparkILoopInterpreter(settings, out, initializeSpark) } val initializationCommands: Seq[String] = Seq( @@ -73,11 +73,15 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) "import org.apache.spark.sql.functions._" ) - def initializeSpark() { - intp.beQuietDuring { - savingReplayStack { // remove the commands from session history. - initializationCommands.foreach(processLine) + def initializeSpark(): Unit = { + if (!intp.reporter.hasErrors) { + // `savingReplayStack` removes the commands from session history. + savingReplayStack { + initializationCommands.foreach(intp quietRun _) } + } else { + throw new RuntimeException(s"Scala $versionString interpreter encountered " + + "errors during initialization") } } @@ -101,16 +105,6 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = standardCommands - /** - * We override `loadFiles` because we need to initialize Spark *before* the REPL - * sees any files, so that the Spark context is visible in those files. This is a bit of a - * hack, but there isn't another hook available to us at this point. - */ - override def loadFiles(settings: Settings): Unit = { - initializeSpark() - super.loadFiles(settings) - } - override def resetCommand(line: String): Unit = { super.resetCommand(line) initializeSpark() diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala index e736607a9a6b9..4e63816402a10 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala @@ -21,8 +21,22 @@ import scala.collection.mutable import scala.tools.nsc.Settings import scala.tools.nsc.interpreter._ -class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) { - self => +class SparkILoopInterpreter(settings: Settings, out: JPrintWriter, initializeSpark: () => Unit) + extends IMain(settings, out) { self => + + /** + * We override `initializeSynchronous` to initialize Spark *after* `intp` is properly initialized + * and *before* the REPL sees any files in the private `loadInitFiles` functions, so that + * the Spark context is visible in those files. + * + * This is a bit of a hack, but there isn't another hook available to us at this point. + * + * See the discussion in Scala community https://github.com/scala/bug/issues/10913 for detail. + */ + override def initializeSynchronous(): Unit = { + super.initializeSynchronous() + initializeSpark() + } override lazy val memberHandlers = new { val intp: self.type = self From e07aee2165af4d301ae12005a6d9ffb030bc2650 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 26 Jun 2018 09:51:55 +0800 Subject: [PATCH 1016/1161] [SPARK-24636][SQL] Type coercion of arrays for array_join function ## What changes were proposed in this pull request? Presto's implementation accepts arbitrary arrays of primitive types as an input: ``` presto> SELECT array_join(ARRAY [1, 2, 3], ', '); _col0 --------- 1, 2, 3 (1 row) ``` This PR proposes to implement a type coercion rule for ```array_join``` function that converts arrays of primitive as well as non-primitive types to arrays of string. ## How was this patch tested? New test cases add into: - sql-tests/inputs/typeCoercion/native/arrayJoin.sql - DataFrameFunctionsSuite.scala Author: Marek Novotny Closes #21620 from mn-mikke/SPARK-24636. --- .../sql/catalyst/analysis/TypeCoercion.scala | 8 ++ .../expressions/collectionOperations.scala | 1 + .../inputs/typeCoercion/native/arrayJoin.sql | 11 +++ .../typeCoercion/native/arrayJoin.sql.out | 90 +++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 ++++ 5 files changed, 127 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7f..637923928a7da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -536,6 +536,14 @@ object TypeCoercion { case None => c } + case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) && + ArrayType.acceptsType(arr.dataType) => + val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull + ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match { + case Some(castedArr) => ArrayJoin(castedArr, d, nr) + case None => aj + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b6137b07555f4..58612f65c1a53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1621,6 +1621,7 @@ case class ArrayJoin( override def dataType: DataType = StringType + override def prettyName: String = "array_join" } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql new file mode 100644 index 0000000000000..99729c007b104 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/arrayJoin.sql @@ -0,0 +1,11 @@ +SELECT array_join(array(true, false), ', '); +SELECT array_join(array(2Y, 1Y), ', '); +SELECT array_join(array(2S, 1S), ', '); +SELECT array_join(array(2, 1), ', '); +SELECT array_join(array(2L, 1L), ', '); +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', '); +SELECT array_join(array(2.0D, 1.0D), ', '); +SELECT array_join(array(float(2.0), float(1.0)), ', '); +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', '); +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', '); +SELECT array_join(array('a', 'b'), ', '); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out new file mode 100644 index 0000000000000..b23a62dacef7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/arrayJoin.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +SELECT array_join(array(true, false), ', ') +-- !query 0 schema +struct +-- !query 0 output +true, false + + +-- !query 1 +SELECT array_join(array(2Y, 1Y), ', ') +-- !query 1 schema +struct +-- !query 1 output +2, 1 + + +-- !query 2 +SELECT array_join(array(2S, 1S), ', ') +-- !query 2 schema +struct +-- !query 2 output +2, 1 + + +-- !query 3 +SELECT array_join(array(2, 1), ', ') +-- !query 3 schema +struct +-- !query 3 output +2, 1 + + +-- !query 4 +SELECT array_join(array(2L, 1L), ', ') +-- !query 4 schema +struct +-- !query 4 output +2, 1 + + +-- !query 5 +SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ') +-- !query 5 schema +struct +-- !query 5 output +9223372036854775809, 9223372036854775808 + + +-- !query 6 +SELECT array_join(array(2.0D, 1.0D), ', ') +-- !query 6 schema +struct +-- !query 6 output +2.0, 1.0 + + +-- !query 7 +SELECT array_join(array(float(2.0), float(1.0)), ', ') +-- !query 7 schema +struct +-- !query 7 output +2.0, 1.0 + + +-- !query 8 +SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ') +-- !query 8 schema +struct +-- !query 8 output +2016-03-14, 2016-03-13 + + +-- !query 9 +SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ') +-- !query 9 schema +struct +-- !query 9 output +2016-11-15 20:54:00, 2016-11-12 20:54:00 + + +-- !query 10 +SELECT array_join(array('a', 'b'), ', ') +-- !query 10 schema +struct +-- !query 10 output +a, b diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 47fe67d8daea3..5d6a6c0832c96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -805,6 +805,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.selectExpr("array_join(x, delimiter, 'NULL')"), Seq(Row("a,b"), Row("a,NULL,b"), Row(""))) + + val idf = Seq(Seq(1, 2, 3)).toDF("x") + + checkAnswer( + idf.select(array_join(idf("x"), ", ")), + Seq(Row("1, 2, 3")) + ) + checkAnswer( + idf.selectExpr("array_join(x, ', ')"), + Seq(Row("1, 2, 3")) + ) + intercept[AnalysisException] { + idf.selectExpr("array_join(x, 1)") + } + intercept[AnalysisException] { + idf.selectExpr("array_join(x, ', ', 1)") + } } test("array_min function") { From dcaa49ff1edd7fcf0f000c6f93ae0e30bd5b6464 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 26 Jun 2018 14:33:04 -0700 Subject: [PATCH 1017/1161] [SPARK-24658][SQL] Remove workaround for ANTLR bug ## What changes were proposed in this pull request? Issue antlr/antlr4#781 has already been fixed, so the workaround of extracting the pattern into a separate rule is no longer needed. The presto already removed it: https://github.com/prestodb/presto/pull/10744. ## How was this patch tested? Existing tests Author: Yuming Wang Closes #21641 from wangyum/ANTLR-780. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 3fe00eefde7d8..dc95751bf905c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -539,18 +539,11 @@ expression booleanExpression : NOT booleanExpression #logicalNot | EXISTS '(' query ')' #exists - | predicated #booleanDefault + | valueExpression predicate? #predicated | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary ; -// workaround for: -// https://github.com/antlr/antlr4/issues/780 -// https://github.com/antlr/antlr4/issues/781 -predicated - : valueExpression predicate? - ; - predicate : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression | NOT? kind=IN '(' expression (',' expression)* ')' From 02f8781fa2649cf1d3a5cb932e1c8408790974ff Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 26 Jun 2018 15:17:00 -0700 Subject: [PATCH 1018/1161] [SPARK-24423][SQL] Add a new option for JDBC sources ## What changes were proposed in this pull request? Here is the description in the JIRA - Currently, our JDBC connector provides the option `dbtable` for users to specify the to-be-loaded JDBC source table. ```SQL val jdbcDf = spark.read .format("jdbc") .option("dbtable", "dbName.tableName") .options(jdbcCredentials: Map) .load() ``` Normally, users do not fetch the whole JDBC table due to the poor performance/throughput of JDBC. Thus, they normally just fetch a small set of tables. For advanced users, they can pass a subquery as the option. ```SQL val query = """ (select * from tableName limit 10) as tmp """ val jdbcDf = spark.read .format("jdbc") .option("dbtable", query) .options(jdbcCredentials: Map) .load() ``` However, this is straightforward to end users. We should simply allow users to specify the query by a new option `query`. We will handle the complexity for them. ```SQL val query = """select * from tableName limit 10""" val jdbcDf = spark.read .format("jdbc") .option("query", query) .options(jdbcCredentials: Map) .load() ``` ## How was this patch tested? Added tests in JDBCSuite and JDBCWriterSuite. Also tested against MySQL, Postgress, Oracle, DB2 (using docker infrastructure) to make sure there are no syntax issues. Author: Dilip Biswal Closes #21590 from dilipbiswal/SPARK-24423. --- docs/sql-programming-guide.md | 30 +++++- .../datasources/jdbc/JDBCOptions.scala | 66 ++++++++++++- .../execution/datasources/jdbc/JDBCRDD.scala | 4 +- .../datasources/jdbc/JDBCRelation.scala | 4 +- .../jdbc/JdbcRelationProvider.scala | 5 +- .../datasources/jdbc/JdbcUtils.scala | 10 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 94 ++++++++++++++++++- .../spark/sql/jdbc/JDBCWriteSuite.scala | 14 ++- 8 files changed, 204 insertions(+), 23 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 196b814420be1..7c4ef41cc8907 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1302,9 +1302,33 @@ the following case-insensitive options: dbtable - The JDBC table that should be read. Note that anything that is valid in a FROM clause of - a SQL query can be used. For example, instead of a full table you could also use a - subquery in parentheses. + The JDBC table that should be read from or written into. Note that when using it in the read + path anything that is valid in a FROM clause of a SQL query can be used. + For example, instead of a full table you could also use a subquery in parentheses. It is not + allowed to specify `dbtable` and `query` options at the same time. + + + + query + + A query that will be used to read data into Spark. The specified query will be parenthesized and used + as a subquery in the FROM clause. Spark will also assign an alias to the subquery clause. + As an example, spark will issue a query of the following form to the JDBC Source.

    + SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

    + Below are couple of restrictions while using this option.
    +
      +
    1. It is not allowed to specify `dbtable` and `query` options at the same time.
    2. +
    3. It is not allowed to spcify `query` and `partitionColumn` options at the same time. When specifying + `partitionColumn` option is required, the subquery can be specified using `dbtable` option instead and + partition columns can be qualified using the subquery alias provided as part of `dbtable`.
      + Example:
      + + spark.read.format("jdbc")
      +    .option("dbtable", "(select c1, c2 from t1) as subq")
      +    .option("partitionColumn", "subq.c1"
      +    .load() +
    4. +
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index a73a97c06fe5a..eea966d30948b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType * Options for the JDBC data source. */ class JDBCOptions( - @transient private val parameters: CaseInsensitiveMap[String]) + @transient val parameters: CaseInsensitiveMap[String]) extends Serializable { import JDBCOptions._ @@ -65,11 +65,31 @@ class JDBCOptions( // Required parameters // ------------------------------------------------------------ require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") - require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") // a JDBC URL val url = parameters(JDBC_URL) - // name of table - val table = parameters(JDBC_TABLE_NAME) + // table name or a table subquery. + val tableOrQuery = (parameters.get(JDBC_TABLE_NAME), parameters.get(JDBC_QUERY_STRING)) match { + case (Some(name), Some(subquery)) => + throw new IllegalArgumentException( + s"Both '$JDBC_TABLE_NAME' and '$JDBC_QUERY_STRING' can not be specified at the same time." + ) + case (None, None) => + throw new IllegalArgumentException( + s"Option '$JDBC_TABLE_NAME' or '$JDBC_QUERY_STRING' is required." + ) + case (Some(name), None) => + if (name.isEmpty) { + throw new IllegalArgumentException(s"Option '$JDBC_TABLE_NAME' can not be empty.") + } else { + name.trim + } + case (None, Some(subquery)) => + if (subquery.isEmpty) { + throw new IllegalArgumentException(s"Option `$JDBC_QUERY_STRING` can not be empty.") + } else { + s"(${subquery}) __SPARK_GEN_JDBC_SUBQUERY_NAME_${curId.getAndIncrement()}" + } + } // ------------------------------------------------------------ // Optional parameters @@ -109,6 +129,20 @@ class JDBCOptions( s"When reading JDBC data sources, users need to specify all or none for the following " + s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " + s"and '$JDBC_NUM_PARTITIONS'") + + require(!(parameters.get(JDBC_QUERY_STRING).isDefined && partitionColumn.isDefined), + s""" + |Options '$JDBC_QUERY_STRING' and '$JDBC_PARTITION_COLUMN' can not be specified together. + |Please define the query using `$JDBC_TABLE_NAME` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + ) + val fetchSize = { val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, @@ -149,7 +183,30 @@ class JDBCOptions( val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) } +class JdbcOptionsInWrite( + @transient override val parameters: CaseInsensitiveMap[String]) + extends JDBCOptions(parameters) { + + import JDBCOptions._ + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + def this(url: String, table: String, parameters: Map[String, String]) = { + this(CaseInsensitiveMap(parameters ++ Map( + JDBCOptions.JDBC_URL -> url, + JDBCOptions.JDBC_TABLE_NAME -> table))) + } + + require( + parameters.get(JDBC_TABLE_NAME).isDefined, + s"Option '$JDBC_TABLE_NAME' is required. " + + s"Option '$JDBC_QUERY_STRING' is not applicable while writing.") + + val table = parameters(JDBC_TABLE_NAME) +} + object JDBCOptions { + private val curId = new java.util.concurrent.atomic.AtomicLong(0L) private val jdbcOptionNames = collection.mutable.Set[String]() private def newOption(name: String): String = { @@ -159,6 +216,7 @@ object JDBCOptions { val JDBC_URL = newOption("url") val JDBC_TABLE_NAME = newOption("dbtable") + val JDBC_QUERY_STRING = newOption("query") val JDBC_DRIVER_CLASS = newOption("driver") val JDBC_PARTITION_COLUMN = newOption("partitionColumn") val JDBC_LOWER_BOUND = newOption("lowerBound") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 0bab3689e5d0e..1b3b17c75e756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -51,7 +51,7 @@ object JDBCRDD extends Logging { */ def resolveTable(options: JDBCOptions): StructType = { val url = options.url - val table = options.table + val table = options.tableOrQuery val dialect = JdbcDialects.get(url) val conn: Connection = JdbcUtils.createConnectionFactory(options)() try { @@ -296,7 +296,7 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index b84543ccd7869..97e2d255cb7be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -189,12 +189,12 @@ private[sql] case class JDBCRelation( override def insert(data: DataFrame, overwrite: Boolean): Unit = { data.write .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) + .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties) } override def toString: String = { val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else "" // credentials should not be included in the plan output, table information is sufficient. - s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo + s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 2b488bb7121dc..782d626c1573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -59,7 +59,7 @@ class JdbcRelationProvider extends CreatableRelationProvider mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = { - val options = new JDBCOptions(parameters) + val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis val conn = JdbcUtils.createConnectionFactory(options)() @@ -86,7 +86,8 @@ class JdbcRelationProvider extends CreatableRelationProvider case SaveMode.ErrorIfExists => throw new AnalysisException( - s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.") + s"Table or view '${options.table}' already exists. " + + s"SaveMode: ErrorIfExists.") case SaveMode.Ignore => // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 433443007cfd8..b81737eda475b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -67,7 +67,7 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, options: JDBCOptions): Boolean = { + def tableExists(conn: Connection, options: JdbcOptionsInWrite): Boolean = { val dialect = JdbcDialects.get(options.url) // Somewhat hacky, but there isn't a good way to identify whether a table exists for all @@ -100,7 +100,7 @@ object JdbcUtils extends Logging { /** * Truncates a table from the JDBC database without side effects. */ - def truncateTable(conn: Connection, options: JDBCOptions): Unit = { + def truncateTable(conn: Connection, options: JdbcOptionsInWrite): Unit = { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { @@ -255,7 +255,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) try { - val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) + val statement = conn.prepareStatement(dialect.getSchemaQuery(options.tableOrQuery)) try { statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) @@ -809,7 +809,7 @@ object JdbcUtils extends Logging { df: DataFrame, tableSchema: Option[StructType], isCaseSensitive: Boolean, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val url = options.url val table = options.table val dialect = JdbcDialects.get(url) @@ -838,7 +838,7 @@ object JdbcUtils extends Logging { def createTable( conn: Connection, df: DataFrame, - options: JDBCOptions): Unit = { + options: JdbcOptionsInWrite): Unit = { val strSchema = schemaString( df, options.url, options.createTableColumnTypes) val table = options.table diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 6ea61f02a8206..0389273d6cdfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,7 +25,7 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec @@ -39,7 +39,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite +class JDBCSuite extends QueryTest with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ @@ -1099,7 +1099,7 @@ class JDBCSuite extends SparkFunSuite test("SPARK-19318: Connection properties keys should be case-sensitive.") { def testJdbcOptions(options: JDBCOptions): Unit = { // Spark JDBC data source options are case-insensitive - assert(options.table == "t1") + assert(options.tableOrQuery == "t1") // When we convert it to properties, it should be case-sensitive. assert(options.asProperties.size == 3) assert(options.asProperties.get("customkey") == null) @@ -1255,4 +1255,92 @@ class JDBCSuite extends SparkFunSuite testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD")) } } + + test("query JDBC option - negative tests") { + val query = "SELECT * FROM test.people WHERE theid = 1" + // load path + val e1 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .option("dbtable", "test.people") + .load() + }.getMessage + assert(e1.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + // jdbc api path + val properties = new Properties() + properties.setProperty(JDBCOptions.JDBC_QUERY_STRING, query) + val e2 = intercept[RuntimeException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() + }.getMessage + assert(e2.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e3 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', dbtable 'TEST.PEOPLE', + | user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e3.contains("Both 'dbtable' and 'query' can not be specified at the same time.")) + + val e4 = intercept[RuntimeException] { + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", "") + .load() + }.getMessage + assert(e4.contains("Option `query` can not be empty.")) + + // Option query and partitioncolumn are not allowed together. + val expectedErrorMsg = + s""" + |Options 'query' and 'partitionColumn' can not be specified together. + |Please define the query using `dbtable` option instead and make sure to qualify + |the partition columns using the supplied subquery alias to resolve any ambiguity. + |Example : + |spark.read.format("jdbc") + | .option("dbtable", "(select c1, c2 from t1) as subq") + | .option("partitionColumn", "subq.c1" + | .load() + """.stripMargin + val e5 = intercept[RuntimeException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e5.contains(expectedErrorMsg)) + } + + test("query JDBC option") { + val query = "SELECT name, theid FROM test.people WHERE theid = 1" + // query option to pass on the query string. + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("query", query) + .load() + checkAnswer( + df, + Row("fred", 1) :: Nil) + + // query option in the create table path. + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW queryOption + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', query '$query', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + checkAnswer( + sql("select name, theid from queryOption"), + Row("fred", 1) :: Nil) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1c2c92d1f0737..b751ec2de4825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -293,13 +293,23 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("save errors if dbtable is not specified") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val e = intercept[RuntimeException] { + val e1 = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .save() + }.getMessage + assert(e1.contains("Option 'dbtable' or 'query' is required")) + + val e2 = intercept[RuntimeException] { df.write.format("jdbc") .option("url", url1) .options(properties.asScala) + .option("query", "select * from TEST.SAVETEST") .save() }.getMessage - assert(e.contains("Option 'dbtable' is required")) + val msg = "Option 'dbtable' is required. Option 'query' is not applicable while writing." + assert(e2.contains(msg)) } test("save errors if wrong user/password combination") { From 16f2c3ea46a330bff7fae33f2521eb36a6280f04 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 26 Jun 2018 15:56:58 -0700 Subject: [PATCH 1019/1161] [SPARK-6237][NETWORK] Network-layer changes to allow stream upload. These changes allow an RPCHandler to receive an upload as a stream of data, without having to buffer the entire message in the FrameDecoder. The primary use case is for replicating large blocks. By itself, this change is adding dead-code that is not being used -- it is a step towards SPARK-24296. Added unit tests for handling streaming data, including successfully sending data, and failures in reading the stream with concurrent requests. Summary of changes: * Introduce a new UploadStream RPC which is sent to push a large payload as a stream (in contrast, the pre-existing StreamRequest and StreamResponse RPCs are used for pull-based streaming). * Generalize RpcHandler.receive() to support requests which contain streams. * Generalize StreamInterceptor to handle both request and response messages (previously it only handled responses). * Introduce StdChannelListener to abstract away common logging logic in ChannelFuture listeners. Author: Imran Rashid Closes #21346 from squito/upload_stream. --- .../network/client/StreamCallbackWithID.java | 22 ++ .../network/client/StreamInterceptor.java | 26 +- .../spark/network/client/TransportClient.java | 175 ++++++----- .../spark/network/crypto/AuthRpcHandler.java | 9 + .../spark/network/protocol/Message.java | 3 +- .../network/protocol/MessageDecoder.java | 3 + .../network/protocol/StreamResponse.java | 2 +- .../spark/network/protocol/UploadStream.java | 107 +++++++ .../spark/network/sasl/SaslRpcHandler.java | 9 + .../spark/network/server/RpcHandler.java | 34 ++- .../server/TransportRequestHandler.java | 95 +++++- .../spark/network/RpcIntegrationSuite.java | 289 ++++++++++++++++-- .../org/apache/spark/network/StreamSuite.java | 94 ++---- .../spark/network/StreamTestHelper.java | 104 +++++++ project/MimaExcludes.scala | 3 + 15 files changed, 799 insertions(+), 176 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java new file mode 100644 index 0000000000000..bd173b653e33e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallbackWithID.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +public interface StreamCallbackWithID extends StreamCallback { + String getID(); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index b0e85bae7c309..f3eb744ff7345 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -22,22 +22,24 @@ import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.server.MessageHandler; import org.apache.spark.network.util.TransportFrameDecoder; /** * An interceptor that is registered with the frame decoder to feed stream data to a * callback. */ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { +public class StreamInterceptor implements TransportFrameDecoder.Interceptor { - private final TransportResponseHandler handler; + private final MessageHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private long bytesRead; - StreamInterceptor( - TransportResponseHandler handler, + public StreamInterceptor( + MessageHandler handler, String streamId, long byteCount, StreamCallback callback) { @@ -50,16 +52,24 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { - handler.deactivateStream(); + deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } + private void deactivateStream() { + if (handler instanceof TransportResponseHandler) { + // we only have to do this for TransportResponseHandler as it exposes numOutstandingFetches + // (there is no extra cleanup that needs to happen) + ((TransportResponseHandler) handler).deactivateStream(); + } + } + @Override public boolean handle(ByteBuf buf) throws Exception { int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead); @@ -72,10 +82,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); - handler.deactivateStream(); + deactivateStream(); throw re; } else if (bytesRead == byteCount) { - handler.deactivateStream(); + deactivateStream(); callback.onComplete(streamId); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..325225dc0ea2c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -32,15 +32,15 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.StreamChunkId; -import org.apache.spark.network.protocol.StreamRequest; +import org.apache.spark.network.protocol.*; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -133,34 +133,21 @@ public void fetchChunk( long streamId, int chunkIndex, ChunkReceivedCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); - - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", streamChunkId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); + StdChannelListener listener = new StdChannelListener(streamChunkId) { + @Override + void handleFailure(String errorMsg, Throwable cause) { handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } + callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); } - }); + }; + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener); } /** @@ -170,7 +157,12 @@ public void fetchChunk( * @param callback Object to call with the stream data. */ public void stream(String streamId, StreamCallback callback) { - long startTime = System.currentTimeMillis(); + StdChannelListener listener = new StdChannelListener(streamId) { + @Override + void handleFailure(String errorMsg, Throwable cause) throws Exception { + callback.onFailure(streamId, new IOException(errorMsg, cause)); + } + }; if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); } @@ -180,25 +172,7 @@ public void stream(String streamId, StreamCallback callback) { // when responses arrive. synchronized (this) { handler.addStreamCallback(streamId, callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request for {} to {} took {} ms", streamId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + channel.writeAndFlush(new StreamRequest(streamId)).addListener(listener); } } @@ -211,35 +185,44 @@ public void stream(String streamId, StreamCallback callback) { * @return The RPC's id. */ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { - long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = requestId(); handler.addRpcRequest(requestId, callback); + RpcChannelListener listener = new RpcChannelListener(requestId, callback); channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) - .addListener(future -> { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", requestId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeRpcRequest(requestId); - channel.close(); - try { - callback.onFailure(new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } - }); + .addListener(listener); + + return requestId; + } + + /** + * Send data to the remote end as a stream. This differs from stream() in that this is a request + * to *send* data to the remote end, not to receive it from the remote. + * + * @param meta meta data associated with the stream, which will be read completely on the + * receiving end before the stream itself. + * @param data this will be streamed to the remote end to allow for transferring large amounts + * of data without reading into memory. + * @param callback handles the reply -- onSuccess will only be called when both message and data + * are received successfully. + */ + public long uploadStream( + ManagedBuffer meta, + ManagedBuffer data, + RpcResponseCallback callback) { + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } + + long requestId = requestId(); + handler.addRpcRequest(requestId, callback); + + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush(new UploadStream(requestId, meta, data)).addListener(listener); return requestId; } @@ -319,4 +302,60 @@ public String toString() { .add("isActive", isActive()) .toString(); } + + private static long requestId() { + return Math.abs(UUID.randomUUID().getLeastSignificantBits()); + } + + private class StdChannelListener + implements GenericFutureListener> { + final long startTime; + final Object requestId; + + StdChannelListener(Object requestId) { + this.startTime = System.currentTimeMillis(); + this.requestId = requestId; + } + + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + if (logger.isTraceEnabled()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + handleFailure(errorMsg, future.cause()); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + + void handleFailure(String errorMsg, Throwable cause) throws Exception {} + } + + private class RpcChannelListener extends StdChannelListener { + final long rpcRequestId; + final RpcResponseCallback callback; + + RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { + super("RPC " + rpcRequestId); + this.rpcRequestId = rpcRequestId; + this.callback = callback; + } + + @Override + void handleFailure(String errorMsg, Throwable cause) { + handler.removeRpcRequest(rpcRequestId); + callback.onFailure(new IOException(errorMsg, cause)); + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081bf..fb44dbbb0953b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.sasl.SaslRpcHandler; @@ -149,6 +150,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..0ccd70c03aba8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,7 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), User(-1); + OneWayMessage(9), UploadStream(10), User(-1); private final byte id; @@ -65,6 +65,7 @@ public static Type decode(ByteBuf buf) { case 7: return StreamResponse; case 8: return StreamFailure; case 9: return OneWayMessage; + case 10: return UploadStream; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..bf80aed0afe10 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -80,6 +80,9 @@ private Message decode(Message.Type msgType, ByteBuf in) { case StreamFailure: return StreamFailure.decode(in); + case UploadStream: + return UploadStream.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java index 87e212f3e157b..50b811604b84b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java @@ -67,7 +67,7 @@ public static StreamResponse decode(ByteBuf buf) { @Override public int hashCode() { - return Objects.hashCode(byteCount, streamId, body()); + return Objects.hashCode(byteCount, streamId); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java new file mode 100644 index 0000000000000..fa1d26e76b852 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * An RPC with data that is sent outside of the frame, so it can be read as a stream. + */ +public final class UploadStream extends AbstractMessage implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + public final ManagedBuffer meta; + public final long bodyByteCount; + + public UploadStream(long requestId, ManagedBuffer meta, ManagedBuffer body) { + super(body, false); // body is *not* included in the frame + this.requestId = requestId; + this.meta = meta; + bodyByteCount = body.size(); + } + + // this version is called when decoding the bytes on the receiving end. The body is handled + // separately. + private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) { + super(null, false); + this.requestId = requestId; + this.meta = meta; + this.bodyByteCount = bodyByteCount; + } + + @Override + public Type type() { return Type.UploadStream; } + + @Override + public int encodedLength() { + // the requestId, meta size, meta and bodyByteCount (body is not included) + return 8 + 4 + ((int) meta.size()) + 8; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + try { + ByteBuffer metaBuf = meta.nioByteBuffer(); + buf.writeInt(metaBuf.remaining()); + buf.writeBytes(metaBuf); + } catch (IOException io) { + throw new RuntimeException(io); + } + buf.writeLong(bodyByteCount); + } + + public static UploadStream decode(ByteBuf buf) { + long requestId = buf.readLong(); + int metaSize = buf.readInt(); + ManagedBuffer meta = new NettyManagedBuffer(buf.readRetainedSlice(metaSize)); + long bodyByteCount = buf.readLong(); + // This is called by the frame decoder, so the data is still null. We need a StreamInterceptor + // to read the data. + return new UploadStream(requestId, meta, bodyByteCount); + } + + @Override + public int hashCode() { + return Long.hashCode(requestId); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UploadStream) { + UploadStream o = (UploadStream) other; + return requestId == o.requestId && super.equals(o); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("body", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318add..355a3def8cc22 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -132,6 +133,14 @@ public void receive(TransportClient client, ByteBuffer message) { delegate.receive(client, message); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + return delegate.receiveStream(client, message, callback); + } + @Override public StreamManager getStreamManager() { return delegate.getStreamManager(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 8f7554e2e07d5..38569baf82bce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; /** @@ -36,7 +37,8 @@ public abstract class RpcHandler { * Receive a single RPC message. Any exception thrown while in this method will be sent back to * the client in string form as a standard RPC failure. * - * This method will not be called in parallel for a single TransportClient (i.e., channel). + * Neither this method nor #receiveStream will be called in parallel for a single + * TransportClient (i.e., channel). * * @param client A channel client which enables the handler to make requests back to the sender * of this RPC. This will always be the exact same object for a particular channel. @@ -49,6 +51,36 @@ public abstract void receive( ByteBuffer message, RpcResponseCallback callback); + /** + * Receive a single RPC message which includes data that is to be received as a stream. Any + * exception thrown while in this method will be sent back to the client in string form as a + * standard RPC failure. + * + * Neither this method nor #receive will be called in parallel for a single TransportClient + * (i.e., channel). + * + * An error while reading data from the stream + * ({@link org.apache.spark.network.client.StreamCallback#onData(String, ByteBuffer)}) + * will fail the entire channel. A failure in "post-processing" the stream in + * {@link org.apache.spark.network.client.StreamCallback#onComplete(String)} will result in an + * rpcFailure, but the channel will remain active. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param messageHeader The serialized bytes of the header portion of the RPC. This is in meant + * to be relatively small, and will be buffered entirely in memory, to + * facilitate how the streaming portion should be received. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + * @return a StreamCallback for handling the accompanying streaming data + */ + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + /** * Returns the StreamManager which contains the state about which streams are currently being * fetched by a TransportClient. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e94453578e6b0..e1d7b2dbff60f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; @@ -28,20 +29,10 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.OneWayMessage; -import org.apache.spark.network.protocol.RequestMessage; -import org.apache.spark.network.protocol.RpcFailure; -import org.apache.spark.network.protocol.RpcRequest; -import org.apache.spark.network.protocol.RpcResponse; -import org.apache.spark.network.protocol.StreamFailure; -import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.util.TransportFrameDecoder; + import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** @@ -52,6 +43,7 @@ * The messages should have been processed by the pipeline setup by {@link TransportServer}. */ public class TransportRequestHandler extends MessageHandler { + private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); /** The Netty channel that this handler is associated with. */ @@ -113,6 +105,8 @@ public void handle(RequestMessage request) { processOneWayMessage((OneWayMessage) request); } else if (request instanceof StreamRequest) { processStreamRequest((StreamRequest) request); + } else if (request instanceof UploadStream) { + processStreamUpload((UploadStream) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -203,6 +197,79 @@ public void onFailure(Throwable e) { } } + /** + * Handle a request from the client to upload a stream of data. + */ + private void processStreamUpload(final UploadStream req) { + assert (req.body() == null); + try { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }; + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + ByteBuffer meta = req.meta.nioByteBuffer(); + StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); + if (streamHandler == null) { + throw new NullPointerException("rpcHandler returned a null streamHandler"); + } + StreamCallbackWithID wrappedCallback = new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + streamHandler.onData(streamId, buf); + } + + @Override + public void onComplete(String streamId) throws IOException { + try { + streamHandler.onComplete(streamId); + callback.onSuccess(ByteBuffer.allocate(0)); + } catch (Exception ex) { + IOException ioExc = new IOException("Failure post-processing complete stream;" + + " failing this rpc and leaving channel active"); + callback.onFailure(ioExc); + streamHandler.onFailure(streamId, ioExc); + } + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + callback.onFailure(new IOException("Destination failed while reading stream", cause)); + streamHandler.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamHandler.getID(); + } + }; + if (req.bodyByteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, wrappedCallback.getID(), + req.bodyByteCount, wrappedCallback); + frameDecoder.setInterceptor(interceptor); + } else { + wrappedCallback.onComplete(wrappedCallback.getID()); + } + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + // We choose to totally fail the channel, rather than trying to recover as we do in other + // cases. We don't know how many bytes of the stream the client has already sent for the + // stream, it's not worth trying to recover. + channel.pipeline().fireExceptionCaught(e); + } finally { + req.meta.release(); + } + } + private void processOneWayMessage(OneWayMessage req) { try { rpcHandler.receive(reverseClient, req.body().nioByteBuffer()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 8ff737b129641..1f4d75c7e2ec5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -17,43 +17,46 @@ package org.apache.spark.network; +import java.io.*; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; +import com.google.common.io.Files; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.*; +import org.apache.spark.network.server.*; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { + static TransportConf conf; static TransportServer server; static TransportClientFactory clientFactory; static RpcHandler rpcHandler; static List oneWayMsgs; + static StreamTestHelper testData; + + static ConcurrentHashMap streamCallbacks = + new ConcurrentHashMap<>(); @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + testData = new StreamTestHelper(); rpcHandler = new RpcHandler() { @Override public void receive( @@ -71,6 +74,14 @@ public void receive( } } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + return receiveStreamHelper(JavaUtils.bytesToString(messageHeader)); + } + @Override public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs.add(JavaUtils.bytesToString(message)); @@ -85,10 +96,71 @@ public void receive(TransportClient client, ByteBuffer message) { oneWayMsgs = new ArrayList<>(); } + private static StreamCallbackWithID receiveStreamHelper(String msg) { + try { + if (msg.startsWith("fail/")) { + String[] parts = msg.split("/"); + switch (parts[1]) { + case "exception-ondata": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + throw new IOException("failed to read stream data!"); + } + + @Override + public void onComplete(String streamId) throws IOException { + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "exception-oncomplete": + return new StreamCallbackWithID() { + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + } + + @Override + public void onComplete(String streamId) throws IOException { + throw new IOException("exception in onComplete"); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + } + + @Override + public String getID() { + return msg; + } + }; + case "null": + return null; + default: + throw new IllegalArgumentException("unexpected msg: " + msg); + } + } else { + VerifyingStreamCallback streamCallback = new VerifyingStreamCallback(msg); + streamCallbacks.put(msg, streamCallback); + return streamCallback; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @AfterClass public static void tearDown() { server.close(); clientFactory.close(); + testData.cleanup(); } static class RpcResult { @@ -130,6 +202,59 @@ public void onFailure(Throwable e) { return res; } + private RpcResult sendRpcWithStream(String... streams) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + for (String stream : streams) { + int idx = stream.lastIndexOf('/'); + ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream)); + String streamName = (idx == -1) ? stream : stream.substring(idx + 1); + ManagedBuffer data = testData.openStream(conf, streamName); + client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem)); + } + + if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + streamCallbacks.values().forEach(streamCallback -> { + try { + streamCallback.verify(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + client.close(); + return res; + } + + private static class RpcStreamCallback implements RpcResponseCallback { + final String streamId; + final RpcResult res; + final Semaphore sem; + + RpcStreamCallback(String streamId, RpcResult res, Semaphore sem) { + this.streamId = streamId; + this.res = res; + this.sem = sem; + } + + @Override + public void onSuccess(ByteBuffer message) { + res.successMessages.add(streamId); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + } + @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); @@ -193,10 +318,83 @@ public void sendOneWayMessage() throws Exception { } } + @Test + public void sendRpcWithStreamOneAtATime() throws Exception { + for (String stream : StreamTestHelper.STREAMS) { + RpcResult res = sendRpcWithStream(stream); + assertTrue("there were error messages!" + res.errorMessages, res.errorMessages.isEmpty()); + assertEquals(Sets.newHashSet(stream), res.successMessages); + } + } + + @Test + public void sendRpcWithStreamConcurrently() throws Exception { + String[] streams = new String[10]; + for (int i = 0; i < 10; i++) { + streams[i] = StreamTestHelper.STREAMS[i % StreamTestHelper.STREAMS.length]; + } + RpcResult res = sendRpcWithStream(streams); + assertEquals(Sets.newHashSet(StreamTestHelper.STREAMS), res.successMessages); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void sendRpcWithStreamFailures() throws Exception { + // when there is a failure reading stream data, we don't try to keep the channel usable, + // just send back a decent error msg. + RpcResult exceptionInCallbackResult = + sendRpcWithStream("fail/exception-ondata/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + RpcResult nullStreamHandler = + sendRpcWithStream("fail/null/smallBuffer", "smallBuffer"); + assertErrorAndClosed(exceptionInCallbackResult, "Destination failed while reading stream"); + + // OTOH, if there is a failure during onComplete, the channel should still be fine + RpcResult exceptionInOnComplete = + sendRpcWithStream("fail/exception-oncomplete/smallBuffer", "smallBuffer"); + assertErrorsContain(exceptionInOnComplete.errorMessages, + Sets.newHashSet("Failure post-processing")); + assertEquals(Sets.newHashSet("smallBuffer"), exceptionInOnComplete.successMessages); + } + private void assertErrorsContain(Set errors, Set contains) { - assertEquals(contains.size(), errors.size()); + assertEquals("Expected " + contains.size() + " errors, got " + errors.size() + "errors: " + + errors, contains.size(), errors.size()); + + Pair, Set> r = checkErrorsContain(errors, contains); + assertTrue("Could not find error containing " + r.getRight() + "; errors: " + errors, + r.getRight().isEmpty()); + + assertTrue(r.getLeft().isEmpty()); + } + + private void assertErrorAndClosed(RpcResult result, String expectedError) { + assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty()); + // we expect 1 additional error, which contains *either* "closed" or "Connection reset" + Set errors = result.errorMessages; + assertEquals("Expected 2 errors, got " + errors.size() + "errors: " + + errors, 2, errors.size()); + + Set containsAndClosed = Sets.newHashSet(expectedError); + containsAndClosed.add("closed"); + containsAndClosed.add("Connection reset"); + + Pair, Set> r = checkErrorsContain(errors, containsAndClosed); + Set errorsNotFound = r.getRight(); + assertEquals(1, errorsNotFound.size()); + String err = errorsNotFound.iterator().next(); + assertTrue(err.equals("closed") || err.equals("Connection reset")); + + assertTrue(r.getLeft().isEmpty()); + } + + private Pair, Set> checkErrorsContain( + Set errors, + Set contains) { Set remainingErrors = Sets.newHashSet(errors); + Set notFound = Sets.newHashSet(); for (String contain : contains) { Iterator it = remainingErrors.iterator(); boolean foundMatch = false; @@ -207,9 +405,66 @@ private void assertErrorsContain(Set errors, Set contains) { break; } } - assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + if (!foundMatch) { + notFound.add(contain); + } + } + return new ImmutablePair<>(remainingErrors, notFound); + } + + private static class VerifyingStreamCallback implements StreamCallbackWithID { + final String streamId; + final StreamSuite.TestCallback helper; + final OutputStream out; + final File outFile; + + VerifyingStreamCallback(String streamId) throws IOException { + if (streamId.equals("file")) { + outFile = File.createTempFile("data", ".tmp", testData.tempDir); + out = new FileOutputStream(outFile); + } else { + out = new ByteArrayOutputStream(); + outFile = null; + } + this.streamId = streamId; + helper = new StreamSuite.TestCallback(out); + } + + void verify() throws IOException { + if (streamId.equals("file")) { + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); + } else { + byte[] result = ((ByteArrayOutputStream)out).toByteArray(); + ByteBuffer srcBuffer = testData.srcBuffer(streamId); + ByteBuffer base; + synchronized (srcBuffer) { + base = srcBuffer.duplicate(); + } + byte[] expected = new byte[base.remaining()]; + base.get(expected); + assertEquals(expected.length, result.length); + assertTrue("buffers don't match", Arrays.equals(expected, result)); + } + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + helper.onData(streamId, buf); } - assertTrue(remainingErrors.isEmpty()); + @Override + public void onComplete(String streamId) throws IOException { + helper.onComplete(streamId); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + helper.onFailure(streamId, cause); + } + + @Override + public String getID() { + return streamId; + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f253a07e64be1..f3050cb79cdfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -26,7 +26,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Random; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; @@ -37,9 +36,7 @@ import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; @@ -51,16 +48,11 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + private static final String[] STREAMS = StreamTestHelper.STREAMS; + private static StreamTestHelper testData; private static TransportServer server; private static TransportClientFactory clientFactory; - private static File testFile; - private static File tempDir; - - private static ByteBuffer emptyBuffer; - private static ByteBuffer smallBuffer; - private static ByteBuffer largeBuffer; private static ByteBuffer createBuffer(int bufSize) { ByteBuffer buf = ByteBuffer.allocate(bufSize); @@ -73,23 +65,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { - tempDir = Files.createTempDir(); - emptyBuffer = createBuffer(0); - smallBuffer = createBuffer(100); - largeBuffer = createBuffer(100000); - - testFile = File.createTempFile("stream-test-file", "txt", tempDir); - FileOutputStream fp = new FileOutputStream(testFile); - try { - Random rnd = new Random(); - for (int i = 0; i < 512; i++) { - byte[] fileContent = new byte[1024]; - rnd.nextBytes(fileContent); - fp.write(fileContent); - } - } finally { - fp.close(); - } + testData = new StreamTestHelper(); final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @@ -100,18 +76,7 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { @Override public ManagedBuffer openStream(String streamId) { - switch (streamId) { - case "largeBuffer": - return new NioManagedBuffer(largeBuffer); - case "smallBuffer": - return new NioManagedBuffer(smallBuffer); - case "emptyBuffer": - return new NioManagedBuffer(emptyBuffer); - case "file": - return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); - default: - throw new IllegalArgumentException("Invalid stream: " + streamId); - } + return testData.openStream(conf, streamId); } }; RpcHandler handler = new RpcHandler() { @@ -137,12 +102,7 @@ public StreamManager getStreamManager() { public static void tearDown() { server.close(); clientFactory.close(); - if (tempDir != null) { - for (File f : tempDir.listFiles()) { - f.delete(); - } - tempDir.delete(); - } + testData.cleanup(); } @Test @@ -234,21 +194,21 @@ public void run() { case "largeBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = largeBuffer; + srcBuffer = testData.largeBuffer; break; case "smallBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = smallBuffer; + srcBuffer = testData.smallBuffer; break; case "file": - outFile = File.createTempFile("data", ".tmp", tempDir); + outFile = File.createTempFile("data", ".tmp", testData.tempDir); out = new FileOutputStream(outFile); break; case "emptyBuffer": baos = new ByteArrayOutputStream(); out = baos; - srcBuffer = emptyBuffer; + srcBuffer = testData.emptyBuffer; break; default: throw new IllegalArgumentException(streamId); @@ -256,10 +216,10 @@ public void run() { TestCallback callback = new TestCallback(out); client.stream(streamId, callback); - waitForCompletion(callback); + callback.waitForCompletion(timeoutMs); if (srcBuffer == null) { - assertTrue("File stream did not match.", Files.equal(testFile, outFile)); + assertTrue("File stream did not match.", Files.equal(testData.testFile, outFile)); } else { ByteBuffer base; synchronized (srcBuffer) { @@ -292,23 +252,9 @@ public void check() throws Throwable { throw error; } } - - private void waitForCompletion(TestCallback callback) throws Exception { - long now = System.currentTimeMillis(); - long deadline = now + timeoutMs; - synchronized (callback) { - while (!callback.completed && now < deadline) { - callback.wait(deadline - now); - now = System.currentTimeMillis(); - } - } - assertTrue("Timed out waiting for stream.", callback.completed); - assertNull(callback.error); - } - } - private static class TestCallback implements StreamCallback { + static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; @@ -344,6 +290,22 @@ public void onFailure(String streamId, Throwable cause) { } } + void waitForCompletion(long timeoutMs) { + long now = System.currentTimeMillis(); + long deadline = now + timeoutMs; + synchronized (this) { + while (!completed && now < deadline) { + try { + wait(deadline - now); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + now = System.currentTimeMillis(); + } + } + assertTrue("Timed out waiting for stream.", completed); + assertNull(error); + } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java new file mode 100644 index 0000000000000..0f5c82c9e9b1f --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamTestHelper.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Random; + +import com.google.common.io.Files; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +class StreamTestHelper { + static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; + + final File testFile; + final File tempDir; + + final ByteBuffer emptyBuffer; + final ByteBuffer smallBuffer; + final ByteBuffer largeBuffer; + + private static ByteBuffer createBuffer(int bufSize) { + ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + return buf; + } + + StreamTestHelper() throws Exception { + tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); + smallBuffer = createBuffer(100); + largeBuffer = createBuffer(100000); + + testFile = File.createTempFile("stream-test-file", "txt", tempDir); + FileOutputStream fp = new FileOutputStream(testFile); + try { + Random rnd = new Random(); + for (int i = 0; i < 512; i++) { + byte[] fileContent = new byte[1024]; + rnd.nextBytes(fileContent); + fp.write(fileContent); + } + } finally { + fp.close(); + } + } + + public ByteBuffer srcBuffer(String name) { + switch (name) { + case "largeBuffer": + return largeBuffer; + case "smallBuffer": + return smallBuffer; + case "emptyBuffer": + return emptyBuffer; + default: + throw new IllegalArgumentException("Invalid stream: " + name); + } + } + + public ManagedBuffer openStream(TransportConf conf, String streamId) { + switch (streamId) { + case "file": + return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); + default: + return new NioManagedBuffer(srcBuffer(streamId)); + } + } + + void cleanup() { + if (tempDir != null) { + try { + JavaUtils.deleteRecursively(tempDir); + } catch (IOException io) { + throw new RuntimeException(io); + } + } + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4f6d5ff898681..eeb097ef153ad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), From 1b9368f7d4c1d5c0df49204f48515d3b4ffe3e13 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 27 Jun 2018 10:27:40 +0800 Subject: [PATCH 1020/1161] [SPARK-24659][SQL] GenericArrayData.equals should respect element type differences ## What changes were proposed in this pull request? Fix `GenericArrayData.equals`, so that it respects the actual types of the elements. e.g. an instance that represents an `array` and another instance that represents an `array` should be considered incompatible, and thus should return false for `equals`. `GenericArrayData` doesn't keep any schema information by itself, and rather relies on the Java objects referenced by its `array` field's elements to keep track of their own object types. So, the most straightforward way to respect their types is to call `equals` on the elements, instead of using Scala's `==` operator, which can have semantics that are not always desirable: ``` new java.lang.Integer(123) == new java.lang.Long(123L) // true in Scala new java.lang.Integer(123).equals(new java.lang.Long(123L)) // false in Scala ``` ## How was this patch tested? Added unit test in `ComplexDataSuite` Author: Kris Mok Closes #21643 from rednaxelafx/fix-genericarraydata-equals. --- .../sql/catalyst/util/GenericArrayData.scala | 2 +- .../sql/catalyst/util/ComplexDataSuite.scala | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 9e39ed9c3a778..83ad08d8e1758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -122,7 +122,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { return false } - case _ => if (o1 != o2) { + case _ => if (!o1.equals(o2)) { return false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala index 9d285916bcf42..229e32479082c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -104,4 +104,40 @@ class ComplexDataSuite extends SparkFunSuite { // The copied data should not be changed externally. assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") } + + test("SPARK-24659: GenericArrayData.equals should respect element type differences") { + import scala.reflect.ClassTag + + // Expected positive cases + def arraysShouldEqual[T: ClassTag](element: T*): Unit = { + val array1 = new GenericArrayData(Array[T](element: _*)) + val array2 = new GenericArrayData(Array[T](element: _*)) + assert(array1.equals(array2)) + } + arraysShouldEqual(true, false) // Boolean + arraysShouldEqual(0.toByte, 123.toByte, Byte.MinValue, Byte.MaxValue) // Byte + arraysShouldEqual(0.toShort, 123.toShort, Short.MinValue, Short.MaxValue) // Short + arraysShouldEqual(0, 123, -65536, Int.MinValue, Int.MaxValue) // Int + arraysShouldEqual(0L, 123L, -65536L, Long.MinValue, Long.MaxValue) // Long + arraysShouldEqual(0.0F, 123.0F, Float.MinValue, Float.MaxValue, Float.MinPositiveValue, + Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN) // Float + arraysShouldEqual(0.0, 123.0, Double.MinValue, Double.MaxValue, Double.MinPositiveValue, + Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN) // Double + arraysShouldEqual(Array[Byte](123.toByte), Array[Byte](), null) // SQL Binary + arraysShouldEqual(UTF8String.fromString("foo"), null) // SQL String + + // Expected negative cases + // Spark SQL considers cases like array vs array to be incompatible, + // so an underlying implementation of array type should return false in such cases. + def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = { + val array1 = new GenericArrayData(Array[T](element1)) + val array2 = new GenericArrayData(Array[U](element2)) + assert(!array1.equals(array2)) + } + arraysShouldNotEqual(true, 1) // Boolean <-> Int + arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int + arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long + arraysShouldNotEqual(123.toShort, 123) // Short <-> Int + arraysShouldNotEqual(123, 123L) // Int <-> Long + } } From d08f53dc61f662f5291f71bcbe1a7b9f531a34d2 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 27 Jun 2018 10:36:51 +0800 Subject: [PATCH 1021/1161] [SPARK-24605][SQL] size(null) returns null instead of -1 ## What changes were proposed in this pull request? In PR, I propose new behavior of `size(null)` under the config flag `spark.sql.legacy.sizeOfNull`. If the former one is disabled, the `size()` function returns `null` for `null` input. By default the `spark.sql.legacy.sizeOfNull` is enabled to keep backward compatibility with previous versions. In that case, `size(null)` returns `-1`. ## How was this patch tested? Modified existing tests for the `size()` function to check new behavior (`null`) and old one (`-1`). Author: Maxim Gekk Closes #21598 from MaxGekk/legacy-size-of-null. --- .../expressions/collectionOperations.scala | 38 ++++++++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 +++ .../CollectionExpressionsSuite.scala | 30 +++++++---- .../org/apache/spark/sql/functions.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 54 +++++++++++-------- 5 files changed, 93 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 58612f65c1a53..abd6c88d3d985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -67,37 +67,61 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression /** - * Given an array or map, returns its size. Returns -1 if null. + * Given an array or map, returns total number of elements in it. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the size of an array or a map. Returns -1 if null.", + usage = """ + _FUNC_(expr) - Returns the size of an array or a map. + The function returns -1 if its input is null and spark.sql.legacy.sizeOfNull is set to true. + If spark.sql.legacy.sizeOfNull is set to false, the function returns null for null input. + By default, the spark.sql.legacy.sizeOfNull parameter is set to true. + """, examples = """ Examples: > SELECT _FUNC_(array('b', 'd', 'c', 'a')); 4 + > SELECT _FUNC_(map('a', 1, 'b', 2)); + 2 + > SELECT _FUNC_(NULL); + -1 """) -case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Size( + child: Expression, + legacySizeOfNull: Boolean) + extends UnaryExpression with ExpectsInputTypes { + + def this(child: Expression) = + this( + child, + legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL)) + override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) - override def nullable: Boolean = false + override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { - -1 + if (legacySizeOfNull) -1 else null } else child.dataType match { case _: ArrayType => value.asInstanceOf[ArrayData].numElements() case _: MapType => value.asInstanceOf[MapData].numElements() + case other => throw new UnsupportedOperationException( + s"The size function doesn't support the operand type ${other.getClass.getCanonicalName}") } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val childGen = child.genCode(ctx) - ev.copy(code = code""" + if (legacySizeOfNull) { + val childGen = child.genCode(ctx) + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = FalseLiteral) + } else { + defineCodeGen(ctx, ev, c => s"($c).numElements()") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e768416f257c9..239c8266351ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1324,6 +1324,12 @@ object SQLConf { "Other column values can be ignored during parsing even if they are malformed.") .booleanConf .createWithDefault(true) + + val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") + .doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " + + "The size function returns null for null input if the flag is disabled.") + .booleanConf + .createWithDefault(true) } /** @@ -1686,6 +1692,8 @@ class SQLConf extends Serializable with Logging { def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 5b8cf5128fe21..caea4fb25ff7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -24,25 +24,37 @@ import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("Array and Map Size") { + def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) - checkEvaluation(Size(a0), 3) - checkEvaluation(Size(a1), 0) - checkEvaluation(Size(a2), 2) + checkEvaluation(Size(a0, legacySizeOfNull), 3) + checkEvaluation(Size(a1, legacySizeOfNull), 0) + checkEvaluation(Size(a2, legacySizeOfNull), 2) val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) - checkEvaluation(Size(m0), 2) - checkEvaluation(Size(m1), 0) - checkEvaluation(Size(m2), 1) + checkEvaluation(Size(m0, legacySizeOfNull), 2) + checkEvaluation(Size(m1, legacySizeOfNull), 0) + checkEvaluation(Size(m2, legacySizeOfNull), 1) + + checkEvaluation( + Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull), + expected = sizeOfNull) + checkEvaluation( + Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull), + expected = sizeOfNull) + } + + test("Array and Map Size - legacy") { + testSize(legacySizeOfNull = true, sizeOfNull = -1) + } - checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) - checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) + test("Array and Map Size") { + testSize(legacySizeOfNull = false, sizeOfNull = null) } test("MapKeys/MapValues") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 40c40e7083d1c..ef99ce3ad69d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3431,7 +3431,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = withExpr { Size(e.expr) } + def size(e: Column): Column = withExpr { new Size(e.expr) } /** * Sorts the input array for the given column in ascending order, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d6a6c0832c96..b109898b5bfb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -487,26 +487,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { }.getMessage().contains("only supports array input")) } - test("array size function") { + def testSizeOfArray(sizeOfNull: Any): Unit = { val df = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "y"), (Seq[Int](1, 2, 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("cardinality(a)"), - Seq(Row(2L), Row(0L), Row(3L), Row(-1L)) - ) + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("cardinality(a)"), Seq(Row(2L), Row(0L), Row(3L), Row(sizeOfNull))) + } + + test("array size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfArray(sizeOfNull = -1) + } + } + + test("array size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfArray(sizeOfNull = null) + } } test("dataframe arrays_zip function") { @@ -567,21 +570,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } - test("map size function") { + def testSizeOfMap(sizeOfNull: Any): Unit = { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), (Map[Int, Int](), "y"), (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"), (null, "empty") ).toDF("a", "b") - checkAnswer( - df.select(size($"a")), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) - checkAnswer( - df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3), Row(-1)) - ) + + checkAnswer(df.select(size($"a")), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + checkAnswer(df.selectExpr("size(a)"), Seq(Row(2), Row(0), Row(3), Row(sizeOfNull))) + } + + test("map size function - legacy") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSizeOfMap(sizeOfNull = -1: Int) + } + } + + test("map size function") { + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSizeOfMap(sizeOfNull = null) + } } test("map_keys/map_values function") { From 2669b4de3b336dde84b698c20dbc73b30abf79d4 Mon Sep 17 00:00:00 2001 From: "Vayda, Oleksandr: IT (PRG)" Date: Wed, 27 Jun 2018 11:52:31 +0900 Subject: [PATCH 1022/1161] [SPARK-23927][SQL] Add "sequence" expression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? The PR adds the SQL function ```sequence```. https://issues.apache.org/jira/browse/SPARK-23927 The behavior of the function is based on Presto's one. Ref: https://prestodb.io/docs/current/functions/array.html - ```sequence(start, stop) → array``` Generate a sequence of integers from ```start``` to ```stop```, incrementing by ```1``` if ```start``` is less than or equal to ```stop```, otherwise ```-1```. - ```sequence(start, stop, step) → array``` Generate a sequence of integers from ```start``` to ```stop```, incrementing by ```step```. - ```sequence(start_date, stop_date) → array``` Generate a sequence of dates from ```start_date``` to ```stop_date```, incrementing by ```interval 1 day``` if ```start_date``` is less than or equal to ```stop_date```, otherwise ```- interval 1 day```. - ```sequence(start_date, stop_date, step_interval) → array``` Generate a sequence of dates from ```start_date``` to ```stop_date```, incrementing by ```step_interval```. The type of ```step_interval``` is ```CalendarInterval```. - ```sequence(start_timestemp, stop_timestemp) → array``` Generate a sequence of timestamps from ```start_timestamps``` to ```stop_timestamps```, incrementing by ```interval 1 day``` if ```start_date``` is less than or equal to ```stop_date```, otherwise ```- interval 1 day```. - ```sequence(start_timestamp, stop_timestamp, step_interval) → array``` Generate a sequence of timestamps from ```start_timestamps``` to ```stop_timestamps```, incrementing by ```step_interval```. The type of ```step_interval``` is ```CalendarInterval```. ## How was this patch tested? Added unit tests. Author: Vayda, Oleksandr: IT (PRG) Closes #21155 from wajda/feature/array-api-sequence. --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 7 + .../expressions/collectionOperations.scala | 402 +++++++++++++++++- .../CollectionExpressionsSuite.scala | 292 +++++++++++++ .../org/apache/spark/sql/functions.scala | 21 + .../spark/sql/DataFrameFunctionsSuite.scala | 56 +++ 6 files changed, 777 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 8abc616c1a3f7..a574d8a84d4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -432,6 +432,7 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[Sequence]("sequence"), expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 637923928a7da..3ebab430ffbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -544,6 +544,13 @@ object TypeCoercion { case None => aj } + case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) => + val types = s.coercibleChildren.map(_.dataType) + findWiderCommonType(types) match { + case Some(widerDataType) => s.castChildrenTo(widerDataType) + case None => s + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index abd6c88d3d985..0395e1ef9a7ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -16,9 +16,10 @@ */ package org.apache.spark.sql.catalyst.expressions -import java.util.Comparator +import java.util.{Comparator, TimeZone} import scala.collection.mutable +import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} @@ -26,11 +27,13 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.collection.OpenHashSet /** @@ -2313,6 +2316,401 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } +@ExpressionDescription( + usage = """ + _FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive), + incrementing by step. The type of the returned elements is the same as the type of argument + expressions. + + Supported types are: byte, short, integer, long, date, timestamp. + + The start and stop expressions must resolve to the same type. + If start and stop expressions resolve to the 'date' or 'timestamp' type + then the step expression must resolve to the 'interval' type, otherwise to the same type + as the start and stop expressions. + """, + arguments = """ + Arguments: + * start - an expression. The start of the range. + * stop - an expression. The end the range (inclusive). + * step - an optional expression. The step of the range. + By default step is 1 if start is less than or equal to stop, otherwise -1. + For the temporal sequences it's 1 day and -1 day respectively. + If start is greater than stop then the step must be negative, and vice versa. + """, + examples = """ + Examples: + > SELECT _FUNC_(1, 5); + [1, 2, 3, 4, 5] + > SELECT _FUNC_(5, 1); + [5, 4, 3, 2, 1] + > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); + [2018-01-01, 2018-02-01, 2018-03-01] + """, + since = "2.4.0" +) +case class Sequence( + start: Expression, + stop: Expression, + stepOpt: Option[Expression], + timeZoneId: Option[String] = None) + extends Expression + with TimeZoneAwareExpression { + + import Sequence._ + + def this(start: Expression, stop: Expression) = + this(start, stop, None, None) + + def this(start: Expression, stop: Expression, step: Expression) = + this(start, stop, Some(step), None) + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Some(timeZoneId)) + + override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + + override def checkInputDataTypes(): TypeCheckResult = { + val startType = start.dataType + def stepType = stepOpt.get.dataType + val typesCorrect = + startType.sameType(stop.dataType) && + (startType match { + case TimestampType | DateType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) + case _: IntegralType => + stepOpt.isEmpty || stepType.sameType(startType) + case _ => false + }) + + if (typesCorrect) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName only supports integral, timestamp or date types") + } + } + + def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType) + + def castChildrenTo(widerType: DataType): Expression = Sequence( + Cast(start, widerType), + Cast(stop, widerType), + stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), + timeZoneId) + + private lazy val impl: SequenceImpl = dataType.elementType match { + case iType: IntegralType => + type T = iType.InternalType + val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) + new IntegralSequenceImpl(iType)(ct, iType.integral) + + case TimestampType => + new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone) + + case DateType => + new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone) + } + + override def eval(input: InternalRow): Any = { + val startVal = start.eval(input) + if (startVal == null) return null + val stopVal = stop.eval(input) + if (stopVal == null) return null + val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal)) + if (stepVal == null) return null + + ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val startGen = start.genCode(ctx) + val stopGen = stop.genCode(ctx) + val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse( + impl.defaultStep.genCode(ctx, startGen, stopGen)) + + val resultType = CodeGenerator.javaType(dataType) + val resultCode = { + val arr = ctx.freshName("arr") + val arrElemType = CodeGenerator.javaType(dataType.elementType) + s""" + |final $arrElemType[] $arr = null; + |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)} + |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr); + """.stripMargin + } + + if (nullable) { + val nullSafeEval = + startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) { + stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) { + stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) { + s""" + |${ev.isNull} = false; + |$resultCode + """.stripMargin + } + } + } + ev.copy(code = + code""" + |boolean ${ev.isNull} = true; + |$resultType ${ev.value} = null; + |$nullSafeEval + """.stripMargin) + + } else { + ev.copy(code = + code""" + |${startGen.code} + |${stopGen.code} + |${stepGen.code} + |$resultType ${ev.value} = null; + |$resultCode + """.stripMargin, + isNull = FalseLiteral) + } + } +} + +object Sequence { + + private type LessThanOrEqualFn = (Any, Any) => Boolean + + private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) { + private val negativeOne = UnaryMinus(Literal(one)).eval() + + def apply(start: Any, stop: Any): Any = { + if (lteq(start, stop)) one else negativeOne + } + + def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = { + val Seq(oneVal, negativeOneVal) = Seq(one, negativeOne).map(Literal(_).genCode(ctx).value) + ExprCode.forNonNullValue(JavaCode.expression( + s"${startGen.value} <= ${stopGen.value} ? $oneVal : $negativeOneVal", + stepType)) + } + } + + private trait SequenceImpl { + def eval(start: Any, stop: Any, step: Any): Any + + def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String + + val defaultStep: DefaultStep + } + + private class IntegralSequenceImpl[T: ClassTag] + (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + elemType, + num.one) + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + import num._ + + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[T] + + var i: Int = getSequenceLength(start, stop, step) + val arr = new Array[T](i) + while (i > 0) { + i -= 1 + arr(i) = start + step * num.fromInt(i) + } + arr + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val i = ctx.freshName("i") + s""" + |${genSequenceLengthCode(ctx, start, stop, step, i)} + |$arr = new $elemType[$i]; + |while ($i > 0) { + | $i--; + | $arr[$i] = ($elemType) ($start + $step * $i); + |} + """.stripMargin + } + } + + private class TemporalSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone) + (implicit num: Integral[T]) extends SequenceImpl { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + CalendarIntervalType, + new CalendarInterval(0, MICROS_PER_DAY)) + + private val backedSequenceImpl = new IntegralSequenceImpl[T](dt) + private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY + + override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { + val start = input1.asInstanceOf[T] + val stop = input2.asInstanceOf[T] + val step = input3.asInstanceOf[CalendarInterval] + val stepMonths = step.months + val stepMicros = step.microseconds + + if (stepMonths == 0) { + backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale)) + + } else { + // To estimate the resulted array length we need to make assumptions + // about a month length in microseconds + val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth + val startMicros: Long = num.toLong(start) * scale + val stopMicros: Long = num.toLong(stop) * scale + val maxEstimatedArrayLength = + getSequenceLength(startMicros, stopMicros, intervalStepInMicros) + + val stepSign = if (stopMicros > startMicros) +1 else -1 + val exclusiveItem = stopMicros + stepSign + val arr = new Array[T](maxEstimatedArrayLength) + var t = startMicros + var i = 0 + + while (t < exclusiveItem ^ stepSign < 0) { + arr(i) = fromLong(t / scale) + t = timestampAddInterval(t, stepMonths, stepMicros, timeZone) + i += 1 + } + + // truncate array to the correct length + if (arr.length == i) arr else arr.slice(0, i) + } + } + + override def genCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + arr: String, + elemType: String): String = { + val stepMonths = ctx.freshName("stepMonths") + val stepMicros = ctx.freshName("stepMicros") + val stepScaled = ctx.freshName("stepScaled") + val intervalInMicros = ctx.freshName("intervalInMicros") + val startMicros = ctx.freshName("startMicros") + val stopMicros = ctx.freshName("stopMicros") + val arrLength = ctx.freshName("arrLength") + val stepSign = ctx.freshName("stepSign") + val exclusiveItem = ctx.freshName("exclusiveItem") + val t = ctx.freshName("t") + val i = ctx.freshName("i") + val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName) + + val sequenceLengthCode = + s""" + |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L; + |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} + """.stripMargin + + val timestampAddIntervalCode = + s""" + |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval( + | $t, $stepMonths, $stepMicros, $genTimeZone); + """.stripMargin + + s""" + |final int $stepMonths = $step.months; + |final long $stepMicros = $step.microseconds; + | + |if ($stepMonths == 0) { + | final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L); + | ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)}; + | + |} else { + | final long $startMicros = $start * ${scale}L; + | final long $stopMicros = $stop * ${scale}L; + | + | $sequenceLengthCode + | + | final int $stepSign = $stopMicros > $startMicros ? +1 : -1; + | final long $exclusiveItem = $stopMicros + $stepSign; + | + | $arr = new $elemType[$arrLength]; + | long $t = $startMicros; + | int $i = 0; + | + | while ($t < $exclusiveItem ^ $stepSign < 0) { + | $arr[$i] = ($elemType) ($t / ${scale}L); + | $timestampAddIntervalCode + | $i += 1; + | } + | + | if ($arr.length > $i) { + | $arr = java.util.Arrays.copyOf($arr, $i); + | } + |} + """.stripMargin + } + } + + private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = { + import num._ + require( + (step > num.zero && start <= stop) + || (step < num.zero && start >= stop) + || (step == num.zero && start == stop), + s"Illegal sequence boundaries: $start to $stop by $step") + + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong + + require( + len <= MAX_ROUNDED_ARRAY_LENGTH, + s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + len.toInt + } + + private def genSequenceLengthCode( + ctx: CodegenContext, + start: String, + stop: String, + step: String, + len: String): String = { + val longLen = ctx.freshName("longLen") + s""" + |if (!(($step > 0 && $start <= $stop) || + | ($step < 0 && $start >= $stop) || + | ($step == 0 && $start == $stop))) { + | throw new IllegalArgumentException( + | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); + |} + |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step; + |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { + | throw new IllegalArgumentException( + | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); + |} + |int $len = (int) $longLen; + """.stripMargin + } +} + /** * Returns the array containing the given input value (left) count (right) times. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index caea4fb25ff7e..d7744eb4c7dc7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -17,10 +17,16 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} +import java.util.TimeZone + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH +import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -484,6 +490,292 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) } + test("Sequence of numbers") { + // test null handling + + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType)), null) + checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType), Literal(1L)), null) + checkEvaluation(new Sequence(Literal(1L), Literal(1L), Literal(null, LongType)), null) + + // test sequence boundaries checking + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)), + EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(0)), EmptyRow, "boundaries: 2 to 1 by 0") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(2), Literal(1), Literal(1)), EmptyRow, "boundaries: 2 to 1 by 1") + checkExceptionInExpression[IllegalArgumentException]( + new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1") + + // test sequence with one element (zero step or equal start and stop) + + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(0)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(1)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(2), Literal(2)), Seq(1)) + checkEvaluation(new Sequence(Literal(1), Literal(0), Literal(-2)), Seq(1)) + + // test sequence of different integral types (ascending and descending) + + checkEvaluation(new Sequence(Literal(1L), Literal(3L), Literal(1L)), Seq(1L, 2L, 3L)) + checkEvaluation(new Sequence(Literal(-3), Literal(3), Literal(3)), Seq(-3, 0, 3)) + checkEvaluation( + new Sequence(Literal(3.toShort), Literal(-3.toShort), Literal(-3.toShort)), + Seq(3.toShort, 0.toShort, -3.toShort)) + checkEvaluation( + new Sequence(Literal(-1.toByte), Literal(-3.toByte), Literal(-1.toByte)), + Seq(-1.toByte, -2.toByte, -3.toByte)) + } + + test("Sequence of timestamps") { + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:01")), + Literal(CalendarInterval.fromString("interval 12 hours"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2017-12-31 23:59:59")), + Literal(CalendarInterval.fromString("interval 12 hours").negate())), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-03-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + Seq( + Timestamp.valueOf("2018-03-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())), + Seq( + Timestamp.valueOf("2018-03-03 00:00:00"), + Timestamp.valueOf("2018-02-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-31 00:00:00")), + Literal(Timestamp.valueOf("2018-04-30 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Timestamp.valueOf("2018-01-31 00:00:00"), + Timestamp.valueOf("2018-02-28 00:00:00"), + Timestamp.valueOf("2018-03-31 00:00:00"), + Timestamp.valueOf("2018-04-30 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(CalendarInterval.fromString("interval 1 month 1 second"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:01"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:04:06")), + Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:02:03"), + Timestamp.valueOf("2018-03-01 00:04:06"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2023-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2022-04-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2022-04-01 00:00:00")), + Literal(Timestamp.valueOf("2017-01-01 00:00:00")), + Literal(CalendarInterval.fromYearMonthString("1-5").negate())), + Seq( + Timestamp.valueOf("2022-04-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + } + + test("Sequence on DST boundaries") { + val timeZone = TimeZone.getTimeZone("Europe/Prague") + val dstOffset = timeZone.getDSTSavings + + def noDST(t: Timestamp): Timestamp = new Timestamp(t.getTime - dstOffset) + + DateTimeTestUtils.withDefaultTimeZone(timeZone) { + // Spring time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-25 01:30:00")), + Literal(Timestamp.valueOf("2018-03-25 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-03-25 01:30:00"), + Timestamp.valueOf("2018-03-25 03:00:00"), + Timestamp.valueOf("2018-03-25 03:30:00"))) + + // Autumn time change + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-10-28 01:30:00")), + Literal(Timestamp.valueOf("2018-10-28 03:30:00")), + Literal(CalendarInterval.fromString("interval 30 minutes"))), + Seq( + Timestamp.valueOf("2018-10-28 01:30:00"), + noDST(Timestamp.valueOf("2018-10-28 02:00:00")), + noDST(Timestamp.valueOf("2018-10-28 02:30:00")), + Timestamp.valueOf("2018-10-28 02:00:00"), + Timestamp.valueOf("2018-10-28 02:30:00"), + Timestamp.valueOf("2018-10-28 03:00:00"), + Timestamp.valueOf("2018-10-28 03:30:00"))) + } + } + + test("Sequence of dates") { + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-05")), + Literal(CalendarInterval.fromString("interval 2 days"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-05"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-03-01")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-31")), + Literal(Date.valueOf("2018-04-30")), + Literal(CalendarInterval.fromString("interval 1 month"))), + Seq( + Date.valueOf("2018-01-31"), + Date.valueOf("2018-02-28"), + Date.valueOf("2018-03-31"), + Date.valueOf("2018-04-30"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2023-01-01")), + Literal(CalendarInterval.fromYearMonthString("1-5"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2019-06-01"), + Date.valueOf("2020-11-01"), + Date.valueOf("2022-04-01"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-02")), + Literal(Date.valueOf("1970-01-01")), + Literal(CalendarInterval.fromString("interval 1 day"))), + EmptyRow, "sequence boundaries: 1 to 0 by 1") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-01")), + Literal(Date.valueOf("1970-02-01")), + Literal(CalendarInterval.fromString("interval 1 month").negate())), + EmptyRow, + s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}") + } + } + + test("Sequence with default step") { + // +/- 1 for integral type + checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3)) + checkEvaluation(new Sequence(Literal(3), Literal(1)), Seq(3, 2, 1)) + + // +/- 1 day for timestamps + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-03 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-03 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-03 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00"))), + Seq( + Timestamp.valueOf("2018-01-03 00:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + // +/- 1 day for dates + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-03"))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-03"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-03")), + Literal(Date.valueOf("2018-01-01"))), + Seq( + Date.valueOf("2018-01-03"), + Date.valueOf("2018-01-02"), + Date.valueOf("2018-01-01"))) + } + test("Reverse") { // Primitive-type elements val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ef99ce3ad69d9..0b4f526799578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3485,6 +3485,27 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Generate a sequence of integers from start to stop, incrementing by step. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column, step: Column): Column = withExpr { + new Sequence(start.expr, stop.expr, step.expr) + } + + /** + * Generate a sequence of integers from start to stop, + * incrementing by 1 if start is less than or equal to stop, otherwise -1. + * + * @group collection_funcs + * @since 2.4.0 + */ + def sequence(start: Column, stop: Column): Column = withExpr { + new Sequence(start.expr, stop.expr) + } + /** * Creates an array containing the left argument repeated the number of times given by the * right argument. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b109898b5bfb3..4c28e2f1cd909 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.util.TimeZone import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -862,6 +865,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("array_max(a)"), answer) } + test("sequence") { + checkAnswer(Seq((-2, 2)).toDF().select(sequence('_1, '_2)), Seq(Row(Array(-2, -1, 0, 1, 2)))) + checkAnswer(Seq((7, 2, -2)).toDF().select(sequence('_1, '_2, '_3)), Seq(Row(Array(7, 5, 3)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01 00:00:00' as timestamp)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) { + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-03-01' as date)" + + ", interval 1 month)"), + Seq(Row(Array( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))))) + } + + // test type coercion + checkAnswer( + Seq((1.toByte, 3L, 1)).toDF().select(sequence('_1, '_2, '_3)), + Seq(Row(Array(1L, 2L, 3L)))) + + checkAnswer( + spark.sql("select sequence(" + + " cast('2018-01-01' as date)" + + ", cast('2018-01-02 00:00:00' as timestamp)" + + ", interval 12 hours)"), + Seq(Row(Array( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))))) + + // test invalid data types + intercept[AnalysisException] { + Seq((true, false)).toDF().selectExpr("sequence(_1, _2)") + } + intercept[AnalysisException] { + Seq((true, false, 42)).toDF().selectExpr("sequence(_1, _2, _3)") + } + intercept[AnalysisException] { + Seq((1, 2, 0.5)).toDF().selectExpr("sequence(_1, _2, _3)") + } + } + test("reverse function") { val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on From 9a76f23c6a1756053c30f58baea2966d1b023981 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 27 Jun 2018 11:52:48 +0800 Subject: [PATCH 1023/1161] [SPARK-23927][SQL][FOLLOW-UP] Fix a build failure. ## What changes were proposed in this pull request? This pr is a follow-up pr of #21155. The #21155 removed unnecessary import at that time, but the import became necessary in another pr. ## How was this patch tested? Existing tests. Author: Takuya UESHIN Closes #21646 from ueshin/issues/SPARK-23927/fup1. --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0395e1ef9a7ad..8b278f067749e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods From a1a64e3583cfa451b4d0d2361c1da2972a5e4444 Mon Sep 17 00:00:00 2001 From: Yuexin Zhang Date: Wed, 27 Jun 2018 16:05:36 +0800 Subject: [PATCH 1024/1161] [SPARK-21335][DOC] doc changes for disallowed un-aliased subquery use case ## What changes were proposed in this pull request? Document a change for un-aliased subquery use case, to address the last question in PR #18559: https://github.com/apache/spark/pull/18559#issuecomment-316884858 (Please fill in changes proposed in this fix) ## How was this patch tested? it does not affect tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Yuexin Zhang Closes #21647 from cnZach/doc_change_for_SPARK-20690_SPARK-21335. --- docs/sql-programming-guide.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7c4ef41cc8907..cd7329b621122 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2017,6 +2017,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses previous rules, ie. it doesn't adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible. - In PySpark, `df.replace` does not allow to omit `value` when `to_replace` is not a dictionary. Previously, `value` could be omitted in the other cases and had `None` by default, which is counterintuitive and error-prone. + - Un-aliased subquery's semantic has not been well defined with confusing behaviors. Since Spark 2.3, we invalidate such confusing cases, for example: `SELECT v.i from (SELECT i FROM v)`, Spark will throw an analysis exception in this case because users should not be able to use the qualifier inside a subquery. See [SPARK-20690](https://issues.apache.org/jira/browse/SPARK-20690) and [SPARK-21335](https://issues.apache.org/jira/browse/SPARK-21335) for more details. ## Upgrading From Spark SQL 2.1 to 2.2 From 6a0b77a55d53e74ac0a0892556c3a7a933474948 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 27 Jun 2018 10:43:06 -0700 Subject: [PATCH 1025/1161] [SPARK-24215][PYSPARK][FOLLOW UP] Implement eager evaluation for DataFrame APIs in PySpark ## What changes were proposed in this pull request? Address comments in #21370 and add more test. ## How was this patch tested? Enhance test in pyspark/sql/test.py and DataFrameSuite Author: Yuanjian Li Closes #21553 from xuanyuanking/SPARK-24215-follow. --- docs/configuration.md | 27 --------- python/pyspark/sql/dataframe.py | 3 +- python/pyspark/sql/tests.py | 46 ++++++++++++++- .../apache/spark/sql/internal/SQLConf.scala | 23 ++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 11 ++-- .../org/apache/spark/sql/DataFrameSuite.scala | 59 +++++++++++++++++++ 6 files changed, 131 insertions(+), 38 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 6aa7878fe614d..0c7c4472be643 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -456,33 +456,6 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. - - spark.sql.repl.eagerEval.enabled - false - - Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation, - Dataset will be ran automatically. The HTML table which generated by _repl_html_ - called by notebooks like Jupyter will feedback the queries user have defined. For plain Python - REPL, the output will be shown like dataframe.show() - (see SPARK-24215 for more details). - - - - spark.sql.repl.eagerEval.maxNumRows - 20 - - Default number of rows in eager evaluation output HTML table generated by _repr_html_ or plain text, - this only take effect when spark.sql.repl.eagerEval.enabled is set to true. - - - - spark.sql.repl.eagerEval.truncate - 20 - - Default number of truncate in eager evaluation output HTML table generated by _repr_html_ or - plain text, this only take effect when spark.sql.repl.eagerEval.enabled set to true. - - spark.files diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1e6a1acebb5ca..cb3fe448b6fc7 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -393,9 +393,8 @@ def _repr_html_(self): self._support_repr_html = True if self._eager_eval: max_num_rows = max(self._max_num_rows, 0) - vertical = False sock_info = self._jdf.getRowsToPython( - max_num_rows, self._truncate, vertical) + max_num_rows, self._truncate) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 35a0636e5cfc0..8d738069adb3d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3351,11 +3351,41 @@ def test_checking_csv_header(self): finally: shutil.rmtree(path) - def test_repr_html(self): + def test_repr_behaviors(self): import re pattern = re.compile(r'^ *\|', re.MULTILINE) df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) - self.assertEquals(None, df._repr_html_()) + + # test when eager evaluation is enabled and _repr_html_ will not be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """+-----+-----+ + || key|value| + |+-----+-----+ + || 1| 1| + ||22222|22222| + |+-----+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected1), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + ||222| 222| + |+---+-----+ + |""" + self.assertEquals(re.sub(pattern, '', expected2), df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """+---+-----+ + ||key|value| + |+---+-----+ + || 1| 1| + |+---+-----+ + |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df.__repr__()) + + # test when eager evaluation is enabled and _repr_html_ will be called with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): expected1 = """ | @@ -3381,6 +3411,18 @@ def test_repr_html(self): |""" self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + # test when eager evaluation is disabled and _repr_html_ will be called + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}): + expected = "DataFrame[key: bigint, value: string]" + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + self.assertEquals(None, df._repr_html_()) + self.assertEquals(expected, df.__repr__()) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 239c8266351ae..e1752ff997b69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1330,6 +1330,29 @@ object SQLConf { "The size function returns null for null input if the flag is disabled.") .booleanConf .createWithDefault(true) + + val REPL_EAGER_EVAL_ENABLED = buildConf("spark.sql.repl.eagerEval.enabled") + .doc("Enables eager evaluation or not. When true, the top K rows of Dataset will be " + + "displayed if and only if the REPL supports the eager evaluation. Currently, the " + + "eager evaluation is only supported in PySpark. For the notebooks like Jupyter, " + + "the HTML table (generated by _repr_html_) will be returned. For plain Python REPL, " + + "the returned outputs are formatted like dataframe.show().") + .booleanConf + .createWithDefault(false) + + val REPL_EAGER_EVAL_MAX_NUM_ROWS = buildConf("spark.sql.repl.eagerEval.maxNumRows") + .doc("The max number of rows that are returned by eager evaluation. This only takes " + + "effect when spark.sql.repl.eagerEval.enabled is set to true. The valid range of this " + + "config is from 0 to (Int.MaxValue - 1), so the invalid config like negative and " + + "greater than (Int.MaxValue - 1) will be normalized to 0 and (Int.MaxValue - 1).") + .intConf + .createWithDefault(20) + + val REPL_EAGER_EVAL_TRUNCATE = buildConf("spark.sql.repl.eagerEval.truncate") + .doc("The max number of characters for each cell that is returned by eager evaluation. " + + "This only takes effect when spark.sql.repl.eagerEval.enabled is set to true.") + .intConf + .createWithDefault(20) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 57f1e173211af..2ec236fc75efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -236,12 +236,10 @@ class Dataset[T] private[sql]( * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, the rows to return do not need truncate. */ private[sql] def getRows( numRows: Int, - truncate: Int, - vertical: Boolean): Seq[Seq[String]] = { + truncate: Int): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -289,7 +287,7 @@ class Dataset[T] private[sql]( vertical: Boolean = false): String = { val numRows = _numRows.max(0).min(Int.MaxValue - 1) // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. - val tmpRows = getRows(numRows, truncate, vertical) + val tmpRows = getRows(numRows, truncate) val hasMoreData = tmpRows.length - 1 > numRows val rows = tmpRows.take(numRows + 1) @@ -3226,11 +3224,10 @@ class Dataset[T] private[sql]( private[sql] def getRowsToPython( _numRows: Int, - truncate: Int, - vertical: Boolean): Array[Any] = { + truncate: Int): Array[Any] = { EvaluatePython.registerPicklers() val numRows = _numRows.max(0).min(Int.MaxValue - 1) - val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray + val rows = getRows(numRows, truncate).map(_.toArray).toArray val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( rows.iterator.map(toJava)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1cc8cb3874c9b..ea00d22bff001 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1044,6 +1044,65 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select($"*").show(1000) } + test("getRows: truncate = [0, 20]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111111111111111111111")) + assert(df.getRows(10, 0) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111111...")) + assert(df.getRows(10, 20) === expectedAnswerForTrue) + } + + test("getRows: truncate = [3, 17]") { + val longString = Array.fill(21)("1").mkString + val df = sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = Seq( + Seq("value"), + Seq("1"), + Seq("111")) + assert(df.getRows(10, 3) === expectedAnswerForFalse) + val expectedAnswerForTrue = Seq( + Seq("value"), + Seq("1"), + Seq("11111111111111...")) + assert(df.getRows(10, 17) === expectedAnswerForTrue) + } + + test("getRows: numRows = 0") { + val expectedAnswer = Seq(Seq("key", "value"), Seq("1", "1")) + assert(testData.select($"*").getRows(0, 20) === expectedAnswer) + } + + test("getRows: array") { + val df = Seq( + (Array(1, 2, 3), Array(1, 2, 3)), + (Array(2, 3, 4), Array(2, 3, 4)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[1, 2, 3]", "[1, 2, 3]"), + Seq("[2, 3, 4]", "[2, 3, 4]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + + test("getRows: binary") { + val df = Seq( + ("12".getBytes(StandardCharsets.UTF_8), "ABC.".getBytes(StandardCharsets.UTF_8)), + ("34".getBytes(StandardCharsets.UTF_8), "12346".getBytes(StandardCharsets.UTF_8)) + ).toDF() + val expectedAnswer = Seq( + Seq("_1", "_2"), + Seq("[31 32]", "[41 42 43 2E]"), + Seq("[33 34]", "[31 32 33 34 36]")) + assert(df.getRows(10, 20) === expectedAnswer) + } + test("showString: truncate = [0, 20]") { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() From 78ecb6d457970b136a2e0e0e27d170c84ea28eac Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 27 Jun 2018 10:57:29 -0700 Subject: [PATCH 1026/1161] [SPARK-24446][YARN] Properly quote library path for YARN. Because the way YARN executes commands via bash -c, everything needs to be quoted so that the whole command is fully contained inside a bash string and is interpreted correctly when the string is read by bash. This is a bit different than the quoting done when executing things as if typing in a bash shell. Tweaked unit tests to exercise the bad behavior, which would cause existing tests to time out without the fix. Also tested on a real cluster, verifying the shell script created by YARN to run the container. Author: Marcelo Vanzin Closes #21476 from vanzin/SPARK-24446. --- .../org/apache/spark/deploy/yarn/Client.scala | 22 +++++++++++++++++-- .../spark/deploy/yarn/ExecutorRunnable.scala | 11 +++++----- .../deploy/yarn/BaseYarnClusterSuite.scala | 9 ++++++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7225ff03dc34e..793d012218490 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -899,7 +899,8 @@ private[spark] class Client( val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) + prefixEnv = Some(createLibraryPathPrefix(libraryPaths.mkString(File.pathSeparator), + sparkConf)) } if (sparkConf.get(AM_JAVA_OPTIONS).isDefined) { logWarning(s"${AM_JAVA_OPTIONS.key} will not take effect in cluster mode") @@ -921,7 +922,7 @@ private[spark] class Client( .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => - prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) + prefixEnv = Some(createLibraryPathPrefix(paths, sparkConf)) } } @@ -1485,6 +1486,23 @@ private object Client extends Logging { YarnAppReport(report.getYarnApplicationState(), report.getFinalApplicationStatus(), diagsOpt) } + /** + * Create a properly quoted and escaped library path string to be added as a prefix to the command + * executed by YARN. This is different from normal quoting / escaping due to YARN executing the + * command through "bash -c". + */ + def createLibraryPathPrefix(libpath: String, conf: SparkConf): String = { + val cmdPrefix = if (Utils.isWindows) { + Utils.libraryPathEnvPrefix(Seq(libpath)) + } else { + val envName = Utils.libraryPathEnvName + // For quotes, escape both the quote and the escape character when encoding in the command + // string. + val quoted = libpath.replace("\"", "\\\\\\\"") + envName + "=\\\"" + quoted + File.pathSeparator + "$" + envName + "\\\"" + } + getClusterPath(conf, cmdPrefix) + } } private[spark] class YarnClusterApplication extends SparkApplication { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index a2a18cdff65af..49a0b93aa5c40 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -131,10 +131,6 @@ private[yarn] class ExecutorRunnable( // Extra options for the JVM val javaOpts = ListBuffer[String]() - // Set the environment variable through a command prefix - // to append to the existing value of the variable - var prefixEnv: Option[String] = None - // Set the JVM memory val executorMemoryString = executorMemory + "m" javaOpts += "-Xmx" + executorMemoryString @@ -144,8 +140,11 @@ private[yarn] class ExecutorRunnable( val subsOpt = Utils.substituteAppNExecIds(opts, appId, executorId) javaOpts ++= Utils.splitCommandString(subsOpt).map(YarnSparkHadoopUtil.escapeForShell) } - sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p => - prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) + + // Set the library path through a command prefix to append to the existing value of the + // env variable. + val prefixEnv = sparkConf.get(EXECUTOR_LIBRARY_PATH).map { libPath => + Client.createLibraryPathPrefix(libPath, sparkConf) } javaOpts += "-Djava.io.tmpdir=" + diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index ac67f2196e0a0..b0abcc9149d08 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher._ import org.apache.spark.util.Utils @@ -216,6 +217,14 @@ abstract class BaseYarnClusterSuite props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + // SPARK-24446: make sure special characters in the library path do not break containers. + if (!Utils.isWindows) { + val libPath = """/tmp/does not exist:$PWD/tmp:/tmp/quote":/tmp/ampersand&""" + props.setProperty(AM_LIBRARY_PATH.key, libPath) + props.setProperty(DRIVER_LIBRARY_PATH.key, libPath) + props.setProperty(EXECUTOR_LIBRARY_PATH.key, libPath) + } + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } From c04cb2d1b72b1edaddf684755f5a9d6aaf00e03b Mon Sep 17 00:00:00 2001 From: debugger87 Date: Wed, 27 Jun 2018 11:34:28 -0700 Subject: [PATCH 1027/1161] [SPARK-21687][SQL] Spark SQL should set createTime for Hive partition ## What changes were proposed in this pull request? Set createTime for every hive partition created in Spark SQL, which could be used to manage data lifecycle in Hive warehouse. We found that almost every partition modified by spark sql has not been set createTime. ``` mysql> select * from partitions where create_time=0 limit 1\G; *************************** 1. row *************************** PART_ID: 1028584 CREATE_TIME: 0 LAST_ACCESS_TIME: 1502203611 PART_NAME: date=20170130 SD_ID: 1543605 TBL_ID: 211605 LINK_TARGET_ID: NULL 1 row in set (0.27 sec) ``` ## How was this patch tested? N/A Author: debugger87 Author: Chaozhong Yang Closes #18900 from debugger87/fix/set-create-time-for-hive-partition. --- .../spark/sql/catalyst/catalog/interface.scala | 6 ++++++ .../sql/catalyst/catalog/SessionCatalogSuite.scala | 6 ++++-- .../results/describe-part-after-analyze.sql.out | 14 ++++++++++++++ .../resources/sql-tests/results/describe.sql.out | 4 ++++ .../sql-tests/results/show-tables.sql.out | 2 ++ .../spark/sql/hive/client/HiveClientImpl.scala | 4 ++++ 6 files changed, 34 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index f3e67dc4e975c..c6105c5526049 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -93,12 +93,16 @@ object CatalogStorageFormat { * @param spec partition spec values indexed by column name * @param storage storage format of the partition * @param parameters some parameters for the partition + * @param createTime creation time of the partition, in milliseconds + * @param lastAccessTime last access time, in milliseconds * @param stats optional statistics (number of rows, total size, etc.) */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, storage: CatalogStorageFormat, parameters: Map[String, String] = Map.empty, + createTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, stats: Option[CatalogStatistics] = None) { def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { @@ -109,6 +113,8 @@ case class CatalogTablePartition( if (parameters.nonEmpty) { map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") } + map.put("Created Time", new Date(createTime).toString) + map.put("Last Access", new Date(lastAccessTime).toString) stats.foreach(s => map.put("Partition Statistics", s.simpleString)) map } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6abab0073cca3..6a7375ee186fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1114,11 +1114,13 @@ abstract class SessionCatalogSuite extends AnalysisTest { // And for hive serde table, hive metastore will set some values(e.g.transient_lastDdlTime) // in table's parameters and storage's properties, here we also ignore them. val actualPartsNormalize = actualParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet val expectedPartsNormalize = expectedParts.map(p => - p.copy(parameters = Map.empty, storage = p.storage.copy( + p.copy(parameters = Map.empty, createTime = -1, lastAccessTime = -1, + storage = p.storage.copy( properties = Map.empty, locationUri = None, serde = None))).toSet actualPartsNormalize == expectedPartsNormalize diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out index 58ed201e2a60f..8ba69c698b551 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out @@ -57,6 +57,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Location [not included in comparison]sql/core/spark-warehouse/t @@ -89,6 +91,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -122,6 +126,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -147,6 +153,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1098 bytes, 4 rows # Storage Information @@ -180,6 +188,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=10] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1121 bytes, 3 rows # Storage Information @@ -205,6 +215,8 @@ Database default Table t Partition Values [ds=2017-08-01, hr=11] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1098 bytes, 4 rows # Storage Information @@ -230,6 +242,8 @@ Database default Table t Partition Values [ds=2017-09-01, hr=5] Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5 +Created Time [not included in comparison] +Last Access [not included in comparison] Partition Statistics 1144 bytes, 2 rows # Storage Information diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index 8c908b7625056..79390cb424444 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -282,6 +282,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 @@ -311,6 +313,8 @@ Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 Storage Properties [a=1, b=2] +Created Time [not included in comparison] +Last Access [not included in comparison] # Storage Information Num Buckets 2 diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index 975bb06124744..abeb7e18f031e 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -178,6 +178,8 @@ struct -- !query 14 output showdb show_t1 false Partition Values: [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/showdb.db/show_t1/c=Us/d=1 +Created Time [not included in comparison] +Last Access [not included in comparison] -- !query 15 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index da9fe2d3088b4..1df46d7431a21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -995,6 +995,8 @@ private[hive] object HiveClientImpl { tpart.setTableName(ht.getTableName) tpart.setValues(partValues.asJava) tpart.setSd(storageDesc) + tpart.setCreateTime((p.createTime / 1000).toInt) + tpart.setLastAccessTime((p.lastAccessTime / 1000).toInt) tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava) new HivePartition(ht, tpart) } @@ -1019,6 +1021,8 @@ private[hive] object HiveClientImpl { compressed = apiPartition.getSd.isCompressed, properties = Option(apiPartition.getSd.getSerdeInfo.getParameters) .map(_.asScala.toMap).orNull), + createTime = apiPartition.getCreateTime.toLong * 1000, + lastAccessTime = apiPartition.getLastAccessTime.toLong * 1000, parameters = properties, stats = readHiveStats(properties)) } From 776befbfd5b3c317a713d4fa3882cda6264db9ba Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 27 Jun 2018 14:26:08 -0700 Subject: [PATCH 1028/1161] [SPARK-24660][SHS] Show correct error pages when downloading logs ## What changes were proposed in this pull request? SHS is showing bad errors when trying to download logs is not successful. This may happen because the requested application doesn't exist or the user doesn't have permissions for it, for instance. The PR fixes the response when errors occur, so that they are displayed properly. ## How was this patch tested? manual tests **Before the patch:** 1. Unauthorized user ![screen shot 2018-06-26 at 3 53 33 pm](https://user-images.githubusercontent.com/8821783/41918118-f8b37e70-795b-11e8-91e8-d0250239f09d.png) 2. Non-existing application ![screen shot 2018-06-26 at 3 25 19 pm](https://user-images.githubusercontent.com/8821783/41918082-e3034c72-795b-11e8-970e-cee4a1eae77f.png) **After the patch** 1. Unauthorized user ![screen shot 2018-06-26 at 3 41 29 pm](https://user-images.githubusercontent.com/8821783/41918155-0d950476-795c-11e8-8d26-7b7ce73e6fe1.png) 2. Non-existing application ![screen shot 2018-06-26 at 3 40 37 pm](https://user-images.githubusercontent.com/8821783/41918175-1a14bb88-795c-11e8-91ab-eadf29190a02.png) Author: Marco Gaido Closes #21644 from mgaido91/SPARK-24660. --- .../spark/status/api/v1/ApiRootResource.scala | 30 ++++--------------- .../status/api/v1/JacksonMessageWriter.scala | 5 +--- .../api/v1/OneApplicationResource.scala | 7 ++--- .../scala/org/apache/spark/ui/UIUtils.scala | 5 ++++ 4 files changed, 13 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index d121068718b8a..84c2ad48f1f27 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -28,7 +28,7 @@ import org.glassfish.jersey.server.ServerProperties import org.glassfish.jersey.servlet.ServletContainer import org.apache.spark.SecurityManager -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} /** * Main entry point for serving spark application metrics as json, using JAX-RS. @@ -148,38 +148,18 @@ private[v1] trait BaseAppResource extends ApiRequestContext { } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( - Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + UIUtils.buildErrorResponse(Response.Status.FORBIDDEN, msg)) private[v1] class NotFoundException(msg: String) extends WebApplicationException( - new NoSuchElementException(msg), - Response - .status(Response.Status.NOT_FOUND) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.NOT_FOUND, msg)) private[v1] class ServiceUnavailable(msg: String) extends WebApplicationException( - new ServiceUnavailableException(msg), - Response - .status(Response.Status.SERVICE_UNAVAILABLE) - .entity(ErrorWrapper(msg)) - .build() -) + UIUtils.buildErrorResponse(Response.Status.SERVICE_UNAVAILABLE, msg)) private[v1] class BadParameterException(msg: String) extends WebApplicationException( - new IllegalArgumentException(msg), - Response - .status(Response.Status.BAD_REQUEST) - .entity(ErrorWrapper(msg)) - .build() -) { + UIUtils.buildErrorResponse(Response.Status.BAD_REQUEST, msg)) { def this(param: String, exp: String, actual: String) = { this(raw"""Bad value for parameter "$param". Expected a $exp, got "$actual"""") } } -/** - * Signal to JacksonMessageWriter to not convert the message into json (which would result in an - * extra set of quotes). - */ -private[v1] case class ErrorWrapper(s: String) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index 76af33c1a18db..4560d300cb0c8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -68,10 +68,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ mediaType: MediaType, multivaluedMap: MultivaluedMap[String, AnyRef], outputStream: OutputStream): Unit = { - t match { - case ErrorWrapper(err) => outputStream.write(err.getBytes(StandardCharsets.UTF_8)) - case _ => mapper.writeValue(outputStream, t) - } + mapper.writeValue(outputStream, t) } override def getSize( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index 974697890dd03..32100c5704538 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -140,11 +140,8 @@ private[v1] class AbstractApplicationResource extends BaseAppResource { .header("Content-Type", MediaType.APPLICATION_OCTET_STREAM) .build() } catch { - case NonFatal(e) => - Response.serverError() - .entity(s"Event logs are not available for app: $appId.") - .status(Response.Status.SERVICE_UNAVAILABLE) - .build() + case NonFatal(_) => + throw new ServiceUnavailable(s"Event logs are not available for app: $appId.") } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 5d015b0531ef6..732b7528f499e 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} import javax.servlet.http.HttpServletRequest +import javax.ws.rs.core.{MediaType, Response} import scala.util.control.NonFatal import scala.xml._ @@ -566,4 +567,8 @@ private[spark] object UIUtils extends Logging { NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, "")) } } + + def buildErrorResponse(status: Response.Status, msg: String): Response = { + Response.status(status).entity(msg).`type`(MediaType.TEXT_PLAIN).build() + } } From 221d03acca19bdf7a2624a29c180c99f098205d8 Mon Sep 17 00:00:00 2001 From: Sanket Chintapalli Date: Wed, 27 Jun 2018 14:37:19 -0700 Subject: [PATCH 1029/1161] [SPARK-24533] Typesafe rebranded to lightbend. Changing the build downloads path Typesafe has rebranded to lightbend. Just changing the downloads path to avoid redirection Tested by running build/mvn -DskipTests package Author: Sanket Chintapalli Closes #21636 from redsanket/SPARK-24533. --- build/mvn | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build/mvn b/build/mvn index 1405983982d4c..ae4276dbc7e32 100755 --- a/build/mvn +++ b/build/mvn @@ -93,7 +93,7 @@ install_mvn() { install_zinc() { local zinc_path="zinc-0.3.15/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/zinc/0.3.15" \ @@ -109,7 +109,7 @@ install_scala() { # determine the Scala version used in Spark local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" - local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com} install_app \ "${TYPESAFE_MIRROR}/scala/${scala_version}" \ From 893ea224cc738766be207c87f4b913fe8fea4c94 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 27 Jun 2018 15:25:51 -0700 Subject: [PATCH 1030/1161] [SPARK-24204][SQL] Verify a schema in Json/Orc/ParquetFileFormat ## What changes were proposed in this pull request? This pr added code to verify a schema in Json/Orc/ParquetFileFormat along with CSVFileFormat. ## How was this patch tested? Added verification tests in `FileBasedDataSourceSuite` and `HiveOrcSourceSuite`. Author: Takeshi Yamamuro Closes #21389 from maropu/SPARK-24204. --- .../datasources/DataSourceUtils.scala | 106 +++++++++ .../datasources/csv/CSVFileFormat.scala | 4 +- .../execution/datasources/csv/CSVUtils.scala | 19 -- .../datasources/json/JsonFileFormat.scala | 4 + .../datasources/orc/OrcFileFormat.scala | 4 + .../parquet/ParquetFileFormat.scala | 3 + .../spark/sql/FileBasedDataSourceSuite.scala | 213 +++++++++++++++++- .../execution/datasources/csv/CSVSuite.scala | 33 --- .../spark/sql/hive/orc/OrcFileFormat.scala | 4 + .../sql/hive/orc/HiveOrcSourceSuite.scala | 49 +++- 10 files changed, 383 insertions(+), 56 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala new file mode 100644 index 0000000000000..c5347218c4b40 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.types._ + + +object DataSourceUtils { + + /** + * Verify if the schema is supported in datasource in write path. + */ + def verifyWriteSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = false) + } + + /** + * Verify if the schema is supported in datasource in read path. + */ + def verifyReadSchema(format: FileFormat, schema: StructType): Unit = { + verifySchema(format, schema, isReadPath = true) + } + + /** + * Verify if the schema is supported in datasource. This verification should be done + * in a driver side, e.g., `prepareWrite`, `buildReader`, and `buildReaderWithPartitionValues` + * in `FileFormat`. + * + * Unsupported data types of csv, json, orc, and parquet are as follows; + * csv -> R/W: Interval, Null, Array, Map, Struct + * json -> W: Interval + * orc -> W: Interval, Null + * parquet -> R/W: Interval, Null + */ + private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { + def throwUnsupportedException(dataType: DataType): Unit = { + throw new UnsupportedOperationException( + s"$format data source does not support ${dataType.simpleString} data type.") + } + + def verifyType(dataType: DataType): Unit = dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + StringType | BinaryType | DateType | TimestampType | _: DecimalType => + + // All the unsupported types for CSV + case _: NullType | _: CalendarIntervalType | _: StructType | _: ArrayType | _: MapType + if format.isInstanceOf[CSVFileFormat] => + throwUnsupportedException(dataType) + + case st: StructType => st.foreach { f => verifyType(f.dataType) } + + case ArrayType(elementType, _) => verifyType(elementType) + + case MapType(keyType, valueType, _) => + verifyType(keyType) + verifyType(valueType) + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + // Interval type not supported in all the write path + case _: CalendarIntervalType if !isReadPath => + throwUnsupportedException(dataType) + + // JSON and ORC don't support an Interval type, but we pass it in read pass + // for back-compatibility. + case _: CalendarIntervalType if format.isInstanceOf[JsonFileFormat] || + format.isInstanceOf[OrcFileFormat] => + + // Interval type not supported in the other read path + case _: CalendarIntervalType => + throwUnsupportedException(dataType) + + // For JSON & ORC backward-compatibility + case _: NullType if format.isInstanceOf[JsonFileFormat] || + (isReadPath && format.isInstanceOf[OrcFileFormat]) => + + // Null type not supported in the other path + case _: NullType => + throwUnsupportedException(dataType) + + // We keep this default case for safeguards + case _ => throwUnsupportedException(dataType) + } + + schema.foreach(field => verifyType(field.dataType)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index b90275de9f40a..fa366ccce6b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -66,7 +66,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - CSVUtils.verifySchema(dataSchema) + DataSourceUtils.verifyWriteSchema(this, dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions( options, @@ -98,7 +98,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - CSVUtils.verifySchema(dataSchema) + DataSourceUtils.verifyReadSchema(this, dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 1012e774118e2..7ce65fa89b02d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -117,25 +117,6 @@ object CSVUtils { } } - /** - * Verify if the schema is supported in CSV datasource. - */ - def verifySchema(schema: StructType): Unit = { - def verifyType(dataType: DataType): Unit = dataType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | - DoubleType | BooleanType | _: DecimalType | TimestampType | - DateType | StringType => - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - case _ => - throw new UnsupportedOperationException( - s"CSV data source does not support ${dataType.simpleString} data type.") - } - - schema.foreach(field => verifyType(field.dataType)) - } - /** * Sample CSV dataset as configured by `samplingRatio`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index e9a0b383b5f49..383bff1375a93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -65,6 +65,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) + val conf = job.getConfiguration val parsedOptions = new JSONOptions( options, @@ -96,6 +98,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 1de2ca2914c44..df488a748e3e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -89,6 +89,8 @@ class OrcFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val conf = job.getConfiguration @@ -141,6 +143,8 @@ class OrcFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + if (sparkSession.sessionState.conf.orcFilterPushDown) { OrcFilters.createFilter(dataSchema, filters).foreach { f => OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 60fc9ec7e1f82..9602a08911dea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -78,6 +78,7 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) @@ -302,6 +303,8 @@ class ParquetFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 06303099f5310..86f9647b4ac4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql -import java.io.FileNotFoundException +import java.io.{File, FileNotFoundException} +import java.util.Locale import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -202,4 +204,213 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } } + + // Unsupported data types of csv, json, orc, and parquet are as follows; + // csv -> R/W: Interval, Null, Array, Map, Struct + // json -> W: Interval + // orc -> W: Interval, Null + // parquet -> R/W: Interval, Null + test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + var msg = intercept[UnsupportedOperationException] { + Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType.fromDDL("a struct") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType.fromDDL("a map") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType.fromDDL("a array") + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + .write.mode("overwrite").csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getMessage + assert(msg.contains("CSV data source does not support array data type.")) + } + } + + test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } + + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support calendarinterval data type.")) + } + + // We expect the types below should be passed for backward-compatibility + Seq("orc", "json").foreach { format => + // Interval type + var schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + + // UDT having interval data + schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + } + } + } + + test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("orc").foreach { format => + // write path + var msg = intercept[UnsupportedOperationException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + // We expect the types below should be passed for backward-compatibility + + // Null type + var schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + + // UDT having null data + schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + } + + Seq("parquet", "csv").foreach { format => + // write path + var msg = intercept[UnsupportedOperationException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + // read path + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(s"$format data source does not support null data type.")) + } + } + } +} + +object TestingUDT { + + @SQLUserDefinedType(udt = classOf[IntervalUDT]) + class IntervalData extends Serializable + + class IntervalUDT extends UserDefinedType[IntervalData] { + + override def sqlType: DataType = CalendarIntervalType + override def serialize(obj: IntervalData): Any = + throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): IntervalData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[IntervalData] = classOf[IntervalData] + } + + @SQLUserDefinedType(udt = classOf[NullUDT]) + private[sql] class NullData extends Serializable + + private[sql] class NullUDT extends UserDefinedType[NullData] { + + override def sqlType: DataType = NullType + override def serialize(obj: NullData): Any = throw new NotImplementedError("Not implemented") + override def deserialize(datum: Any): NullData = + throw new NotImplementedError("Not implemented") + override def userClass: Class[NullData] = classOf[NullData] + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d2f166c7d1877..365239d040ef2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -740,39 +740,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(numbers.count() == 8) } - test("error handling for unsupported data types.") { - withTempDir { dir => - val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { - Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support struct data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support map data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands").write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") - .write.csv(csvDir) - }.getMessage - assert(msg.contains("CSV data source does not support array data type")) - - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) - spark.range(1).write.csv(csvDir) - spark.read.schema(schema).csv(csvDir).collect() - }.getMessage - assert(msg.contains("CSV data source does not support array data type.")) - } - } - test("SPARK-15585 turn off quotations") { val cars = spark.read .format("csv") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 237ed9bc05988..dd2144c5fcea8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -72,6 +72,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + DataSourceUtils.verifyWriteSchema(this, dataSchema) + val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val configuration = job.getConfiguration @@ -121,6 +123,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + DataSourceUtils.verifyReadSchema(this, dataSchema) + if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index d556a030e2186..69009e1b520c2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.TestingUDT.{IntervalData, IntervalUDT} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.orc.OrcSuite import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.HiveSerDe +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { @@ -133,4 +135,49 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { Utils.deleteRecursively(location) } } + + test("SPARK-24204 error handling for unsupported data types") { + withTempDir { dir => + val orcDir = new File(dir, "orc").getCanonicalPath + + // write path + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("Cannot save interval data type into external storage.")) + + msg = intercept[UnsupportedOperationException] { + sql("select null").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + spark.udf.register("testType", () => new IntervalData()) + sql("select testType()").write.mode("overwrite").orc(orcDir) + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + // read path + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support null data type.")) + + msg = intercept[UnsupportedOperationException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.mode("overwrite").orc(orcDir) + spark.read.schema(schema).orc(orcDir).collect() + }.getMessage + assert(msg.contains("ORC data source does not support calendarinterval data type.")) + } + } } From c5aa54d54b301555bad1ff0653df11293f0033ed Mon Sep 17 00:00:00 2001 From: "Kallman, Steven" Date: Wed, 27 Jun 2018 15:36:59 -0700 Subject: [PATCH 1031/1161] [SPARK-24553][WEB-UI] http 302 fixes for href redirect ## What changes were proposed in this pull request? Updated URL/href links to include a '/' before '?id' to make links consistent and avoid http 302 redirect errors within UI port 4040 tabs. ## How was this patch tested? Built a runnable distribution and executed jobs. Validated that http 302 redirects are no longer encountered when clicking on links within UI port 4040 tabs. Author: Steven Kallman Author: Kallman, Steven Closes #21600 from SJKallman/{Spark-24553}{WEB-UI}-redirect-href-fixes. --- .../src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 2 +- .../org/apache/spark/sql/execution/ui/AllExecutionsPage.scala | 4 ++-- .../org/apache/spark/sql/execution/ui/ExecutionPage.scala | 2 +- .../spark/sql/hive/thriftserver/ui/ThriftServerPage.scala | 4 ++-- .../sql/hive/thriftserver/ui/ThriftServerSessionPage.scala | 2 +- .../main/scala/org/apache/spark/streaming/ui/BatchPage.scala | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 178d2c8d1a10a..90e9a7a3630cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -464,7 +464,7 @@ private[ui] class JobDataSource( val jobDescription = UIUtils.makeDescription(lastStageDescription, basePath, plainText = false) - val detailUrl = "%s/jobs/job?id=%s".format(basePath, jobData.jobId) + val detailUrl = "%s/jobs/job/?id=%s".format(basePath, jobData.jobId) new JobTableRowData( jobData, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index d4e6a7bc3effa..55eb989962668 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val _taskTable = new TaskPagedTable( stageData, UIUtils.prependBaseUri(request, parent.basePath) + - s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 56e4d6838a99a..d01acdae59c9f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -370,7 +370,7 @@ private[ui] class StagePagedTable( Seq.empty } - val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" + val nameLinkUri = s"$basePathUri/stages/stage/?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = store.rddList().filter { rdd => s.rddIds.contains(rdd.id) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index bf46bc4cf904d..a7a24ac3641b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -214,11 +214,11 @@ private[ui] abstract class ExecutionTable( } private def jobURL(request: HttpServletRequest, jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def executionURL(request: HttpServletRequest, executionID: Long): String = s"${UIUtils.prependBaseUri( - request, parent.basePath)}/${parent.prefix}/execution?id=$executionID" + request, parent.basePath)}/${parent.prefix}/execution/?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 282f7b4bb5a58..877176b030f8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -122,7 +122,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging } private def jobURL(request: HttpServletRequest, jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) + "%s/jobs/job/?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 0950b30126773..771104ceb8842 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -76,7 +76,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - [{id}] @@ -147,7 +147,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s".format( + val sessionLink = "%s/%s/session/?id=%s".format( UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId)
    diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index c884aa0ecbdf8..163eb43aabc72 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -86,7 +86,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - [{id}] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index ca9da6139649a..884d21d0afdd3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -109,7 +109,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") val detailUrl = s"${SparkUIUtils.prependBaseUri( - request, parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + request, parent.basePath)}/jobs/job/?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". From bd32b509a1728366494cba13f8f6612b7bd46ec0 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Jun 2018 09:19:25 +0800 Subject: [PATCH 1032/1161] [SPARK-24645][SQL] Skip parsing when csvColumnPruning enabled and partitions scanned only ## What changes were proposed in this pull request? In the master, when `csvColumnPruning`(implemented in [this commit](https://github.com/apache/spark/commit/64fad0b519cf35b8c0a0dec18dd3df9488a5ed25#diff-d19881aceddcaa5c60620fdcda99b4c4)) enabled and partitions scanned only, it throws an exception below; ``` scala> val dir = "/tmp/spark-csv/csv" scala> spark.range(10).selectExpr("id % 2 AS p", "id").write.mode("overwrite").partitionBy("p").csv(dir) scala> spark.read.csv(dir).selectExpr("sum(p)").collect() 18/06/25 13:12:51 ERROR Executor: Exception in task 0.0 in stage 2.0 (TID 5) java.lang.NullPointerException at org.apache.spark.sql.execution.datasources.csv.UnivocityParser.org$apache$spark$sql$execution$datasources$csv$UnivocityParser$$convert(UnivocityParser.scala:197) at org.apache.spark.sql.execution.datasources.csv.UnivocityParser.parse(UnivocityParser.scala:190) at org.apache.spark.sql.execution.datasources.csv.UnivocityParser$$anonfun$5.apply(UnivocityParser.scala:309) at org.apache.spark.sql.execution.datasources.csv.UnivocityParser$$anonfun$5.apply(UnivocityParser.scala:309) at org.apache.spark.sql.execution.datasources.FailureSafeParser.parse(FailureSafeParser.scala:61) ... ``` This pr modified code to skip CSV parsing in the case. ## How was this patch tested? Added tests in `CSVSuite`. Author: Takeshi Yamamuro Closes #21631 from maropu/SPARK-24645. --- .../execution/datasources/csv/UnivocityParser.scala | 10 +++++++++- .../spark/sql/execution/datasources/csv/CSVSuite.scala | 10 ++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 5f7d5696b71a6..aa545e1a0c00a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -183,11 +183,19 @@ class UnivocityParser( } } + private val doParse = if (schema.nonEmpty) { + (input: String) => convert(tokenizer.parseLine(input)) + } else { + // If `columnPruning` enabled and partition attributes scanned only, + // `schema` gets empty. + (_: String) => InternalRow.empty + } + /** * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): InternalRow = convert(tokenizer.parseLine(input)) + def parse(input: String): InternalRow = doParse(input) private def convert(tokens: Array[String]): InternalRow = { if (tokens.length != schema.length) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 365239d040ef2..84b91f6309fe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1569,4 +1569,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te assert(testAppender2.events.asScala .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) } + + test("SPARK-24645 skip parsing when columnPruning enabled and partitions scanned only") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "true") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id").write.partitionBy("p").csv(dir) + checkAnswer(spark.read.csv(dir).selectExpr("sum(p)"), Row(5)) + } + } + } } From 1c9acc2438f9a97134ae5213a12112b2361fbb78 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 28 Jun 2018 09:21:10 +0800 Subject: [PATCH 1033/1161] [SPARK-24206][SQL][FOLLOW-UP] Update DataSourceReadBenchmark benchmark results ## What changes were proposed in this pull request? This pr corrected the default configuration (`spark.master=local[1]`) for benchmarks. Also, this updated performance results on the AWS `r3.xlarge`. ## How was this patch tested? N/A Author: Takeshi Yamamuro Closes #21625 from maropu/FixDataSourceReadBenchmark. --- .../benchmark/DataSourceReadBenchmark.scala | 296 +++++++++--------- 1 file changed, 152 insertions(+), 144 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index fc6d8abc03c09..8711f5a8fa1ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -39,9 +39,11 @@ import org.apache.spark.util.{Benchmark, Utils} object DataSourceReadBenchmark { val conf = new SparkConf() .setAppName("DataSourceReadBenchmark") - .setIfMissing("spark.master", "local[1]") + // Since `spark.master` always exists, overrides this value + .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") .setIfMissing("spark.executor.memory", "3g") + .setIfMissing("spark.ui.enabled", "false") val spark = SparkSession.builder.config(conf).getOrCreate() @@ -154,73 +156,73 @@ object DataSourceReadBenchmark { } } - /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 15231 / 15267 1.0 968.3 1.0X - SQL Json 8476 / 8498 1.9 538.9 1.8X - SQL Parquet Vectorized 121 / 127 130.0 7.7 125.9X - SQL Parquet MR 1515 / 1543 10.4 96.3 10.1X - SQL ORC Vectorized 164 / 171 95.9 10.4 92.9X - SQL ORC Vectorized with copy 228 / 234 69.0 14.5 66.8X - SQL ORC MR 1297 / 1309 12.1 82.5 11.7X + SQL CSV 22964 / 23096 0.7 1460.0 1.0X + SQL Json 8469 / 8593 1.9 538.4 2.7X + SQL Parquet Vectorized 164 / 177 95.8 10.4 139.9X + SQL Parquet MR 1687 / 1706 9.3 107.2 13.6X + SQL ORC Vectorized 191 / 197 82.3 12.2 120.2X + SQL ORC Vectorized with copy 215 / 219 73.2 13.7 106.9X + SQL ORC MR 1392 / 1412 11.3 88.5 16.5X SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 16344 / 16374 1.0 1039.1 1.0X - SQL Json 8634 / 8648 1.8 548.9 1.9X - SQL Parquet Vectorized 172 / 177 91.5 10.9 95.1X - SQL Parquet MR 1744 / 1746 9.0 110.9 9.4X - SQL ORC Vectorized 189 / 194 83.1 12.0 86.4X - SQL ORC Vectorized with copy 244 / 250 64.5 15.5 67.0X - SQL ORC MR 1341 / 1386 11.7 85.3 12.2X + SQL CSV 24090 / 24097 0.7 1531.6 1.0X + SQL Json 8791 / 8813 1.8 558.9 2.7X + SQL Parquet Vectorized 204 / 212 77.0 13.0 117.9X + SQL Parquet MR 1813 / 1850 8.7 115.3 13.3X + SQL ORC Vectorized 226 / 230 69.7 14.4 106.7X + SQL ORC Vectorized with copy 295 / 298 53.3 18.8 81.6X + SQL ORC MR 1526 / 1549 10.3 97.1 15.8X SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 17874 / 17875 0.9 1136.4 1.0X - SQL Json 9190 / 9204 1.7 584.3 1.9X - SQL Parquet Vectorized 141 / 160 111.2 9.0 126.4X - SQL Parquet MR 1930 / 2049 8.2 122.7 9.3X - SQL ORC Vectorized 259 / 264 60.7 16.5 69.0X - SQL ORC Vectorized with copy 265 / 272 59.4 16.8 67.5X - SQL ORC MR 1528 / 1569 10.3 97.2 11.7X + SQL CSV 25637 / 25791 0.6 1629.9 1.0X + SQL Json 9532 / 9570 1.7 606.0 2.7X + SQL Parquet Vectorized 181 / 191 86.8 11.5 141.5X + SQL Parquet MR 2210 / 2227 7.1 140.5 11.6X + SQL ORC Vectorized 309 / 317 50.9 19.6 83.0X + SQL ORC Vectorized with copy 316 / 322 49.8 20.1 81.2X + SQL ORC MR 1650 / 1680 9.5 104.9 15.5X SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 22812 / 22839 0.7 1450.4 1.0X - SQL Json 12026 / 12054 1.3 764.6 1.9X - SQL Parquet Vectorized 222 / 227 70.8 14.1 102.6X - SQL Parquet MR 2199 / 2204 7.2 139.8 10.4X - SQL ORC Vectorized 331 / 335 47.6 21.0 69.0X - SQL ORC Vectorized with copy 338 / 343 46.6 21.5 67.6X - SQL ORC MR 1618 / 1622 9.7 102.9 14.1X + SQL CSV 31617 / 31764 0.5 2010.1 1.0X + SQL Json 12440 / 12451 1.3 790.9 2.5X + SQL Parquet Vectorized 284 / 315 55.4 18.0 111.4X + SQL Parquet MR 2382 / 2390 6.6 151.5 13.3X + SQL ORC Vectorized 398 / 403 39.5 25.3 79.5X + SQL ORC Vectorized with copy 410 / 413 38.3 26.1 77.1X + SQL ORC MR 1783 / 1813 8.8 113.4 17.7X SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 18703 / 18740 0.8 1189.1 1.0X - SQL Json 11779 / 11869 1.3 748.9 1.6X - SQL Parquet Vectorized 143 / 145 110.1 9.1 130.9X - SQL Parquet MR 1954 / 1963 8.0 124.2 9.6X - SQL ORC Vectorized 347 / 355 45.3 22.1 53.8X - SQL ORC Vectorized with copy 356 / 359 44.1 22.7 52.5X - SQL ORC MR 1570 / 1598 10.0 99.8 11.9X + SQL CSV 26679 / 26742 0.6 1696.2 1.0X + SQL Json 12490 / 12541 1.3 794.1 2.1X + SQL Parquet Vectorized 174 / 183 90.4 11.1 153.3X + SQL Parquet MR 2201 / 2223 7.1 140.0 12.1X + SQL ORC Vectorized 415 / 429 37.9 26.4 64.3X + SQL ORC Vectorized with copy 422 / 428 37.2 26.9 63.2X + SQL ORC MR 1767 / 1773 8.9 112.3 15.1X SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 23832 / 23838 0.7 1515.2 1.0X - SQL Json 16204 / 16226 1.0 1030.2 1.5X - SQL Parquet Vectorized 242 / 306 65.1 15.4 98.6X - SQL Parquet MR 2462 / 2482 6.4 156.5 9.7X - SQL ORC Vectorized 419 / 451 37.6 26.6 56.9X - SQL ORC Vectorized with copy 426 / 447 36.9 27.1 55.9X - SQL ORC MR 1885 / 1931 8.3 119.8 12.6X + SQL CSV 34223 / 34324 0.5 2175.8 1.0X + SQL Json 17784 / 17785 0.9 1130.7 1.9X + SQL Parquet Vectorized 277 / 283 56.7 17.6 123.4X + SQL Parquet MR 2356 / 2386 6.7 149.8 14.5X + SQL ORC Vectorized 533 / 536 29.5 33.9 64.2X + SQL ORC Vectorized with copy 541 / 546 29.1 34.4 63.3X + SQL ORC MR 2166 / 2177 7.3 137.7 15.8X */ sqlBenchmark.run() @@ -294,41 +296,42 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 187 / 201 84.2 11.9 1.0X - ParquetReader Vectorized -> Row 101 / 103 156.4 6.4 1.9X + ParquetReader Vectorized 198 / 202 79.4 12.6 1.0X + ParquetReader Vectorized -> Row 119 / 121 132.3 7.6 1.7X Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 272 / 288 57.8 17.3 1.0X - ParquetReader Vectorized -> Row 213 / 219 73.7 13.6 1.3X + ParquetReader Vectorized 282 / 287 55.8 17.9 1.0X + ParquetReader Vectorized -> Row 246 / 247 64.0 15.6 1.1X Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 252 / 288 62.5 16.0 1.0X - ParquetReader Vectorized -> Row 232 / 246 67.7 14.8 1.1X + ParquetReader Vectorized 258 / 262 60.9 16.4 1.0X + ParquetReader Vectorized -> Row 259 / 260 60.8 16.5 1.0X Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 415 / 454 37.9 26.4 1.0X - ParquetReader Vectorized -> Row 407 / 432 38.6 25.9 1.0X + ParquetReader Vectorized 361 / 369 43.6 23.0 1.0X + ParquetReader Vectorized -> Row 361 / 371 43.6 22.9 1.0X Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 251 / 302 62.7 16.0 1.0X - ParquetReader Vectorized -> Row 220 / 234 71.5 14.0 1.1X + ParquetReader Vectorized 253 / 261 62.2 16.1 1.0X + ParquetReader Vectorized -> Row 254 / 256 61.9 16.2 1.0X Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - ParquetReader Vectorized 432 / 436 36.4 27.5 1.0X - ParquetReader Vectorized -> Row 414 / 422 38.0 26.4 1.0X + ParquetReader Vectorized 357 / 364 44.0 22.7 1.0X + ParquetReader Vectorized -> Row 358 / 366 44.0 22.7 1.0X */ parquetReaderBenchmark.run() } @@ -382,16 +385,17 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 19172 / 19173 0.5 1828.4 1.0X - SQL Json 12799 / 12873 0.8 1220.6 1.5X - SQL Parquet Vectorized 2558 / 2564 4.1 244.0 7.5X - SQL Parquet MR 4514 / 4583 2.3 430.4 4.2X - SQL ORC Vectorized 2561 / 2697 4.1 244.3 7.5X - SQL ORC Vectorized with copy 3076 / 3110 3.4 293.4 6.2X - SQL ORC MR 4197 / 4283 2.5 400.2 4.6X + SQL CSV 27145 / 27158 0.4 2588.7 1.0X + SQL Json 12969 / 13337 0.8 1236.8 2.1X + SQL Parquet Vectorized 2419 / 2448 4.3 230.7 11.2X + SQL Parquet MR 4631 / 4633 2.3 441.7 5.9X + SQL ORC Vectorized 2412 / 2465 4.3 230.0 11.3X + SQL ORC Vectorized with copy 2633 / 2675 4.0 251.1 10.3X + SQL ORC MR 4280 / 4350 2.4 408.2 6.3X */ benchmark.run() } @@ -445,16 +449,17 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 10889 / 10924 1.0 1038.5 1.0X - SQL Json 7903 / 7931 1.3 753.7 1.4X - SQL Parquet Vectorized 777 / 799 13.5 74.1 14.0X - SQL Parquet MR 1682 / 1708 6.2 160.4 6.5X - SQL ORC Vectorized 532 / 534 19.7 50.7 20.5X - SQL ORC Vectorized with copy 742 / 743 14.1 70.7 14.7X - SQL ORC MR 1996 / 2002 5.3 190.4 5.5X + SQL CSV 17345 / 17424 0.6 1654.1 1.0X + SQL Json 8639 / 8664 1.2 823.9 2.0X + SQL Parquet Vectorized 839 / 854 12.5 80.0 20.7X + SQL Parquet MR 1771 / 1775 5.9 168.9 9.8X + SQL ORC Vectorized 550 / 569 19.1 52.4 31.6X + SQL ORC Vectorized with copy 785 / 849 13.4 74.9 22.1X + SQL ORC MR 2168 / 2202 4.8 206.7 8.0X */ benchmark.run() } @@ -574,30 +579,31 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - Data column - CSV 25428 / 25454 0.6 1616.7 1.0X - Data column - Json 12689 / 12774 1.2 806.7 2.0X - Data column - Parquet Vectorized 222 / 231 70.7 14.1 114.3X - Data column - Parquet MR 3355 / 3397 4.7 213.3 7.6X - Data column - ORC Vectorized 332 / 338 47.4 21.1 76.6X - Data column - ORC Vectorized with copy 338 / 341 46.5 21.5 75.2X - Data column - ORC MR 2329 / 2356 6.8 148.0 10.9X - Partition column - CSV 17465 / 17502 0.9 1110.4 1.5X - Partition column - Json 10865 / 10876 1.4 690.8 2.3X - Partition column - Parquet Vectorized 48 / 52 325.4 3.1 526.1X - Partition column - Parquet MR 1695 / 1696 9.3 107.8 15.0X - Partition column - ORC Vectorized 49 / 54 319.9 3.1 517.2X - Partition column - ORC Vectorized with copy 49 / 52 324.1 3.1 524.0X - Partition column - ORC MR 1548 / 1549 10.2 98.4 16.4X - Both columns - CSV 25568 / 25595 0.6 1625.6 1.0X - Both columns - Json 13658 / 13673 1.2 868.4 1.9X - Both columns - Parquet Vectorized 270 / 296 58.3 17.1 94.3X - Both columns - Parquet MR 3501 / 3521 4.5 222.6 7.3X - Both columns - ORC Vectorized 377 / 380 41.7 24.0 67.4X - Both column - ORC Vectorized with copy 447 / 448 35.2 28.4 56.9X - Both columns - ORC MR 2440 / 2446 6.4 155.2 10.4X + Data column - CSV 32613 / 32841 0.5 2073.4 1.0X + Data column - Json 13343 / 13469 1.2 848.3 2.4X + Data column - Parquet Vectorized 302 / 318 52.1 19.2 108.0X + Data column - Parquet MR 2908 / 2924 5.4 184.9 11.2X + Data column - ORC Vectorized 412 / 425 38.1 26.2 79.1X + Data column - ORC Vectorized with copy 442 / 446 35.6 28.1 73.8X + Data column - ORC MR 2390 / 2396 6.6 152.0 13.6X + Partition column - CSV 9626 / 9683 1.6 612.0 3.4X + Partition column - Json 10909 / 10923 1.4 693.6 3.0X + Partition column - Parquet Vectorized 69 / 76 228.4 4.4 473.6X + Partition column - Parquet MR 1898 / 1933 8.3 120.7 17.2X + Partition column - ORC Vectorized 67 / 74 236.0 4.2 489.4X + Partition column - ORC Vectorized with copy 65 / 72 241.9 4.1 501.6X + Partition column - ORC MR 1743 / 1749 9.0 110.8 18.7X + Both columns - CSV 35523 / 35552 0.4 2258.5 0.9X + Both columns - Json 13676 / 13681 1.2 869.5 2.4X + Both columns - Parquet Vectorized 317 / 326 49.5 20.2 102.7X + Both columns - Parquet MR 3333 / 3336 4.7 211.9 9.8X + Both columns - ORC Vectorized 441 / 446 35.6 28.1 73.9X + Both column - ORC Vectorized with copy 517 / 524 30.4 32.9 63.1X + Both columns - ORC MR 2574 / 2577 6.1 163.6 12.7X */ benchmark.run() } @@ -684,41 +690,42 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 13518 / 13529 0.8 1289.2 1.0X - SQL Json 10895 / 10926 1.0 1039.0 1.2X - SQL Parquet Vectorized 1539 / 1581 6.8 146.8 8.8X - SQL Parquet MR 3746 / 3811 2.8 357.3 3.6X - ParquetReader Vectorized 1070 / 1112 9.8 102.0 12.6X - SQL ORC Vectorized 1389 / 1408 7.6 132.4 9.7X - SQL ORC Vectorized with copy 1736 / 1750 6.0 165.6 7.8X - SQL ORC MR 3799 / 3892 2.8 362.3 3.6X + SQL CSV 14875 / 14920 0.7 1418.6 1.0X + SQL Json 10974 / 10992 1.0 1046.5 1.4X + SQL Parquet Vectorized 1711 / 1750 6.1 163.2 8.7X + SQL Parquet MR 3838 / 3884 2.7 366.0 3.9X + ParquetReader Vectorized 1155 / 1168 9.1 110.2 12.9X + SQL ORC Vectorized 1341 / 1380 7.8 127.9 11.1X + SQL ORC Vectorized with copy 1659 / 1716 6.3 158.2 9.0X + SQL ORC MR 3594 / 3634 2.9 342.7 4.1X String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 10854 / 10892 1.0 1035.2 1.0X - SQL Json 8129 / 8138 1.3 775.3 1.3X - SQL Parquet Vectorized 1053 / 1104 10.0 100.4 10.3X - SQL Parquet MR 2840 / 2854 3.7 270.8 3.8X - ParquetReader Vectorized 978 / 1008 10.7 93.2 11.1X - SQL ORC Vectorized 1312 / 1387 8.0 125.1 8.3X - SQL ORC Vectorized with copy 1764 / 1772 5.9 168.2 6.2X - SQL ORC MR 3435 / 3445 3.1 327.6 3.2X + SQL CSV 17219 / 17264 0.6 1642.1 1.0X + SQL Json 8843 / 8864 1.2 843.3 1.9X + SQL Parquet Vectorized 1169 / 1178 9.0 111.4 14.7X + SQL Parquet MR 2676 / 2697 3.9 255.2 6.4X + ParquetReader Vectorized 1068 / 1071 9.8 101.8 16.1X + SQL ORC Vectorized 1319 / 1319 7.9 125.8 13.1X + SQL ORC Vectorized with copy 1638 / 1639 6.4 156.2 10.5X + SQL ORC MR 3230 / 3257 3.2 308.1 5.3X String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 8043 / 8048 1.3 767.1 1.0X - SQL Json 4911 / 4923 2.1 468.4 1.6X - SQL Parquet Vectorized 206 / 209 51.0 19.6 39.1X - SQL Parquet MR 1528 / 1537 6.9 145.8 5.3X - ParquetReader Vectorized 216 / 219 48.6 20.6 37.2X - SQL ORC Vectorized 462 / 466 22.7 44.1 17.4X - SQL ORC Vectorized with copy 568 / 572 18.5 54.2 14.2X - SQL ORC MR 1647 / 1649 6.4 157.1 4.9X + SQL CSV 13976 / 14053 0.8 1332.8 1.0X + SQL Json 5166 / 5176 2.0 492.6 2.7X + SQL Parquet Vectorized 274 / 282 38.2 26.2 50.9X + SQL Parquet MR 1553 / 1555 6.8 148.1 9.0X + ParquetReader Vectorized 241 / 246 43.5 23.0 57.9X + SQL ORC Vectorized 476 / 479 22.0 45.4 29.3X + SQL ORC Vectorized with copy 584 / 588 17.9 55.7 23.9X + SQL ORC MR 1720 / 1734 6.1 164.1 8.1X */ benchmark.run() } @@ -773,38 +780,39 @@ object DataSourceReadBenchmark { } /* - Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz + OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 + Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 3663 / 3665 0.3 3493.2 1.0X - SQL Json 3122 / 3160 0.3 2977.5 1.2X - SQL Parquet Vectorized 40 / 42 26.2 38.2 91.5X - SQL Parquet MR 189 / 192 5.5 180.2 19.4X - SQL ORC Vectorized 48 / 51 21.6 46.2 75.6X - SQL ORC Vectorized with copy 49 / 52 21.4 46.7 74.9X - SQL ORC MR 280 / 289 3.7 267.1 13.1X + SQL CSV 3478 / 3481 0.3 3316.4 1.0X + SQL Json 2646 / 2654 0.4 2523.6 1.3X + SQL Parquet Vectorized 67 / 72 15.8 63.5 52.2X + SQL Parquet MR 207 / 214 5.1 197.6 16.8X + SQL ORC Vectorized 69 / 76 15.2 66.0 50.3X + SQL ORC Vectorized with copy 70 / 76 15.0 66.5 49.9X + SQL ORC MR 299 / 303 3.5 285.1 11.6X Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 11420 / 11505 0.1 10891.1 1.0X - SQL Json 11905 / 12120 0.1 11353.6 1.0X - SQL Parquet Vectorized 50 / 54 20.9 47.8 227.7X - SQL Parquet MR 195 / 199 5.4 185.8 58.6X - SQL ORC Vectorized 61 / 65 17.3 57.8 188.3X - SQL ORC Vectorized with copy 62 / 65 17.0 58.8 185.2X - SQL ORC MR 847 / 865 1.2 807.4 13.5X + SQL CSV 9214 / 9236 0.1 8786.7 1.0X + SQL Json 9943 / 9978 0.1 9482.7 0.9X + SQL Parquet Vectorized 77 / 86 13.6 73.3 119.8X + SQL Parquet MR 229 / 235 4.6 218.6 40.2X + SQL ORC Vectorized 84 / 96 12.5 80.0 109.9X + SQL ORC Vectorized with copy 83 / 91 12.6 79.4 110.7X + SQL ORC MR 843 / 854 1.2 804.0 10.9X - Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Single Column Scan from 100 columns Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------- - SQL CSV 21278 / 21404 0.0 20292.4 1.0X - SQL Json 22455 / 22625 0.0 21414.7 0.9X - SQL Parquet Vectorized 73 / 75 14.4 69.3 292.8X - SQL Parquet MR 220 / 226 4.8 209.7 96.8X - SQL ORC Vectorized 82 / 86 12.8 78.2 259.4X - SQL ORC Vectorized with copy 82 / 90 12.7 78.7 258.0X - SQL ORC MR 1568 / 1582 0.7 1495.4 13.6X + SQL CSV 16503 / 16622 0.1 15738.9 1.0X + SQL Json 19109 / 19184 0.1 18224.2 0.9X + SQL Parquet Vectorized 99 / 108 10.6 94.3 166.8X + SQL Parquet MR 253 / 264 4.1 241.6 65.1X + SQL ORC Vectorized 107 / 114 9.8 101.6 154.8X + SQL ORC Vectorized with copy 107 / 118 9.8 102.1 154.1X + SQL ORC MR 1526 / 1529 0.7 1455.3 10.8X */ benchmark.run() } From 6a97e8eb31da76fe5af912a6304c07b63735062f Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 28 Jun 2018 09:59:00 +0800 Subject: [PATCH 1034/1161] [SPARK-24603][SQL] Fix findTightestCommonType reference in comments findTightestCommonTypeOfTwo has been renamed to findTightestCommonType ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Fokko Driesprong Closes #21597 from Fokko/fd-typo. --- .../sql/execution/datasources/json/JsonInferSchema.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index f6edc7bfb3750..8e1b430f4eb33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -334,8 +334,8 @@ private[sql] object JsonInferSchema { ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in - // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when - // the given `DecimalType` is not capable of the given `IntegralType`. + // `findTightestCommonType`. Both cases below will be executed only when the given + // `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => From 5b0596648854c0c733b7c607661b78af7df18b89 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 28 Jun 2018 14:19:50 +0800 Subject: [PATCH 1035/1161] [SPARK-24564][TEST] Add test suite for RecordBinaryComparator ## What changes were proposed in this pull request? Add a new test suite to test RecordBinaryComparator. ## How was this patch tested? New test suite. Author: Xingbo Jiang Closes #21570 from jiangxb1987/rbc-test. --- .../spark/memory/TestMemoryConsumer.java | 10 + .../sort/RecordBinaryComparatorSuite.java | 256 ++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index db91329c94cb6..0bbaea6b834b8 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -17,6 +17,10 @@ package org.apache.spark.memory; +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.unsafe.memory.MemoryBlock; + import java.io.IOException; public class TestMemoryConsumer extends MemoryConsumer { @@ -43,6 +47,12 @@ void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } + + @VisibleForTesting + public void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java new file mode 100644 index 0000000000000..a19ddbdbadba2 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.execution.sort; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.RecordBinaryComparator; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.collection.unsafe.sort.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form. + */ +public class RecordBinaryComparatorSuite { + + private final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + + private MemoryBlock dataPage; + private long pageCursor; + + private LongArray array; + private int pos; + + @Before + public void beforeEach() { + // Only compare between two input rows. + array = consumer.allocateArray(2); + pos = 0; + + dataPage = memoryManager.allocatePage(4096, consumer); + pageCursor = dataPage.getBaseOffset(); + } + + @After + public void afterEach() { + consumer.freePage(dataPage); + dataPage = null; + pageCursor = 0; + + consumer.freeArray(array); + array = null; + pos = 0; + } + + private void insertRow(UnsafeRow row) { + Object recordBase = row.getBaseObject(); + long recordOffset = row.getBaseOffset(); + int recordLength = row.getSizeInBytes(); + + Object baseObject = dataPage.getBaseObject(); + assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size()); + long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor); + UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength); + pageCursor += recordLength; + + assert(pos < 2); + array.set(pos, recordAddress); + pos++; + } + + private int compare(int index1, int index2) { + Object baseObject = dataPage.getBaseObject(); + + long recordAddress1 = array.get(index1); + long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize; + int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize); + + long recordAddress2 = array.get(index2); + long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize; + int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize); + + return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject, + baseOffset2, recordLength2); + } + + private final RecordComparator binaryComparator = new RecordBinaryComparator(); + + // Compute the most compact size for UnsafeRow's backing data. + private int computeSizeInBytes(int originalSize) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + (originalSize + 7) / 8 * 8; + } + + // Compute the relative offset of variable-length values. + private long relativeOffset(int numFields) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + numFields * 8L; + } + + @Test + public void testBinaryComparatorForSingleColumnRow() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 42); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForMultipleColumnRow() throws Exception { + int numFields = 5; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setDouble(i, i * 3.14); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row2.setDouble(i, 198.7 / (i + 1)); + } + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForArrayColumn() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1}); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes())); + row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes()); + Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22}); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes())); + row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes()); + Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForMixedColumns() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UTF8String str1 = UTF8String.fromString("Milk tea"); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes())); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes()); + Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UTF8String str2 = UTF8String.fromString("Java"); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes())); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes()); + Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForNullColumns() throws Exception { + int numFields = 3; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setNullAt(i); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields - 1; i++) { + row2.setNullAt(i); + } + row2.setDouble(numFields - 1, 3.14); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } +} From 524827f0626281847582ec3056982db7eb83f8b1 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Thu, 28 Jun 2018 12:40:39 -0700 Subject: [PATCH 1036/1161] [SPARK-14712][ML] LogisticRegressionModel.toString should summarize model ## What changes were proposed in this pull request? [SPARK-14712](https://issues.apache.org/jira/browse/SPARK-14712) spark.mllib LogisticRegressionModel overrides toString to print a little model info. We should do the same in spark.ml and override repr in pyspark. ## How was this patch tested? LogisticRegressionSuite.scala Python doctest in pyspark.ml.classification.py Author: bravo-zhang Closes #18826 from bravo-zhang/spark-14712. --- .../apache/spark/ml/classification/LogisticRegression.scala | 5 +++++ .../spark/ml/classification/LogisticRegressionSuite.scala | 6 ++++++ python/pyspark/ml/classification.py | 5 +++++ python/pyspark/mllib/classification.py | 3 +++ 4 files changed, 19 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 06ca37bc75146..92e342ed4a464 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + override def toString: String = { + s"LogisticRegressionModel: " + + s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 36b7e51f93d01..75c2aeb146786 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(model.getFamily === family) } } + + test("toString") { + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) + val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3" + assert(model.toString === expected) + } } object LogisticRegressionSuite { diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 1754c48937a62..d5963f4f7042c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True + >>> model2 + LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2 .. versionadded:: 1.3.0 """ @@ -562,6 +564,9 @@ def evaluate(self, dataset): java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionSummary(JavaWrapper): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index bb281981fd56b..e00ed95ef0701 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -258,6 +258,9 @@ def load(cls, sc, path): model.setThreshold(threshold) return model + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """ From a95a4af76459016b0d52df90adab68a49904da99 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 28 Jun 2018 13:20:08 -0700 Subject: [PATCH 1037/1161] [SPARK-23120][PYSPARK][ML] Add basic PMML export support to PySpark ## What changes were proposed in this pull request? Adds basic PMML export support for Spark ML stages to PySpark as was previously done in Scala. Includes LinearRegressionModel as the first stage to implement. ## How was this patch tested? Doctest, the main testing work for this is on the Scala side. (TODO holden add the unittest once I finish locally). Author: Holden Karau Closes #21172 from holdenk/SPARK-23120-add-pmml-export-support-to-pyspark. --- python/pyspark/ml/regression.py | 3 ++- python/pyspark/ml/tests.py | 17 ++++++++++++ python/pyspark/ml/util.py | 46 +++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dba0e57b01a0b..83f0edb397271 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction True >>> model.numFeatures 1 + >>> model.write().format("pmml").save(model_path + "_2") .. versionadded:: 1.4.0 """ @@ -161,7 +162,7 @@ def getEpsilon(self): return self.getOrDefault(self.epsilon) -class LinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class LinearRegressionModel(JavaModel, JavaPredictionModel, GeneralJavaMLWritable, JavaMLReadable): """ Model fitted by :class:`LinearRegression`. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ebd36cbb5f7a7..bc782138292bf 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1362,6 +1362,23 @@ def test_linear_regression(self): except OSError: pass + def test_linear_regression_pmml_basic(self): + # Most of the validation is done in the Scala side, here we just check + # that we output text rather than parquet (e.g. that the format flag + # was respected). + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1) + model = lr.fit(df) + path = tempfile.mkdtemp() + lr_path = path + "/lr-pmml" + model.write().format("pmml").save(lr_path) + pmml_text_list = self.sc.textFile(lr_path).collect() + pmml_text = "\n".join(pmml_text_list) + self.assertIn("Apache Spark", pmml_text) + self.assertIn("PMML", pmml_text) + def test_logistic_regression(self): lr = LogisticRegression(maxIter=1) path = tempfile.mkdtemp() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9fa85664939b8..080cd299f4fde 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -148,6 +148,23 @@ def overwrite(self): return self +@inherit_doc +class GeneralMLWriter(MLWriter): + """ + Utility class that can save ML instances in different formats. + + .. versionadded:: 2.4.0 + """ + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self.source = source + return self + + @inherit_doc class JavaMLWriter(MLWriter): """ @@ -192,6 +209,24 @@ def session(self, sparkSession): return self +@inherit_doc +class GeneralJavaMLWriter(JavaMLWriter): + """ + (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types + """ + + def __init__(self, instance): + super(GeneralJavaMLWriter, self).__init__(instance) + + def format(self, source): + """ + Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class + name for export). + """ + self._jwrite.format(source) + return self + + @inherit_doc class MLWritable(object): """ @@ -220,6 +255,17 @@ def write(self): return JavaMLWriter(self) +@inherit_doc +class GeneralJavaMLWritable(JavaMLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`. + """ + + def write(self): + """Returns an GeneralMLWriter instance for this ML instance.""" + return GeneralJavaMLWriter(self) + + @inherit_doc class MLReader(BaseReadWrite): """ From e1d3f80103f6df2eb8a962607dd5427df4b355dd Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 28 Jun 2018 13:22:52 -0700 Subject: [PATCH 1038/1161] [SPARK-24408][SQL][DOC] Move abs function to math_funcs group ## What changes were proposed in this pull request? A few math functions (`abs` , `bitwiseNOT`, `isnan`, `nanvl`) are not in **math_funcs** group. They should really be. ## How was this patch tested? Awaiting Jenkins Author: Jacek Laskowski Closes #21448 from jaceklaskowski/SPARK-24408-math-funcs-doc. --- .../scala/org/apache/spark/sql/functions.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0b4f526799578..acca9572cb14c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1031,14 +1031,6 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Computes the absolute value. - * - * @group normal_funcs - * @since 1.3.0 - */ - def abs(e: Column): Column = withExpr { Abs(e.expr) } - /** * Creates a new array column. The input columns must all have the same data type. * @@ -1336,7 +1328,7 @@ object functions { } /** - * Computes bitwise NOT. + * Computes bitwise NOT (~) of a number. * * @group normal_funcs * @since 1.4.0 @@ -1364,6 +1356,14 @@ object functions { // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value of a numeric value. + * + * @group math_funcs + * @since 1.3.0 + */ + def abs(e: Column): Column = withExpr { Abs(e.expr) } + /** * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos` * From 2224861f2f93830d736b625c9a4cb72c918512b2 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 28 Jun 2018 14:07:28 -0700 Subject: [PATCH 1039/1161] [SPARK-24439][ML][PYTHON] Add distanceMeasure to BisectingKMeans in PySpark ## What changes were proposed in this pull request? add distanceMeasure to BisectingKMeans in Python. ## How was this patch tested? added doctest and also manually tested it. Author: Huaxin Gao Closes #21557 from huaxingao/spark-24439. --- python/pyspark/ml/clustering.py | 35 +++++++++++++------ .../ml/param/_shared_params_code_gen.py | 4 ++- python/pyspark/ml/param/shared.py | 24 +++++++++++++ 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 4aa1cf84b5824..6d77baf7349e4 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -349,8 +349,8 @@ def summary(self): @inherit_doc -class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - JavaMLWritable, JavaMLReadable): +class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, HasMaxIter, + HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). @@ -406,9 +406,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol typeConverter=TypeConverters.toString) initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) - distanceMeasure = Param(Params._dummy(), "distanceMeasure", "The distance measure. " + - "Supported options: 'euclidean' and 'cosine'.", - typeConverter=TypeConverters.toString) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -544,8 +541,8 @@ def summary(self): @inherit_doc -class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, - JavaMLWritable, JavaMLReadable): +class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol, + HasMaxIter, HasSeed, JavaMLWritable, JavaMLReadable): """ A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. @@ -585,6 +582,8 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte >>> bkm2 = BisectingKMeans.load(bkm_path) >>> bkm2.getK() 2 + >>> bkm2.getDistanceMeasure() + 'euclidean' >>> model_path = temp_path + "/bkm_model" >>> model.save(model_path) >>> model2 = BisectingKMeansModel.load(model_path) @@ -607,10 +606,10 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ __init__(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") """ super(BisectingKMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", @@ -622,10 +621,10 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 @keyword_only @since("2.0.0") def setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, - seed=None, k=4, minDivisibleClusterSize=1.0): + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean"): """ setParams(self, featuresCol="features", predictionCol="prediction", maxIter=20, \ - seed=None, k=4, minDivisibleClusterSize=1.0) + seed=None, k=4, minDivisibleClusterSize=1.0, distanceMeasure="euclidean") Sets params for BisectingKMeans. """ kwargs = self._input_kwargs @@ -659,6 +658,20 @@ def getMinDivisibleClusterSize(self): """ return self.getOrDefault(self.minDivisibleClusterSize) + @since("2.4.0") + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + @since("2.4.0") + def getDistanceMeasure(self): + """ + Gets the value of `distanceMeasure` or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + def _create_model(self, java_model): return BisectingKMeansModel(java_model) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 6e9e0a34cdfde..e45ba840b412b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -162,7 +162,9 @@ def get$Name(self): "fitting. If set to true, then all sub-models will be available. Warning: For large " + "models, collecting all sub-models can cause OOMs on the Spark driver.", "False", "TypeConverters.toBoolean"), - ("loss", "the loss function to be optimized.", None, "TypeConverters.toString")] + ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"), + ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", + "'euclidean'", "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 08408ee8fbfcc..618f5bf0a8103 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -790,3 +790,27 @@ def getCacheNodeIds(self): """ return self.getOrDefault(self.cacheNodeIds) + +class HasDistanceMeasure(Params): + """ + Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'. + """ + + distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasDistanceMeasure, self).__init__() + self._setDefault(distanceMeasure='euclidean') + + def setDistanceMeasure(self, value): + """ + Sets the value of :py:attr:`distanceMeasure`. + """ + return self._set(distanceMeasure=value) + + def getDistanceMeasure(self): + """ + Gets the value of distanceMeasure or its default value. + """ + return self.getOrDefault(self.distanceMeasure) + From f6e6899a8b8af99cd06e84cae7c69e0fc35bc60a Mon Sep 17 00:00:00 2001 From: Jose Torres Date: Thu, 28 Jun 2018 16:25:40 -0700 Subject: [PATCH 1040/1161] [SPARK-24386][SS] coalesce(1) aggregates in continuous processing ## What changes were proposed in this pull request? Provide a continuous processing implementation of coalesce(1), as well as allowing aggregates on top of it. The changes in ContinuousQueuedDataReader and such are to use split.index (the ID of the partition within the RDD currently being compute()d) rather than context.partitionId() (the partition ID of the scheduled task within the Spark job - that is, the post coalesce writer). In the absence of a narrow dependency, these values were previously always the same, so there was no need to distinguish. ## How was this patch tested? new unit test Author: Jose Torres Closes #21560 from jose-torres/coalesce. --- .../UnsupportedOperationChecker.scala | 11 ++ .../datasources/v2/DataSourceV2Strategy.scala | 16 ++- .../continuous/ContinuousCoalesceExec.scala | 51 +++++++ .../continuous/ContinuousCoalesceRDD.scala | 136 ++++++++++++++++++ .../continuous/ContinuousDataSourceRDD.scala | 7 +- .../continuous/ContinuousExecution.scala | 4 + .../ContinuousQueuedDataReader.scala | 6 +- .../shuffle/ContinuousShuffleReadRDD.scala | 10 +- .../shuffle/RPCContinuousShuffleReader.scala | 4 +- .../sources/ContinuousMemoryStream.scala | 11 +- .../ContinuousAggregationSuite.scala | 63 +++++++- .../ContinuousQueuedDataReaderSuite.scala | 2 +- .../shuffle/ContinuousShuffleSuite.scala | 7 +- 13 files changed, 310 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 2bed41672fe33..5ced1ca200daa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -349,6 +349,17 @@ object UnsupportedOperationChecker { _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias | _: TypedFilter) => case node if node.nodeName == "StreamingRelationV2" => + case Repartition(1, false, _) => + case node: Aggregate => + val aboveSinglePartitionCoalesce = node.find { + case Repartition(1, false, _) => true + case _ => false + }.isDefined + + if (!aboveSinglePartitionCoalesce) { + throwError(s"In continuous processing mode, coalesce(1) must be called before " + + s"aggregate operation ${node.nodeName}.") + } case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 182aa2906cf1e..2a7f1de2c7c19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -22,11 +22,12 @@ import scala.collection.mutable import org.apache.spark.sql.{sources, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader object DataSourceV2Strategy extends Strategy { @@ -141,6 +142,17 @@ object DataSourceV2Strategy extends Strategy { case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil + case Repartition(1, false, child) => + val isContinuous = child.collectFirst { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + }.isDefined + + if (isContinuous) { + ContinuousCoalesceExec(1, planLater(child)) :: Nil + } else { + Nil + } + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala new file mode 100644 index 0000000000000..5f60343bacfaa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark.{HashPartitioner, SparkEnv} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition, ContinuousShuffleReadRDD} + +/** + * Physical plan for coalescing a continuous processing plan. + * + * Currently, only coalesces to a single partition are supported. `numPartitions` must be 1. + */ +case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan { + override def output: Seq[Attribute] = child.output + + override def children: Seq[SparkPlan] = child :: Nil + + override def outputPartitioning: Partitioning = SinglePartition + + override def doExecute(): RDD[InternalRow] = { + assert(numPartitions == 1) + new ContinuousCoalesceRDD( + sparkContext, + numPartitions, + conf.continuousStreamingExecutorQueueSize, + sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY).toLong, + child.execute()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala new file mode 100644 index 0000000000000..ba85b355f974f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.UUID + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.continuous.shuffle._ +import org.apache.spark.util.ThreadUtils + +case class ContinuousCoalesceRDDPartition( + index: Int, + endpointName: String, + queueSize: Int, + numShuffleWriters: Int, + epochIntervalMs: Long) + extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new RPCContinuousShuffleReader( + queueSize, numShuffleWriters, epochIntervalMs, env) + val endpoint = env.setupEndpoint(endpointName, receiver) + + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } + // This flag will be flipped on the executors to indicate that the threads processing + // partitions of the write-side RDD have been started. These will run indefinitely + // asynchronously as epochs of the coalesce RDD complete on the read side. + private[continuous] var writersInitialized: Boolean = false +} + +/** + * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local + * continuous shuffle, and then reads them in the task thread using `reader`. + */ +class ContinuousCoalesceRDD( + context: SparkContext, + numPartitions: Int, + readerQueueSize: Int, + epochIntervalMs: Long, + prev: RDD[InternalRow]) + extends RDD[InternalRow](context, Nil) { + + // When we support more than 1 target partition, we'll need to figure out how to pass in the + // required partitioner. + private val outputPartitioner = new HashPartitioner(1) + + private val readerEndpointNames = (0 until numPartitions).map { i => + s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}" + } + + override def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousCoalesceRDDPartition( + partIndex, + readerEndpointNames(partIndex), + readerQueueSize, + prev.getNumPartitions, + epochIntervalMs) + }.toArray + } + + private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool( + prev.getNumPartitions, + this.name) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val part = split.asInstanceOf[ContinuousCoalesceRDDPartition] + + if (!part.writersInitialized) { + val rpcEnv = SparkEnv.get.rpcEnv + + // trigger lazy initialization + part.endpoint + val endpointRefs = readerEndpointNames.map { endpointName => + rpcEnv.setupEndpointRef(rpcEnv.address, endpointName) + } + + val runnables = prev.partitions.map { prevSplit => + new Runnable() { + override def run(): Unit = { + TaskContext.setTaskContext(context) + + val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter( + prevSplit.index, outputPartitioner, endpointRefs.toArray) + + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) + while (!context.isInterrupted() && !context.isCompleted()) { + writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]]) + // Note that current epoch is a non-inheritable thread local, so each writer thread + // can properly increment its own epoch without affecting the main task thread. + EpochTracker.incrementCurrentEpoch() + } + } + } + } + + context.addTaskCompletionListener { ctx => + threadPool.shutdownNow() + } + + part.writersInitialized = true + + runnables.foreach(threadPool.execute) + } + + part.reader.read() + } + + override def clearDependencies(): Unit = { + throw new IllegalStateException("Continuous RDDs cannot be checkpointed") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index a7ccce10b0cee..73868d5967e90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -51,11 +51,11 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - @transient private val readerFactories: Seq[InputPartition[UnsafeRow]]) + private val readerInputPartitions: Seq[InputPartition[UnsafeRow]]) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { + readerInputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } @@ -74,8 +74,7 @@ class ContinuousDataSourceRDD( val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] if (partition.queueReader == null) { partition.queueReader = - new ContinuousQueuedDataReader( - partition.inputPartition, context, dataQueueSize, epochPollIntervalMs) + new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index e3d0cea608b2a..a0bb8292d7766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -216,6 +216,9 @@ class ContinuousExecution( currentEpochCoordinatorId = epochCoordinatorId sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId) + sparkSessionForQuery.sparkContext.setLocalProperty( + ContinuousExecution.EPOCH_INTERVAL_KEY, + trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = @@ -382,4 +385,5 @@ class ContinuousExecution( object ContinuousExecution { val START_EPOCH_KEY = "__continuous_start_epoch" val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id" + val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index f38577b6a9f16..8c74b8244d096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: InputPartition[UnsafeRow], + partition: ContinuousDataSourceRDDPartition, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.createPartitionReader() + private val reader = partition.inputPartition.createPartitionReader() // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) + partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala index cf6572d3de1f7..518223f3cd008 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -21,12 +21,14 @@ import java.util.UUID import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.RpcAddress import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.NextIterator case class ContinuousShuffleReadPartition( index: Int, + endpointName: String, queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long) @@ -36,7 +38,7 @@ case class ContinuousShuffleReadPartition( val env = SparkEnv.get.rpcEnv val receiver = new RPCContinuousShuffleReader( queueSize, numShuffleWriters, epochIntervalMs, env) - val endpoint = env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver) + val endpoint = env.setupEndpoint(endpointName, receiver) TaskContext.get().addTaskCompletionListener { ctx => env.stop(endpoint) @@ -61,12 +63,14 @@ class ContinuousShuffleReadRDD( numPartitions: Int, queueSize: Int = 1024, numShuffleWriters: Int = 1, - epochIntervalMs: Long = 1000) + epochIntervalMs: Long = 1000, + val endpointNames: Seq[String] = Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}")) extends RDD[UnsafeRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { (0 until numPartitions).map { partIndex => - ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs) + ContinuousShuffleReadPartition( + partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, epochIntervalMs) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala index 834e84675c7d5..502ae0d4822e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala @@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends RPCContin * TODO: Support multiple source tasks. We need to output a single epoch marker once all * source tasks have sent one. */ -private[shuffle] class RPCContinuousShuffleReader( +private[continuous] class RPCContinuousShuffleReader( queueSize: Int, numShuffleWriters: Int, epochIntervalMs: Long, @@ -107,7 +107,7 @@ private[shuffle] class RPCContinuousShuffleReader( } logWarning( s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " + - s"for writers $writerIdsUncommitted to send epoch markers.") + s"for writers ${writerIdsUncommitted.mkString(",")} to send epoch markers.") // The completion service guarantees this future will be available immediately. case future => future.get() match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index d1c3498450096..0bf90b8063326 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ +import scala.collection.SortedMap import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.{Encoder, Row, SQLContext} import org.apache.spark.sql.execution.streaming._ @@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader( private var currentOffset = startOffset private var current: Option[Row] = None + // Defense-in-depth against failing to propagate the task context. Since it's not inheritable, + // we have to do a bit of error prone work to get it into every thread used by continuous + // processing. We hope that some unit test will end up instantiating a continuous memory stream + // in such cases. + if (TaskContext.get() == null) { + throw new IllegalStateException("Task context was not set!") + } + override def next(): Boolean = { current = getRecord while (current.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala index b7ef637f5270e..0223812600961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -31,7 +31,8 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { testStream(input.toDF().agg(max('value)), OutputMode.Complete)() } - assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + assert(ex.getMessage.contains( + "In continuous processing mode, coalesce(1) must be called before aggregate operation")) } test("basic") { @@ -50,6 +51,66 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase { } } + test("multiple partitions with coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF().coalesce(1).agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with coalesce - multiple transformations") { + val input = ContinuousMemoryStream[Int] + + // We use a barrier to make sure predicates both before and after coalesce work + val df = input.toDF() + .select('value as 'copy, 'value) + .where('copy =!= 1) + .planWithBarrier + .coalesce(1) + .where('copy =!= 2) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(0), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + + test("multiple partitions with multiple coalesce") { + val input = ContinuousMemoryStream[Int] + + val df = input.toDF() + .coalesce(1) + .planWithBarrier + .coalesce(1) + .select('value as 'copy, 'value) + .agg(max('value)) + + testStream(df, OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + test("repeated restart") { withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { val input = ContinuousMemoryStream.singlePartition[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index e663fa8312da4..0e7e6febb53df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -92,7 +92,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { } } val reader = new ContinuousQueuedDataReader( - factory, + new ContinuousDataSourceRDDPartition(0, factory), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala index a8e3611b585cf..f84f3d49707bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle +import java.util.UUID + import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} @@ -124,7 +126,10 @@ class ContinuousShuffleSuite extends StreamTest { } test("reader - multiple partitions") { - val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + val rdd = new ContinuousShuffleReadRDD( + sparkContext, + numPartitions = 5, + endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}")) // Send all data before processing to ensure there's no crossover. for (p <- rdd.partitions) { val part = p.asInstanceOf[ContinuousShuffleReadPartition] From f71e8da5efde96aacc89e59c6e27b71fffcbc25f Mon Sep 17 00:00:00 2001 From: xueyu <278006819@qq.com> Date: Fri, 29 Jun 2018 10:44:17 -0700 Subject: [PATCH 1041/1161] [SPARK-24566][CORE] Fix spark.storage.blockManagerSlaveTimeoutMs default config This PR use spark.network.timeout in place of spark.storage.blockManagerSlaveTimeoutMs when it is not configured, as configuration doc said manual test Author: xueyu <278006819@qq.com> Closes #21575 from xueyumusic/slaveTimeOutConfig. --- core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 5 ++--- .../cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ff960b396dbf1..bcbc8df0d5865 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -74,10 +74,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = - sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") private val executorTimeoutMs = - sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 + sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s") // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index d35bea4aca311..1ce2f816dffb2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -634,7 +634,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( slave.hostname, externalShufflePort, sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L}ms"), + s"${sc.conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), sc.conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s")) slave.shuffleRegistered = true } From 03545ce6de08bd0ad685c5f59b73bc22dfc40887 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 30 Jun 2018 13:58:50 +0800 Subject: [PATCH 1042/1161] [SPARK-24638][SQL] StringStartsWith support push down ## What changes were proposed in this pull request? `StringStartsWith` support push down. About 50% savings in compute time. ## How was this patch tested? unit tests, manual tests and performance test: ```scala cat < SPARK-24638.scala def benchmark(func: () => Unit): Long = { val start = System.currentTimeMillis() for(i <- 0 until 100) { func() } val end = System.currentTimeMillis() end - start } val path = "/tmp/spark/parquet/string/" spark.range(10000000).selectExpr("concat(id, 'str', id) as id").coalesce(1).write.mode("overwrite").option("parquet.block.size", 1048576).parquet(path) val df = spark.read.parquet(path) spark.sql("set spark.sql.parquet.filterPushdown.string.startsWith=true") val pushdownEnable = benchmark(() => df.where("id like '999998%'").count()) spark.sql("set spark.sql.parquet.filterPushdown.string.startsWith=false") val pushdownDisable = benchmark(() => df.where("id like '999998%'").count()) val improvements = pushdownDisable - pushdownEnable println(s"improvements: $improvements") EOF bin/spark-shell -i SPARK-24638.scala ``` result: ```scala Loading SPARK-24638.scala... benchmark: (func: () => Unit)Long path: String = /tmp/spark/parquet/string/ df: org.apache.spark.sql.DataFrame = [id: string] res1: org.apache.spark.sql.DataFrame = [key: string, value: string] pushdownEnable: Long = 11608 res2: org.apache.spark.sql.DataFrame = [key: string, value: string] pushdownDisable: Long = 31981 improvements: Long = 20373 ``` Author: Yuming Wang Closes #21623 from wangyum/SPARK-24638. --- .../apache/spark/sql/internal/SQLConf.scala | 11 +++ .../parquet/ParquetFileFormat.scala | 4 +- .../datasources/parquet/ParquetFilters.scala | 35 +++++++- .../parquet/ParquetFilterSuite.scala | 84 ++++++++++++++++++- 4 files changed, 130 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e1752ff997b69..da1c34cdc78f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -378,6 +378,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.string.startsWith") + .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1459,6 +1467,9 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownStringStartWith: Boolean = + getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 9602a08911dea..93de1faef527a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -348,6 +348,7 @@ class ParquetFileFormat // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -358,7 +359,8 @@ class ParquetFileFormat // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(new ParquetFilters(pushDownDate).createFilter(requiredSchema, _)) + .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) + .createFilter(requiredSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 310626197a763..21c9e2e4f82b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -22,16 +22,18 @@ import java.sql.Date import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.PrimitiveComparator import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] class ParquetFilters(pushDownDate: Boolean) { +private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) { private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -270,6 +272,37 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean) { case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) + case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => + Option(prefix).map { v => + FilterApi.userDefined(binaryColumn(name), + new UserDefinedPredicate[Binary] with Serializable { + private val strToBinary = Binary.fromReusedByteArray(v.getBytes) + private val size = strToBinary.length + + override def canDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 || + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0 + } + + override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = { + val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR + val max = statistics.getMax + val min = statistics.getMin + comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 && + comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0 + } + + override def keep(value: Binary): Boolean = { + UTF8String.fromBytes(value.getBytes).startsWith( + UTF8String.fromBytes(strToBinary.getBytes)) + } + } + ) + } + case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 90da7eb8c4fb5..d9ae5858e5ed0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,7 +55,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} */ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { - private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate) + private lazy val parquetFilters = + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith) override def beforeEach(): Unit = { super.beforeEach() @@ -82,6 +83,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) @@ -140,6 +142,31 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. + private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + dataFrame.write.option("parquet.block.size", 512).parquet(path) + Seq(true, false).foreach { pushDown => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> pushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) + + val df = spark.read.parquet(path).filter(filter) + df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) + if (pushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } + + AccumulatorContext.remove(accu.id) + } + } + } + } + test("filter pushdown - boolean") { withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) @@ -574,7 +601,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition((it: Iterator[Row]) => it.foreach(v => accu.add(0))) - df.collect if (enablePushDown) { assert(accu.value == 0) @@ -660,6 +686,60 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(df.where("col > 0").count() === 2) } } + + test("filter pushdown - StringStartsWith") { + withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => + checkFilterPredicate( + '_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + "2str2") + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + '_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq.empty[Row]) + } + + checkFilterPredicate( + !'_1.startsWith("").asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq().map(Row(_))) + + Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "3str3", "4str4").map(Row(_))) + } + + Seq("2S", "null", "2str22").foreach { prefix => + checkFilterPredicate( + !'_1.startsWith(prefix).asInstanceOf[Predicate], + classOf[UserDefinedByInstance[_, _]], + Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) + } + + assertResult(None) { + parquetFilters.createFilter( + df.schema, + sources.StringStartsWith("_1", null)) + } + } + + import testImplicits._ + // Test canDrop() has taken effect + testStringStartsWith(spark.range(1024).map(_.toString).toDF(), "value like 'a%'") + // Test inverseCanDrop() has taken effect + testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'") + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 797971ed42cab41cbc3d039c0af4b26199bff783 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Fri, 29 Jun 2018 23:46:12 -0700 Subject: [PATCH 1043/1161] [SPARK-24696][SQL] ColumnPruning rule fails to remove extra Project ## What changes were proposed in this pull request? The ColumnPruning rule tries adding an extra Project if an input node produces fields more than needed, but as a post-processing step, it needs to remove the lower Project in the form of "Project - Filter - Project" otherwise it would conflict with PushPredicatesThroughProject and would thus cause a infinite optimization loop. The current post-processing method is defined as: ``` private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { case p1 Project(_, f Filter(_, p2 Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) } ``` This method works well when there is only one Filter but would not if there's two or more Filters. In this case, there is a deterministic filter and a non-deterministic filter so they stay as separate filter nodes and cannot be combined together. An simplified illustration of the optimization process that forms the infinite loop is shown below (F1 stands for the 1st filter, F2 for the 2nd filter, P for project, S for scan of relation, PredicatePushDown as abbrev. of PushPredicatesThroughProject): ``` F1 - F2 - P - S PredicatePushDown => F1 - P - F2 - S ColumnPruning => F1 - P - F2 - P - S => F1 - P - F2 - S (Project removed) PredicatePushDown => P - F1 - F2 - S ColumnPruning => P - F1 - P - F2 - S => P - F1 - P - F2 - P - S => P - F1 - F2 - P - S (only one Project removed) RemoveRedundantProject => F1 - F2 - P - S (goes back to the loop start) ``` So the problem is the ColumnPruning rule adds a Project under a Filter (and fails to remove it in the end), and that new Project triggers PushPredicateThroughProject. Once the filters have been push through the Project, a new Project will be added by the ColumnPruning rule and this goes on and on. The fix should be when adding Projects, the rule applies top-down, but later when removing extra Projects, the process should go bottom-up to ensure all extra Projects can be matched. ## How was this patch tested? Added a optimization rule test in ColumnPruningSuite; and a end-to-end test in SQLQuerySuite. Author: maryannxue Closes #21674 from maryannxue/spark-24696. --- .../spark/sql/catalyst/dsl/package.scala | 1 + .../sql/catalyst/optimizer/Optimizer.scala | 5 +++-- .../optimizer/ColumnPruningSuite.scala | 9 +++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 21 +++++++++++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index efb2eba655e15..8cf69c6f3c922 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -149,6 +149,7 @@ package object dsl { } } + def rand(e: Long): Expression = Rand(Literal.create(e, LongType)) def sum(e: Expression): Expression = Sum(e).toAggregateExpression() def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) def count(e: Expression): Expression = Count(e).toAggregateExpression() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aa992def1ce6c..2cc27d82f7d20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -526,9 +526,10 @@ object ColumnPruning extends Rule[LogicalPlan] { /** * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, - * so remove it. + * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up + * order, otherwise lower Projects can be missed. */ - private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform { + private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(child.outputSet) => p1.copy(child = f.copy(child = child)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 3f41f4b144096..8b05ba32e6eef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -370,5 +369,13 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized2, expected2.analyze) } + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + val input = LocalRelation('key.int, 'value.string) + val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze + val optimized = Optimize.execute(query) + val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze + comparePlans(optimized, expected) + } + // todo: add more tests for column pruning } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 640affc10ee58..dfb9c137b74f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2792,4 +2792,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { + withTable("fact_stats", "dim_stats") { + val factData = Seq((1, 1, 99, 1), (2, 2, 99, 2), (3, 1, 99, 3), (4, 2, 99, 4)) + val storeData = Seq((1, "BW", "DE"), (2, "AZ", "US")) + spark.udf.register("filterND", udf((value: Int) => value > 2).asNondeterministic) + factData.toDF("date_id", "store_id", "product_id", "units_sold") + .write.mode("overwrite").partitionBy("store_id").format("parquet").saveAsTable("fact_stats") + storeData.toDF("store_id", "state_province", "country") + .write.mode("overwrite").format("parquet").saveAsTable("dim_stats") + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.store_id FROM + |(SELECT date_id, product_id, store_id + | FROM fact_stats WHERE filterND(date_id)) AS f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) + checkAnswer(df, Seq(Row(3, 99, 1))) + } + } } From d54d8b86301581142293341af25fd78b3278a2e8 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 29 Jun 2018 23:51:13 -0700 Subject: [PATCH 1044/1161] simplify rand in dsl/package.scala --- .../main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8cf69c6f3c922..89e8c998f740d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -149,7 +149,7 @@ package object dsl { } } - def rand(e: Long): Expression = Rand(Literal.create(e, LongType)) + def rand(e: Long): Expression = Rand(e) def sum(e: Expression): Expression = Sum(e).toAggregateExpression() def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) def count(e: Expression): Expression = Count(e).toAggregateExpression() From f825847c82042a9eee7bd5cfab106310d279fc32 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 30 Jun 2018 19:27:16 -0500 Subject: [PATCH 1045/1161] [SPARK-24654][BUILD] Update, fix LICENSE and NOTICE, and specialize for source vs binary Whew, lots of work to track down again all the license requirements, but this ought to be a pretty good pass. Below, find a writeup on how I approached it for future reference. - LICENSE and NOTICE and licenses/ now reflect the *source* release - LICENSE-binary and NOTICE-binary and licenses-binary now reflect the binary release - Recreated all the license info from scratch - Added notes about how this was constructed for next time - License-oriented info was moved from NOTICE to LICENSE, esp. for Cat B deps - Some seemingly superfluous or stale license info was removed, especially for test-scope deps - Updated release script to put binary-oriented versions in binary releases ---- # Principles ASF projects distribute source and binary code under the Apache License 2.0. However these project distributions frequently include copies of source or binary code from third parties, under possibly other license terms. This triggers conditions of those licenses, which essentially amount to including license information in a LICENSE and/or NOTICE file, and including copies of license texts (here, in a directory called `license/`). See http://www.apache.org/dev/licensing-howto.html and https://www.apache.org/legal/resolved.html#required-third-party-notices # In Spark Spark produces source releases, and also binary releases of that code. Spark source code may contain source from third parties, possibly modified. This is true in Scala, Java, Python and R, and in the UI's JavaScript and CSS files. These must be handled appropriately per above in a LICENSE and NOTICE file created for the source release. Separately, the binary releases may contain binary code from third parties. This is very much true for Scala and Java, as Spark produces an 'assembly' binary release which includes all transitive binary dependencies of this part of Spark. With perhaps the exception of py4j, this doesn't occur in the same way for Python or R because of the way these ecosystems work. (Note that the JS and CSS for the UI will be in both 'source' and 'binary' releases.) These must also be handled in a separate LICENSE and NOTICE file for the binary release. # Binary Release License ## Transitive Maven Dependencies We'll first tackle the binary release, and that almost entirely means assessing the transitive dependencies of the Scala/Java backbone of Spark. Run `project-info-reports:dependencies` with essentially all profiles: a set that would bring in all different possible transitive dependencies. However, don't activate any of the '-lgpl' profiles as these would bring in LGPL-licensed dependencies that are explicitly excluded from Spark binary releases. ``` mvn -Phadoop-2.7 -Pyarn -Phive -Pmesos -Pkubernetes -Pflume -Pkinesis-asl -Pdocker-integration-tests -Phive-thriftserver -Pkafka-0-8 -Ddependency.locations.enabled=false project-info-reports:dependencies ``` Open `assembly/target/site/dependencies.html`. Find "Project Transitive Dependencies", and find "compile" and "runtime" (if exists). This is a list of all the dependencies that Spark is going to ship in its binary "assembly" distro and therefore whose licenses need to be appropriately considered in LICENSE and NOTICE. Copy this table into a spreadsheet for easy management. Next job is to fill in some blanks, as a few projects will not have clearly declared their licenses in a POM. Sort by license. This is a good time to verify all the dependencies are at least Cat A/B licenses, and not Cat X! http://www.apache.org/legal/resolved.html ### Apache License 2 The Apache License 2 variants are typically easiest to deal with as they will not require you to modify LICENSE, nor add to license/. It's still good form to list the ALv2 dependencies in LICENSE for completeness, but optional. They may require you to propagate bits from NOTICE. It's tedious to track down all the NOTICE files and evaluate what if anything needs to be copied to NOTICE. Fortunately, this can be made easier as the assembly module can be temporarily modified to produce a NOTICE file that concatenates all NOTICE files bundled with transitive dependencies. First change the packaging of `assembly/spark-assembly_2.11/pom.xml` to `jar`. Next add this stanza somewhere in the body of the same POM file: ``` org.apache.maven.plugins maven-shade-plugin false *:* package shade ``` Finally execute `mvn ... package` with all of the same `-P` profile flags as above. In the JAR file at `assembly/target/spark-assembly_2.11....jar` you'll find a file `META-INF/NOTICE` that concatenates all NOTICE files bundled with transitive dependencies. This should be the starting point for the binary release's NOTICE file. Some elements in the file are from Spark itself, like: ``` Spark Project Assembly Copyright 2018 The Apache Software Foundation Spark Project Core Copyright 2018 The Apache Software Foundation ``` These can be removed. Remove elements of the combined NOTICE file that aren't relevant to Spark. It's actually rare that we are sure that some element is completely irrelevant to Spark, because each transitive dependency includes all its transitive dependencies. So there may be nothing that can be done here. Of course, some projects may not publish NOTICE in their Maven artifacts. Ideally, search for the NOTICE file of projects that don't seem to have produced any text in NOTICE, but, there is some argument that projects that don't produce a NOTICE in their Maven artifacts don't entail an obligation on projects that depend solely on their Maven artifacts. ### Other Licenses Next are "Cat A" permissively licensed (BSD 2-Clause, BSD 3-Clause, MIT) components. List the components grouped by their license type in LICENSE. Then add the text of the license to licenses/. For example if you list "foo bar" as a BSD-licensed dependency, add its license text as licenses/LICENSE-foo-bar.txt. Public domain and similar works are treated like permissively licensed dependencies. And the same goes for all Cat B licenses too, like CDDL. However these additional require at least a URL pointer to the project's page. Use the artifact hyperlink in your spreadsheet if possible; if non-existent or doesn't resolve, do your best to determine a URL for the project's source. ### Shaded third-party dependencies Some third party dependencies actually copy in other dependencies rather than depend on them as Maven artifacts. This means they don't show up in the process above. These can be quite hard to track down, but are rare. A key example is reflectasm, embedded in kryo. ### Examples module The above _almost_ considers everything bundled in a Spark binary release. The main assembly won't include examples. The same must be done for dependencies marked as 'compile' for the examples module. See `examples/target/site/dependencies.html`. At the time of this writing however this just adds one dependency: `scopt`. ### provided scope Above we considered just compile and runtime scope dependencies, which makes sense as they are the ones that are packaged. However, for complicated reasons (shading), a few components that Spark does bundle are not marked as compile dependencies in the assembly. Therefore it's also necessary to consider 'provided' dependencies from `assembly/target/site/dependencies.html` actually! Right now that's just Jetty and JPMML artifacts. ## Python, R Don't forget that Py4J is also distributed in the binary release, actually. There should be no other R, Python code in the binary release. That's it. ## Sense checking Compare the contents of `jars/`, `examples/jars/` and `python/lib` from a recent binary release to see if anything appears there that doesn't seem to have been covered above. These additional components will have to be handled manually, but should be few or none of this type. # Source Release License While there are relatively fewer third-party source artifacts included as source code, there is no automated way to detect it, really. It requires some degree of manual auditing. Most third party source comes from included JS and CSS files. At the time of this writing, some places to look or consider: `build/sbt-launch-lib.bash`, `python/lib`, third party source in `python/pyspark` like `heapq3.py`, `docs/js/vendor`, and `core/src/main/resources/org/apache/spark/ui/static`. The principles are the same as above. Remember some JS files copy in other JS files! Look out for Modernizr. # One More Thing: JS and CSS in Binary Release Now that you've got a handle on source licenses, recall that all the JS and CSS source code will *also* be part of the binary release. Copy that info from source to binary license files accordingly. Author: Sean Owen Closes #21640 from srowen/SPARK-24654. --- LICENSE | 158 +-- LICENSE-binary | 520 ++++++++ NOTICE | 661 ---------- NOTICE-binary | 1170 +++++++++++++++++ dev/.rat-excludes | 4 + dev/make-distribution.sh | 7 +- .../LICENSE-AnchorJS.txt | 0 licenses-binary/LICENSE-CC0.txt | 121 ++ .../LICENSE-antlr.txt | 0 licenses-binary/LICENSE-arpack.txt | 8 + licenses-binary/LICENSE-automaton.txt | 24 + licenses-binary/LICENSE-bootstrap.txt | 13 + .../LICENSE-bouncycastle-bcprov.txt | 7 + licenses-binary/LICENSE-cloudpickle.txt | 28 + licenses-binary/LICENSE-d3.min.js.txt | 26 + .../LICENSE-dagre-d3.txt | 4 +- licenses-binary/LICENSE-datatables.txt | 7 + {licenses => licenses-binary}/LICENSE-f2j.txt | 0 licenses-binary/LICENSE-graphlib-dot.txt | 19 + licenses-binary/LICENSE-heapq.txt | 280 ++++ licenses-binary/LICENSE-janino.txt | 31 + licenses-binary/LICENSE-javassist.html | 373 ++++++ .../LICENSE-javolution.txt | 0 .../LICENSE-jline.txt | 0 .../LICENSE-jodd.txt | 12 +- .../LICENSE-join.txt | 0 licenses-binary/LICENSE-jquery.txt | 20 + licenses-binary/LICENSE-json-formatter.txt | 6 + licenses-binary/LICENSE-jtransforms.html | 388 ++++++ .../LICENSE-kryo.txt | 0 licenses-binary/LICENSE-leveldbjni.txt | 27 + licenses-binary/LICENSE-machinist.txt | 19 + .../LICENSE-matchMedia-polyfill.txt | 1 + .../LICENSE-minlog.txt | 0 licenses-binary/LICENSE-modernizr.txt | 21 + .../LICENSE-netlib.txt | 0 .../LICENSE-paranamer.txt | 0 .../LICENSE-pmml-model.txt | 0 .../LICENSE-protobuf.txt | 0 licenses-binary/LICENSE-py4j.txt | 27 + .../LICENSE-pyrolite.txt | 0 .../LICENSE-reflectasm.txt | 0 licenses-binary/LICENSE-respond.txt | 22 + licenses-binary/LICENSE-sbt-launch-lib.txt | 26 + .../LICENSE-scala.txt | 0 licenses-binary/LICENSE-scopt.txt | 9 + .../LICENSE-slf4j.txt | 0 licenses-binary/LICENSE-sorttable.js.txt | 16 + .../LICENSE-spire.txt | 0 licenses-binary/LICENSE-vis.txt | 22 + .../LICENSE-xmlenc.txt | 0 .../LICENSE-zstd-jni.txt | 0 .../LICENSE-zstd.txt | 0 licenses/LICENSE-CC0.txt | 121 ++ licenses/LICENSE-SnapTree.txt | 35 - licenses/LICENSE-bootstrap.txt | 13 + licenses/LICENSE-boto.txt | 20 - licenses/LICENSE-datatables.txt | 7 + licenses/LICENSE-graphlib-dot.txt | 2 +- licenses/LICENSE-jbcrypt.txt | 17 - .../{LICENSE-jmock.txt => LICENSE-join.txt} | 22 +- licenses/LICENSE-jquery.txt | 23 +- licenses/LICENSE-json-formatter.txt | 6 + licenses/LICENSE-matchMedia-polyfill.txt | 1 + licenses/LICENSE-postgresql.txt | 24 - licenses/LICENSE-respond.txt | 22 + licenses/LICENSE-scalacheck.txt | 32 - licenses/LICENSE-vis.txt | 22 + 68 files changed, 3526 insertions(+), 918 deletions(-) create mode 100644 LICENSE-binary create mode 100644 NOTICE-binary rename licenses/LICENSE-scopt.txt => licenses-binary/LICENSE-AnchorJS.txt (100%) create mode 100644 licenses-binary/LICENSE-CC0.txt rename {licenses => licenses-binary}/LICENSE-antlr.txt (100%) create mode 100644 licenses-binary/LICENSE-arpack.txt create mode 100644 licenses-binary/LICENSE-automaton.txt create mode 100644 licenses-binary/LICENSE-bootstrap.txt create mode 100644 licenses-binary/LICENSE-bouncycastle-bcprov.txt create mode 100644 licenses-binary/LICENSE-cloudpickle.txt create mode 100644 licenses-binary/LICENSE-d3.min.js.txt rename licenses/LICENSE-Mockito.txt => licenses-binary/LICENSE-dagre-d3.txt (94%) create mode 100644 licenses-binary/LICENSE-datatables.txt rename {licenses => licenses-binary}/LICENSE-f2j.txt (100%) create mode 100644 licenses-binary/LICENSE-graphlib-dot.txt create mode 100644 licenses-binary/LICENSE-heapq.txt create mode 100644 licenses-binary/LICENSE-janino.txt create mode 100644 licenses-binary/LICENSE-javassist.html rename {licenses => licenses-binary}/LICENSE-javolution.txt (100%) rename {licenses => licenses-binary}/LICENSE-jline.txt (100%) rename licenses/LICENSE-junit-interface.txt => licenses-binary/LICENSE-jodd.txt (69%) rename licenses/LICENSE-DPark.txt => licenses-binary/LICENSE-join.txt (100%) create mode 100644 licenses-binary/LICENSE-jquery.txt create mode 100644 licenses-binary/LICENSE-json-formatter.txt create mode 100644 licenses-binary/LICENSE-jtransforms.html rename {licenses => licenses-binary}/LICENSE-kryo.txt (100%) create mode 100644 licenses-binary/LICENSE-leveldbjni.txt create mode 100644 licenses-binary/LICENSE-machinist.txt create mode 100644 licenses-binary/LICENSE-matchMedia-polyfill.txt rename {licenses => licenses-binary}/LICENSE-minlog.txt (100%) create mode 100644 licenses-binary/LICENSE-modernizr.txt rename {licenses => licenses-binary}/LICENSE-netlib.txt (100%) rename {licenses => licenses-binary}/LICENSE-paranamer.txt (100%) rename licenses/LICENSE-jpmml-model.txt => licenses-binary/LICENSE-pmml-model.txt (100%) rename {licenses => licenses-binary}/LICENSE-protobuf.txt (100%) create mode 100644 licenses-binary/LICENSE-py4j.txt rename {licenses => licenses-binary}/LICENSE-pyrolite.txt (100%) rename {licenses => licenses-binary}/LICENSE-reflectasm.txt (100%) create mode 100644 licenses-binary/LICENSE-respond.txt create mode 100644 licenses-binary/LICENSE-sbt-launch-lib.txt rename {licenses => licenses-binary}/LICENSE-scala.txt (100%) create mode 100644 licenses-binary/LICENSE-scopt.txt rename {licenses => licenses-binary}/LICENSE-slf4j.txt (100%) create mode 100644 licenses-binary/LICENSE-sorttable.js.txt rename {licenses => licenses-binary}/LICENSE-spire.txt (100%) create mode 100644 licenses-binary/LICENSE-vis.txt rename {licenses => licenses-binary}/LICENSE-xmlenc.txt (100%) rename {licenses => licenses-binary}/LICENSE-zstd-jni.txt (100%) rename {licenses => licenses-binary}/LICENSE-zstd.txt (100%) create mode 100644 licenses/LICENSE-CC0.txt delete mode 100644 licenses/LICENSE-SnapTree.txt create mode 100644 licenses/LICENSE-bootstrap.txt delete mode 100644 licenses/LICENSE-boto.txt create mode 100644 licenses/LICENSE-datatables.txt delete mode 100644 licenses/LICENSE-jbcrypt.txt rename licenses/{LICENSE-jmock.txt => LICENSE-join.txt} (60%) create mode 100644 licenses/LICENSE-json-formatter.txt create mode 100644 licenses/LICENSE-matchMedia-polyfill.txt delete mode 100644 licenses/LICENSE-postgresql.txt create mode 100644 licenses/LICENSE-respond.txt delete mode 100644 licenses/LICENSE-scalacheck.txt create mode 100644 licenses/LICENSE-vis.txt diff --git a/LICENSE b/LICENSE index 6f5d9452e800d..b771bd552b762 100644 --- a/LICENSE +++ b/LICENSE @@ -201,103 +201,61 @@ limitations under the License. -======================================================================= -Apache Spark Subcomponents: - -The Apache Spark project contains subcomponents with separate copyright -notices and license terms. Your use of the source code for the these -subcomponents is subject to the terms and conditions of the following -licenses. - - -======================================================================== -For heapq (pyspark/heapq3.py): -======================================================================== - -See license/LICENSE-heapq.txt - -======================================================================== -For SnapTree: -======================================================================== - -See license/LICENSE-SnapTree.txt - -======================================================================== -For jbcrypt: -======================================================================== - -See license/LICENSE-jbcrypt.txt - -======================================================================== -BSD-style licenses -======================================================================== - -The following components are provided under a BSD-style license. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) - (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) - (BSD 3 Clause) jmock (org.jmock:jmock-junit4:2.8.4 - http://jmock.org/) - (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) - (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) - (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) - (BSD) JLine (jline:jline:2.14.3 - https://github.com/jline/jline2) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) - (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) - (Interpreter classes (all .scala files in repl/src/main/scala - except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), - and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.12 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.12 - http://www.scala-lang.org/) - (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) - (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) - (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) - (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) - (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) - (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) - (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) - (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) - (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (BSD licence) sbt and sbt-launch-lib.bash - (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) - (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) - (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) - (BSD 2 Clause) Zstd-jni (https://github.com/luben/zstd-jni/blob/master/LICENSE) - (BSD license) Zstd (https://github.com/facebook/zstd/blob/v1.3.1/LICENSE) - -======================================================================== -MIT licenses -======================================================================== - -The following components are provided under the MIT License. See project link for details. -The text of each license is also included at licenses/LICENSE-[project].txt. - - (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J API Module (org.slf4j:slf4j-api:1.7.5 - http://www.slf4j.org) - (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) - (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) - (MIT License) scopt (com.github.scopt:scopt_2.11:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) - (MIT License) jquery (https://jquery.org/license/) - (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) - (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) - (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) - (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) - (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) - (MIT License) datatables (http://datatables.net/license) - (MIT License) mustache (https://github.com/mustache/mustache/blob/master/LICENSE) - (MIT License) cookies (http://code.google.com/p/cookies/wiki/License) - (MIT License) blockUI (http://jquery.malsup.com/block/) - (MIT License) RowsGroup (http://datatables.net/license/mit) - (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) - (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) - (MIT License) machinist (https://github.com/typelevel/machinist) +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache Software Foundation License 2.0 +-------------------------------------- + +common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +BSD 3-Clause +------------ + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg \ No newline at end of file diff --git a/LICENSE-binary b/LICENSE-binary new file mode 100644 index 0000000000000..c033dd8ad2e6a --- /dev/null +++ b/LICENSE-binary @@ -0,0 +1,520 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +commons-beanutils:commons-beanutils +org.apache.zookeeper:zookeeper +oro:oro +commons-configuration:commons-configuration +commons-digester:commons-digester +com.chuusai:shapeless_2.11 +com.googlecode.javaewah:JavaEWAH +com.twitter:chill-java +com.twitter:chill_2.11 +com.univocity:univocity-parsers +javax.jdo:jdo-api +joda-time:joda-time +net.sf.opencsv:opencsv +org.apache.derby:derby +org.objenesis:objenesis +org.roaringbitmap:RoaringBitmap +org.scalanlp:breeze-macros_2.11 +org.scalanlp:breeze_2.11 +org.typelevel:macro-compat_2.11 +org.yaml:snakeyaml +org.apache.xbean:xbean-asm5-shaded +com.squareup.okhttp3:logging-interceptor +com.squareup.okhttp3:okhttp +com.squareup.okio:okio +net.java.dev.jets3t:jets3t +org.apache.spark:spark-catalyst_2.11 +org.apache.spark:spark-kvstore_2.11 +org.apache.spark:spark-launcher_2.11 +org.apache.spark:spark-mllib-local_2.11 +org.apache.spark:spark-network-common_2.11 +org.apache.spark:spark-network-shuffle_2.11 +org.apache.spark:spark-sketch_2.11 +org.apache.spark:spark-tags_2.11 +org.apache.spark:spark-unsafe_2.11 +commons-httpclient:commons-httpclient +com.vlkan:flatbuffers +com.ning:compress-lzf +io.airlift:aircompressor +io.dropwizard.metrics:metrics-core +io.dropwizard.metrics:metrics-ganglia +io.dropwizard.metrics:metrics-graphite +io.dropwizard.metrics:metrics-json +io.dropwizard.metrics:metrics-jvm +org.iq80.snappy:snappy +com.clearspring.analytics:stream +com.jamesmurty.utils:java-xmlbuilder +commons-codec:commons-codec +commons-collections:commons-collections +io.fabric8:kubernetes-client +io.fabric8:kubernetes-model +io.netty:netty +io.netty:netty-all +net.hydromatic:eigenbase-properties +net.sf.supercsv:super-csv +org.apache.arrow:arrow-format +org.apache.arrow:arrow-memory +org.apache.arrow:arrow-vector +org.apache.calcite:calcite-avatica +org.apache.calcite:calcite-core +org.apache.calcite:calcite-linq4j +org.apache.commons:commons-crypto +org.apache.commons:commons-lang3 +org.apache.hadoop:hadoop-annotations +org.apache.hadoop:hadoop-auth +org.apache.hadoop:hadoop-client +org.apache.hadoop:hadoop-common +org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-mapreduce-client-app +org.apache.hadoop:hadoop-mapreduce-client-common +org.apache.hadoop:hadoop-mapreduce-client-core +org.apache.hadoop:hadoop-mapreduce-client-jobclient +org.apache.hadoop:hadoop-mapreduce-client-shuffle +org.apache.hadoop:hadoop-yarn-api +org.apache.hadoop:hadoop-yarn-client +org.apache.hadoop:hadoop-yarn-common +org.apache.hadoop:hadoop-yarn-server-common +org.apache.hadoop:hadoop-yarn-server-web-proxy +org.apache.httpcomponents:httpclient +org.apache.httpcomponents:httpcore +org.apache.orc:orc-core +org.apache.orc:orc-mapreduce +org.mortbay.jetty:jetty +org.mortbay.jetty:jetty-util +com.jolbox:bonecp +org.json4s:json4s-ast_2.11 +org.json4s:json4s-core_2.11 +org.json4s:json4s-jackson_2.11 +org.json4s:json4s-scalap_2.11 +com.carrotsearch:hppc +com.fasterxml.jackson.core:jackson-annotations +com.fasterxml.jackson.core:jackson-core +com.fasterxml.jackson.core:jackson-databind +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.module:jackson-module-jaxb-annotations +com.fasterxml.jackson.module:jackson-module-paranamer +com.fasterxml.jackson.module:jackson-module-scala_2.11 +com.github.mifmif:generex +com.google.code.findbugs:jsr305 +com.google.code.gson:gson +com.google.inject:guice +com.google.inject.extensions:guice-servlet +com.twitter:parquet-hadoop-bundle +commons-beanutils:commons-beanutils-core +commons-cli:commons-cli +commons-dbcp:commons-dbcp +commons-io:commons-io +commons-lang:commons-lang +commons-logging:commons-logging +commons-net:commons-net +commons-pool:commons-pool +io.fabric8:zjsonpatch +javax.inject:javax.inject +javax.validation:validation-api +log4j:apache-log4j-extras +log4j:log4j +net.sf.jpam:jpam +org.apache.avro:avro +org.apache.avro:avro-ipc +org.apache.avro:avro-mapred +org.apache.commons:commons-compress +org.apache.commons:commons-math3 +org.apache.curator:curator-client +org.apache.curator:curator-framework +org.apache.curator:curator-recipes +org.apache.directory.api:api-asn1-api +org.apache.directory.api:api-util +org.apache.directory.server:apacheds-i18n +org.apache.directory.server:apacheds-kerberos-codec +org.apache.htrace:htrace-core +org.apache.ivy:ivy +org.apache.mesos:mesos +org.apache.parquet:parquet-column +org.apache.parquet:parquet-common +org.apache.parquet:parquet-encoding +org.apache.parquet:parquet-format +org.apache.parquet:parquet-hadoop +org.apache.parquet:parquet-jackson +org.apache.thrift:libfb303 +org.apache.thrift:libthrift +org.codehaus.jackson:jackson-core-asl +org.codehaus.jackson:jackson-mapper-asl +org.datanucleus:datanucleus-api-jdo +org.datanucleus:datanucleus-core +org.datanucleus:datanucleus-rdbms +org.lz4:lz4-java +org.spark-project.hive:hive-beeline +org.spark-project.hive:hive-cli +org.spark-project.hive:hive-exec +org.spark-project.hive:hive-jdbc +org.spark-project.hive:hive-metastore +org.xerial.snappy:snappy-java +stax:stax-api +xerces:xercesImpl +org.codehaus.jackson:jackson-jaxrs +org.codehaus.jackson:jackson-xc +org.eclipse.jetty:jetty-client +org.eclipse.jetty:jetty-continuation +org.eclipse.jetty:jetty-http +org.eclipse.jetty:jetty-io +org.eclipse.jetty:jetty-jndi +org.eclipse.jetty:jetty-plus +org.eclipse.jetty:jetty-proxy +org.eclipse.jetty:jetty-security +org.eclipse.jetty:jetty-server +org.eclipse.jetty:jetty-servlet +org.eclipse.jetty:jetty-servlets +org.eclipse.jetty:jetty-util +org.eclipse.jetty:jetty-webapp +org.eclipse.jetty:jetty-xml + +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses-binary/ +for text of these licenses. + + +BSD 2-Clause +------------ + +com.github.luben:zstd-jni +javolution:javolution +com.esotericsoftware:kryo-shaded +com.esotericsoftware:minlog +com.esotericsoftware:reflectasm +com.google.protobuf:protobuf-java +org.codehaus.janino:commons-compiler +org.codehaus.janino:janino +jline:jline +org.jodd:jodd-core + + +BSD 3-Clause +------------ + +dk.brics.automaton:automaton +org.antlr:antlr-runtime +org.antlr:ST4 +org.antlr:stringtemplate +org.antlr:antlr4-runtime +antlr:antlr +com.github.fommil.netlib:core +com.thoughtworks.paranamer:paranamer +org.scala-lang:scala-compiler +org.scala-lang:scala-library +org.scala-lang:scala-reflect +org.scala-lang.modules:scala-parser-combinators_2.11 +org.scala-lang.modules:scala-xml_2.11 +org.fusesource.leveldbjni:leveldbjni-all +net.sourceforge.f2j:arpack_combined_all +xmlenc:xmlenc +net.sf.py4j:py4j +org.jpmml:pmml-model +org.jpmml:pmml-schema + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +org.spire-math:spire-macros_2.11 +org.spire-math:spire_2.11 +org.typelevel:machinist_2.11 +net.razorvine:pyrolite +org.slf4j:jcl-over-slf4j +org.slf4j:jul-to-slf4j +org.slf4j:slf4j-api +org.slf4j:slf4j-log4j12 +com.github.scopt:scopt_2.11 +org.bouncycastle:bcprov-jdk15on + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Common Development and Distribution License (CDDL) 1.0 +------------------------------------------------------ + +javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html +javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 + + +Common Development and Distribution License (CDDL) 1.1 +------------------------------------------------------ + +javax.annotation:javax.annotation-api https://jcp.org/en/jsr/detail?id=250 +javax.servlet:javax.servlet-api https://javaee.github.io/servlet-spec/ +javax.transaction:jta http://www.oracle.com/technetwork/java/index.html +javax.ws.rs:javax.ws.rs-api https://github.com/jax-rs +javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2 +org.glassfish.hk2:hk2-api https://github.com/javaee/glassfish +org.glassfish.hk2:hk2-locator (same) +org.glassfish.hk2:hk2-utils +org.glassfish.hk2:osgi-resource-locator +org.glassfish.hk2.external:aopalliance-repackaged +org.glassfish.hk2.external:javax.inject +org.glassfish.jersey.bundles.repackaged:jersey-guava +org.glassfish.jersey.containers:jersey-container-servlet +org.glassfish.jersey.containers:jersey-container-servlet-core +org.glassfish.jersey.core:jersey-client +org.glassfish.jersey.core:jersey-common +org.glassfish.jersey.core:jersey-server +org.glassfish.jersey.media:jersey-media-jaxb + + +Mozilla Public License (MPL) 1.1 +-------------------------------- + +com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +Public Domain +------------- + +aopalliance:aopalliance +net.iharder:base64 +org.tukaani:xz + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/NOTICE b/NOTICE index 6ec240efbf12e..9246cc54caa3a 100644 --- a/NOTICE +++ b/NOTICE @@ -4,664 +4,3 @@ Copyright 2014 and onwards The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). - -======================================================================== -Common Development and Distribution License 1.0 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.0. See project link for details. - - (CDDL 1.0) Glassfish Jasper (org.mortbay.jetty:jsp-2.1:6.1.14 - http://jetty.mortbay.org/project/modules/jsp-2.1) - (CDDL 1.0) JAX-RS (https://jax-rs-spec.java.net/) - (CDDL 1.0) Servlet Specification 2.5 API (org.mortbay.jetty:servlet-api-2.5:6.1.14 - http://jetty.mortbay.org/project/modules/servlet-api-2.5) - (CDDL 1.0) (GPL2 w/ CPE) javax.annotation API (https://glassfish.java.net/nonav/public/CDDL+GPL.html) - (COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0) (GNU General Public Library) Streaming API for XML (javax.xml.stream:stax-api:1.0-2 - no url defined) - (Common Development and Distribution License (CDDL) v1.0) JavaBeans Activation Framework (JAF) (javax.activation:activation:1.1 - http://java.sun.com/products/javabeans/jaf/index.jsp) - -======================================================================== -Common Development and Distribution License 1.1 -======================================================================== - -The following components are provided under the Common Development and Distribution License 1.1. See project link for details. - - (CDDL 1.1) (GPL2 w/ CPE) org.glassfish.hk2 (https://hk2.java.net) - (CDDL 1.1) (GPL2 w/ CPE) JAXB API bundle for GlassFish V3 (javax.xml.bind:jaxb-api:2.2.2 - https://jaxb.dev.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) JAXB RI (com.sun.xml.bind:jaxb-impl:2.2.3-1 - http://jaxb.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) Jersey 2 (https://jersey.java.net) - -======================================================================== -Common Public License 1.0 -======================================================================== - -The following components are provided under the Common Public 1.0 License. See project link for details. - - (Common Public License Version 1.0) JUnit (junit:junit-dep:4.10 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:3.8.1 - http://junit.org) - (Common Public License Version 1.0) JUnit (junit:junit:4.8.2 - http://junit.org) - -======================================================================== -Eclipse Public License 1.0 -======================================================================== - -The following components are provided under the Eclipse Public License 1.0. See project link for details. - - (Eclipse Public License v1.0) Eclipse JDT Core (org.eclipse.jdt:core:3.1.1 - http://www.eclipse.org/jdt/) - -======================================================================== -Mozilla Public License 1.0 -======================================================================== - -The following components are provided under the Mozilla Public License 1.0. See project link for details. - - (GPL) (LGPL) (MPL) JTransforms (com.github.rwl:jtransforms:2.4.0 - http://sourceforge.net/projects/jtransforms/) - (Mozilla Public License Version 1.1) jamon-runtime (org.jamon:jamon-runtime:2.3.1 - http://www.jamon.org/jamon-runtime/) - - - -======================================================================== -NOTICE files -======================================================================== - -The following NOTICEs are pertain to software distributed with this project. - - -// ------------------------------------------------------------------ -// NOTICE file corresponding to the section 4d of The Apache License, -// Version 2.0, in this case for -// ------------------------------------------------------------------ - -Apache Avro -Copyright 2009-2013 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -Apache Commons Codec -Copyright 2002-2009 The Apache Software Foundation - -This product includes software developed by -The Apache Software Foundation (http://www.apache.org/). - --------------------------------------------------------------------------------- -src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java contains -test data from http://aspell.sourceforge.net/test/batch0.tab. - -Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org). Verbatim copying -and distribution of this entire article is permitted in any medium, -provided this notice is preserved. --------------------------------------------------------------------------------- - -Apache HttpComponents HttpClient -Copyright 1999-2011 The Apache Software Foundation - -This project contains annotations derived from JCIP-ANNOTATIONS -Copyright (c) 2005 Brian Goetz and Tim Peierls. See http://www.jcip.net - -Apache HttpComponents HttpCore -Copyright 2005-2011 The Apache Software Foundation - -Curator Recipes -Copyright 2011-2014 The Apache Software Foundation - -Curator Framework -Copyright 2011-2014 The Apache Software Foundation - -Curator Client -Copyright 2011-2014 The Apache Software Foundation - -Apache Geronimo -Copyright 2003-2008 The Apache Software Foundation - -Activation 1.1 -Copyright 2003-2007 The Apache Software Foundation - -Apache Commons Lang -Copyright 2001-2014 The Apache Software Foundation - -This product includes software from the Spring Framework, -under the Apache License 2.0 (see: StringUtils.containsWhitespace()) - -Apache log4j -Copyright 2007 The Apache Software Foundation - -# Compress LZF - -This library contains efficient implementation of LZF compression format, -as well as additional helper classes that build on JDK-provided gzip (deflat) -codec. - -## Licensing - -Library is licensed under Apache License 2.0, as per accompanying LICENSE file. - -## Credit - -Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). -It was started at Ning, inc., as an official Open Source process used by -platform backend, but after initial versions has been developed outside of -Ning by supporting community. - -Other contributors include: - -* Jon Hartlaub (first versions of streaming reader/writer; unit tests) -* Cedrik Lime: parallel LZF implementation - -Various community members have contributed bug reports, and suggested minor -fixes; these can be found from file "VERSION.txt" in SCM. - -Objenesis -Copyright 2006-2009 Joe Walnes, Henri Tremblay, Leonardo Mesquita - -Apache Commons Net -Copyright 2001-2010 The Apache Software Foundation - - The Netty Project - ================= - -Please visit the Netty web site for more information: - - * http://netty.io/ - -Copyright 2011 The Netty Project - -The Netty Project licenses this file to you under the Apache License, -version 2.0 (the "License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at: - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -License for the specific language governing permissions and limitations -under the License. - -Also, please refer to each LICENSE..txt file, which is located in -the 'license' directory of the distribution file, for the license terms of the -components that this product depends on. - -------------------------------------------------------------------------------- -This product contains the extensions to Java Collections Framework which has -been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: - - * LICENSE: - * license/LICENSE.jsr166y.txt (Public Domain) - * HOMEPAGE: - * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ - * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ - -This product contains a modified version of Robert Harder's Public Domain -Base64 Encoder and Decoder, which can be obtained at: - - * LICENSE: - * license/LICENSE.base64.txt (Public Domain) - * HOMEPAGE: - * http://iharder.sourceforge.net/current/java/base64/ - -This product contains a modified version of 'JZlib', a re-implementation of -zlib in pure Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.jzlib.txt (BSD Style License) - * HOMEPAGE: - * http://www.jcraft.com/jzlib/ - -This product optionally depends on 'Protocol Buffers', Google's data -interchange format, which can be obtained at: - - * LICENSE: - * license/LICENSE.protobuf.txt (New BSD License) - * HOMEPAGE: - * http://code.google.com/p/protobuf/ - -This product optionally depends on 'SLF4J', a simple logging facade for Java, -which can be obtained at: - - * LICENSE: - * license/LICENSE.slf4j.txt (MIT License) - * HOMEPAGE: - * http://www.slf4j.org/ - -This product optionally depends on 'Apache Commons Logging', a logging -framework, which can be obtained at: - - * LICENSE: - * license/LICENSE.commons-logging.txt (Apache License 2.0) - * HOMEPAGE: - * http://commons.apache.org/logging/ - -This product optionally depends on 'Apache Log4J', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.log4j.txt (Apache License 2.0) - * HOMEPAGE: - * http://logging.apache.org/log4j/ - -This product optionally depends on 'JBoss Logging', a logging framework, -which can be obtained at: - - * LICENSE: - * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) - * HOMEPAGE: - * http://anonsvn.jboss.org/repos/common/common-logging-spi/ - -This product optionally depends on 'Apache Felix', an open source OSGi -framework implementation, which can be obtained at: - - * LICENSE: - * license/LICENSE.felix.txt (Apache License 2.0) - * HOMEPAGE: - * http://felix.apache.org/ - -This product optionally depends on 'Webbit', a Java event based -WebSocket and HTTP server: - - * LICENSE: - * license/LICENSE.webbit.txt (BSD License) - * HOMEPAGE: - * https://github.com/joewalnes/webbit - -# Jackson JSON processor - -Jackson is a high-performance, Free/Open Source JSON processing library. -It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has -been in development since 2007. -It is currently developed by a community of developers, as well as supported -commercially by FasterXML.com. - -Jackson core and extension components may be licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -## Credits - -A list of contributors may be found from CREDITS file, which is included -in some artifacts (usually source distributions); but is always available -from the source code management (SCM) system project uses. - -Jackson core and extension components may licensed under different licenses. -To find the details that apply to this artifact see the accompanying LICENSE file. -For more information, including possible other licensing options, contact -FasterXML.com (http://fasterxml.com). - -mesos -Copyright 2014 The Apache Software Foundation - -Apache Thrift -Copyright 2006-2010 The Apache Software Foundation. - - Apache Ant - Copyright 1999-2013 The Apache Software Foundation - - The task is based on code Copyright (c) 2002, Landmark - Graphics Corp that has been kindly donated to the Apache Software - Foundation. - -Apache Commons IO -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons Math -Copyright 2001-2013 The Apache Software Foundation - -=============================================================================== - -The inverse error function implementation in the Erf class is based on CUDA -code developed by Mike Giles, Oxford-Man Institute of Quantitative Finance, -and published in GPU Computing Gems, volume 2, 2010. -=============================================================================== - -The BracketFinder (package org.apache.commons.math3.optimization.univariate) -and PowellOptimizer (package org.apache.commons.math3.optimization.general) -classes are based on the Python code in module "optimize.py" (version 0.5) -developed by Travis E. Oliphant for the SciPy library (http://www.scipy.org/) -Copyright © 2003-2009 SciPy Developers. -=============================================================================== - -The LinearConstraint, LinearObjectiveFunction, LinearOptimizer, -RelationShip, SimplexSolver and SimplexTableau classes in package -org.apache.commons.math3.optimization.linear include software developed by -Benjamin McCann (http://www.benmccann.com) and distributed with -the following copyright: Copyright 2009 Google Inc. -=============================================================================== - -This product includes software developed by the -University of Chicago, as Operator of Argonne National -Laboratory. -The LevenbergMarquardtOptimizer class in package -org.apache.commons.math3.optimization.general includes software -translated from the lmder, lmpar and qrsolv Fortran routines -from the Minpack package -Minpack Copyright Notice (1999) University of Chicago. All rights reserved -=============================================================================== - -The GraggBulirschStoerIntegrator class in package -org.apache.commons.math3.ode.nonstiff includes software translated -from the odex Fortran routine developed by E. Hairer and G. Wanner. -Original source copyright: -Copyright (c) 2004, Ernst Hairer -=============================================================================== - -The EigenDecompositionImpl class in package -org.apache.commons.math3.linear includes software translated -from some LAPACK Fortran routines. Original source copyright: -Copyright (c) 1992-2008 The University of Tennessee. All rights reserved. -=============================================================================== - -The MersenneTwister class in package org.apache.commons.math3.random -includes software translated from the 2002-01-26 version of -the Mersenne-Twister generator written in C by Makoto Matsumoto and Takuji -Nishimura. Original source copyright: -Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, -All rights reserved -=============================================================================== - -The LocalizedFormatsTest class in the unit tests is an adapted version of -the OrekitMessagesTest class from the orekit library distributed under the -terms of the Apache 2 licence. Original source copyright: -Copyright 2010 CS Systèmes d'Information -=============================================================================== - -The HermiteInterpolator class and its corresponding test have been imported from -the orekit library distributed under the terms of the Apache 2 licence. Original -source copyright: -Copyright 2010-2012 CS Systèmes d'Information -=============================================================================== - -The creation of the package "o.a.c.m.analysis.integration.gauss" was inspired -by an original code donated by Sébastien Brisard. -=============================================================================== - -The complete text of licenses and disclaimers associated with the the original -sources enumerated above at the time of code translation are in the LICENSE.txt -file. - -This product currently only contains code developed by authors -of specific components, as identified by the source code files; -if such notes are missing files have been created by -Tatu Saloranta. - -For additional credits (generally to people who reported problems) -see CREDITS file. - -Apache Commons Lang -Copyright 2001-2011 The Apache Software Foundation - -Apache Commons Compress -Copyright 2002-2012 The Apache Software Foundation - -Apache Commons CLI -Copyright 2001-2009 The Apache Software Foundation - -Google Guice - Extensions - Servlet -Copyright 2006-2011 Google, Inc. - -Google Guice - Core Library -Copyright 2006-2011 Google, Inc. - -Apache Jakarta HttpClient -Copyright 1999-2007 The Apache Software Foundation - -Apache Hive -Copyright 2008-2013 The Apache Software Foundation - -This product includes software developed by The Apache Software -Foundation (http://www.apache.org/). - -This product includes software developed by The JDBM Project -(http://jdbm.sourceforge.net/). - -This product includes/uses ANTLR (http://www.antlr.org/), -Copyright (c) 2003-2011, Terrence Parr. - -This product includes/uses StringTemplate (http://www.stringtemplate.org/), -Copyright (c) 2011, Terrence Parr. - -This product includes/uses ASM (http://asm.ow2.org/), -Copyright (c) 2000-2007 INRIA, France Telecom. - -This product includes/uses JLine (http://jline.sourceforge.net/), -Copyright (c) 2002-2006, Marc Prud'hommeaux . - -This product includes/uses SQLLine (http://sqlline.sourceforge.net), -Copyright (c) 2002, 2003, 2004, 2005 Marc Prud'hommeaux . - -This product includes/uses SLF4J (http://www.slf4j.org/), -Copyright (c) 2004-2010 QOS.ch - -This product includes/uses Bootstrap (http://twitter.github.com/bootstrap/), -Copyright (c) 2012 Twitter, Inc. - -This product includes/uses Glyphicons (http://glyphicons.com/), -Copyright (c) 2010 - 2012 Jan Kovarík - -This product includes DataNucleus (http://www.datanucleus.org/) -Copyright 2008-2008 DataNucleus - -This product includes Guava (http://code.google.com/p/guava-libraries/) -Copyright (C) 2006 Google Inc. - -This product includes JavaEWAH (http://code.google.com/p/javaewah/) -Copyright (C) 2011 Google Inc. - -Apache Commons Pool -Copyright 1999-2009 The Apache Software Foundation - -This product includes/uses Kubernetes & OpenShift 3 Java Client (https://github.com/fabric8io/kubernetes-client) -Copyright (C) 2015 Red Hat, Inc. - -This product includes/uses OkHttp (https://github.com/square/okhttp) -Copyright (C) 2012 The Android Open Source Project - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the DataNucleus distribution. == -========================================================================= - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Erik Bengtson -Andy Jefferson - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Joerg von Frantzius -Thomas Marti -Barry Haddow -Marco Schulze -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Marcus Mennemeier -Xuan Baldauf -Eric Sultan - -=================================================================== -This product also includes software developed by the TJDO project -(http://tjdo.sourceforge.net/). -=================================================================== - -=================================================================== -This product includes software developed by many individuals, -including the following: -=================================================================== -Andy Jefferson -Erik Bengtson -Joerg von Frantzius -Marco Schulze - -=================================================================== -This product has included contributions from some individuals, -including the following: -=================================================================== -Barry Haddow -Ralph Ullrich -David Ezzio -Brendan de Beer -David Eaves -Martin Taal -Tony Lai -Roland Szabo -Anton Troshin (Timesten) - -=================================================================== -This product also includes software developed by the Apache Commons project -(http://commons.apache.org/). -=================================================================== - -Apache Java Data Objects (JDO) -Copyright 2005-2006 The Apache Software Foundation - -========================================================================= -== NOTICE file corresponding to section 4(d) of the Apache License, == -== Version 2.0, in this case for the Apache Derby distribution. == -========================================================================= - -Apache Derby -Copyright 2004-2008 The Apache Software Foundation - -Portions of Derby were originally developed by -International Business Machines Corporation and are -licensed to the Apache Software Foundation under the -"Software Grant and Corporate Contribution License Agreement", -informally known as the "Derby CLA". -The following copyright notice(s) were affixed to portions of the code -with which this file is now or was at one time distributed -and are placed here unaltered. - -(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. - -(C) Copyright IBM Corp. 2003. - -The portion of the functionTests under 'nist' was originally -developed by the National Institute of Standards and Technology (NIST), -an agency of the United States Department of Commerce, and adapted by -International Business Machines Corporation in accordance with the NIST -Software Acknowledgment and Redistribution document at -http://www.itl.nist.gov/div897/ctg/sql_form.htm - -Apache Commons Collections -Copyright 2001-2008 The Apache Software Foundation - -Apache Commons Configuration -Copyright 2001-2008 The Apache Software Foundation - -Apache Jakarta Commons Digester -Copyright 2001-2006 The Apache Software Foundation - -Apache Commons BeanUtils -Copyright 2000-2008 The Apache Software Foundation - -Apache Avro Mapred API -Copyright 2009-2013 The Apache Software Foundation - -Apache Avro IPC -Copyright 2009-2013 The Apache Software Foundation - - -Vis.js -Copyright 2010-2015 Almende B.V. - -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - - and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. - - -Vis.js uses and redistributes the following third-party libraries: - -- component-emitter - https://github.com/component/emitter - The MIT License - -- hammer.js - http://hammerjs.github.io/ - The MIT License - -- moment.js - http://momentjs.com/ - The MIT License - -- keycharm - https://github.com/AlexDM0/keycharm - The MIT License - -=============================================================================== - -The CSS style for the navigation sidebar of the documentation was originally -submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project -is distributed under the 3-Clause BSD license. -=============================================================================== - -For CSV functionality: - -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2015 Ayasdi Inc - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -=============================================================================== -For dev/sparktestsupport/toposort.py: - -Copyright 2014 True Blade Systems, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/NOTICE-binary b/NOTICE-binary new file mode 100644 index 0000000000000..d56f99bdb55a6 --- /dev/null +++ b/NOTICE-binary @@ -0,0 +1,1170 @@ +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +Hive Beeline +Copyright 2016 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro +Copyright 2009-2014 The Apache Software Foundation + +This product currently only contains code developed by authors +of specific components, as identified by the source code files; +if such notes are missing files have been created by +Tatu Saloranta. + +For additional credits (generally to people who reported problems) +see CREDITS file. + +Apache Commons Compress +Copyright 2002-2012 The Apache Software Foundation + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro Mapred API +Copyright 2009-2014 The Apache Software Foundation + +Apache Avro IPC +Copyright 2009-2014 The Apache Software Foundation + +Objenesis +Copyright 2006-2013 Joe Walnes, Henri Tremblay, Leonardo Mesquita + +Apache XBean :: ASM 5 shaded (repackaged) +Copyright 2005-2015 The Apache Software Foundation + +-------------------------------------- + +This product includes software developed at +OW2 Consortium (http://asm.ow2.org/) + +This product includes software developed by The Apache Software +Foundation (http://www.apache.org/). + +The binary distribution of this product bundles binaries of +org.iq80.leveldb:leveldb-api (https://github.com/dain/leveldb), which has the +following notices: +* Copyright 2011 Dain Sundstrom +* Copyright 2011 FuseSource Corp. http://fusesource.com + +The binary distribution of this product bundles binaries of +org.fusesource.hawtjni:hawtjni-runtime (https://github.com/fusesource/hawtjni), +which has the following notices: +* This product includes software developed by FuseSource Corp. + http://fusesource.com +* This product includes software developed at + Progress Software Corporation and/or its subsidiaries or affiliates. +* This product includes software developed by IBM Corporation and others. + +The binary distribution of this product bundles binaries of +Gson 2.2.4, +which has the following notices: + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE..txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product contains a modified portion of 'ArrayDeque', written by Josh +Bloch of Google, Inc: + + * LICENSE: + * license/LICENSE.deque.txt (Public Domain) + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * http://archive.apache.org/dist/harmony/ + +This product contains a modified version of Roland Kuhn's ASL2 +AbstractNodeQueue, which is based on Dmitriy Vyukov's non-intrusive MPSC queue. +It can be obtained at: + + * LICENSE: + * license/LICENSE.abstractnodequeue.txt (Public Domain) + * HOMEPAGE: + * https://github.com/akka/akka/blob/wip-2.2.3-for-scala-2.11/akka-actor/src/main/java/akka/dispatch/AbstractNodeQueue.java + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/libdivsufsort/ + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/snappy/ + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://www.jboss.org/jbossmarshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * http://code.google.com/p/caliper/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified portion of 'Apache Commons Lang', a Java library +provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + +The binary distribution of this product bundles binaries of +Commons Codec 1.4, +which has the following notices: + * src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.javacontains test data from http://aspell.net/test/orig/batch0.tab.Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + =============================================================================== + The content of package org.apache.commons.codec.language.bm has been translated + from the original php source code available at http://stevemorse.org/phoneticinfo.htm + with permission from the original authors. + Original source copyright:Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +The binary distribution of this product bundles binaries of +Commons Lang 2.6, +which has the following notices: + * This product includes software from the Spring Framework,under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +The binary distribution of this product bundles binaries of +Apache Log4j 1.2.17, +which has the following notices: + * ResolverUtil.java + Copyright 2005-2006 Tim Fennell + Dumbster SMTP test server + Copyright 2004 Jason Paul Kitchen + TypeUtil.java + Copyright 2002-2012 Ramnivas Laddad, Juergen Hoeller, Chris Beams + +The binary distribution of this product bundles binaries of +Jetty 6.1.26, +which has the following notices: + * ============================================================== + Jetty Web Container + Copyright 1995-2016 Mort Bay Consulting Pty Ltd. + ============================================================== + + The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd + unless otherwise noted. + + Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + + Jetty may be distributed under either license. + + ------ + Eclipse + + The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + + The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + ------ + Oracle + + The following artifacts are CDDL + GPLv2 with classpath exception. + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + + ------ + Oracle OpenJDK + + If ALPN is used to negotiate HTTP/2 connections, then the following + artifacts may be included in the distribution or downloaded when ALPN + module is selected. + + * java.sun.security.ssl + + These artifacts replace/modify OpenJDK classes. The modififications + are hosted at github and both modified and original are under GPL v2 with + classpath exceptions. + http://openjdk.java.net/legal/gplv2+ce.html + + ------ + OW2 + + The following artifacts are licensed by the OW2 Foundation according to the + terms of http://asm.ow2.org/license.html + + org.ow2.asm:asm-commons + org.ow2.asm:asm + + ------ + Apache + + The following artifacts are ASL2 licensed. + + org.apache.taglibs:taglibs-standard-spec + org.apache.taglibs:taglibs-standard-impl + + ------ + MortBay + + The following artifacts are ASL2 licensed. Based on selected classes from + following Apache Tomcat jars, all ASL2 licensed. + + org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + + org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + ------ + Mortbay + + The following artifacts are CDDL + GPLv2 with classpath exception. + + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + org.eclipse.jetty.toolchain:jetty-schemas + + ------ + Assorted + + The UnixCrypt.java code implements the one way cryptography used by + Unix systems for simple password protection. Copyright 1996 Aki Yoshida, + modified April 2001 by Iris Van den Broeke, Daniel Deville. + Permission to use, copy, modify and distribute UnixCrypt + for non-commercial or commercial purposes and without fee is + granted provided that the copyright notice appears in all copies./ + +The binary distribution of this product bundles binaries of +Snappy for Java 1.0.4.1, +which has the following notices: + * This product includes software developed by Google + Snappy: http://code.google.com/p/snappy/ (New BSD License) + + This product includes software developed by Apache + PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ + (Apache 2.0 license) + + This library containd statically linked libstdc++. This inclusion is allowed by + "GCC RUntime Library Exception" + http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html + + == Contributors == + * Tatu Saloranta + * Providing benchmark suite + * Alec Wysoker + * Performance and memory usage improvement + +The binary distribution of this product bundles binaries of +Xerces2 Java Parser 2.9.1, +which has the following notices: + * ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Apache Xerces Java + Copyright 1999-2007 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +Apache Commons Collections +Copyright 2001-2015 The Apache Software Foundation + +Apache Commons Configuration +Copyright 2001-2008 The Apache Software Foundation + +Apache Jakarta Commons Digester +Copyright 2001-2006 The Apache Software Foundation + +Apache Commons BeanUtils +Copyright 2000-2008 The Apache Software Foundation + +ApacheDS Protocol Kerberos Codec +Copyright 2003-2013 The Apache Software Foundation + +ApacheDS I18n +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory API ASN.1 API +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory LDAP API Utilities +Copyright 2003-2013 The Apache Software Foundation + +Curator Client +Copyright 2011-2015 The Apache Software Foundation + +htrace-core +Copyright 2015 The Apache Software Foundation + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. + +Apache HttpCore +Copyright 2005-2017 The Apache Software Foundation + +Curator Recipes +Copyright 2011-2015 The Apache Software Foundation + +Curator Framework +Copyright 2011-2015 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2016 The Apache Software Foundation + +This product includes software from the Spring Framework, +under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + +Apache log4j +Copyright 2007 The Apache Software Foundation + +# Compress LZF + +This library contains efficient implementation of LZF compression format, +as well as additional helper classes that build on JDK-provided gzip (deflat) +codec. + +Library is licensed under Apache License 2.0, as per accompanying LICENSE file. + +## Credit + +Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). +It was started at Ning, inc., as an official Open Source process used by +platform backend, but after initial versions has been developed outside of +Ning by supporting community. + +Other contributors include: + +* Jon Hartlaub (first versions of streaming reader/writer; unit tests) +* Cedrik Lime: parallel LZF implementation + +Various community members have contributed bug reports, and suggested minor +fixes; these can be found from file "VERSION.txt" in SCM. + +Apache Commons Net +Copyright 2001-2012 The Apache Software Foundation + +Copyright 2011 The Netty Project + +http://www.apache.org/licenses/LICENSE-2.0 + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +Apache Ivy (TM) +Copyright 2007-2014 The Apache Software Foundation + +Portions of Ivy were originally developed at +Jayasoft SARL (http://www.jayasoft.fr/) +and are licensed to the Apache Software Foundation under the +"Software Grant License Agreement" + +SSH and SFTP support is provided by the JCraft JSch package, +which is open source software, available under +the terms of a BSD style license. +The original software and related information is available +at http://www.jcraft.com/jsch/. + + +ORC Core +Copyright 2013-2018 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2011 The Apache Software Foundation + +ORC MapReduce +Copyright 2013-2018 The Apache Software Foundation + +Apache Parquet Format +Copyright 2017 The Apache Software Foundation + +Arrow Vectors +Copyright 2017 The Apache Software Foundation + +Arrow Format +Copyright 2017 The Apache Software Foundation + +Arrow Memory +Copyright 2017 The Apache Software Foundation + +Apache Commons CLI +Copyright 2001-2009 The Apache Software Foundation + +Google Guice - Extensions - Servlet +Copyright 2006-2011 Google, Inc. + +Apache Commons IO +Copyright 2002-2012 The Apache Software Foundation + +Google Guice - Core Library +Copyright 2006-2011 Google, Inc. + +mesos +Copyright 2017 The Apache Software Foundation + +Apache Parquet Hadoop Bundle (Incubating) +Copyright 2015 The Apache Software Foundation + +Hive Query Language +Copyright 2016 The Apache Software Foundation + +Apache Extras Companion for log4j 1.2. +Copyright 2007 The Apache Software Foundation + +Hive Metastore +Copyright 2016 The Apache Software Foundation + +Apache Commons Logging +Copyright 2003-2013 The Apache Software Foundation + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, == +== Version 2.0, in this case for the DataNucleus distribution. == +========================================================================= + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Erik Bengtson +Andy Jefferson + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Andy Jefferson +Erik Bengtson +Joerg von Frantzius +Marco Schulze + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Barry Haddow +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Anton Troshin (Timesten) + +=================================================================== +This product also includes software developed by the TJDO project +(http://tjdo.sourceforge.net/). +=================================================================== + +=================================================================== +This product also includes software developed by the Apache Commons project +(http://commons.apache.org/). +=================================================================== + +Apache Commons Pool +Copyright 1999-2009 The Apache Software Foundation + +Apache Commons DBCP +Copyright 2001-2010 The Apache Software Foundation + +Apache Java Data Objects (JDO) +Copyright 2005-2006 The Apache Software Foundation + +Apache Jakarta HttpClient +Copyright 1999-2007 The Apache Software Foundation + +Calcite Avatica +Copyright 2012-2015 The Apache Software Foundation + +Calcite Core +Copyright 2012-2015 The Apache Software Foundation + +Calcite Linq4j +Copyright 2012-2015 The Apache Software Foundation + +Apache HttpClient +Copyright 1999-2017 The Apache Software Foundation + +Apache Commons Codec +Copyright 2002-2014 The Apache Software Foundation + +src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java +contains test data from http://aspell.net/test/orig/batch0.tab. +Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + +=============================================================================== + +The content of package org.apache.commons.codec.language.bm has been translated +from the original php source code available at http://stevemorse.org/phoneticinfo.htm +with permission from the original authors. +Original source copyright: +Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Joerg von Frantzius +Thomas Marti +Barry Haddow +Marco Schulze +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Marcus Mennemeier +Xuan Baldauf +Eric Sultan + +Apache Thrift +Copyright 2006-2010 The Apache Software Foundation. + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, +== Version 2.0, in this case for the Apache Derby distribution. +== +== DO NOT EDIT THIS FILE DIRECTLY. IT IS GENERATED +== BY THE buildnotice TARGET IN THE TOP LEVEL build.xml FILE. +== +========================================================================= + +Apache Derby +Copyright 2004-2015 The Apache Software Foundation + +========================================================================= + +Portions of Derby were originally developed by +International Business Machines Corporation and are +licensed to the Apache Software Foundation under the +"Software Grant and Corporate Contribution License Agreement", +informally known as the "Derby CLA". +The following copyright notice(s) were affixed to portions of the code +with which this file is now or was at one time distributed +and are placed here unaltered. + +(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. + +(C) Copyright IBM Corp. 2003. + +The portion of the functionTests under 'nist' was originally +developed by the National Institute of Standards and Technology (NIST), +an agency of the United States Department of Commerce, and adapted by +International Business Machines Corporation in accordance with the NIST +Software Acknowledgment and Redistribution document at +http://www.itl.nist.gov/div897/ctg/sql_form.htm + +The JDBC apis for small devices and JDBC3 (under java/stubs/jsr169 and +java/stubs/jdbc3) were produced by trimming sources supplied by the +Apache Harmony project. In addition, the Harmony SerialBlob and +SerialClob implementations are used. The following notice covers the Harmony sources: + +Portions of Harmony were originally developed by +Intel Corporation and are licensed to the Apache Software +Foundation under the "Software Grant and Corporate Contribution +License Agreement", informally known as the "Intel Harmony CLA". + +The Derby build relies on source files supplied by the Apache Felix +project. The following notice covers the Felix files: + + Apache Felix Main + Copyright 2008 The Apache Software Foundation + + I. Included Software + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Licensed under the Apache License 2.0. + + This product includes software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + This product includes software from http://kxml.sourceforge.net. + Copyright (c) 2002,2003, Stefan Haustein, Oberhausen, Rhld., Germany. + Licensed under BSD License. + + II. Used Software + + This product uses software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + III. License Summary + - Apache License 2.0 + - BSD License + +The Derby build relies on jar files supplied by the Apache Lucene +project. The following notice covers the Lucene files: + +Apache Lucene +Copyright 2013 The Apache Software Foundation + +Includes software from other Apache Software Foundation projects, +including, but not limited to: + - Apache Ant + - Apache Jakarta Regexp + - Apache Commons + - Apache Xerces + +ICU4J, (under analysis/icu) is licensed under an MIT styles license +and Copyright (c) 1995-2008 International Business Machines Corporation and others + +Some data files (under analysis/icu/src/data) are derived from Unicode data such +as the Unicode Character Database. See http://unicode.org/copyright.html for more +details. + +Brics Automaton (under core/src/java/org/apache/lucene/util/automaton) is +BSD-licensed, created by Anders Møller. See http://www.brics.dk/automaton/ + +The levenshtein automata tables (under core/src/java/org/apache/lucene/util/automaton) were +automatically generated with the moman/finenight FSA library, created by +Jean-Philippe Barrette-LaPierre. This library is available under an MIT license, +see http://sites.google.com/site/rrettesite/moman and +http://bitbucket.org/jpbarrette/moman/overview/ + +The class org.apache.lucene.util.WeakIdentityMap was derived from +the Apache CXF project and is Apache License 2.0. + +The Google Code Prettify is Apache License 2.0. +See http://code.google.com/p/google-code-prettify/ + +JUnit (junit-4.10) is licensed under the Common Public License v. 1.0 +See http://junit.sourceforge.net/cpl-v10.html + +This product includes code (JaspellTernarySearchTrie) from Java Spelling Checkin +g Package (jaspell): http://jaspell.sourceforge.net/ +License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) + +The snowball stemmers in + analysis/common/src/java/net/sf/snowball +were developed by Martin Porter and Richard Boulton. +The snowball stopword lists in + analysis/common/src/resources/org/apache/lucene/analysis/snowball +were developed by Martin Porter and Richard Boulton. +The full snowball package is available from + http://snowball.tartarus.org/ + +The KStem stemmer in + analysis/common/src/org/apache/lucene/analysis/en +was developed by Bob Krovetz and Sergio Guzman-Lara (CIIR-UMass Amherst) +under the BSD-license. + +The Arabic,Persian,Romanian,Bulgarian, and Hindi analyzers (common) come with a default +stopword list that is BSD-licensed created by Jacques Savoy. These files reside in: +analysis/common/src/resources/org/apache/lucene/analysis/ar/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/fa/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/ro/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/bg/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/hi/stopwords.txt +See http://members.unine.ch/jacques.savoy/clef/index.html. + +The German,Spanish,Finnish,French,Hungarian,Italian,Portuguese,Russian and Swedish light stemmers +(common) are based on BSD-licensed reference implementations created by Jacques Savoy and +Ljiljana Dolamic. These files reside in: +analysis/common/src/java/org/apache/lucene/analysis/de/GermanLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/de/GermanMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/es/SpanishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fi/FinnishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/hu/HungarianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/it/ItalianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/pt/PortugueseLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/ru/RussianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/sv/SwedishLightStemmer.java + +The Stempel analyzer (stempel) includes BSD-licensed software developed +by the Egothor project http://egothor.sf.net/, created by Leo Galambos, Martin Kvapil, +and Edmond Nolan. + +The Polish analyzer (stempel) comes with a default +stopword list that is BSD-licensed created by the Carrot2 project. The file resides +in stempel/src/resources/org/apache/lucene/analysis/pl/stopwords.txt. +See http://project.carrot2.org/license.html. + +The SmartChineseAnalyzer source code (smartcn) was +provided by Xiaoping Gao and copyright 2009 by www.imdict.net. + +WordBreakTestUnicode_*.java (under modules/analysis/common/src/test/) +is derived from Unicode data such as the Unicode Character Database. +See http://unicode.org/copyright.html for more details. + +The Morfologik analyzer (morfologik) includes BSD-licensed software +developed by Dawid Weiss and Marcin Miłkowski (http://morfologik.blogspot.com/). + +Morfologik uses data from Polish ispell/myspell dictionary +(http://www.sjp.pl/slownik/en/) licenced on the terms of (inter alia) +LGPL and Creative Commons ShareAlike. + +Morfologic includes data from BSD-licensed dictionary of Polish (SGJP) +(http://sgjp.pl/morfeusz/) + +Servlet-api.jar and javax.servlet-*.jar are under the CDDL license, the original +source code for this can be found at http://www.eclipse.org/jetty/downloads.php + +=========================================================================== +Kuromoji Japanese Morphological Analyzer - Apache Lucene Integration +=========================================================================== + +This software includes a binary and/or source version of data from + + mecab-ipadic-2.7.0-20070801 + +which can be obtained from + + http://atilika.com/releases/mecab-ipadic/mecab-ipadic-2.7.0-20070801.tar.gz + +or + + http://jaist.dl.sourceforge.net/project/mecab/mecab-ipadic/2.7.0-20070801/mecab-ipadic-2.7.0-20070801.tar.gz + +=========================================================================== +mecab-ipadic-2.7.0-20070801 Notice +=========================================================================== + +Nara Institute of Science and Technology (NAIST), +the copyright holders, disclaims all warranties with regard to this +software, including all implied warranties of merchantability and +fitness, in no event shall NAIST be liable for +any special, indirect or consequential damages or any damages +whatsoever resulting from loss of use, data or profits, whether in an +action of contract, negligence or other tortuous action, arising out +of or in connection with the use or performance of this software. + +A large portion of the dictionary entries +originate from ICOT Free Software. The following conditions for ICOT +Free Software applies to the current dictionary as well. + +Each User may also freely distribute the Program, whether in its +original form or modified, to any third party or parties, PROVIDED +that the provisions of Section 3 ("NO WARRANTY") will ALWAYS appear +on, or be attached to, the Program, which is distributed substantially +in the same form as set out herein and that such intended +distribution, if actually made, will neither violate or otherwise +contravene any of the laws and regulations of the countries having +jurisdiction over the User or the intended distribution itself. + +NO WARRANTY + +The program was produced on an experimental basis in the course of the +research and development conducted during the project and is provided +to users as so produced on an experimental basis. Accordingly, the +program is provided without any warranty whatsoever, whether express, +implied, statutory or otherwise. The term "warranty" used herein +includes, but is not limited to, any warranty of the quality, +performance, merchantability and fitness for a particular purpose of +the program and the nonexistence of any infringement or violation of +any right of any third party. + +Each user of the program will agree and understand, and be deemed to +have agreed and understood, that there is no warranty whatsoever for +the program and, accordingly, the entire risk arising from or +otherwise connected with the program is assumed by the user. + +Therefore, neither ICOT, the copyright holder, or any other +organization that participated in or was otherwise related to the +development of the program and their respective officials, directors, +officers and other employees shall be held liable for any and all +damages, including, without limitation, general, special, incidental +and consequential damages, arising out of or otherwise in connection +with the use or inability to use the program or any product, material +or result produced or otherwise obtained by using the program, +regardless of whether they have been advised of, or otherwise had +knowledge of, the possibility of such damages at any time during the +project or thereafter. Each user will be deemed to have agreed to the +foregoing by his or her commencement of use of the program. The term +"use" as used herein includes, but is not limited to, the use, +modification, copying and distribution of the program and the +production of secondary products from the program. + +In the case where the program, whether in its original form or +modified, was distributed or delivered to or received by a user from +any person, organization or entity other than ICOT, unless it makes or +grants independently of ICOT any specific warranty to the user in +writing, such person, organization or entity, will also be exempted +from and not be held liable to the user for any such damages as noted +above as far as the program is concerned. + +The Derby build relies on a jar file supplied by the JSON Simple +project, hosted at https://code.google.com/p/json-simple/. +The JSON simple jar file is licensed under the Apache 2.0 License. + +Hive CLI +Copyright 2016 The Apache Software Foundation + +Hive JDBC +Copyright 2016 The Apache Software Foundation + + +Chill is a set of Scala extensions for Kryo. +Copyright 2012 Twitter, Inc. + +Third Party Dependencies: + +Kryo 2.17 +BSD 3-Clause License +http://code.google.com/p/kryo + +Commons-Codec 1.7 +Apache Public License 2.0 +http://hadoop.apache.org + + + +Breeze is distributed under an Apache License V2.0 (See LICENSE) + +=============================================================================== + +Proximal algorithms outlined in Proximal.scala (package breeze.optimize.proximal) +are based on https://github.com/cvxgrp/proximal (see LICENSE for details) and distributed with +Copyright (c) 2014 by Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +QuadraticMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2014, Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +NonlinearMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2015, Debasish Das (Verizon), all rights reserved. + + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the distribution of jets3t. == + ========================================================================= + + This product includes software developed by: + + The Apache Software Foundation (http://www.apache.org/). + + The ExoLab Project (http://www.exolab.org/) + + Sun Microsystems (http://www.sun.com/) + + Codehaus (http://castor.codehaus.org) + + Tatu Saloranta (http://wiki.fasterxml.com/TatuSaloranta) + + + +stream-lib +Copyright 2016 AddThis + +This product includes software developed by AddThis. + +This product also includes code adapted from: + +Apache Solr (http://lucene.apache.org/solr/) +Copyright 2014 The Apache Software Foundation + +Apache Mahout (http://mahout.apache.org/) +Copyright 2014 The Apache Software Foundation \ No newline at end of file diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 23b24212b4d29..466135e72233a 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -11,6 +11,10 @@ cache .rat-excludes .*md derby.log +licenses/* +licenses-binary/* +LICENSE +NOTICE TAGS RELEASE control diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 84233c64caa9c..ad99ce55806af 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -211,9 +211,10 @@ mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME/examples/src/main" "$DISTDIR/examples/src/" # Copy license and ASF files -cp "$SPARK_HOME/LICENSE" "$DISTDIR" -cp -r "$SPARK_HOME/licenses" "$DISTDIR" -cp "$SPARK_HOME/NOTICE" "$DISTDIR" +cp "$SPARK_HOME/LICENSE-binary" "$DISTDIR/LICENSE" +mkdir -p "$DISTDIR/licenses" +cp -r "$SPARK_HOME/licenses-binary" "$DISTDIR/licenses" +cp "$SPARK_HOME/NOTICE-binary" "$DISTDIR/NOTICE" if [ -e "$SPARK_HOME/CHANGES.txt" ]; then cp "$SPARK_HOME/CHANGES.txt" "$DISTDIR" diff --git a/licenses/LICENSE-scopt.txt b/licenses-binary/LICENSE-AnchorJS.txt similarity index 100% rename from licenses/LICENSE-scopt.txt rename to licenses-binary/LICENSE-AnchorJS.txt diff --git a/licenses-binary/LICENSE-CC0.txt b/licenses-binary/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses-binary/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-antlr.txt b/licenses-binary/LICENSE-antlr.txt similarity index 100% rename from licenses/LICENSE-antlr.txt rename to licenses-binary/LICENSE-antlr.txt diff --git a/licenses-binary/LICENSE-arpack.txt b/licenses-binary/LICENSE-arpack.txt new file mode 100644 index 0000000000000..a3ad80087bb63 --- /dev/null +++ b/licenses-binary/LICENSE-arpack.txt @@ -0,0 +1,8 @@ +Copyright © 2018 The University of Tennessee. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +· Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +· Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. +· Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. in no event shall the copyright owner or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. \ No newline at end of file diff --git a/licenses-binary/LICENSE-automaton.txt b/licenses-binary/LICENSE-automaton.txt new file mode 100644 index 0000000000000..2fc6e8c3432f0 --- /dev/null +++ b/licenses-binary/LICENSE-automaton.txt @@ -0,0 +1,24 @@ +Copyright (c) 2001-2017 Anders Moeller +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +3. The name of the author may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF +THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-bootstrap.txt b/licenses-binary/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses-binary/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses-binary/LICENSE-bouncycastle-bcprov.txt b/licenses-binary/LICENSE-bouncycastle-bcprov.txt new file mode 100644 index 0000000000000..c445a93a06dd4 --- /dev/null +++ b/licenses-binary/LICENSE-bouncycastle-bcprov.txt @@ -0,0 +1,7 @@ +Copyright (c) 2000 - 2018 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-cloudpickle.txt b/licenses-binary/LICENSE-cloudpickle.txt new file mode 100644 index 0000000000000..b1e20fa1eda88 --- /dev/null +++ b/licenses-binary/LICENSE-cloudpickle.txt @@ -0,0 +1,28 @@ +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-d3.min.js.txt b/licenses-binary/LICENSE-d3.min.js.txt new file mode 100644 index 0000000000000..c71e3f254c068 --- /dev/null +++ b/licenses-binary/LICENSE-d3.min.js.txt @@ -0,0 +1,26 @@ +Copyright (c) 2010-2015, Michael Bostock +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* The name Michael Bostock may not be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-Mockito.txt b/licenses-binary/LICENSE-dagre-d3.txt similarity index 94% rename from licenses/LICENSE-Mockito.txt rename to licenses-binary/LICENSE-dagre-d3.txt index e0840a446caf5..4864fe05e9803 100644 --- a/licenses/LICENSE-Mockito.txt +++ b/licenses-binary/LICENSE-dagre-d3.txt @@ -1,6 +1,4 @@ -The MIT License - -Copyright (c) 2007 Mockito contributors +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses-binary/LICENSE-datatables.txt b/licenses-binary/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses-binary/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-f2j.txt b/licenses-binary/LICENSE-f2j.txt similarity index 100% rename from licenses/LICENSE-f2j.txt rename to licenses-binary/LICENSE-f2j.txt diff --git a/licenses-binary/LICENSE-graphlib-dot.txt b/licenses-binary/LICENSE-graphlib-dot.txt new file mode 100644 index 0000000000000..4864fe05e9803 --- /dev/null +++ b/licenses-binary/LICENSE-graphlib-dot.txt @@ -0,0 +1,19 @@ +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-heapq.txt b/licenses-binary/LICENSE-heapq.txt new file mode 100644 index 0000000000000..0c4c4b954bea4 --- /dev/null +++ b/licenses-binary/LICENSE-heapq.txt @@ -0,0 +1,280 @@ + +# A. HISTORY OF THE SOFTWARE +# ========================== +# +# Python was created in the early 1990s by Guido van Rossum at Stichting +# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands +# as a successor of a language called ABC. Guido remains Python's +# principal author, although it includes many contributions from others. +# +# In 1995, Guido continued his work on Python at the Corporation for +# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) +# in Reston, Virginia where he released several versions of the +# software. +# +# In May 2000, Guido and the Python core development team moved to +# BeOpen.com to form the BeOpen PythonLabs team. In October of the same +# year, the PythonLabs team moved to Digital Creations (now Zope +# Corporation, see http://www.zope.com). In 2001, the Python Software +# Foundation (PSF, see http://www.python.org/psf/) was formed, a +# non-profit organization created specifically to own Python-related +# Intellectual Property. Zope Corporation is a sponsoring member of +# the PSF. +# +# All Python releases are Open Source (see http://www.opensource.org for +# the Open Source Definition). Historically, most, but not all, Python +# releases have also been GPL-compatible; the table below summarizes +# the various releases. +# +# Release Derived Year Owner GPL- +# from compatible? (1) +# +# 0.9.0 thru 1.2 1991-1995 CWI yes +# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes +# 1.6 1.5.2 2000 CNRI no +# 2.0 1.6 2000 BeOpen.com no +# 1.6.1 1.6 2001 CNRI yes (2) +# 2.1 2.0+1.6.1 2001 PSF no +# 2.0.1 2.0+1.6.1 2001 PSF yes +# 2.1.1 2.1+2.0.1 2001 PSF yes +# 2.2 2.1.1 2001 PSF yes +# 2.1.2 2.1.1 2002 PSF yes +# 2.1.3 2.1.2 2002 PSF yes +# 2.2.1 2.2 2002 PSF yes +# 2.2.2 2.2.1 2002 PSF yes +# 2.2.3 2.2.2 2003 PSF yes +# 2.3 2.2.2 2002-2003 PSF yes +# 2.3.1 2.3 2002-2003 PSF yes +# 2.3.2 2.3.1 2002-2003 PSF yes +# 2.3.3 2.3.2 2002-2003 PSF yes +# 2.3.4 2.3.3 2004 PSF yes +# 2.3.5 2.3.4 2005 PSF yes +# 2.4 2.3 2004 PSF yes +# 2.4.1 2.4 2005 PSF yes +# 2.4.2 2.4.1 2005 PSF yes +# 2.4.3 2.4.2 2006 PSF yes +# 2.4.4 2.4.3 2006 PSF yes +# 2.5 2.4 2006 PSF yes +# 2.5.1 2.5 2007 PSF yes +# 2.5.2 2.5.1 2008 PSF yes +# 2.5.3 2.5.2 2008 PSF yes +# 2.6 2.5 2008 PSF yes +# 2.6.1 2.6 2008 PSF yes +# 2.6.2 2.6.1 2009 PSF yes +# 2.6.3 2.6.2 2009 PSF yes +# 2.6.4 2.6.3 2009 PSF yes +# 2.6.5 2.6.4 2010 PSF yes +# 2.7 2.6 2010 PSF yes +# +# Footnotes: +# +# (1) GPL-compatible doesn't mean that we're distributing Python under +# the GPL. All Python licenses, unlike the GPL, let you distribute +# a modified version without making your changes open source. The +# GPL-compatible licenses make it possible to combine Python with +# other software that is released under the GPL; the others don't. +# +# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, +# because its license has a choice of law clause. According to +# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 +# is "not incompatible" with the GPL. +# +# Thanks to the many outside volunteers who have worked under Guido's +# direction to make these releases possible. +# +# +# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON +# =============================================================== +# +# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# -------------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained +# in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 +# ------------------------------------------- +# +# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 +# +# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an +# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the +# Individual or Organization ("Licensee") accessing and otherwise using +# this software in source or binary form and its associated +# documentation ("the Software"). +# +# 2. Subject to the terms and conditions of this BeOpen Python License +# Agreement, BeOpen hereby grants Licensee a non-exclusive, +# royalty-free, world-wide license to reproduce, analyze, test, perform +# and/or display publicly, prepare derivative works, distribute, and +# otherwise use the Software alone or in any derivative version, +# provided, however, that the BeOpen Python License is retained in the +# Software, alone or in any derivative version prepared by Licensee. +# +# 3. BeOpen is making the Software available to Licensee on an "AS IS" +# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE +# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS +# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY +# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 5. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 6. This License Agreement shall be governed by and interpreted in all +# respects by the law of the State of California, excluding conflict of +# law provisions. Nothing in this License Agreement shall be deemed to +# create any relationship of agency, partnership, or joint venture +# between BeOpen and Licensee. This License Agreement does not grant +# permission to use BeOpen trademarks or trade names in a trademark +# sense to endorse or promote products or services of Licensee, or any +# third party. As an exception, the "BeOpen Python" logos available at +# http://www.pythonlabs.com/logos.html may be used according to the +# permissions granted on that web page. +# +# 7. By copying, installing or otherwise using the software, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. +# +# +# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 +# --------------------------------------- +# +# 1. This LICENSE AGREEMENT is between the Corporation for National +# Research Initiatives, having an office at 1895 Preston White Drive, +# Reston, VA 20191 ("CNRI"), and the Individual or Organization +# ("Licensee") accessing and otherwise using Python 1.6.1 software in +# source or binary form and its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, CNRI +# hereby grants Licensee a nonexclusive, royalty-free, world-wide +# license to reproduce, analyze, test, perform and/or display publicly, +# prepare derivative works, distribute, and otherwise use Python 1.6.1 +# alone or in any derivative version, provided, however, that CNRI's +# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) +# 1995-2001 Corporation for National Research Initiatives; All Rights +# Reserved" are retained in Python 1.6.1 alone or in any derivative +# version prepared by Licensee. Alternately, in lieu of CNRI's License +# Agreement, Licensee may substitute the following text (omitting the +# quotes): "Python 1.6.1 is made available subject to the terms and +# conditions in CNRI's License Agreement. This Agreement together with +# Python 1.6.1 may be located on the Internet using the following +# unique, persistent identifier (known as a handle): 1895.22/1013. This +# Agreement may also be obtained from a proxy server on the Internet +# using the following URL: http://hdl.handle.net/1895.22/1013". +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python 1.6.1 or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python 1.6.1. +# +# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" +# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. This License Agreement shall be governed by the federal +# intellectual property law of the United States, including without +# limitation the federal copyright law, and, to the extent such +# U.S. federal law does not apply, by the law of the Commonwealth of +# Virginia, excluding Virginia's conflict of law provisions. +# Notwithstanding the foregoing, with regard to derivative works based +# on Python 1.6.1 that incorporate non-separable material that was +# previously distributed under the GNU General Public License (GPL), the +# law of the Commonwealth of Virginia shall govern this License +# Agreement only as to issues arising under or with respect to +# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this +# License Agreement shall be deemed to create any relationship of +# agency, partnership, or joint venture between CNRI and Licensee. This +# License Agreement does not grant permission to use CNRI trademarks or +# trade name in a trademark sense to endorse or promote products or +# services of Licensee, or any third party. +# +# 8. By clicking on the "ACCEPT" button where indicated, or by copying, +# installing or otherwise using Python 1.6.1, Licensee agrees to be +# bound by the terms and conditions of this License Agreement. +# +# ACCEPT +# +# +# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 +# -------------------------------------------------- +# +# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, +# The Netherlands. All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation, and that the name of Stichting Mathematisch +# Centrum or CWI not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO +# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE +# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-janino.txt b/licenses-binary/LICENSE-janino.txt new file mode 100644 index 0000000000000..d1e1f237c4641 --- /dev/null +++ b/licenses-binary/LICENSE-janino.txt @@ -0,0 +1,31 @@ +Janino - An embedded Java[TM] compiler + +Copyright (c) 2001-2016, Arno Unkrig +Copyright (c) 2015-2016 TIBCO Software Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials + provided with the distribution. + 3. Neither the name of JANINO nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-javassist.html b/licenses-binary/LICENSE-javassist.html new file mode 100644 index 0000000000000..5abd563a0c4d9 --- /dev/null +++ b/licenses-binary/LICENSE-javassist.html @@ -0,0 +1,373 @@ + + + Javassist License + + + + +
    MOZILLA PUBLIC LICENSE
    Version + 1.1 +

    +


    +
    +

    1. Definitions. +

      1.0.1. "Commercial Use" means distribution or otherwise making the + Covered Code available to a third party. +

      1.1. ''Contributor'' means each entity that creates or contributes + to the creation of Modifications. +

      1.2. ''Contributor Version'' means the combination of the Original + Code, prior Modifications used by a Contributor, and the Modifications made by + that particular Contributor. +

      1.3. ''Covered Code'' means the Original Code or Modifications or + the combination of the Original Code and Modifications, in each case including + portions thereof. +

      1.4. ''Electronic Distribution Mechanism'' means a mechanism + generally accepted in the software development community for the electronic + transfer of data. +

      1.5. ''Executable'' means Covered Code in any form other than Source + Code. +

      1.6. ''Initial Developer'' means the individual or entity identified + as the Initial Developer in the Source Code notice required by Exhibit + A. +

      1.7. ''Larger Work'' means a work which combines Covered Code or + portions thereof with code not governed by the terms of this License. +

      1.8. ''License'' means this document. +

      1.8.1. "Licensable" means having the right to grant, to the maximum + extent possible, whether at the time of the initial grant or subsequently + acquired, any and all of the rights conveyed herein. +

      1.9. ''Modifications'' means any addition to or deletion from the + substance or structure of either the Original Code or any previous + Modifications. When Covered Code is released as a series of files, a + Modification is: +

        A. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +

        B. Any new file that contains any part of the Original Code or + previous Modifications.
         

      1.10. ''Original Code'' +means Source Code of computer software code which is described in the Source +Code notice required by Exhibit A as Original Code, and which, at the +time of its release under this License is not already Covered Code governed by +this License. +

      1.10.1. "Patent Claims" means any patent claim(s), now owned or + hereafter acquired, including without limitation,  method, process, and + apparatus claims, in any patent Licensable by grantor. +

      1.11. ''Source Code'' means the preferred form of the Covered Code + for making modifications to it, including all modules it contains, plus any + associated interface definition files, scripts used to control compilation and + installation of an Executable, or source code differential comparisons against + either the Original Code or another well known, available Covered Code of the + Contributor's choice. The Source Code can be in a compressed or archival form, + provided the appropriate decompression or de-archiving software is widely + available for no charge. +

      1.12. "You'' (or "Your")  means an individual or a legal entity + exercising rights under, and complying with all of the terms of, this License + or a future version of this License issued under Section 6.1. For legal + entities, "You'' includes any entity which controls, is controlled by, or is + under common control with You. For purposes of this definition, "control'' + means (a) the power, direct or indirect, to cause the direction or management + of such entity, whether by contract or otherwise, or (b) ownership of more + than fifty percent (50%) of the outstanding shares or beneficial ownership of + such entity.

    2. Source Code License. +
      2.1. The Initial Developer Grant.
      The Initial Developer hereby + grants You a world-wide, royalty-free, non-exclusive license, subject to third + party intellectual property claims: +
        (a)  under intellectual property rights (other than + patent or trademark) Licensable by Initial Developer to use, reproduce, + modify, display, perform, sublicense and distribute the Original Code (or + portions thereof) with or without Modifications, and/or as part of a Larger + Work; and +

        (b) under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for + sale, and/or otherwise dispose of the Original Code (or portions thereof). +

          +
          (c) the licenses granted in this Section 2.1(a) and (b) + are effective on the date Initial Developer first distributes Original Code + under the terms of this License. +

          (d) Notwithstanding Section 2.1(b) above, no patent license is + granted: 1) for code that You delete from the Original Code; 2) separate + from the Original Code;  or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original + Code with other software or devices.
           

        2.2. Contributor + Grant.
        Subject to third party intellectual property claims, each + Contributor hereby grants You a world-wide, royalty-free, non-exclusive + license +

          (a)  under intellectual property rights (other + than patent or trademark) Licensable by Contributor, to use, reproduce, + modify, display, perform, sublicense and distribute the Modifications + created by such Contributor (or portions thereof) either on an unmodified + basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +

          (b) under Patent Claims infringed by the making, using, or selling + of  Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such + combination), to make, use, sell, offer for sale, have made, and/or + otherwise dispose of: 1) Modifications made by that Contributor (or portions + thereof); and 2) the combination of  Modifications made by that + Contributor with its Contributor Version (or portions of such + combination). +

          (c) the licenses granted in Sections 2.2(a) and 2.2(b) are + effective on the date Contributor first makes Commercial Use of the Covered + Code. +

          (d)    Notwithstanding Section 2.2(b) above, no + patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2)  separate from the Contributor + Version;  3)  for infringements caused by: i) third party + modifications of Contributor Version or ii)  the combination of + Modifications made by that Contributor with other software  (except as + part of the Contributor Version) or other devices; or 4) under Patent Claims + infringed by Covered Code in the absence of Modifications made by that + Contributor.

      +


      3. Distribution Obligations. +

        3.1. Application of License.
        The Modifications which You create + or to which You contribute are governed by the terms of this License, + including without limitation Section 2.2. The Source Code version of + Covered Code may be distributed only under the terms of this License or a + future version of this License released under Section 6.1, and You must + include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version + that alters or restricts the applicable version of this License or the + recipients' rights hereunder. However, You may include an additional document + offering the additional rights described in Section 3.5. +

        3.2. Availability of Source Code.
        Any Modification which You + create or to which You contribute must be made available in Source Code form + under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom + you made an Executable version available; and if made available via Electronic + Distribution Mechanism, must remain available for at least twelve (12) months + after the date it initially became available, or at least six (6) months after + a subsequent version of that particular Modification has been made available + to such recipients. You are responsible for ensuring that the Source Code + version remains available even if the Electronic Distribution Mechanism is + maintained by a third party. +

        3.3. Description of Modifications.
        You must cause all Covered + Code to which You contribute to contain a file documenting the changes You + made to create that Covered Code and the date of any change. You must include + a prominent statement that the Modification is derived, directly or + indirectly, from Original Code provided by the Initial Developer and including + the name of the Initial Developer in (a) the Source Code, and (b) in any + notice in an Executable version or related documentation in which You describe + the origin or ownership of the Covered Code. +

        3.4. Intellectual Property Matters +

          (a) Third Party Claims.
          If Contributor has knowledge that a + license under a third party's intellectual property rights is required to + exercise the rights granted by such Contributor under Sections 2.1 or 2.2, + Contributor must include a text file with the Source Code distribution + titled "LEGAL'' which describes the claim and the party making the claim in + sufficient detail that a recipient will know whom to contact. If Contributor + obtains such knowledge after the Modification is made available as described + in Section 3.2, Contributor shall promptly modify the LEGAL file in all + copies Contributor makes available thereafter and shall take other steps + (such as notifying appropriate mailing lists or newsgroups) reasonably + calculated to inform those who received the Covered Code that new knowledge + has been obtained. +

          (b) Contributor APIs.
          If Contributor's Modifications include + an application programming interface and Contributor has knowledge of patent + licenses which are reasonably necessary to implement that API, Contributor + must also include this information in the LEGAL file. +
           

                  +(c)    Representations. +
          Contributor represents that, except as disclosed pursuant to Section + 3.4(a) above, Contributor believes that Contributor's Modifications are + Contributor's original creation(s) and/or Contributor has sufficient rights + to grant the rights conveyed by this License.
        +


        3.5. Required Notices.
        You must duplicate the notice in + Exhibit A in each file of the Source Code.  If it is not possible + to put such notice in a particular Source Code file due to its structure, then + You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice.  If You created + one or more Modification(s) You may add your name as a Contributor to the + notice described in Exhibit A.  You must also duplicate this + License in any documentation for the Source Code where You describe + recipients' rights or ownership rights relating to Covered Code.  You may + choose to offer, and to charge a fee for, warranty, support, indemnity or + liability obligations to one or more recipients of Covered Code. However, You + may do so only on Your own behalf, and not on behalf of the Initial Developer + or any Contributor. You must make it absolutely clear than any such warranty, + support, indemnity or liability obligation is offered by You alone, and You + hereby agree to indemnify the Initial Developer and every Contributor for any + liability incurred by the Initial Developer or such Contributor as a result of + warranty, support, indemnity or liability terms You offer. +

        3.6. Distribution of Executable Versions.
        You may distribute + Covered Code in Executable form only if the requirements of Section + 3.1-3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available + under the terms of this License, including a description of how and where You + have fulfilled the obligations of Section 3.2. The notice must be + conspicuously included in any notice in an Executable version, related + documentation or collateral in which You describe recipients' rights relating + to the Covered Code. You may distribute the Executable version of Covered Code + or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the + terms of this License and that the license for the Executable version does not + attempt to limit or alter the recipient's rights in the Source Code version + from the rights set forth in this License. If You distribute the Executable + version under a different license You must make it absolutely clear that any + terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the + Initial Developer and every Contributor for any liability incurred by the + Initial Developer or such Contributor as a result of any such terms You offer. + +

        3.7. Larger Works.
        You may create a Larger Work by combining + Covered Code with other code not governed by the terms of this License and + distribute the Larger Work as a single product. In such a case, You must make + sure the requirements of this License are fulfilled for the Covered + Code.

      4. Inability to Comply Due to Statute or Regulation. +
        If it is impossible for You to comply with any of the terms of this + License with respect to some or all of the Covered Code due to statute, + judicial order, or regulation then You must: (a) comply with the terms of this + License to the maximum extent possible; and (b) describe the limitations and + the code they affect. Such description must be included in the LEGAL file + described in Section 3.4 and must be included with all distributions of + the Source Code. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it.
      5. Application of this License. +
        This License applies to code to which the Initial Developer has attached + the notice in Exhibit A and to related Covered Code.
      6. Versions + of the License. +
        6.1. New Versions.
        Netscape Communications Corporation + (''Netscape'') may publish revised and/or new versions of the License from + time to time. Each version will be given a distinguishing version number. +

        6.2. Effect of New Versions.
        Once Covered Code has been + published under a particular version of the License, You may always continue + to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License + published by Netscape. No one other than Netscape has the right to modify the + terms applicable to Covered Code created under this License. +

        6.3. Derivative Works.
        If You create or use a modified version + of this License (which you may only do in order to apply it to code which is + not already Covered Code governed by this License), You must (a) rename Your + license so that the phrases ''Mozilla'', ''MOZILLAPL'', ''MOZPL'', + ''Netscape'', "MPL", ''NPL'' or any confusingly similar phrase do not appear + in your license (except to note that your license differs from this License) + and (b) otherwise make it clear that Your version of the license contains + terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or + Contributor in the notice described in Exhibit A shall not of + themselves be deemed to be modifications of this License.)

      7. + DISCLAIMER OF WARRANTY. +
        COVERED CODE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS'' BASIS, WITHOUT + WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT + LIMITATION, WARRANTIES THAT THE COVERED CODE IS FREE OF DEFECTS, MERCHANTABLE, + FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE + QUALITY AND PERFORMANCE OF THE COVERED CODE IS WITH YOU. SHOULD ANY COVERED + CODE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY + OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR + CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS + LICENSE. NO USE OF ANY COVERED CODE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS + DISCLAIMER.
      8. TERMINATION. +
        8.1.  This License and the rights granted hereunder will + terminate automatically if You fail to comply with terms herein and fail to + cure such breach within 30 days of becoming aware of the breach. All + sublicenses to the Covered Code which are properly granted shall survive any + termination of this License. Provisions which, by their nature, must remain in + effect beyond the termination of this License shall survive. +

        8.2.  If You initiate litigation by asserting a patent + infringement claim (excluding declatory judgment actions) against Initial + Developer or a Contributor (the Initial Developer or Contributor against whom + You file such action is referred to as "Participant")  alleging that: +

        (a)  such Participant's Contributor Version directly or + indirectly infringes any patent, then any and all rights granted by such + Participant to You under Sections 2.1 and/or 2.2 of this License shall, upon + 60 days notice from Participant terminate prospectively, unless if within 60 + days after receipt of notice You either: (i)  agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future + use of Modifications made by such Participant, or (ii) withdraw Your + litigation claim with respect to the Contributor Version against such + Participant.  If within 60 days of notice, a reasonable royalty and + payment arrangement are not mutually agreed upon in writing by the parties or + the litigation claim is not withdrawn, the rights granted by Participant to + You under Sections 2.1 and/or 2.2 automatically terminate at the expiration of + the 60 day notice period specified above. +

        (b)  any software, hardware, or device, other than such + Participant's Contributor Version, directly or indirectly infringes any + patent, then any rights granted to You by such Participant under Sections + 2.1(b) and 2.2(b) are revoked effective as of the date You first made, used, + sold, distributed, or had made, Modifications made by that Participant. +

        8.3.  If You assert a patent infringement claim against + Participant alleging that such Participant's Contributor Version directly or + indirectly infringes any patent where such claim is resolved (such as by + license or settlement) prior to the initiation of patent infringement + litigation, then the reasonable value of the licenses granted by such + Participant under Sections 2.1 or 2.2 shall be taken into account in + determining the amount or value of any payment or license. +

        8.4.  In the event of termination under Sections 8.1 or 8.2 + above,  all end user license agreements (excluding distributors and + resellers) which have been validly granted by You or any distributor hereunder + prior to termination shall survive termination.

      9. LIMITATION OF + LIABILITY. +
        UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING + NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY + OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED CODE, OR ANY SUPPLIER OF ANY + OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, + INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT + LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR + MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH + PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS + LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL + INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW + PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR + LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND + LIMITATION MAY NOT APPLY TO YOU.
      10. U.S. GOVERNMENT END USERS. +
        The Covered Code is a ''commercial item,'' as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of ''commercial computer software'' and + ''commercial computer software documentation,'' as such terms are used in 48 + C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein.
      11. + MISCELLANEOUS. +
        This License represents the complete agreement concerning subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. This License shall be governed by California law provisions + (except to the extent applicable law, if any, provides otherwise), excluding + its conflict-of-law provisions. With respect to disputes in which at least one + party is a citizen of, or an entity chartered or registered to do business in + the United States of America, any litigation relating to this License shall be + subject to the jurisdiction of the Federal Courts of the Northern District of + California, with venue lying in Santa Clara County, California, with the + losing party responsible for costs, including without limitation, court costs + and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is + expressly excluded. Any law or regulation which provides that the language of + a contract shall be construed against the drafter shall not apply to this + License.
      12. RESPONSIBILITY FOR CLAIMS. +
        As between Initial Developer and the Contributors, each party is + responsible for claims and damages arising, directly or indirectly, out of its + utilization of rights under this License and You agree to work with Initial + Developer and Contributors to distribute such responsibility on an equitable + basis. Nothing herein is intended or shall be deemed to constitute any + admission of liability.
      13. MULTIPLE-LICENSED CODE. +
        Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed".  "Multiple-Licensed" means that the Initial + Developer permits you to utilize portions of the Covered Code under Your + choice of the MPL or the alternative licenses, if any, specified by the + Initial Developer in the file described in Exhibit A.
      +


      EXHIBIT A -Mozilla Public License. +

        The contents of this file are subject to the Mozilla Public License + Version 1.1 (the "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at +
        http://www.mozilla.org/MPL/ +

        Software distributed under the License is distributed on an "AS IS" basis, + WITHOUT WARRANTY OF
        ANY KIND, either express or implied. See the License + for the specific language governing rights and
        limitations under the + License. +

        The Original Code is Javassist. +

        The Initial Developer of the Original Code is Shigeru Chiba. + Portions created by the Initial Developer are
          + Copyright (C) 1999- Shigeru Chiba. All Rights Reserved. +

        Contributor(s): __Bill Burke, Jason T. Greene______________. + +

        Alternatively, the contents of this software may be used under the + terms of the GNU Lesser General Public License Version 2.1 or later + (the "LGPL"), or the Apache License Version 2.0 (the "AL"), + in which case the provisions of the LGPL or the AL are applicable + instead of those above. If you wish to allow use of your version of + this software only under the terms of either the LGPL or the AL, and not to allow others to + use your version of this software under the terms of the MPL, indicate + your decision by deleting the provisions above and replace them with + the notice and other provisions required by the LGPL or the AL. If you do not + delete the provisions above, a recipient may use your version of this + software under the terms of any one of the MPL, the LGPL or the AL. + +

      + + \ No newline at end of file diff --git a/licenses/LICENSE-javolution.txt b/licenses-binary/LICENSE-javolution.txt similarity index 100% rename from licenses/LICENSE-javolution.txt rename to licenses-binary/LICENSE-javolution.txt diff --git a/licenses/LICENSE-jline.txt b/licenses-binary/LICENSE-jline.txt similarity index 100% rename from licenses/LICENSE-jline.txt rename to licenses-binary/LICENSE-jline.txt diff --git a/licenses/LICENSE-junit-interface.txt b/licenses-binary/LICENSE-jodd.txt similarity index 69% rename from licenses/LICENSE-junit-interface.txt rename to licenses-binary/LICENSE-jodd.txt index e835350c4e2a4..cc6b458adb386 100644 --- a/licenses/LICENSE-junit-interface.txt +++ b/licenses-binary/LICENSE-jodd.txt @@ -1,15 +1,15 @@ -Copyright (c) 2009-2012, Stefan Zeiger +Copyright (c) 2003-present, Jodd Team (https://jodd.org) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. +1. Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/licenses/LICENSE-DPark.txt b/licenses-binary/LICENSE-join.txt similarity index 100% rename from licenses/LICENSE-DPark.txt rename to licenses-binary/LICENSE-join.txt diff --git a/licenses-binary/LICENSE-jquery.txt b/licenses-binary/LICENSE-jquery.txt new file mode 100644 index 0000000000000..45930542204fb --- /dev/null +++ b/licenses-binary/LICENSE-jquery.txt @@ -0,0 +1,20 @@ +Copyright JS Foundation and other contributors, https://js.foundation/ + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-json-formatter.txt b/licenses-binary/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses-binary/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses-binary/LICENSE-jtransforms.html b/licenses-binary/LICENSE-jtransforms.html new file mode 100644 index 0000000000000..351c17412357b --- /dev/null +++ b/licenses-binary/LICENSE-jtransforms.html @@ -0,0 +1,388 @@ + + +Mozilla Public License version 1.1 + + + + +

      Mozilla Public License Version 1.1

      +

      1. Definitions.

      +
      +
      1.0.1. "Commercial Use" +
      means distribution or otherwise making the Covered Code available to a third party. +
      1.1. "Contributor" +
      means each entity that creates or contributes to the creation of Modifications. +
      1.2. "Contributor Version" +
      means the combination of the Original Code, prior Modifications used by a Contributor, + and the Modifications made by that particular Contributor. +
      1.3. "Covered Code" +
      means the Original Code or Modifications or the combination of the Original Code and + Modifications, in each case including portions thereof. +
      1.4. "Electronic Distribution Mechanism" +
      means a mechanism generally accepted in the software development community for the + electronic transfer of data. +
      1.5. "Executable" +
      means Covered Code in any form other than Source Code. +
      1.6. "Initial Developer" +
      means the individual or entity identified as the Initial Developer in the Source Code + notice required by Exhibit A. +
      1.7. "Larger Work" +
      means a work which combines Covered Code or portions thereof with code not governed + by the terms of this License. +
      1.8. "License" +
      means this document. +
      1.8.1. "Licensable" +
      means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently acquired, any and all of the rights + conveyed herein. +
      1.9. "Modifications" +
      +

      means any addition to or deletion from the substance or structure of either the + Original Code or any previous Modifications. When Covered Code is released as a + series of files, a Modification is: +

        +
      1. Any addition to or deletion from the contents of a file + containing Original Code or previous Modifications. +
      2. Any new file that contains any part of the Original Code or + previous Modifications. +
      +
      1.10. "Original Code" +
      means Source Code of computer software code which is described in the Source Code + notice required by Exhibit A as Original Code, and which, + at the time of its release under this License is not already Covered Code governed + by this License. +
      1.10.1. "Patent Claims" +
      means any patent claim(s), now owned or hereafter acquired, including without + limitation, method, process, and apparatus claims, in any patent Licensable by + grantor. +
      1.11. "Source Code" +
      means the preferred form of the Covered Code for making modifications to it, + including all modules it contains, plus any associated interface definition files, + scripts used to control compilation and installation of an Executable, or source + code differential comparisons against either the Original Code or another well known, + available Covered Code of the Contributor's choice. The Source Code can be in a + compressed or archival form, provided the appropriate decompression or de-archiving + software is widely available for no charge. +
      1.12. "You" (or "Your") +
      means an individual or a legal entity exercising rights under, and complying with + all of the terms of, this License or a future version of this License issued under + Section 6.1. For legal entities, "You" includes any entity + which controls, is controlled by, or is under common control with You. For purposes of + this definition, "control" means (a) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or otherwise, or (b) + ownership of more than fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. +
      +

      2. Source Code License.

      +

      2.1. The Initial Developer Grant.

      +

      The Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive + license, subject to third party intellectual property claims: +

        +
      1. under intellectual property rights (other than patent or + trademark) Licensable by Initial Developer to use, reproduce, modify, display, perform, + sublicense and distribute the Original Code (or portions thereof) with or without + Modifications, and/or as part of a Larger Work; and +
      2. under Patents Claims infringed by the making, using or selling + of Original Code, to make, have made, use, practice, sell, and offer for sale, and/or + otherwise dispose of the Original Code (or portions thereof). +
      3. the licenses granted in this Section 2.1 + (a) and (b) are effective on + the date Initial Developer first distributes Original Code under the terms of this + License. +
      4. Notwithstanding Section 2.1 (b) + above, no patent license is granted: 1) for code that You delete from the Original Code; + 2) separate from the Original Code; or 3) for infringements caused by: i) the + modification of the Original Code or ii) the combination of the Original Code with other + software or devices. +
      +

      2.2. Contributor Grant.

      +

      Subject to third party intellectual property claims, each Contributor hereby grants You + a world-wide, royalty-free, non-exclusive license +

        +
      1. under intellectual property rights (other than patent or trademark) + Licensable by Contributor, to use, reproduce, modify, display, perform, sublicense and + distribute the Modifications created by such Contributor (or portions thereof) either on + an unmodified basis, with other Modifications, as Covered Code and/or as part of a Larger + Work; and +
      2. under Patent Claims infringed by the making, using, or selling of + Modifications made by that Contributor either alone and/or in combination with its + Contributor Version (or portions of such combination), to make, use, sell, offer for + sale, have made, and/or otherwise dispose of: 1) Modifications made by that Contributor + (or portions thereof); and 2) the combination of Modifications made by that Contributor + with its Contributor Version (or portions of such combination). +
      3. the licenses granted in Sections 2.2 + (a) and 2.2 (b) are effective + on the date Contributor first makes Commercial Use of the Covered Code. +
      4. Notwithstanding Section 2.2 (b) + above, no patent license is granted: 1) for any code that Contributor has deleted from + the Contributor Version; 2) separate from the Contributor Version; 3) for infringements + caused by: i) third party modifications of Contributor Version or ii) the combination of + Modifications made by that Contributor with other software (except as part of the + Contributor Version) or other devices; or 4) under Patent Claims infringed by Covered Code + in the absence of Modifications made by that Contributor. +
      +

      3. Distribution Obligations.

      +

      3.1. Application of License.

      +

      The Modifications which You create or to which You contribute are governed by the terms + of this License, including without limitation Section 2.2. The + Source Code version of Covered Code may be distributed only under the terms of this License + or a future version of this License released under Section 6.1, + and You must include a copy of this License with every copy of the Source Code You + distribute. You may not offer or impose any terms on any Source Code version that alters or + restricts the applicable version of this License or the recipients' rights hereunder. + However, You may include an additional document offering the additional rights described in + Section 3.5. +

      3.2. Availability of Source Code.

      +

      Any Modification which You create or to which You contribute must be made available in + Source Code form under the terms of this License either on the same media as an Executable + version or via an accepted Electronic Distribution Mechanism to anyone to whom you made an + Executable version available; and if made available via Electronic Distribution Mechanism, + must remain available for at least twelve (12) months after the date it initially became + available, or at least six (6) months after a subsequent version of that particular + Modification has been made available to such recipients. You are responsible for ensuring + that the Source Code version remains available even if the Electronic Distribution + Mechanism is maintained by a third party. +

      3.3. Description of Modifications.

      +

      You must cause all Covered Code to which You contribute to contain a file documenting the + changes You made to create that Covered Code and the date of any change. You must include a + prominent statement that the Modification is derived, directly or indirectly, from Original + Code provided by the Initial Developer and including the name of the Initial Developer in + (a) the Source Code, and (b) in any notice in an Executable version or related documentation + in which You describe the origin or ownership of the Covered Code. +

      3.4. Intellectual Property Matters

      +

      (a) Third Party Claims

      +

      If Contributor has knowledge that a license under a third party's intellectual property + rights is required to exercise the rights granted by such Contributor under Sections + 2.1 or 2.2, Contributor must include a + text file with the Source Code distribution titled "LEGAL" which describes the claim and the + party making the claim in sufficient detail that a recipient will know whom to contact. If + Contributor obtains such knowledge after the Modification is made available as described in + Section 3.2, Contributor shall promptly modify the LEGAL file in + all copies Contributor makes available thereafter and shall take other steps (such as + notifying appropriate mailing lists or newsgroups) reasonably calculated to inform those who + received the Covered Code that new knowledge has been obtained. +

      (b) Contributor APIs

      +

      If Contributor's Modifications include an application programming interface and Contributor + has knowledge of patent licenses which are reasonably necessary to implement that + API, Contributor must also include this information in the + legal file. +

      (c) Representations.

      +

      Contributor represents that, except as disclosed pursuant to Section 3.4 + (a) above, Contributor believes that Contributor's Modifications + are Contributor's original creation(s) and/or Contributor has sufficient rights to grant the + rights conveyed by this License. +

      3.5. Required Notices.

      +

      You must duplicate the notice in Exhibit A in each file of the + Source Code. If it is not possible to put such notice in a particular Source Code file due to + its structure, then You must include such notice in a location (such as a relevant directory) + where a user would be likely to look for such a notice. If You created one or more + Modification(s) You may add your name as a Contributor to the notice described in + Exhibit A. You must also duplicate this License in any documentation + for the Source Code where You describe recipients' rights or ownership rights relating to + Covered Code. You may choose to offer, and to charge a fee for, warranty, support, indemnity + or liability obligations to one or more recipients of Covered Code. However, You may do so + only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You + must make it absolutely clear than any such warranty, support, indemnity or liability + obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer + and every Contributor for any liability incurred by the Initial Developer or such Contributor + as a result of warranty, support, indemnity or liability terms You offer. +

      3.6. Distribution of Executable Versions.

      +

      You may distribute Covered Code in Executable form only if the requirements of Sections + 3.1, 3.2, + 3.3, 3.4 and + 3.5 have been met for that Covered Code, and if You include a + notice stating that the Source Code version of the Covered Code is available under the terms + of this License, including a description of how and where You have fulfilled the obligations + of Section 3.2. The notice must be conspicuously included in any + notice in an Executable version, related documentation or collateral in which You describe + recipients' rights relating to the Covered Code. You may distribute the Executable version of + Covered Code or ownership rights under a license of Your choice, which may contain terms + different from this License, provided that You are in compliance with the terms of this + License and that the license for the Executable version does not attempt to limit or alter the + recipient's rights in the Source Code version from the rights set forth in this License. If + You distribute the Executable version under a different license You must make it absolutely + clear that any terms which differ from this License are offered by You alone, not by the + Initial Developer or any Contributor. You hereby agree to indemnify the Initial Developer and + every Contributor for any liability incurred by the Initial Developer or such Contributor as + a result of any such terms You offer. +

      3.7. Larger Works.

      +

      You may create a Larger Work by combining Covered Code with other code not governed by the + terms of this License and distribute the Larger Work as a single product. In such a case, + You must make sure the requirements of this License are fulfilled for the Covered Code. +

      4. Inability to Comply Due to Statute or Regulation.

      +

      If it is impossible for You to comply with any of the terms of this License with respect to + some or all of the Covered Code due to statute, judicial order, or regulation then You must: + (a) comply with the terms of this License to the maximum extent possible; and (b) describe + the limitations and the code they affect. Such description must be included in the + legal file described in Section + 3.4 and must be included with all distributions of the Source Code. + Except to the extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to understand it. +

      5. Application of this License.

      +

      This License applies to code to which the Initial Developer has attached the notice in + Exhibit A and to related Covered Code. +

      6. Versions of the License.

      +

      6.1. New Versions

      +

      Netscape Communications Corporation ("Netscape") may publish revised and/or new versions + of the License from time to time. Each version will be given a distinguishing version number. +

      6.2. Effect of New Versions

      +

      Once Covered Code has been published under a particular version of the License, You may + always continue to use it under the terms of that version. You may also choose to use such + Covered Code under the terms of any subsequent version of the License published by Netscape. + No one other than Netscape has the right to modify the terms applicable to Covered Code + created under this License. +

      6.3. Derivative Works

      +

      If You create or use a modified version of this License (which you may only do in order to + apply it to code which is not already Covered Code governed by this License), You must (a) + rename Your license so that the phrases "Mozilla", "MOZILLAPL", "MOZPL", "Netscape", "MPL", + "NPL" or any confusingly similar phrase do not appear in your license (except to note that + your license differs from this License) and (b) otherwise make it clear that Your version of + the license contains terms which differ from the Mozilla Public License and Netscape Public + License. (Filling in the name of the Initial Developer, Original Code or Contributor in the + notice described in Exhibit A shall not of themselves be deemed to + be modifications of this License.) +

      7. Disclaimer of warranty

      +

      Covered code is provided under this license on an "as is" + basis, without warranty of any kind, either expressed or implied, including, without + limitation, warranties that the covered code is free of defects, merchantable, fit for a + particular purpose or non-infringing. The entire risk as to the quality and performance of + the covered code is with you. Should any covered code prove defective in any respect, you + (not the initial developer or any other contributor) assume the cost of any necessary + servicing, repair or correction. This disclaimer of warranty constitutes an essential part + of this license. No use of any covered code is authorized hereunder except under this + disclaimer. +

      8. Termination

      +

      8.1. This License and the rights granted hereunder will terminate + automatically if You fail to comply with terms herein and fail to cure such breach + within 30 days of becoming aware of the breach. All sublicenses to the Covered Code which + are properly granted shall survive any termination of this License. Provisions which, by + their nature, must remain in effect beyond the termination of this License shall survive. +

      8.2. If You initiate litigation by asserting a patent infringement + claim (excluding declatory judgment actions) against Initial Developer or a Contributor + (the Initial Developer or Contributor against whom You file such action is referred to + as "Participant") alleging that: +

        +
      1. such Participant's Contributor Version directly or indirectly + infringes any patent, then any and all rights granted by such Participant to You under + Sections 2.1 and/or 2.2 of this + License shall, upon 60 days notice from Participant terminate prospectively, unless if + within 60 days after receipt of notice You either: (i) agree in writing to pay + Participant a mutually agreeable reasonable royalty for Your past and future use of + Modifications made by such Participant, or (ii) withdraw Your litigation claim with + respect to the Contributor Version against such Participant. If within 60 days of + notice, a reasonable royalty and payment arrangement are not mutually agreed upon in + writing by the parties or the litigation claim is not withdrawn, the rights granted by + Participant to You under Sections 2.1 and/or + 2.2 automatically terminate at the expiration of the 60 day + notice period specified above. +
      2. any software, hardware, or device, other than such Participant's + Contributor Version, directly or indirectly infringes any patent, then any rights + granted to You by such Participant under Sections 2.1(b) + and 2.2(b) are revoked effective as of the date You first + made, used, sold, distributed, or had made, Modifications made by that Participant. +
      +

      8.3. If You assert a patent infringement claim against Participant + alleging that such Participant's Contributor Version directly or indirectly infringes + any patent where such claim is resolved (such as by license or settlement) prior to the + initiation of patent infringement litigation, then the reasonable value of the licenses + granted by such Participant under Sections 2.1 or + 2.2 shall be taken into account in determining the amount or + value of any payment or license. +

      8.4. In the event of termination under Sections + 8.1 or 8.2 above, all end user + license agreements (excluding distributors and resellers) which have been validly + granted by You or any distributor hereunder prior to termination shall survive + termination. +

      9. Limitation of liability

      +

      Under no circumstances and under no legal theory, whether + tort (including negligence), contract, or otherwise, shall you, the initial developer, + any other contributor, or any distributor of covered code, or any supplier of any of + such parties, be liable to any person for any indirect, special, incidental, or + consequential damages of any character including, without limitation, damages for loss + of goodwill, work stoppage, computer failure or malfunction, or any and all other + commercial damages or losses, even if such party shall have been informed of the + possibility of such damages. This limitation of liability shall not apply to liability + for death or personal injury resulting from such party's negligence to the extent + applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion + or limitation of incidental or consequential damages, so this exclusion and limitation + may not apply to you. +

      10. U.S. government end users

      +

      The Covered Code is a "commercial item," as that term is defined in 48 + C.F.R. 2.101 (Oct. 1995), consisting of + "commercial computer software" and "commercial computer software documentation," as such + terms are used in 48 C.F.R. 12.212 (Sept. + 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. + 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users + acquire Covered Code with only those rights set forth herein. +

      11. Miscellaneous

      +

      This License represents the complete agreement concerning subject matter hereof. If + any provision of this License is held to be unenforceable, such provision shall be + reformed only to the extent necessary to make it enforceable. This License shall be + governed by California law provisions (except to the extent applicable law, if any, + provides otherwise), excluding its conflict-of-law provisions. With respect to + disputes in which at least one party is a citizen of, or an entity chartered or + registered to do business in the United States of America, any litigation relating to + this License shall be subject to the jurisdiction of the Federal Courts of the + Northern District of California, with venue lying in Santa Clara County, California, + with the losing party responsible for costs, including without limitation, court + costs and reasonable attorneys' fees and expenses. The application of the United + Nations Convention on Contracts for the International Sale of Goods is expressly + excluded. Any law or regulation which provides that the language of a contract + shall be construed against the drafter shall not apply to this License. +

      12. Responsibility for claims

      +

      As between Initial Developer and the Contributors, each party is responsible for + claims and damages arising, directly or indirectly, out of its utilization of rights + under this License and You agree to work with Initial Developer and Contributors to + distribute such responsibility on an equitable basis. Nothing herein is intended or + shall be deemed to constitute any admission of liability. +

      13. Multiple-licensed code

      +

      Initial Developer may designate portions of the Covered Code as + "Multiple-Licensed". "Multiple-Licensed" means that the Initial Developer permits + you to utilize portions of the Covered Code under Your choice of the MPL + or the alternative licenses, if any, specified by the Initial Developer in the file + described in Exhibit A. +

      Exhibit A - Mozilla Public License.

      +
      "The contents of this file are subject to the Mozilla Public License
      +Version 1.1 (the "License"); you may not use this file except in
      +compliance with the License. You may obtain a copy of the License at
      +http://www.mozilla.org/MPL/
      +
      +Software distributed under the License is distributed on an "AS IS"
      +basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
      +License for the specific language governing rights and limitations
      +under the License.
      +
      +The Original Code is JTransforms.
      +
      +The Initial Developer of the Original Code is
      +Piotr Wendykier, Emory University.
      +Portions created by the Initial Developer are Copyright (C) 2007-2009
      +the Initial Developer. All Rights Reserved.
      +
      +Alternatively, the contents of this file may be used under the terms of
      +either the GNU General Public License Version 2 or later (the "GPL"), or
      +the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
      +in which case the provisions of the GPL or the LGPL are applicable instead
      +of those above. If you wish to allow use of your version of this file only
      +under the terms of either the GPL or the LGPL, and not to allow others to
      +use your version of this file under the terms of the MPL, indicate your
      +decision by deleting the provisions above and replace them with the notice
      +and other provisions required by the GPL or the LGPL. If you do not delete
      +the provisions above, a recipient may use your version of this file under
      +the terms of any one of the MPL, the GPL or the LGPL.
      +

      NOTE: The text of this Exhibit A may differ slightly from the text of + the notices in the Source Code files of the Original Code. You should + use the text of this Exhibit A rather than the text found in the + Original Code Source Code for Your Modifications. + +

      \ No newline at end of file diff --git a/licenses/LICENSE-kryo.txt b/licenses-binary/LICENSE-kryo.txt similarity index 100% rename from licenses/LICENSE-kryo.txt rename to licenses-binary/LICENSE-kryo.txt diff --git a/licenses-binary/LICENSE-leveldbjni.txt b/licenses-binary/LICENSE-leveldbjni.txt new file mode 100644 index 0000000000000..b4dabb9174c6d --- /dev/null +++ b/licenses-binary/LICENSE-leveldbjni.txt @@ -0,0 +1,27 @@ +Copyright (c) 2011 FuseSource Corp. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of FuseSource Corp. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-machinist.txt b/licenses-binary/LICENSE-machinist.txt new file mode 100644 index 0000000000000..68cc3a3e3a9c4 --- /dev/null +++ b/licenses-binary/LICENSE-machinist.txt @@ -0,0 +1,19 @@ +Copyright (c) 2011-2014 Erik Osheim, Tom Switzer + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-matchMedia-polyfill.txt b/licenses-binary/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses-binary/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-minlog.txt b/licenses-binary/LICENSE-minlog.txt similarity index 100% rename from licenses/LICENSE-minlog.txt rename to licenses-binary/LICENSE-minlog.txt diff --git a/licenses-binary/LICENSE-modernizr.txt b/licenses-binary/LICENSE-modernizr.txt new file mode 100644 index 0000000000000..2bf24b9b9f848 --- /dev/null +++ b/licenses-binary/LICENSE-modernizr.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-netlib.txt b/licenses-binary/LICENSE-netlib.txt similarity index 100% rename from licenses/LICENSE-netlib.txt rename to licenses-binary/LICENSE-netlib.txt diff --git a/licenses/LICENSE-paranamer.txt b/licenses-binary/LICENSE-paranamer.txt similarity index 100% rename from licenses/LICENSE-paranamer.txt rename to licenses-binary/LICENSE-paranamer.txt diff --git a/licenses/LICENSE-jpmml-model.txt b/licenses-binary/LICENSE-pmml-model.txt similarity index 100% rename from licenses/LICENSE-jpmml-model.txt rename to licenses-binary/LICENSE-pmml-model.txt diff --git a/licenses/LICENSE-protobuf.txt b/licenses-binary/LICENSE-protobuf.txt similarity index 100% rename from licenses/LICENSE-protobuf.txt rename to licenses-binary/LICENSE-protobuf.txt diff --git a/licenses-binary/LICENSE-py4j.txt b/licenses-binary/LICENSE-py4j.txt new file mode 100644 index 0000000000000..70af3e69ed67a --- /dev/null +++ b/licenses-binary/LICENSE-py4j.txt @@ -0,0 +1,27 @@ +Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +- Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +- The name of the author may not be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/LICENSE-pyrolite.txt b/licenses-binary/LICENSE-pyrolite.txt similarity index 100% rename from licenses/LICENSE-pyrolite.txt rename to licenses-binary/LICENSE-pyrolite.txt diff --git a/licenses/LICENSE-reflectasm.txt b/licenses-binary/LICENSE-reflectasm.txt similarity index 100% rename from licenses/LICENSE-reflectasm.txt rename to licenses-binary/LICENSE-reflectasm.txt diff --git a/licenses-binary/LICENSE-respond.txt b/licenses-binary/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses-binary/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses-binary/LICENSE-sbt-launch-lib.txt b/licenses-binary/LICENSE-sbt-launch-lib.txt new file mode 100644 index 0000000000000..3b9156baaab78 --- /dev/null +++ b/licenses-binary/LICENSE-sbt-launch-lib.txt @@ -0,0 +1,26 @@ +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-scala.txt b/licenses-binary/LICENSE-scala.txt similarity index 100% rename from licenses/LICENSE-scala.txt rename to licenses-binary/LICENSE-scala.txt diff --git a/licenses-binary/LICENSE-scopt.txt b/licenses-binary/LICENSE-scopt.txt new file mode 100644 index 0000000000000..e92e9b592fba0 --- /dev/null +++ b/licenses-binary/LICENSE-scopt.txt @@ -0,0 +1,9 @@ +This project is licensed under the MIT license. + +Copyright (c) scopt contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-slf4j.txt b/licenses-binary/LICENSE-slf4j.txt similarity index 100% rename from licenses/LICENSE-slf4j.txt rename to licenses-binary/LICENSE-slf4j.txt diff --git a/licenses-binary/LICENSE-sorttable.js.txt b/licenses-binary/LICENSE-sorttable.js.txt new file mode 100644 index 0000000000000..b31a5b206bf40 --- /dev/null +++ b/licenses-binary/LICENSE-sorttable.js.txt @@ -0,0 +1,16 @@ +Copyright (c) 1997-2007 Stuart Langridge + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/licenses/LICENSE-spire.txt b/licenses-binary/LICENSE-spire.txt similarity index 100% rename from licenses/LICENSE-spire.txt rename to licenses-binary/LICENSE-spire.txt diff --git a/licenses-binary/LICENSE-vis.txt b/licenses-binary/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses-binary/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file diff --git a/licenses/LICENSE-xmlenc.txt b/licenses-binary/LICENSE-xmlenc.txt similarity index 100% rename from licenses/LICENSE-xmlenc.txt rename to licenses-binary/LICENSE-xmlenc.txt diff --git a/licenses/LICENSE-zstd-jni.txt b/licenses-binary/LICENSE-zstd-jni.txt similarity index 100% rename from licenses/LICENSE-zstd-jni.txt rename to licenses-binary/LICENSE-zstd-jni.txt diff --git a/licenses/LICENSE-zstd.txt b/licenses-binary/LICENSE-zstd.txt similarity index 100% rename from licenses/LICENSE-zstd.txt rename to licenses-binary/LICENSE-zstd.txt diff --git a/licenses/LICENSE-CC0.txt b/licenses/LICENSE-CC0.txt new file mode 100644 index 0000000000000..1625c17936079 --- /dev/null +++ b/licenses/LICENSE-CC0.txt @@ -0,0 +1,121 @@ +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/licenses/LICENSE-SnapTree.txt b/licenses/LICENSE-SnapTree.txt deleted file mode 100644 index a538825d89ec5..0000000000000 --- a/licenses/LICENSE-SnapTree.txt +++ /dev/null @@ -1,35 +0,0 @@ -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. diff --git a/licenses/LICENSE-bootstrap.txt b/licenses/LICENSE-bootstrap.txt new file mode 100644 index 0000000000000..6c711832fbc85 --- /dev/null +++ b/licenses/LICENSE-bootstrap.txt @@ -0,0 +1,13 @@ +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/licenses/LICENSE-boto.txt b/licenses/LICENSE-boto.txt deleted file mode 100644 index 7bba0cd9e10a4..0000000000000 --- a/licenses/LICENSE-boto.txt +++ /dev/null @@ -1,20 +0,0 @@ -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-datatables.txt b/licenses/LICENSE-datatables.txt new file mode 100644 index 0000000000000..bb7708b5b5a49 --- /dev/null +++ b/licenses/LICENSE-datatables.txt @@ -0,0 +1,7 @@ +Copyright (C) 2008-2018, SpryMedia Ltd. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-graphlib-dot.txt b/licenses/LICENSE-graphlib-dot.txt index c9e18cd562423..4864fe05e9803 100644 --- a/licenses/LICENSE-graphlib-dot.txt +++ b/licenses/LICENSE-graphlib-dot.txt @@ -1,4 +1,4 @@ -Copyright (c) 2012-2013 Chris Pettitt +Copyright (c) 2013 Chris Pettitt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/licenses/LICENSE-jbcrypt.txt b/licenses/LICENSE-jbcrypt.txt deleted file mode 100644 index d332534c06356..0000000000000 --- a/licenses/LICENSE-jbcrypt.txt +++ /dev/null @@ -1,17 +0,0 @@ -jBCrypt is subject to the following license: - -/* - * Copyright (c) 2006 Damien Miller - * - * Permission to use, copy, modify, and distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ diff --git a/licenses/LICENSE-jmock.txt b/licenses/LICENSE-join.txt similarity index 60% rename from licenses/LICENSE-jmock.txt rename to licenses/LICENSE-join.txt index ed7964fe3d9ef..1d916090e4ea0 100644 --- a/licenses/LICENSE-jmock.txt +++ b/licenses/LICENSE-join.txt @@ -1,19 +1,21 @@ -Copyright (c) 2000-2017, jMock.org +Copyright (c) 2011, Douban Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -Redistributions of source code must retain the above copyright notice, -this list of conditions and the following disclaimer. Redistributions -in binary form must reproduce the above copyright notice, this list of -conditions and the following disclaimer in the documentation and/or -other materials provided with the distribution. + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. -Neither the name of jMock nor the names of its contributors may be -used to endorse or promote products derived from this software without -specific prior written permission. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -25,4 +27,4 @@ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-jquery.txt b/licenses/LICENSE-jquery.txt index e1dd696d3b6cc..45930542204fb 100644 --- a/licenses/LICENSE-jquery.txt +++ b/licenses/LICENSE-jquery.txt @@ -1,9 +1,20 @@ -The MIT License (MIT) +Copyright JS Foundation and other contributors, https://js.foundation/ -Copyright (c) +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-json-formatter.txt b/licenses/LICENSE-json-formatter.txt new file mode 100644 index 0000000000000..5193348fce126 --- /dev/null +++ b/licenses/LICENSE-json-formatter.txt @@ -0,0 +1,6 @@ +Copyright 2014 Mohsen Azimi + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-matchMedia-polyfill.txt b/licenses/LICENSE-matchMedia-polyfill.txt new file mode 100644 index 0000000000000..2fd0bc2b37448 --- /dev/null +++ b/licenses/LICENSE-matchMedia-polyfill.txt @@ -0,0 +1 @@ +matchMedia() polyfill - Test a CSS media type/query in JS. Authors & copyright (c) 2012: Scott Jehl, Paul Irish, Nicholas Zakas. Dual MIT/BSD license \ No newline at end of file diff --git a/licenses/LICENSE-postgresql.txt b/licenses/LICENSE-postgresql.txt deleted file mode 100644 index 515bf9af4d432..0000000000000 --- a/licenses/LICENSE-postgresql.txt +++ /dev/null @@ -1,24 +0,0 @@ -PostgreSQL Database Management System -(formerly known as Postgres, then as Postgres95) - -Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group - -Portions Copyright (c) 1994, The Regents of the University of California - -Permission to use, copy, modify, and distribute this software and its -documentation for any purpose, without fee, and without a written agreement -is hereby granted, provided that the above copyright notice and this -paragraph and the following two paragraphs appear in all copies. - -IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR -DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING -LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS -DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - -THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY -AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS -ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO -PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. - diff --git a/licenses/LICENSE-respond.txt b/licenses/LICENSE-respond.txt new file mode 100644 index 0000000000000..dea4ff9e5b2ea --- /dev/null +++ b/licenses/LICENSE-respond.txt @@ -0,0 +1,22 @@ +Copyright (c) 2012 Scott Jehl + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-scalacheck.txt b/licenses/LICENSE-scalacheck.txt deleted file mode 100644 index cb8f97842f4c4..0000000000000 --- a/licenses/LICENSE-scalacheck.txt +++ /dev/null @@ -1,32 +0,0 @@ -ScalaCheck LICENSE - -Copyright (c) 2007-2015, Rickard Nilsson -All rights reserved. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of the author nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. \ No newline at end of file diff --git a/licenses/LICENSE-vis.txt b/licenses/LICENSE-vis.txt new file mode 100644 index 0000000000000..18b7323059a41 --- /dev/null +++ b/licenses/LICENSE-vis.txt @@ -0,0 +1,22 @@ +vis.js +https://github.com/almende/vis + +A dynamic, browser-based visualization library. + +@version 4.16.1 +@date 2016-04-18 + +@license +Copyright (C) 2011-2016 Almende B.V, http://almende.com + +Vis.js is dual licensed under both + +* The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + +and + +* The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. \ No newline at end of file From 8f91c697e251423b826cd6ac4ddd9e2dac15b96e Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 2 Jul 2018 14:35:37 +0800 Subject: [PATCH 1046/1161] [SPARK-24665][PYSPARK] Use SQLConf in PySpark to manage all sql configs ## What changes were proposed in this pull request? Use SQLConf for PySpark to manage all sql configs, drop all the hard code in config usage. ## How was this patch tested? Existing UT. Author: Yuanjian Li Closes #21648 from xuanyuanking/SPARK-24665. --- python/pyspark/sql/context.py | 5 +++ python/pyspark/sql/dataframe.py | 42 +++++-------------- .../apache/spark/sql/internal/SQLConf.scala | 6 +++ 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e9ec7ba866761..9c094dd9a9033 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -93,6 +93,11 @@ def _ssql_ctx(self): """ return self._jsqlContext + @property + def _conf(self): + """Accessor for the JVM SQL-specific configurations""" + return self.sparkSession._jsparkSession.sessionState().conf() + @classmethod @since(1.6) def getOrCreate(cls, sc): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index cb3fe448b6fc7..c40aea9bcef0a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -354,32 +354,12 @@ def show(self, n=20, truncate=True, vertical=False): else: print(self._jdf.showString(n, int(truncate), vertical)) - @property - def _eager_eval(self): - """Returns true if the eager evaluation enabled. - """ - return self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.enabled", "false").lower() == "true" - - @property - def _max_num_rows(self): - """Returns the max row number for eager evaluation. - """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.maxNumRows", "20")) - - @property - def _truncate(self): - """Returns the truncate length for eager evaluation. - """ - return int(self.sql_ctx.getConf( - "spark.sql.repl.eagerEval.truncate", "20")) - def __repr__(self): - if not self._support_repr_html and self._eager_eval: + if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled(): vertical = False return self._jdf.showString( - self._max_num_rows, self._truncate, vertical) + self.sql_ctx._conf.replEagerEvalMaxNumRows(), + self.sql_ctx._conf.replEagerEvalTruncate(), vertical) else: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -391,10 +371,10 @@ def _repr_html_(self): import cgi if not self._support_repr_html: self._support_repr_html = True - if self._eager_eval: - max_num_rows = max(self._max_num_rows, 0) + if self.sql_ctx._conf.isReplEagerEvalEnabled(): + max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0) sock_info = self._jdf.getRowsToPython( - max_num_rows, self._truncate) + max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate()) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] @@ -2049,13 +2029,12 @@ def toPandas(self): import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \ - == "true": - timezone = self.sql_ctx.getConf("spark.sql.session.timeZone") + if self.sql_ctx._conf.pandasRespectSessionTimeZone(): + timezone = self.sql_ctx._conf.sessionLocalTimeZone() else: timezone = None - if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": + if self.sql_ctx._conf.arrowEnabled(): use_arrow = True try: from pyspark.sql.types import to_arrow_schema @@ -2065,8 +2044,7 @@ def toPandas(self): to_arrow_schema(self.schema) except Exception as e: - if self.sql_ctx.getConf("spark.sql.execution.arrow.fallback.enabled", "true") \ - .lower() == "true": + if self.sql_ctx._conf.arrowFallbackEnabled(): msg = ( "toPandas attempted Arrow optimization because " "'spark.sql.execution.arrow.enabled' is set to true; however, " diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index da1c34cdc78f2..e2c48e2d8a14c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1728,6 +1728,12 @@ class SQLConf extends Serializable with Logging { def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL) + def isReplEagerEvalEnabled: Boolean = getConf(SQLConf.REPL_EAGER_EVAL_ENABLED) + + def replEagerEvalMaxNumRows: Int = getConf(SQLConf.REPL_EAGER_EVAL_MAX_NUM_ROWS) + + def replEagerEvalTruncate: Int = getConf(SQLConf.REPL_EAGER_EVAL_TRUNCATE) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ From 8008f9cb82e7c228b94eade2e7cb484d6d17e6a4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 2 Jul 2018 22:09:47 +0800 Subject: [PATCH 1047/1161] [SPARK-24715][BUILD] Override jline version as 2.14.3 in SBT ## What changes were proposed in this pull request? During SPARK-24418 (Upgrade Scala to 2.11.12 and 2.12.6), we upgrade `jline` version together. So, `mvn` works correctly. However, `sbt` brings old jline library and is hitting `NoSuchMethodError` in `master` branch, see https://github.com/apache/spark/pull/21495#issuecomment-401560826. This overrides jline version in SBT to make sbt build work. ## How was this patch tested? Manually test. Author: Liang-Chi Hsieh Closes #21692 from viirya/SPARK-24715. --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b606f9355e03b..f887e4570c85d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -464,7 +464,8 @@ object DockerIntegrationTests { */ object DependencyOverrides { lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1", + dependencyOverrides += "jline" % "jline" % "2.14.3") } /** From f599cde69506a5aedeeec449cba9a8b5ab128282 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 2 Jul 2018 22:39:00 +0800 Subject: [PATCH 1048/1161] [SPARK-24507][DOCUMENTATION] Update streaming guide ## What changes were proposed in this pull request? Updated streaming guide for direct stream and link to integration guide. ## How was this patch tested? jekyll build Author: Rekha Joshi Closes #21683 from rekhajoshm/SPARK-24507. --- docs/streaming-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c30959263cdfa..118b05355c74d 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2176,6 +2176,8 @@ the input data stream (using `inputStream.repartition()`). This distributes the received batches of data across the specified number of machines in the cluster before further processing. +For direct stream, please refer to [Spark Streaming + Kafka Integration Guide](streaming-kafka-integration.html) + ### Level of Parallelism in Data Processing {:.no_toc} Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the From 42815548c7ef498439b9ba47134a6f3e1b519c83 Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 2 Jul 2018 10:24:04 -0700 Subject: [PATCH 1049/1161] [SPARK-24683][K8S] Fix k8s no resource ## What changes were proposed in this pull request? Make SparkSubmit pass in the main class even if `SparkLauncher.NO_RESOURCE` is the primary resource. ## How was this patch tested? New integration test written to capture this case. Author: mcheah Closes #21660 from mccheah/fix-k8s-no-resource. --- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 2 ++ .../deploy/k8s/submit/KubernetesDriverBuilder.scala | 3 ++- .../deploy/k8s/integrationtest/KubernetesSuite.scala | 10 ++++++++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e83d82f847c61..2da778a29779d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -702,6 +702,8 @@ private[spark] class SparkSubmit extends Logging { childArgs ++= Array("--primary-java-resource", args.primaryResource) childArgs ++= Array("--main-class", args.mainClass) } + } else { + childArgs ++= Array("--main-class", args.mainClass) } if (args.childArgs != null) { args.childArgs.foreach { arg => diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 5762d8245f778..0dd1c37661707 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -64,7 +64,8 @@ private[spark] class KubernetesDriverBuilder( case JavaMainAppResource(_) => provideJavaStep(kubernetesConf) case PythonMainAppResource(_) => - providePythonStep(kubernetesConf)}.getOrElse(provideJavaStep(kubernetesConf)) + providePythonStep(kubernetesConf)} + .getOrElse(provideJavaStep(kubernetesConf)) val allFeatures: Seq[KubernetesFeatureConfigStep] = (baseFeatures :+ bindingsStep) ++ diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 65c513cf241a4..6e334c83fbde8 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -21,17 +21,17 @@ import java.nio.file.{Path, Paths} import java.util.UUID import java.util.regex.Pattern -import scala.collection.JavaConverters._ - import com.google.common.io.PatternFilenameFilter import io.fabric8.kubernetes.api.model.{Container, Pod} import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} +import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory} import org.apache.spark.deploy.k8s.integrationtest.config._ +import org.apache.spark.launcher.SparkLauncher private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { @@ -109,6 +109,12 @@ private[spark] class KubernetesSuite extends SparkFunSuite runSparkPiAndVerifyCompletion() } + test("Use SparkLauncher.NO_RESOURCE") { + sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) + runSparkPiAndVerifyCompletion( + appResource = SparkLauncher.NO_RESOURCE) + } + test("Run SparkPi with a master URL without a scheme.") { val url = kubernetesTestComponents.kubernetesClient.getMasterUrl val k8sMasterUrl = if (url.getPort < 0) { From 85fe1297e35bcff9cf86bd53fee615e140ee5bfb Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Mon, 2 Jul 2018 13:08:16 -0700 Subject: [PATCH 1050/1161] [SPARK-24428][K8S] Fix unused code ## What changes were proposed in this pull request? Remove code that is misleading and is a leftover from a previous implementation. ## How was this patch tested? Manually. Author: Stavros Kontopoulos Closes #21462 from skonto/fix-k8s-docs. --- .../org/apache/spark/deploy/k8s/Constants.scala | 6 ------ .../cluster/k8s/KubernetesClusterManager.scala | 2 -- .../docker/src/main/dockerfiles/spark/entrypoint.sh | 12 +++++------- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala index 69bd03d1eda6f..5ecdd3a04d77b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala @@ -25,9 +25,6 @@ private[spark] object Constants { val SPARK_POD_DRIVER_ROLE = "driver" val SPARK_POD_EXECUTOR_ROLE = "executor" - // Annotations - val SPARK_APP_NAME_ANNOTATION = "spark-app-name" - // Credentials secrets val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = "/mnt/secrets/spark-kubernetes-credentials" @@ -50,17 +47,14 @@ private[spark] object Constants { val DEFAULT_BLOCKMANAGER_PORT = 7079 val DRIVER_PORT_NAME = "driver-rpc-port" val BLOCK_MANAGER_PORT_NAME = "blockmanager" - val EXECUTOR_PORT_NAME = "executor" // Environment Variables - val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" val ENV_DRIVER_URL = "SPARK_DRIVER_URL" val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" - val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" val ENV_CLASSPATH = "SPARK_CLASSPATH" val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index c6e931a38405f..de2a52bc7a0b8 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -48,8 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 2f4e115e84ecd..8bdb0f7a10795 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -51,12 +51,10 @@ esac SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt -if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then - SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH" -fi -if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then - cp -R "$SPARK_MOUNTED_FILES_DIR/." . +readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt + +if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then + SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH" fi if [ -n "$PYSPARK_FILES" ]; then @@ -101,7 +99,7 @@ case "$SPARK_K8S_CMD" in executor) CMD=( ${JAVA_HOME}/bin/java - "${SPARK_JAVA_OPTS[@]}" + "${SPARK_EXECUTOR_JAVA_OPTS[@]}" -Xms$SPARK_EXECUTOR_MEMORY -Xmx$SPARK_EXECUTOR_MEMORY -cp "$SPARK_CLASSPATH" From a7c8f0c8cb144a026ea21e8780107e363ceacb8d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 3 Jul 2018 12:20:03 +0800 Subject: [PATCH 1051/1161] [SPARK-24385][SQL] Resolve self-join condition ambiguity for EqualNullSafe ## What changes were proposed in this pull request? In Dataset.join we have a small hack for resolving ambiguity in the column name for self-joins. The current code supports only `EqualTo`. The PR extends the fix to `EqualNullSafe`. Credit for this PR should be given to daniel-shields. ## How was this patch tested? added UT Author: Marco Gaido Closes #21605 from mgaido91/SPARK-24385_2. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 5 +++++ .../scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2ec236fc75efc..c97246f30220d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1016,6 +1016,11 @@ class Dataset[T] private[sql]( catalyst.expressions.EqualTo( withPlan(plan.left).resolve(a.name), withPlan(plan.right).resolve(b.name)) + case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualNullSafe( + withPlan(plan.left).resolve(a.name), + withPlan(plan.right).resolve(b.name)) }} withPlan { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0d9eeabb397a1..10d9a11d2ee79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -287,4 +287,12 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { dfOne.join(dfTwo, $"a" === $"b", "left").queryExecution.optimizedPlan } } + + test("SPARK-24385: Resolve ambiguity in self-joins with EqualNullSafe") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val df = spark.range(2) + // this throws an exception before the fix + df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan + } + } } From 5585c5765f13519a447587ca778d52ce6a36a484 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 3 Jul 2018 10:13:48 -0700 Subject: [PATCH 1052/1161] [SPARK-24420][BUILD] Upgrade ASM to 6.1 to support JDK9+ ## What changes were proposed in this pull request? Upgrade ASM to 6.1 to support JDK9+ ## How was this patch tested? Existing tests. Author: DB Tsai Closes #21459 from dbtsai/asm. --- core/pom.xml | 2 +- .../main/scala/org/apache/spark/util/ClosureCleaner.scala | 4 ++-- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- dev/deps/spark-deps-hadoop-3.1 | 2 +- graphx/pom.xml | 2 +- .../org/apache/spark/graphx/util/BytecodeUtils.scala | 4 ++-- pom.xml | 8 ++++---- repl/pom.xml | 4 ++-- .../scala/org/apache/spark/repl/ExecutorClassLoader.scala | 4 ++-- sql/core/pom.xml | 2 +- 11 files changed, 18 insertions(+), 18 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 220522d3a8296..d0b869e6ef92c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -56,7 +56,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index ad0c0639521f6..073d71c63b0c7 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.{Map, Set, Stack} import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 96e9c27210d05..f50a0aac0aefc 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -192,7 +192,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 4a6ee027ec355..774f9dc39ce4d 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -193,7 +193,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index e0b560c8ec71f..19c05ad1e991f 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -214,7 +214,7 @@ token-provider-1.0.1.jar univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar -xbean-asm5-shaded-4.4.jar +xbean-asm6-shaded-4.8.jar xz-1.0.jar zjsonpatch-0.3.0.jar zookeeper-3.4.9.jar diff --git a/graphx/pom.xml b/graphx/pom.xml index fbe77fcb958d5..0f5dc548600b2 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -53,7 +53,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded com.google.guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d76e84ed8c9ed..a559685b1633c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.util.Utils diff --git a/pom.xml b/pom.xml index 90e64ff71d229..ca30f9f12b098 100644 --- a/pom.xml +++ b/pom.xml @@ -313,13 +313,13 @@ chill-java ${chill.version} - org.apache.xbean - xbean-asm5-shaded - 4.4 + xbean-asm6-shaded + 4.8 @@ -166,7 +166,7 @@ - + scala-2.12 diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4dc399827ffed..42298b06a2c86 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder} import java.nio.channels.Channels import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.xbean.asm5._ -import org.apache.xbean.asm5.Opcodes._ +import org.apache.xbean.asm6._ +import org.apache.xbean.asm6.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil diff --git a/sql/core/pom.xml b/sql/core/pom.xml index f270c70fbfcf0..18ae314309d7b 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -118,7 +118,7 @@ org.apache.xbean - xbean-asm5-shaded + xbean-asm6-shaded org.scalacheck From 776f299fc8146b400e97185b1577b0fc8f06e14b Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 4 Jul 2018 09:38:18 +0800 Subject: [PATCH 1053/1161] [SPARK-24709][SQL] schema_of_json() - schema inference from an example ## What changes were proposed in this pull request? In the PR, I propose to add new function - *schema_of_json()* which infers schema of JSON string literal. The result of the function is a string containing a schema in DDL format. One of the use cases is using of *schema_of_json()* in the combination with *from_json()*. Currently, _from_json()_ requires a schema as a mandatory argument. The *schema_of_json()* function will allow to point out an JSON string as an example which has the same schema as the first argument of _from_json()_. For instance: ```sql select from_json(json_column, schema_of_json('{"c1": [0], "c2": [{"c3":0}]}')) from json_table; ``` ## How was this patch tested? Added new test to `JsonFunctionsSuite`, `JsonExpressionsSuite` and SQL tests to `json-functions.sql` Author: Maxim Gekk Closes #21686 from MaxGekk/infer_schema_json. --- python/pyspark/sql/functions.py | 27 ++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/jsonExpressions.scala | 52 ++++++++++++++++--- .../sql/catalyst}/json/JsonInferSchema.scala | 5 +- .../expressions/JsonExpressionsSuite.scala | 7 +++ .../datasources/json/JsonDataSource.scala | 2 +- .../org/apache/spark/sql/functions.scala | 42 +++++++++++++++ .../sql-tests/inputs/json-functions.sql | 4 ++ .../sql-tests/results/json-functions.sql.out | 20 ++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 17 +++++- .../datasources/json/JsonSuite.scala | 4 +- 11 files changed, 163 insertions(+), 18 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JsonInferSchema.scala (98%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9652d3e79b875..4d371976364d3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2189,11 +2189,16 @@ def from_json(col, schema, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=[Row(a=1)])] + >>> schema = schema_of_json(lit('''{"a": 0}''')) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] """ sc = SparkContext._active_spark_context if isinstance(schema, DataType): schema = schema.json() + elif isinstance(schema, Column): + schema = _to_java_column(schema) jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) @@ -2235,6 +2240,28 @@ def to_json(col, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.4) +def schema_of_json(col): + """ + Parses a column containing a JSON string and infers its schema in DDL format. + + :param col: string column in json format + + >>> from pyspark.sql.types import * + >>> data = [(1, '{"a": 1}')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_json(df.value).alias("json")).collect() + [Row(json=u'struct')] + >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect() + [Row(json=u'struct')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_json(_to_java_column(col)) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a574d8a84d4fb..80a0af672bf74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -505,6 +505,7 @@ object FunctionRegistry { // json expression[StructsToJson]("to_json"), expression[JsonToStructs]("from_json"), + expression[SchemaOfJson]("schema_of_json"), // cast expression[Cast]("cast"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index f6d74f5b74c8e..8cd86053a01c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} +import java.io._ import scala.util.parsing.combinator.RegexParsers @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -525,17 +526,19 @@ case class JsonToStructs( override def nullable: Boolean = true // Used in `FunctionRegistry` - def this(child: Expression, schema: Expression) = + def this(child: Expression, schema: Expression, options: Map[String, String]) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), - options = Map.empty[String, String], + schema = JsonExprUtils.evalSchemaExpr(schema), + options = options, child = child, timeZoneId = None, forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) + def this(child: Expression, schema: Expression, options: Expression) = this( - schema = JsonExprUtils.validateSchemaLiteral(schema), + schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, timeZoneId = None, @@ -744,11 +747,44 @@ case class StructsToJson( override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil } +/** + * A function infers schema of JSON string. + */ +@ExpressionDescription( + usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.", + examples = """ + Examples: + > SELECT _FUNC_('[{"col":0}]'); + array> + """, + since = "2.4.0") +case class SchemaOfJson(child: Expression) + extends UnaryExpression with String2StringExpression with CodegenFallback { + + private val jsonOptions = new JSONOptions(Map.empty, "UTC") + private val jsonFactory = new JsonFactory() + jsonOptions.setJacksonOptions(jsonFactory) + + override def convert(v: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser => + parser.nextToken() + inferField(parser, jsonOptions) + } + + UTF8String.fromString(dt.catalogString) + } +} + object JsonExprUtils { - def validateSchemaLiteral(exp: Expression): DataType = exp match { + def evalSchemaExpr(exp: Expression): DataType = exp match { case Literal(s, StringType) => DataType.fromDDL(s.toString) - case e => throw new AnalysisException(s"Expected a string literal instead of $e") + case e @ SchemaOfJson(_: Literal) => + val ddlSchema = e.eval().asInstanceOf[UTF8String] + DataType.fromDDL(ddlSchema.toString) + case e => throw new AnalysisException( + "Schema should be specified in DDL format as a string literal" + + s" or output of the schema_of_json function instead of ${e.sql}") } def convertToMapData(exp: Expression): Map[String, String] = exp match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 8e1b430f4eb33..491ca005877f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.util.Comparator @@ -25,7 +25,6 @@ import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil -import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -103,7 +102,7 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { + def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 00e97637eee7e..52203b9e337ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -706,4 +706,11 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with assert(schemaToCompare == schema) } } + + test("SPARK-24709: infer schema of json strings") { + checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct") + checkEvaluation( + SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")), + "struct,col1:struct>") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 3b6df45e949e8..2fee2128ba1f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat} import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index acca9572cb14c..614f65f0faaba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3381,6 +3381,48 @@ object functions { from_json(e, dataType, options) } + /** + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column): Column = { + from_json(e, schema, Map.empty[String, String].asJava) + } + + /** + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.4.0 + */ + def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = { + withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap)) + } + + /** + * Parses a column containing a JSON string and infers its schema. + * + * @param e a string column containing JSON data. + * + * @group collection_funcs + * @since 2.4.0 + */ + def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr)) + /** * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s, * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema. diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql index dc15d13cd1dd3..79fdd5895e691 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -35,3 +35,7 @@ DROP VIEW IF EXISTS jsonTable; -- from_json - complex types select from_json('{"a":1, "b":2}', 'map'); select from_json('{"a":1, "b":"2"}', 'struct'); + +-- infer schema of json literal +select schema_of_json('{"c1":0, "c2":[1]}'); +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 2b3288dc5a137..3d49323751a10 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 +-- Number of queries: 30 -- !query 0 @@ -183,7 +183,7 @@ select from_json('{"a":1}', 1) struct<> -- !query 17 output org.apache.spark.sql.AnalysisException -Expected a string literal instead of 1;; line 1 pos 7 +Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7 -- !query 18 @@ -274,3 +274,19 @@ select from_json('{"a":1, "b":"2"}', 'struct') struct> -- !query 27 output {"a":1,"b":"2"} + + +-- !query 28 +select schema_of_json('{"c1":0, "c2":[1]}') +-- !query 28 schema +struct +-- !query 28 output +struct> + + +-- !query 29 +select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) +-- !query 29 schema +struct>> +-- !query 29 output +{"c1":[1,2,3]} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 7bf17cbcd9c97..d3b2701f2558e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -311,7 +311,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg1 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 1)") } - assert(errMsg1.getMessage.startsWith("Expected a string literal instead of")) + assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string")) val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } @@ -392,4 +392,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), Row(null)) } + + test("SPARK-24709: infers schemas of json strings and pass them to from_json") { + val in = Seq("""{"a": [1, 2, 3]}""").toDS() + val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed") + val expected = StructType(StructField( + "parsed", + StructType(StructField( + "a", + ArrayType(LongType, true), true) :: Nil), + true) :: Nil) + + assert(out.schema == expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 897424daca0cb..eab15b35c97d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ From b42fda8ab3b5f82b33b96fce3f584c50f2ed5a3a Mon Sep 17 00:00:00 2001 From: cclauss Date: Wed, 4 Jul 2018 09:40:58 +0800 Subject: [PATCH 1054/1161] [SPARK-23698] Remove raw_input() from Python 2 Signed-off-by: cclauss ## What changes were proposed in this pull request? Humans will be able to enter text in Python 3 prompts which they can not do today. The Python builtin __raw_input()__ was removed in Python 3 in favor of __input()__. This PR does the same thing in Python 2. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) flake8 testing Please review http://spark.apache.org/contributing.html before opening a pull request. Author: cclauss Closes #21702 from cclauss/python-fix-raw_input. --- dev/create-release/releaseutils.py | 5 ++++- dev/merge_spark_pr.py | 21 ++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 32f6cbb29f0be..ab812e1bb7c04 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -49,13 +49,16 @@ print("Install using 'sudo pip install unidecode'") sys.exit(-1) +if sys.version < '3': + input = raw_input + # Contributors list file name contributors_file_name = "contributors.txt" # Prompt the user to answer yes or no until they do so def yesOrNoPrompt(msg): - response = raw_input("%s [y/n]: " % msg) + response = input("%s [y/n]: " % msg) while response != "y" and response != "n": return yesOrNoPrompt(msg) return response == "y" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 7f46a1c8f6a7c..79c7c021fe74a 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -39,6 +39,9 @@ except ImportError: JIRA_IMPORTED = False +if sys.version < '3': + input = raw_input + # Location of your Spark git development area SPARK_HOME = os.environ.get("SPARK_HOME", os.getcwd()) # Remote name which points to the Gihub site @@ -95,7 +98,7 @@ def run_cmd(cmd): def continue_maybe(prompt): - result = raw_input("\n%s (y/n): " % prompt) + result = input("\n%s (y/n): " % prompt) if result.lower() != "y": fail("Okay, exiting") @@ -134,7 +137,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): '--pretty=format:%an <%ae>']).split("\n") distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) - primary_author = raw_input( + primary_author = input( "Enter primary author in the format of \"name \" [%s]: " % distinct_authors[0]) if primary_author == "": @@ -184,7 +187,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): def cherry_pick(pr_num, merge_hash, default_branch): - pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) + pick_ref = input("Enter a branch name [%s]: " % default_branch) if pick_ref == "": pick_ref = default_branch @@ -231,7 +234,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) - jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) + jira_id = input("Enter a JIRA id [%s]: " % default_jira_id) if jira_id == "": jira_id = default_jira_id @@ -276,7 +279,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""): default_fix_versions = filter(lambda x: x != v, default_fix_versions) default_fix_versions = ",".join(default_fix_versions) - fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) + fix_versions = input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) if fix_versions == "": fix_versions = default_fix_versions fix_versions = fix_versions.replace(" ", "").split(",") @@ -315,7 +318,7 @@ def choose_jira_assignee(issue, asf_jira): if author in commentors: annotations.append("Commentor") print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations))) - raw_assignee = raw_input( + raw_assignee = input( "Enter number of user, or userid, to assign to (blank to leave unassigned):") if raw_assignee == "": return None @@ -428,7 +431,7 @@ def main(): # Assumes branch names can be sorted lexicographically latest_branch = sorted(branch_names, reverse=True)[0] - pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr_num = input("Which pull request would you like to merge? (e.g. 34): ") pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) @@ -440,7 +443,7 @@ def main(): print("I've re-written the title as follows to match the standard format:") print("Original: %s" % pr["title"]) print("Modified: %s" % modified_title) - result = raw_input("Would you like to use the modified title? (y/n): ") + result = input("Would you like to use the modified title? (y/n): ") if result.lower() == "y": title = modified_title print("Using modified title:") @@ -491,7 +494,7 @@ def main(): merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) pick_prompt = "Would you like to pick %s into another branch?" % merge_hash - while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + while input("\n%s (y/n): " % pick_prompt).lower() == "y": merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] if JIRA_IMPORTED: From 5bf95f2a37e624eb6fb0ef6fbd2a40a129d5a470 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 4 Jul 2018 09:53:04 +0800 Subject: [PATCH 1055/1161] [BUILD] Close stale PRs Closes #20932 Closes #17843 Closes #13477 Closes #14291 Closes #20919 Closes #17907 Closes #18766 Closes #20809 Closes #8849 Closes #21076 Closes #21507 Closes #21336 Closes #21681 Closes #21691 Author: Sean Owen Closes #21708 from srowen/CloseStalePRs. From 7c08eb6d61d55ce45229f3302e6d463e7669183d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 4 Jul 2018 12:21:26 +0800 Subject: [PATCH 1056/1161] [SPARK-24732][SQL] Type coercion between MapTypes. ## What changes were proposed in this pull request? Currently we don't allow type coercion between maps. We can support type coercion between MapTypes where both the key types and the value types are compatible. ## How was this patch tested? Added tests. Author: Takuya UESHIN Closes #21703 from ueshin/issues/SPARK-24732/maptypecoercion. --- .../sql/catalyst/analysis/TypeCoercion.scala | 12 +++++ .../catalyst/analysis/TypeCoercionSuite.scala | 45 ++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 3ebab430ffbcd..cf90e6e555fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -179,6 +179,12 @@ object TypeCoercion { .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findWiderTypeForTwo(kt1, kt2).flatMap { kt => + findWiderTypeForTwo(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } case _ => None }) } @@ -220,6 +226,12 @@ object TypeCoercion { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => findWiderTypeWithoutStringPromotionForTwo(et1, et2) .map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt => + findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } case _ => None }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447d..4e5ca1b8cdd36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,8 +54,9 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable only when the internal child types also match; otherwise, not castable. // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { @@ -487,12 +488,38 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(ArrayType(IntegerType), containsNull = false), ArrayType(ArrayType(LongType), containsNull = false), Some(ArrayType(ArrayType(LongType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(MapType(IntegerType, FloatType), containsNull = false), + ArrayType(MapType(LongType, DoubleType), containsNull = false), + Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + + // MapType + widenTestWithStringPromotion( + MapType(ShortType, TimestampType, valueContainsNull = true), + MapType(DoubleType, StringType, valueContainsNull = false), + Some(MapType(DoubleType, StringType, valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, ArrayType(TimestampType), valueContainsNull = false), + MapType(LongType, ArrayType(StringType), valueContainsNull = true), + Some(MapType(LongType, ArrayType(StringType), valueContainsNull = true))) + widenTestWithStringPromotion( + MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), + MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), + Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) widenTestWithoutStringPromotion(StringType, TimestampType, None) widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None) widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None) + widenTestWithoutStringPromotion( + MapType(LongType, IntegerType), MapType(StringType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, LongType), MapType(IntegerType, StringType), None) + widenTestWithoutStringPromotion( + MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) + widenTestWithoutStringPromotion( + MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -501,6 +528,22 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType))) widenTestWithStringPromotion( ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType))) + widenTestWithStringPromotion( + MapType(LongType, IntegerType), + MapType(StringType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, LongType), + MapType(IntegerType, StringType), + Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + MapType(StringType, IntegerType), + MapType(TimestampType, IntegerType), + Some(MapType(StringType, IntegerType))) + widenTestWithStringPromotion( + MapType(IntegerType, StringType), + MapType(IntegerType, TimestampType), + Some(MapType(IntegerType, StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { From 772060d0940a97d89807befd682a70ae82e83ef4 Mon Sep 17 00:00:00 2001 From: Stan Zhai Date: Wed, 4 Jul 2018 10:12:36 +0200 Subject: [PATCH 1057/1161] [SPARK-24704][WEBUI] Fix the order of stages in the DAG graph ## What changes were proposed in this pull request? Before: ![wx20180630-155537](https://user-images.githubusercontent.com/1438757/42123357-2c2e2d84-7c83-11e8-8abd-1c2860f38783.png) After: ![wx20180630-155604](https://user-images.githubusercontent.com/1438757/42123359-32fae990-7c83-11e8-8a7b-cdcee94f9123.png) ## How was this patch tested? Manual tests. Author: Stan Zhai Closes #21680 from stanzhai/fix-dag-graph. --- .../src/main/scala/org/apache/spark/status/AppStatusStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala index 688f25a9fdea1..e237281c552b1 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -471,7 +471,7 @@ private[spark] class AppStatusStore( def operationGraphForJob(jobId: Int): Seq[RDDOperationGraph] = { val job = store.read(classOf[JobDataWrapper], jobId) - val stages = job.info.stageIds + val stages = job.info.stageIds.sorted stages.map { id => val g = store.read(classOf[RDDOperationGraphWrapper], id).toRDDOperationGraph() From b2deef64f604ddd9502a31105ed47cb63470ec85 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 4 Jul 2018 20:04:18 +0800 Subject: [PATCH 1058/1161] [SPARK-24727][SQL] Add a static config to control cache size for generated classes ## What changes were proposed in this pull request? Since SPARK-24250 has been resolved, executors correctly references user-defined configurations. So, this pr added a static config to control cache size for generated classes in `CodeGenerator`. ## How was this patch tested? Added tests in `ExecutorSideSQLConfSuite`. Author: Takeshi Yamamuro Closes #21705 from maropu/SPARK-24727. --- .../expressions/codegen/CodeGenerator.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 ++ .../spark/sql/internal/StaticSQLConf.scala | 8 +++++ .../internal/ExecutorSideSQLConfSuite.scala | 31 +++++++++++++++---- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4cc0968911cb5..838c045d5bcce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1415,7 +1415,7 @@ object CodeGenerator extends Logging { * weak keys/values and thus does not respond to memory pressure. */ private val cache = CacheBuilder.newBuilder() - .maximumSize(100) + .maximumSize(SQLConf.get.codegenCacheMaxEntries) .build( new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { override def load(code: CodeAndComment): (GeneratedClass, Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e2c48e2d8a14c..50965c1abc68c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1508,6 +1508,8 @@ class SQLConf extends Serializable with Logging { def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) + def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 382ef28f49a7a..384b1917a1f79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -66,6 +66,14 @@ object StaticSQLConf { .checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative") .createWithDefault(1000) + val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries") + .internal() + .doc("When nonzero, enable caching of generated classes for operators and expressions. " + + "All jobs share the cache that can use up to the specified number for generated classes.") + .intConf + .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative") + .createWithDefault(100) + // When enabling the debug, Spark SQL internal table properties are not filtered out; however, // some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly. val DEBUG_MODE = buildStaticConf("spark.sql.debug") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 3dd0712e02448..855fe4f4523f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.test.SQLTestUtils class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { @@ -40,16 +40,24 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { spark = null } + override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + pairs.foreach { case (k, v) => + SQLConf.get.setConfString(k, v) + } + try f finally { + pairs.foreach { case (k, _) => + SQLConf.get.unsetConf(k) + } + } + } + test("ReadOnlySQLConf is correctly created at the executor side") { - SQLConf.get.setConfString("spark.sql.x", "a") - try { - val checks = spark.range(10).mapPartitions { it => + withSQLConf("spark.sql.x" -> "a") { + val checks = spark.range(10).mapPartitions { _ => val conf = SQLConf.get Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") }.collect() assert(checks.forall(_ == true)) - } finally { - SQLConf.get.unsetConf("spark.sql.x") } } @@ -63,4 +71,15 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } } } + + test("SPARK-24727 CODEGEN_CACHE_MAX_ENTRIES is correctly referenced at the executor side") { + withSQLConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key -> "300") { + val checks = spark.range(10).mapPartitions { _ => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && + conf.getConfString(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES.key) == "300") + }.collect() + assert(checks.forall(_ == true)) + } + } } From 021145f36432b386cce30450c888a85393d5169f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 4 Jul 2018 20:15:40 +0800 Subject: [PATCH 1059/1161] [SPARK-24716][SQL] Refactor ParquetFilters ## What changes were proposed in this pull request? Replace DataFrame schema to Parquet file schema when create `ParquetFilters`. Thus we can easily implement `Decimal` and `Timestamp` push down. some thing like this: ```scala // DecimalType: 32BitDecimalType case ParquetSchemaType(DECIMAL, INT32, decimal) if pushDownDecimal => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(_.asInstanceOf[JBigDecimal].unscaledValue().intValue() .asInstanceOf[Integer]).orNull) // DecimalType: 64BitDecimalType case ParquetSchemaType(DECIMAL, INT64, decimal) if pushDownDecimal => (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[JBigDecimal].unscaledValue().longValue() .asInstanceOf[java.lang.Long]).orNull) // DecimalType: LegacyParquetFormat 32BitDecimalType & 64BitDecimalType case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, decimal) if pushDownDecimal && decimal.getPrecision <= Decimal.MAX_LONG_DIGITS => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(d => decimalToBinaryUsingUnscaledLong(decimal.getPrecision, d.asInstanceOf[JBigDecimal])).orNull) // DecimalType: ByteArrayDecimalType case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, decimal) if pushDownDecimal && decimal.getPrecision > Decimal.MAX_LONG_DIGITS => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(d => decimalToBinaryUsingUnscaledBytes(decimal.getPrecision, d.asInstanceOf[JBigDecimal])).orNull) ``` ```scala // INT96 doesn't support pushdown case ParquetSchemaType(TIMESTAMP_MICROS, INT64, null) => (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) .asInstanceOf[java.lang.Long]).orNull) case ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null) => (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[java.lang.Long]).orNull) ``` ## How was this patch tested? unit tests Author: Yuming Wang Closes #21696 from wangyum/SPARK-24716. --- .../parquet/ParquetFileFormat.scala | 34 ++-- .../datasources/parquet/ParquetFilters.scala | 173 ++++++++++-------- .../apache/spark/sql/sources/filters.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 13 +- 4 files changed, 121 insertions(+), 101 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 93de1faef527a..52a18abb55241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -353,25 +353,13 @@ class ParquetFileFormat (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) - // Try to push down filters when filter push-down is enabled. - val pushed = if (enableParquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) - .createFilter(requiredSchema, _)) - .reduceOption(FilterApi.and) - } else { - None - } - val fileSplit = new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty) + val filePath = fileSplit.getPath val split = new org.apache.parquet.hadoop.ParquetInputSplit( - fileSplit.getPath, + filePath, fileSplit.getStart, fileSplit.getStart + fileSplit.getLength, fileSplit.getLength, @@ -379,12 +367,28 @@ class ParquetFileFormat null) val sharedConf = broadcastedHadoopConf.value.value + + // Try to push down filters when filter push-down is enabled. + val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) + .getFileMetaData.getSchema + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) + .createFilter(parquetSchema, _)) + .reduceOption(FilterApi.and) + } else { + None + } + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' // *only* if the file was created by something other than "parquet-mr", so check the actual // writer here for this file. We have to do this per-file, as each file in the table may // have different writers. def isCreatedByParquetMr(): Boolean = { - val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS) + val footer = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr") } val convertTz = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 21c9e2e4f82b4..4827f706e6016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -19,15 +19,19 @@ package org.apache.spark.sql.execution.datasources.parquet import java.sql.Date +import scala.collection.JavaConverters.asScalaBufferConverter + import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.PrimitiveComparator +import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator, PrimitiveType} +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate import org.apache.spark.sql.sources -import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -35,171 +39,180 @@ import org.apache.spark.unsafe.types.UTF8String */ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) { + private case class ParquetSchemaType( + originalType: OriginalType, + primitiveTypeName: PrimitiveTypeName, + decimalMetadata: DecimalMetadata) + + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, null) + private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } - private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => + case ParquetIntegerType => (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetLongType => (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) // Binary.fromString and Binary.fromByteArray don't accept null values - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } - private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => + private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetBooleanType => (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => + case ParquetIntegerType => (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetLongType => (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) - case BinaryType => + case ParquetBinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull) - case DateType if pushDownDate => + case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) } - private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => + private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => + case ParquetLongType => (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.lt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } - private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => + (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.ltEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } - private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => + (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.gt( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } - private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => + private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { + case ParquetIntegerType => + (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetLongType => (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => + case ParquetFloatType => (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => + case ParquetDoubleType => (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => + case ParquetStringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), - Binary.fromString(v.asInstanceOf[String])) - case BinaryType => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + case ParquetBinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])) - case DateType if pushDownDate => - (n: String, v: Any) => FilterApi.gtEq( - intColumn(n), - Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetDateType if pushDownDate => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) } /** * Returns a map from name of the column to the data type, if predicate push down applies. */ - private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match { - case StructType(fields) => + private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match { + case m: MessageType => // Here we don't flatten the fields in the nested schema but just look up through // root fields. Currently, accessing to nested fields does not push down filters // and it does not support to create filters for them. - fields.map(f => f.name -> f.dataType).toMap - case _ => Map.empty[String, DataType] + m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => + f.getName -> ParquetSchemaType( + f.getOriginalType, f.getPrimitiveTypeName, f.getDecimalMetadata) + }.toMap + case _ => Map.empty[String, ParquetSchemaType] } /** * Converts data sources filters to Parquet filter predicates. */ - def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { + def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToType = getFieldMap(schema) // Parquet does not allow dots in the column name because dots are used as a column path diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 2499e9b604f3e..bdd8c4da6bd30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -199,7 +199,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { /** * A filter that evaluates to `true` iff the attribute evaluates to - * a string that starts with `value`. + * a string that ends with `value`. * * @since 1.3.1 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index d9ae5858e5ed0..8b96c841c8c6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -103,7 +103,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assert(selectedFilters.nonEmpty, "No filter is pushed down") selectedFilters.foreach { pred => - val maybeFilter = parquetFilters.createFilter(df.schema, pred) + val maybeFilter = parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") // Doesn't bother checking type parameters here (e.g. `Eq[Integer]`) maybeFilter.exists(_.getClass === filterClass) @@ -542,12 +543,14 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex StructField("c", DoubleType, nullable = true) )) + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + assertResult(Some(and( lt(intColumn("a"), 10: Integer), gt(doubleColumn("c"), 1.5: java.lang.Double))) ) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.GreaterThan("c", 1.5D))) @@ -555,7 +558,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.And( sources.LessThan("a", 10), sources.StringContains("b", "prefix"))) @@ -563,7 +566,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - schema, + parquetSchema, sources.Not( sources.And( sources.GreaterThan("a", 1), @@ -729,7 +732,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex assertResult(None) { parquetFilters.createFilter( - df.schema, + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.StringStartsWith("_1", null)) } } From 1a2655a9e75627b584787f9e4c6cdaa92e61fa3f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 4 Jul 2018 20:42:08 +0800 Subject: [PATCH 1060/1161] [SPARK-24635][SQL] Remove Blocks class from JavaCode class hierarchy ## What changes were proposed in this pull request? The `Blocks` class in `JavaCode` class hierarchy is not necessary. Its function can be taken by `CodeBlock`. We should remove it to make simpler class hierarchy. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #21619 from viirya/SPARK-24635. --- .../expressions/codegen/javaCode.scala | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 250ce48d059e0..44f63e21e93bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -119,6 +119,7 @@ object JavaCode { * A trait representing a block of java code. */ trait Block extends JavaCode { + import Block._ // The expressions to be evaluated inside this block. def exprValues: Set[ExprValue] @@ -148,14 +149,17 @@ trait Block extends JavaCode { } // Concatenates this block with other block. - def + (other: Block): Block + def + (other: Block): Block = other match { + case EmptyBlock => this + case _ => code"$this\n$other" + } } object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 - implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { def code(args: Any*): Block = { @@ -190,18 +194,17 @@ object Block { while (strings.hasNext) { val input = inputs.next input match { - case _: ExprValue | _: Block => + case _: ExprValue | _: CodeBlock => codeParts += buf.toString buf.clear blockInputs += input.asInstanceOf[JavaCode] + case EmptyBlock => case _ => buf.append(input) } buf.append(strings.next) } - if (buf.nonEmpty) { - codeParts += buf.toString - } + codeParts += buf.toString (codeParts.toSeq, blockInputs.toSeq) } @@ -209,7 +212,11 @@ object Block { /** * A block of java code. Including a sequence of code parts and some inputs to this block. - * The actual java code is generated by embedding the inputs into the code parts. + * The actual java code is generated by embedding the inputs into the code parts. Here we keep + * inputs of `JavaCode` instead of simply folding them as a string of code, because we need to + * track expressions (`ExprValue`) in this code block. We need to be able to manipulate the + * expressions later without changing the behavior of this code block in some applications, e.g., + * method splitting. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { override lazy val exprValues: Set[ExprValue] = { @@ -230,30 +237,11 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } buf.toString } - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(Seq(this, c)) - case b: Blocks => Blocks(Seq(this) ++ b.blocks) - case EmptyBlock => this - } -} - -case class Blocks(blocks: Seq[Block]) extends Block { - override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet - override lazy val code: String = blocks.map(_.toString).mkString("\n") - - override def + (other: Block): Block = other match { - case c: CodeBlock => Blocks(blocks :+ c) - case b: Blocks => Blocks(blocks ++ b.blocks) - case EmptyBlock => this - } } object EmptyBlock extends Block with Serializable { override val code: String = "" override val exprValues: Set[ExprValue] = Set.empty - - override def + (other: Block): Block = other } /** From ca8243f30fc6939ee099a9534e3b811d5c64d2cf Mon Sep 17 00:00:00 2001 From: Shahid Date: Wed, 4 Jul 2018 09:56:24 -0500 Subject: [PATCH 1061/1161] [MINOR][ML] Minor correction in the powerIterationSuite ## What changes were proposed in this pull request? Currently the power iteration clustering test in spark ml, maps the results to the labels 0 and 1 for assertion. Since the clustering outputs need not be the same as the mapped labels, it may cause failure in the test case. Even if it correctly maps, theoretically we cannot guarantee which set belongs to which cluster label. KMeans can assign label 0 to either of the set. PowerIterationClusteringSuite in the MLLib checks the clustering results without mapping to the particular cluster label, as shown below. `` val predictions = Array.fill(2)(mutable.Set.empty[Long]) model.assignments.collect().foreach { a => predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) `` ## How was this patch tested? Existing tests Author: Shahid Closes #21689 from shahidki31/picTestSuiteMinorCorrection. --- .../PowerIterationClusteringSuite.scala | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index b7072728d48f0..55b460f1a4524 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import scala.collection.mutable + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setMaxIter(40) .setWeightCol("weight") .assignClusters(data) - val localAssignments = assignments - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++ - (n1 until n).map(x => (x, 0)).toSet - assert(localAssignments === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + assignments.foreach { + case (id, cluster) => predictions(cluster) += id + } + assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) val assignments2 = new PowerIterationClustering() .setK(2) @@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setInitMode("degree") .setWeightCol("weight") .assignClusters(data) - val localAssignments2 = assignments2 - .select('id, 'cluster) - .as[(Long, Int)].collect().toSet - assert(localAssignments2 === expectedResult) + .select("id", "cluster") + .as[(Long, Int)] + .collect() + + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + assignments2.foreach { + case (id, cluster) => predictions2(cluster) += id + } + assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet)) } test("supported input types") { From bf764a33bef617aa9bae535a5ea73d6a3e278d42 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Jul 2018 18:36:09 -0700 Subject: [PATCH 1062/1161] [SPARK-22384][SQL][FOLLOWUP] Refine partition pruning when attribute is wrapped in Cast ## What changes were proposed in this pull request? As mentioned in https://github.com/apache/spark/pull/21586 , `Cast.mayTruncate` is not 100% safe, string to boolean is allowed. Since changing `Cast.mayTruncate` also changes the behavior of Dataset, here I propose to add a new `Cast.canSafeCast` for partition pruning. ## How was this patch tested? new test cases Author: Wenchen Fan Closes #21712 from cloud-fan/safeCast. --- .../spark/sql/catalyst/expressions/Cast.scala | 20 +++++++++++++++++++ .../spark/sql/hive/client/HiveShim.scala | 5 +++-- .../sql/hive/client/HiveClientSuite.scala | 20 +++++++++++++++++-- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 699ea53b5df0f..7971ae602bd37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -134,6 +134,26 @@ object Cast { toPrecedence > 0 && fromPrecedence > toPrecedence } + /** + * Returns true iff we can safely cast the `from` type to `to` type without any truncating or + * precision lose, e.g. int -> long, date -> timestamp. + */ + def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match { + case _ if from == to => true + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (from, to) if legalNumericPrecedence(from, to) => true + case (DateType, TimestampType) => true + case (_, StringType) => true + case _ => false + } + + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) + fromPrecedence >= 0 && fromPrecedence < toPrecedence + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8620f3f6d99fb..933384ed43e98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegralType, StringType} +import org.apache.spark.sql.types.{AtomicType, IntegralType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -660,7 +660,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def unapply(expr: Expression): Option[Attribute] = { expr match { case attr: Attribute => Some(attr) - case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child) + case Cast(child @ AtomicType(), dt: AtomicType, _) + if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child) case _ => None } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index 55275f6b37945..fa9f753795f65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType} // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) @@ -122,6 +122,22 @@ class HiveClientSuite(version: String) "aa" :: Nil) } + test("getPartitionsByFilter: cast(chunk as int)=1 (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(IntegerType) === 1, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(chunk as boolean)=true (not a valid partition predicate)") { + testMetastorePartitionFiltering( + attr("chunk").cast(BooleanType) === true, + 20170101 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( Literal(20170101) === attr("ds"), @@ -138,7 +154,7 @@ class HiveClientSuite(version: String) "aa" :: "ab" :: "ba" :: "bb" :: Nil) } - test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") { testMetastorePartitionFiltering( attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, From 489a5294d106130beda1509e3cbbaf707a3d703d Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Thu, 5 Jul 2018 09:56:48 +0800 Subject: [PATCH 1063/1161] [SPARK-17213][SPARK-17213][FOLLOW-UP] Improve the test of ## What changes were proposed in this pull request? This is a minor improvement for the test of SPARK-17213 ## How was this patch tested? N/A Author: Xiao Li Closes #21716 from gatorsmile/testMaster23. --- .../parquet/ParquetFilterSuite.scala | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 8b96c841c8c6e..f2c0bda256239 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -618,21 +618,25 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-17213: Broken Parquet filter push-down for string columns") { - withTempPath { dir => - import testImplicits._ + Seq(true, false).foreach { vectorizedEnabled => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedEnabled.toString) { + withTempPath { dir => + import testImplicits._ - val path = dir.getCanonicalPath - // scalastyle:off nonascii - Seq("a", "é").toDF("name").write.parquet(path) - // scalastyle:on nonascii + val path = dir.getCanonicalPath + // scalastyle:off nonascii + Seq("a", "é").toDF("name").write.parquet(path) + // scalastyle:on nonascii - assert(spark.read.parquet(path).where("name > 'a'").count() == 1) - assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) + assert(spark.read.parquet(path).where("name > 'a'").count() == 1) + assert(spark.read.parquet(path).where("name >= 'a'").count() == 2) - // scalastyle:off nonascii - assert(spark.read.parquet(path).where("name < 'é'").count() == 1) - assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) - // scalastyle:on nonascii + // scalastyle:off nonascii + assert(spark.read.parquet(path).where("name < 'é'").count() == 1) + assert(spark.read.parquet(path).where("name <= 'é'").count() == 2) + // scalastyle:on nonascii + } + } } } From f997be0c3136f85762b841469e7dfcde7e699ced Mon Sep 17 00:00:00 2001 From: mcteo Date: Thu, 5 Jul 2018 10:05:41 +0800 Subject: [PATCH 1064/1161] [SPARK-24698][PYTHON] Fixed typo in pyspark.ml's Identifiable class. ## What changes were proposed in this pull request? Fixed a small typo in the code that caused 20 random characters to be added to the UID, rather than 12. Author: mcteo Closes #21675 from mcteo/SPARK-24698-fix. --- python/pyspark/ml/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 080cd299f4fde..e846834761e49 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -63,7 +63,7 @@ def _randomUID(cls): Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:]) @inherit_doc From 4be9f0c028cebb0d2975e93a6ebc56337cd2c585 Mon Sep 17 00:00:00 2001 From: Antonio Murgia Date: Thu, 5 Jul 2018 16:10:34 +0800 Subject: [PATCH 1065/1161] [SPARK-24673][SQL] scala sql function from_utc_timestamp second argument could be Column instead of String ## What changes were proposed in this pull request? Add an overloaded version to `from_utc_timestamp` and `to_utc_timestamp` having second argument as a `Column` instead of `String`. ## How was this patch tested? Unit testing, especially adding two tests to org.apache.spark.sql.DateFunctionsSuite.scala Author: Antonio Murgia Author: Antonio Murgia Closes #21693 from tmnd1991/feature/SPARK-24673. --- .../org/apache/spark/sql/functions.scala | 22 +++++++++++ .../apache/spark/sql/DateFunctionsSuite.scala | 38 ++++++++++++++++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 614f65f0faaba..f2627e69939cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2934,6 +2934,17 @@ object functions { FromUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield + * '2017-07-14 03:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + FromUTCTimestamp(ts.expr, tz.expr) + } + /** * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield @@ -2945,6 +2956,17 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @group datetime_funcs + * @since 2.4.0 + */ + def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr { + ToUTCTimestamp(ts.expr, tz.expr) + } + /** * Bucketize rows into one or more time windows given a timestamp specifying column. Window * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 237412aa692e5..3af80b36ec42c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -663,7 +663,7 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1))) } - test("from_utc_timestamp") { + test("from_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -680,7 +680,24 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-24 17:00:00")))) } - test("to_utc_timestamp") { + test("from_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST") + ).toDF("a", "b", "c") + checkAnswer( + df.select(from_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + checkAnswer( + df.select(from_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 02:00:00")), + Row(Timestamp.valueOf("2015-07-24 17:00:00")))) + } + + test("to_utc_timestamp with literal zone") { val df = Seq( (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"), (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00") @@ -697,6 +714,23 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(Timestamp.valueOf("2015-07-25 07:00:00")))) } + test("to_utc_timestamp with column zone") { + val df = Seq( + (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"), + (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET") + ).toDF("a", "b", "c") + checkAnswer( + df.select(to_utc_timestamp(col("a"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + checkAnswer( + df.select(to_utc_timestamp(col("b"), col("c"))), + Seq( + Row(Timestamp.valueOf("2015-07-24 07:00:00")), + Row(Timestamp.valueOf("2015-07-24 22:00:00")))) + } + test("SPARK-23715: to/from_utc_timestamp can retain the previous behavior") { withSQLConf(SQLConf.REJECT_TIMEZONE_IN_STRING.key -> "false") { checkAnswer( From 32cfd3e75a5ca65696fedfa4d49681e6fc3e698d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Jul 2018 20:48:55 +0800 Subject: [PATCH 1066/1161] [SPARK-24361][SQL] Polish code block manipulation API ## What changes were proposed in this pull request? Current code block manipulation API is immature and hacky. We need a formal API to manipulate code blocks. The basic idea is making `JavaCode` as `TreeNode`. So we can use familiar `transform` API to manipulate code blocks and expressions in code blocks. For example, we can replace `SimpleExprValue` in a code block like this: ```scala code.transformExprValues { case SimpleExprValue("1 + 1", _) => aliasedParam } ``` The example use case is splitting code to methods. For example, we have an `ExprCode` containing generated code. But it is too long and we need to split it as method. Because statement-based expressions can't be directly passed into. We need to transform them as variables first: ```scala def getExprValues(block: Block): Set[ExprValue] = block match { case c: CodeBlock => c.blockInputs.collect { case e: ExprValue => e }.toSet case _ => Set.empty } def currentCodegenInputs(ctx: CodegenContext): Set[ExprValue] = { // Collects current variables in ctx.currentVars and ctx.INPUT_ROW. // It looks roughly like... ctx.currentVars.flatMap { v => getExprValues(v.code) ++ Set(v.value, v.isNull) }.toSet + ctx.INPUT_ROW } // A code block of an expression contains too long code, making it as method if (eval.code.length > 1024) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { ... } else { "" } // Pick up variables and statements necessary to pass in. val currentVars = currentCodegenInputs(ctx) val varsPassIn = getExprValues(eval.code).intersect(currentVars) val aliasedExprs = HashMap.empty[SimpleExprValue, VariableValue] // Replace statement-based expressions which can't be directly passed in the method. val newCode = eval.code.transform { case block => block.transformExprValues { case s: SimpleExprValue(_, javaType) if varsPassIn.contains(s) => if (aliasedExprs.contains(s)) { aliasedExprs(s) } else { val aliasedVariable = JavaCode.variable(ctx.freshName("aliasedVar"), javaType) aliasedExprs += s -> aliasedVariable varsPassIn += aliasedVariable aliasedVariable } } } val params = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue])).map { variable => s"${variable.javaType.getName} ${variable.variableName}" }.mkString(", ") val funcName = ctx.freshName("nodeName") val javaType = CodeGenerator.javaType(dataType) val newValue = JavaCode.variable(ctx.freshName("value"), dataType) val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName($params) { | $newCode | $setIsNull | return ${eval.value}; |} """.stripMargin)) eval.value = newValue val args = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue])).map { variable => s"${variable.variableName}" } // Create a code block to assign statements to aliased variables. val createVariables = aliasedExprs.foldLeft(EmptyBlock) { (block, (statement, variable)) => block + code"${statement.javaType.getName} $variable = $statement;" } eval.code = createVariables + code"$javaType $newValue = $funcFullName($args);" } ``` ## How was this patch tested? Added unite tests. Author: Liang-Chi Hsieh Closes #21405 from viirya/codeblock-api. --- .../expressions/codegen/javaCode.scala | 48 +++++++++--- .../expressions/codegen/CodeBlockSuite.scala | 75 +++++++++++++++++-- 2 files changed, 104 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 44f63e21e93bb..2f8c853e836ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool} import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -118,12 +119,9 @@ object JavaCode { /** * A trait representing a block of java code. */ -trait Block extends JavaCode { +trait Block extends TreeNode[Block] with JavaCode { import Block._ - // The expressions to be evaluated inside this block. - def exprValues: Set[ExprValue] - // Returns java code string for this code block. override def toString: String = _marginChar match { case Some(c) => code.stripMargin(c).trim @@ -148,11 +146,41 @@ trait Block extends JavaCode { this } + /** + * Apply a map function to each java expression codes present in this java code, and return a new + * java code based on the mapped java expression codes. + */ + def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = { + var changed = false + + @inline def transform(e: ExprValue): ExprValue = { + val newE = f lift e + if (!newE.isDefined || newE.get.equals(e)) { + e + } else { + changed = true + newE.get + } + } + + def doTransform(arg: Any): AnyRef = arg match { + case e: ExprValue => transform(e) + case Some(value) => Some(doTransform(value)) + case seq: Traversable[_] => seq.map(doTransform) + case other: AnyRef => other + } + + val newArgs = mapProductIterator(doTransform) + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } + // Concatenates this block with other block. def + (other: Block): Block = other match { case EmptyBlock => this case _ => code"$this\n$other" } + + override def verboseString: String = toString } object Block { @@ -219,12 +247,8 @@ object Block { * method splitting. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { - override lazy val exprValues: Set[ExprValue] = { - blockInputs.flatMap { - case b: Block => b.exprValues - case e: ExprValue => Set(e) - }.toSet - } + override def children: Seq[Block] = + blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] override lazy val code: String = { val strings = codeParts.iterator @@ -239,9 +263,9 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends } } -object EmptyBlock extends Block with Serializable { +case object EmptyBlock extends Block with Serializable { override val code: String = "" - override val exprValues: Set[ExprValue] = Set.empty + override def children: Seq[Block] = Seq.empty } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index d2c6420eadb20..55569b6f2933e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite { |boolean $isNull = false; |int $value = -1; """.stripMargin - val exprValues = code.exprValues + val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }.toSet assert(exprValues.size == 2) assert(exprValues === Set(value, isNull)) } @@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.toString == expected) - val exprValues = code.exprValues + val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }).toSet assert(exprValues.size == 5) assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) } @@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite { assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) } - test("replace expr values in code block") { + test("transform expr in code block") { val expr = JavaCode.expression("1 + 1", IntegerType) val isNull = JavaCode.isNullVariable("expr1_isNull") val exprInFunc = JavaCode.variable("expr1", IntegerType) @@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin val aliasedParam = JavaCode.variable("aliased", expr.javaType) - val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { - case _: SimpleExprValue => aliasedParam - case other => other + + // We want to replace all occurrences of `expr` with the variable `aliasedParam`. + val aliasedCode = code.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam } - val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin val expected = code""" |callFunc(int $aliasedParam) { @@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin assert(aliasedCode.toString == expected.toString) } + + test ("transform expr in nested blocks") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val funcs = Seq("callFunc1", "callFunc2", "callFunc3") + val subBlocks = funcs.map { funcName => + code""" + |$funcName(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + } + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}" + val transformedBlock = block.transform { + case b: Block => b.transformExprValues { + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam + } + }.asInstanceOf[CodeBlock] + + val expected1 = + code""" + |callFunc1(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected2 = + code""" + |callFunc2(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected3 = + code""" + |callFunc3(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val exprValues = transformedBlock.children.flatMap { block => + block.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + } + }.toSet + + assert(transformedBlock.children(0).toString == expected1.toString) + assert(transformedBlock.children(1).toString == expected2.toString) + assert(transformedBlock.children(2).toString == expected3.toString) + assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) + assert(exprValues === Set(isNull, exprInFunc, aliasedParam)) + } } From e58dadb77ed6cac3e1b2a037a6449e5a6e7f2cec Mon Sep 17 00:00:00 2001 From: Michael Mior Date: Thu, 5 Jul 2018 08:32:20 -0500 Subject: [PATCH 1067/1161] [SPARK-23820][CORE] Enable use of long form of callsite in logs This adds an option to event logging to include the long form of the callsite instead of the short form. Author: Michael Mior Closes #21433 from michaelmior/long-callsite. --- .../org/apache/spark/internal/config/package.scala | 3 +++ .../main/scala/org/apache/spark/storage/RDDInfo.scala | 10 +++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 38a043c85ae33..bda9795a0b925 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -72,6 +72,9 @@ package object config { private[spark] val EVENT_LOG_OVERWRITE = ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EVENT_LOG_CALLSITE_FORM = + ConfigBuilder("spark.eventLog.callsite").stringConf.createWithDefault("short") + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index e5abbf745cc41..9ccc8f9cc585b 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -17,7 +17,9 @@ package org.apache.spark.storage +import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.Utils @@ -53,10 +55,16 @@ class RDDInfo( } private[spark] object RDDInfo { + private val callsiteForm = SparkEnv.get.conf.get(EVENT_LOG_CALLSITE_FORM) + def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) + val callSite = callsiteForm match { + case "short" => rdd.creationSite.shortForm + case "long" => rdd.creationSite.longForm + } new RDDInfo(rdd.id, rddName, rdd.partitions.length, - rdd.getStorageLevel, parentIds, rdd.creationSite.shortForm, rdd.scope) + rdd.getStorageLevel, parentIds, callSite, rdd.scope) } } From 7bd6d5412072643f2320fd389f323cfc51368c81 Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 5 Jul 2018 08:38:26 -0500 Subject: [PATCH 1068/1161] [SPARK-24711][K8S] Fix tags for integration tests ## What changes were proposed in this pull request? - disables maven surfire plugin to allow tags function properly, doc here: http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin ## How was this patch tested? Manually by adding tags. Author: Stavros Kontopoulos Closes #21697 from skonto/fix-tags. --- pom.xml | 2 +- .../dev/dev-run-integration-tests.sh | 20 +++++++++++++++++++ .../kubernetes/integration-tests/pom.xml | 11 ++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index ca30f9f12b098..cd567e227f331 100644 --- a/pom.xml +++ b/pom.xml @@ -2122,7 +2122,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.20.1 + 2.22.0 diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index ea893fa39eede..3acd0f5cd3349 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -27,6 +27,8 @@ IMAGE_TAG="N/A" SPARK_MASTER= NAMESPACE= SERVICE_ACCOUNT= +INCLUDE_TAGS= +EXCLUDE_TAGS= # Parse arguments while (( "$#" )); do @@ -59,6 +61,14 @@ while (( "$#" )); do SERVICE_ACCOUNT="$2" shift ;; + --include-tags) + INCLUDE_TAGS="$2" + shift + ;; + --exclude-tags) + EXCLUDE_TAGS="$2" + shift + ;; *) break ;; @@ -90,4 +100,14 @@ then properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) fi +if [ -n $EXCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) +fi + +if [ -n $INCLUDE_TAGS ]; +then + properties=( ${properties[@]} -Dtest.include.tags=$INCLUDE_TAGS ) +fi + ../../../build/mvn integration-test ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 520bda89e034d..6a2fff891098b 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -40,6 +40,7 @@ minikube docker.io/kubespark + jar Spark Project Kubernetes Integration Tests @@ -102,6 +103,15 @@ + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + @@ -126,6 +136,7 @@ ${spark.kubernetes.test.serviceAccountName} ${test.exclude.tags} + ${test.include.tags} From ac78bcce00ff8ec8e5b7335c2807aa0cd0f5406a Mon Sep 17 00:00:00 2001 From: cluo <0512lc@163.com> Date: Thu, 5 Jul 2018 09:06:25 -0500 Subject: [PATCH 1069/1161] [SPARK-24743][EXAMPLES] Update the JavaDirectKafkaWordCount example to support the new API of kafka ## What changes were proposed in this pull request? Add some required configs for Kafka consumer in JavaDirectKafkaWordCount class. ## How was this patch tested? Manual tests on Local mode. Author: cluo <0512lc@163.com> Closes #21717 from cluo512/SPARK-24743-update-JavaDirectKafkaWordCount. --- .../streaming/JavaDirectKafkaWordCount.java | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java index b6b163fa8b2cd..748bf58f30350 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java @@ -26,7 +26,9 @@ import scala.Tuple2; +import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.serialization.StringDeserializer; import org.apache.spark.SparkConf; import org.apache.spark.streaming.api.java.*; @@ -37,30 +39,33 @@ /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: JavaDirectKafkaWordCount + * Usage: JavaDirectKafkaWordCount * is a list of one or more Kafka brokers + * is a consumer group name to consume from topics * is a list of one or more kafka topics to consume from * * Example: * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \ - * topic1,topic2 + * consumer-group topic1,topic2 */ public final class JavaDirectKafkaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaDirectKafkaWordCount \n" + - " is a list of one or more Kafka brokers\n" + - " is a list of one or more kafka topics to consume from\n\n"); + if (args.length < 3) { + System.err.println("Usage: JavaDirectKafkaWordCount \n" + + " is a list of one or more Kafka brokers\n" + + " is a consumer group name to consume from topics\n" + + " is a list of one or more kafka topics to consume from\n\n"); System.exit(1); } StreamingExamples.setStreamingLogLevels(); String brokers = args[0]; - String topics = args[1]; + String groupId = args[1]; + String topics = args[2]; // Create context with a 2 seconds batch interval SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount"); @@ -68,7 +73,10 @@ public static void main(String[] args) throws Exception { Set topicsSet = new HashSet<>(Arrays.asList(topics.split(","))); Map kafkaParams = new HashMap<>(); - kafkaParams.put("metadata.broker.list", brokers); + kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers); + kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); // Create direct kafka stream with brokers and topics JavaInputDStream> messages = KafkaUtils.createDirectStream( From 33952cfa8182c1e925083e18c63c6152dcc3c8b4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 5 Jul 2018 09:25:19 -0700 Subject: [PATCH 1070/1161] [SPARK-24675][SQL] Rename table: validate existence of new location MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? If table is renamed to a existing new location, data won't show up. ``` scala> Seq("hello").toDF("a").write.format("parquet").saveAsTable("t") scala> sql("select * from t").show() +-----+ | a| +-----+ |hello| +-----+ scala> sql("alter table t rename to test") res2: org.apache.spark.sql.DataFrame = [] scala> sql("select * from test").show() +---+ | a| +---+ +---+ ``` The file layout is like ``` $ tree test test ├── gabage └── t ├── _SUCCESS └── part-00000-856b0f10-08f1-42d6-9eb3-7719261f3d5e-c000.snappy.parquet ``` In Hive, if the new location exists, the renaming will fail even the location is empty. We should have the same validation in Catalog, in case of unexpected bugs. ## How was this patch tested? New unit test. Author: Gengliang Wang Closes #21655 from gengliangwang/validate_rename_table. --- docs/sql-programming-guide.md | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 20 +++++++++++++++++++ .../sql/execution/command/DDLSuite.scala | 18 +++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cd7329b621122..ad23dae7c6b7c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1850,6 +1850,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file formats (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with empty schema. - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promotes both sides to TIMESTAMP. To set `false` to `spark.sql.hive.compareDateTimestampInTimestamp` restores the previous behavior. This option will be removed in Spark 3.0. - Since Spark 2.4, creating a managed table with nonempty location is not allowed. An exception is thrown when attempting to create a managed table with nonempty location. To set `true` to `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the previous behavior. This option will be removed in Spark 3.0. + - Since Spark 2.4, renaming a managed table to existing location is not allowed. An exception is thrown when attempting to rename a managed table to existing location. - Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, no matter how the input arguments order. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception. - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time. - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c390337c03ff5..c26a34528c162 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -619,6 +619,7 @@ class SessionCatalog( requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) + validateNewLocationOfRename(oldName, newName) externalCatalog.renameTable(db, oldTableName, newTableName) } else { if (newName.database.isDefined) { @@ -1366,4 +1367,23 @@ class SessionCatalog( // copy over temporary views tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } + + /** + * Validate the new locatoin before renaming a managed table, which should be non-existent. + */ + private def validateNewLocationOfRename( + oldName: TableIdentifier, + newName: TableIdentifier): Unit = { + val oldTable = getTableMetadata(oldName) + if (oldTable.tableType == CatalogTableType.MANAGED) { + val databaseLocation = + externalCatalog.getDatabase(oldName.database.getOrElse(currentDb)).locationUri + val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table)) + val fs = newTableLocation.getFileSystem(hadoopConf) + if (fs.exists(newTableLocation)) { + throw new AnalysisException(s"Can not rename the managed table('$oldName')" + + s". The associated location('$newTableLocation') already exists.") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 3998ceca38b30..270ed7f80197c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -441,6 +441,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + test("rename a managed table with existing empty directory") { + val tableLoc = new File(spark.sessionState.catalog.defaultTablePath(TableIdentifier("tab2"))) + try { + withTable("tab1") { + sql(s"CREATE TABLE tab1 USING $dataSource AS SELECT 1, 'a'") + tableLoc.mkdir() + val ex = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO tab2") + }.getMessage + val expectedMsg = "Can not rename the managed table('`tab1`'). The associated location" + assert(ex.contains(expectedMsg)) + } + } finally { + waitForTasksToFinish() + Utils.deleteRecursively(tableLoc) + } + } + private def checkSchemaInCreatedDataSourceTable( path: File, userSpecifiedSchema: Option[String], From e71e93aaaa0d26301e10d3dc65f4db298424e99a Mon Sep 17 00:00:00 2001 From: Stavros Kontopoulos Date: Thu, 5 Jul 2018 16:35:16 -0500 Subject: [PATCH 1071/1161] [SPARK-24694][K8S] Pass all app args to integration tests ## What changes were proposed in this pull request? - Allows to pass more than one app args to tests. ## How was this patch tested? Manually tested it with a spark test that requires more than on app args. Author: Stavros Kontopoulos Closes #21672 from skonto/fix_itsets-args. --- .../k8s/integrationtest/KubernetesTestComponents.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index 48727142dd052..b2471e51116cb 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -105,16 +105,13 @@ private[spark] object SparkAppLauncher extends Logging { sparkHomeDir: Path): Unit = { val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") - val appArgsArray = - if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" ")) - else Array[String]() val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, "--deploy-mode", "cluster", "--class", appArguments.mainClass, "--master", appConf.get("spark.master") ) ++ appConf.toStringArray :+ appArguments.mainAppResource) ++ - appArgsArray + appArguments.appArgs ProcessUtils.executeProcess(commandLine, timeoutSecs) } } From 01fcba2c685be0603a404392685e9d52fb4cb82a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 6 Jul 2018 11:10:50 +0800 Subject: [PATCH 1072/1161] [SPARK-24737][SQL] Type coercion between StructTypes. ## What changes were proposed in this pull request? We can support type coercion between `StructType`s where all the internal types are compatible. ## How was this patch tested? Added tests. Author: Takuya UESHIN Closes #21713 from ueshin/issues/SPARK-24737/structtypecoercion. --- .../sql/catalyst/analysis/TypeCoercion.scala | 69 ++++++-------- .../catalyst/analysis/TypeCoercionSuite.scala | 93 +++++++++++++++++-- 2 files changed, 114 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cf90e6e555fc8..b6ca30c7398f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -102,25 +102,7 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - - case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) - - case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2) - val valueType = findTightestCommonType(vt1, vt2) - Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) - - case _ => None + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } /** Promotes all the way to StringType. */ @@ -166,6 +148,30 @@ object TypeCoercion { case (l, r) => None } + private def findTypeForComplex( + t1: DataType, + t2: DataType, + findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match { + case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => + findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) + case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => + findTypeFunc(kt1, kt2).flatMap { kt => + findTypeFunc(vt1, vt2).map { vt => + MapType(kt, vt, valueContainsNull1 || valueContainsNull2) + } + } + case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length => + val resolver = SQLConf.get.resolver + fields1.zip(fields2).foldLeft(Option(new StructType())) { + case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) => + findTypeFunc(field1.dataType, field2.dataType).map { + dt => struct.add(field1.name, dt, field1.nullable || field2.nullable) + } + case _ => None + } + case _ => None + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -176,17 +182,7 @@ object TypeCoercion { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) - case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findWiderTypeForTwo(kt1, kt2).flatMap { kt => - findWiderTypeForTwo(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } - } - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo)) } /** @@ -222,18 +218,7 @@ object TypeCoercion { t2: DataType): Option[DataType] = { findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) - .orElse((t1, t2) match { - case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2) - .map(ArrayType(_, containsNull1 || containsNull2)) - case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(kt1, kt2).flatMap { kt => - findWiderTypeWithoutStringPromotionForTwo(vt1, vt2).map { vt => - MapType(kt, vt, valueContainsNull1 || valueContainsNull2) - } - } - case _ => None - }) + .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo)) } def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 4e5ca1b8cdd36..8cc5a23779a2a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -54,7 +54,7 @@ class TypeCoercionSuite extends AnalysisTest { // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: StructType* is castable only when the internal child types also match; otherwise, not castable. + // Note: StructType* is castable when all the internal child types are castable according to the table. // Note: ArrayType* is castable when the element type is castable according to the table. // Note: MapType* is castable when both the key type and the value type are castable according to the table. // scalastyle:on line.size.limit @@ -397,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest { widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), StructType(Seq(StructField("a", DoubleType, nullable = false))), - None) + Some(StructType(Seq(StructField("a", DoubleType, nullable = false))))) widenTest( StructType(Seq(StructField("a", IntegerType, nullable = false))), @@ -454,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest { def widenTestWithStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric) } def widenTestWithoutStringPromotion( t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { - checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected) + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + checkWidenType( + TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric) } // Decimal @@ -492,6 +495,10 @@ class TypeCoercionSuite extends AnalysisTest { ArrayType(MapType(IntegerType, FloatType), containsNull = false), ArrayType(MapType(LongType, DoubleType), containsNull = false), Some(ArrayType(MapType(LongType, DoubleType), containsNull = false))) + widenTestWithStringPromotion( + ArrayType(new StructType().add("num", ShortType), containsNull = false), + ArrayType(new StructType().add("num", LongType), containsNull = false), + Some(ArrayType(new StructType().add("num", LongType), containsNull = false))) // MapType widenTestWithStringPromotion( @@ -506,6 +513,64 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false), MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false), Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false), + MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false), + Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false))) + + // StructType + widenTestWithStringPromotion( + new StructType() + .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false), + new StructType() + .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true), + Some(new StructType() + .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true))) + widenTestWithStringPromotion( + new StructType() + .add("arr", ArrayType(ShortType, containsNull = false), nullable = false), + new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false), + Some(new StructType() + .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false))) + widenTestWithStringPromotion( + new StructType() + .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false), + new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false), + Some(new StructType() + .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false))) + + widenTestWithStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("num", IntegerType), + new StructType().add("num", LongType).add("str", StringType), + None) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("A", LongType), + Some(new StructType().add("a", LongType)), + isSymmetric = false) + } // Without string promotion widenTestWithoutStringPromotion(IntegerType, StringType, None) @@ -520,6 +585,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None) widenTestWithoutStringPromotion( MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None) + widenTestWithoutStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + None) + widenTestWithoutStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + None) // String promotion widenTestWithStringPromotion(IntegerType, StringType, Some(StringType)) @@ -544,6 +617,14 @@ class TypeCoercionSuite extends AnalysisTest { MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), Some(MapType(IntegerType, StringType))) + widenTestWithStringPromotion( + new StructType().add("a", IntegerType), + new StructType().add("a", StringType), + Some(new StructType().add("a", StringType))) + widenTestWithStringPromotion( + new StructType().add("a", StringType), + new StructType().add("a", IntegerType), + Some(new StructType().add("a", StringType))) } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { From bf67f70c48881ee99751f7d51fbcbda1e593d90a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 6 Jul 2018 11:13:57 +0800 Subject: [PATCH 1073/1161] [SPARK-24692][TESTS] Improvement FilterPushdownBenchmark ## What changes were proposed in this pull request? Refer to the [`WideSchemaBenchmark`](https://github.com/apache/spark/blob/v2.3.1/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala) update `FilterPushdownBenchmark`: 1. Write the result to `benchmarks/FilterPushdownBenchmark-results.txt` for easy maintenance. 2. Add more benchmark case: `StringStartsWith`, `Decimal`, `InSet -> InFilters` and `tinyint`. ## How was this patch tested? manual tests Author: Yuming Wang Closes #21677 from wangyum/SPARK-24692. --- .../FilterPushdownBenchmark-results.txt | 580 ++++++++++++++++++ .../benchmark/FilterPushdownBenchmark.scala | 405 +++++------- 2 files changed, 748 insertions(+), 237 deletions(-) create mode 100644 sql/core/benchmarks/FilterPushdownBenchmark-results.txt diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt new file mode 100644 index 0000000000000..29fe4345d69da --- /dev/null +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -0,0 +1,580 @@ +================================================================================================ +Pushdown for many distinct value case +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X +Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X +Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X +Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X +Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X +Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X +Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X +Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X +Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X +Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X +Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X +Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X +Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X +Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X +Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X +Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X +Parquet Vectorized (Pushdown) 19924 / 20089 0.8 1266.7 1.0X +Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X +Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X +Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X +Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X +Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X +Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X +Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X +Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X +Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X +Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X +Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X +Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X +Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X +Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X +Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X +Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X +Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X +Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X +Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X +Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X +Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X +Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X +Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X +Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X +Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X +Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X +Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X +Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X +Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X +Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X +Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X +Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X +Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X +Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X +Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X +Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X + + +================================================================================================ +Pushdown for few distinct value case (use dictionary encoding) +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7245 / 7322 2.2 460.7 1.0X +Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X +Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X +Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X +Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X +Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X +Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X +Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X +Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X +Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X +Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X +Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X +Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X +Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X +Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X +Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X +Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X +Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X +Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X + + +================================================================================================ +Pushdown benchmark for StringStartsWith +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X +Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X +Native ORC Vectorized 10094 / 10695 1.6 641.8 1.0X +Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X +Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X +Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X +Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X +Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X +Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X +Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X + + +================================================================================================ +Pushdown benchmark for decimal +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3785 / 3867 4.2 240.6 1.0X +Parquet Vectorized (Pushdown) 3820 / 3928 4.1 242.9 1.0X +Native ORC Vectorized 3981 / 4049 4.0 253.1 1.0X +Native ORC Vectorized (Pushdown) 702 / 735 22.4 44.6 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4694 / 4813 3.4 298.4 1.0X +Parquet Vectorized (Pushdown) 4839 / 4907 3.3 307.6 1.0X +Native ORC Vectorized 4943 / 5032 3.2 314.2 0.9X +Native ORC Vectorized (Pushdown) 2043 / 2085 7.7 129.9 2.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8321 / 8472 1.9 529.0 1.0X +Parquet Vectorized (Pushdown) 8125 / 8471 1.9 516.6 1.0X +Native ORC Vectorized 8524 / 8616 1.8 541.9 1.0X +Native ORC Vectorized (Pushdown) 7961 / 8383 2.0 506.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9587 / 10112 1.6 609.5 1.0X +Parquet Vectorized (Pushdown) 9726 / 10370 1.6 618.3 1.0X +Native ORC Vectorized 10119 / 11147 1.6 643.4 0.9X +Native ORC Vectorized (Pushdown) 9366 / 9497 1.7 595.5 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4060 / 4093 3.9 258.1 1.0X +Parquet Vectorized (Pushdown) 4037 / 4125 3.9 256.6 1.0X +Native ORC Vectorized 4756 / 4811 3.3 302.4 0.9X +Native ORC Vectorized (Pushdown) 824 / 889 19.1 52.4 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5157 / 5271 3.0 327.9 1.0X +Parquet Vectorized (Pushdown) 5051 / 5141 3.1 321.1 1.0X +Native ORC Vectorized 5723 / 6146 2.7 363.9 0.9X +Native ORC Vectorized (Pushdown) 2198 / 2317 7.2 139.8 2.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8608 / 8647 1.8 547.3 1.0X +Parquet Vectorized (Pushdown) 8471 / 8584 1.9 538.6 1.0X +Native ORC Vectorized 9249 / 10048 1.7 588.0 0.9X +Native ORC Vectorized (Pushdown) 7645 / 8091 2.1 486.1 1.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11658 / 11888 1.3 741.2 1.0X +Parquet Vectorized (Pushdown) 11812 / 12098 1.3 751.0 1.0X +Native ORC Vectorized 12943 / 13312 1.2 822.9 0.9X +Native ORC Vectorized (Pushdown) 13139 / 13465 1.2 835.4 0.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5491 / 5716 2.9 349.1 1.0X +Parquet Vectorized (Pushdown) 5515 / 5615 2.9 350.6 1.0X +Native ORC Vectorized 4582 / 4654 3.4 291.3 1.2X +Native ORC Vectorized (Pushdown) 815 / 861 19.3 51.8 6.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6432 / 6527 2.4 409.0 1.0X +Parquet Vectorized (Pushdown) 6513 / 6607 2.4 414.1 1.0X +Native ORC Vectorized 5618 / 6085 2.8 357.2 1.1X +Native ORC Vectorized (Pushdown) 2403 / 2443 6.5 152.8 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11041 / 11467 1.4 701.9 1.0X +Parquet Vectorized (Pushdown) 10909 / 11484 1.4 693.5 1.0X +Native ORC Vectorized 9860 / 10436 1.6 626.9 1.1X +Native ORC Vectorized (Pushdown) 7908 / 8069 2.0 502.8 1.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 14816 / 16877 1.1 942.0 1.0X +Parquet Vectorized (Pushdown) 15383 / 15740 1.0 978.0 1.0X +Native ORC Vectorized 14408 / 14771 1.1 916.0 1.0X +Native ORC Vectorized (Pushdown) 13968 / 14805 1.1 888.1 1.1X + + +================================================================================================ +Pushdown benchmark for InSet -> InFilters +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7477 / 7587 2.1 475.4 1.0X +Parquet Vectorized (Pushdown) 7862 / 8346 2.0 499.9 1.0X +Native ORC Vectorized 6447 / 7021 2.4 409.9 1.2X +Native ORC Vectorized (Pushdown) 983 / 1003 16.0 62.5 7.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7107 / 7290 2.2 451.9 1.0X +Parquet Vectorized (Pushdown) 7196 / 7258 2.2 457.5 1.0X +Native ORC Vectorized 6102 / 6222 2.6 388.0 1.2X +Native ORC Vectorized (Pushdown) 926 / 958 17.0 58.9 7.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7374 / 7692 2.1 468.8 1.0X +Parquet Vectorized (Pushdown) 7771 / 7848 2.0 494.1 0.9X +Native ORC Vectorized 6184 / 6356 2.5 393.2 1.2X +Native ORC Vectorized (Pushdown) 920 / 963 17.1 58.5 8.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7073 / 7326 2.2 449.7 1.0X +Parquet Vectorized (Pushdown) 7304 / 7647 2.2 464.4 1.0X +Native ORC Vectorized 6222 / 6579 2.5 395.6 1.1X +Native ORC Vectorized (Pushdown) 958 / 994 16.4 60.9 7.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7121 / 7501 2.2 452.7 1.0X +Parquet Vectorized (Pushdown) 7751 / 8334 2.0 492.8 0.9X +Native ORC Vectorized 6225 / 6680 2.5 395.8 1.1X +Native ORC Vectorized (Pushdown) 998 / 1020 15.8 63.5 7.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7157 / 7399 2.2 455.1 1.0X +Parquet Vectorized (Pushdown) 7806 / 7911 2.0 496.3 0.9X +Native ORC Vectorized 6548 / 6720 2.4 416.3 1.1X +Native ORC Vectorized (Pushdown) 1016 / 1050 15.5 64.6 7.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7662 / 7805 2.1 487.1 1.0X +Parquet Vectorized (Pushdown) 7590 / 7861 2.1 482.5 1.0X +Native ORC Vectorized 6840 / 8073 2.3 434.9 1.1X +Native ORC Vectorized (Pushdown) 1041 / 1075 15.1 66.2 7.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8230 / 9266 1.9 523.2 1.0X +Parquet Vectorized (Pushdown) 7735 / 7960 2.0 491.8 1.1X +Native ORC Vectorized 6945 / 7109 2.3 441.6 1.2X +Native ORC Vectorized (Pushdown) 1123 / 1144 14.0 71.4 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7656 / 8058 2.1 486.7 1.0X +Parquet Vectorized (Pushdown) 7860 / 8247 2.0 499.7 1.0X +Native ORC Vectorized 6684 / 7003 2.4 424.9 1.1X +Native ORC Vectorized (Pushdown) 1085 / 1172 14.5 69.0 7.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7594 / 8128 2.1 482.8 1.0X +Parquet Vectorized (Pushdown) 7845 / 7923 2.0 498.8 1.0X +Native ORC Vectorized 5859 / 6421 2.7 372.5 1.3X +Native ORC Vectorized (Pushdown) 1037 / 1054 15.2 66.0 7.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6762 / 6775 2.3 429.9 1.0X +Parquet Vectorized (Pushdown) 6911 / 6970 2.3 439.4 1.0X +Native ORC Vectorized 5884 / 5960 2.7 374.1 1.1X +Native ORC Vectorized (Pushdown) 1028 / 1052 15.3 65.4 6.6X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 6718 / 6767 2.3 427.1 1.0X +Parquet Vectorized (Pushdown) 6812 / 6909 2.3 433.1 1.0X +Native ORC Vectorized 5842 / 5883 2.7 371.4 1.1X +Native ORC Vectorized (Pushdown) 1040 / 1058 15.1 66.1 6.5X + + +================================================================================================ +Pushdown benchmark for tinyint +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 3726 / 3775 4.2 236.9 1.0X +Parquet Vectorized (Pushdown) 3741 / 3789 4.2 237.9 1.0X +Native ORC Vectorized 2793 / 2909 5.6 177.6 1.3X +Native ORC Vectorized (Pushdown) 530 / 561 29.7 33.7 7.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4385 / 4406 3.6 278.8 1.0X +Parquet Vectorized (Pushdown) 4398 / 4454 3.6 279.6 1.0X +Native ORC Vectorized 3420 / 3501 4.6 217.4 1.3X +Native ORC Vectorized (Pushdown) 1395 / 1432 11.3 88.7 3.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 7307 / 7394 2.2 464.6 1.0X +Parquet Vectorized (Pushdown) 7411 / 7461 2.1 471.2 1.0X +Native ORC Vectorized 6501 / 7814 2.4 413.4 1.1X +Native ORC Vectorized (Pushdown) 7341 / 8637 2.1 466.7 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 11886 / 13122 1.3 755.7 1.0X +Parquet Vectorized (Pushdown) 12557 / 14173 1.3 798.4 0.9X +Native ORC Vectorized 10758 / 11971 1.5 684.0 1.1X +Native ORC Vectorized (Pushdown) 10564 / 10713 1.5 671.6 1.1X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 6d7c7de9a856e..fc716dec9f337 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -17,25 +17,30 @@ package org.apache.spark.sql.execution.benchmark -import java.io.File +import java.io.{File, FileOutputStream, OutputStream} import scala.util.{Random, Try} +import org.scalatest.{BeforeAndAfterEachTestData, Suite, TestData} + import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType} import org.apache.spark.util.{Benchmark, Utils} - /** * Benchmark to measure read performance with Filter pushdown. * To run this: - * spark-submit --class + * build/sbt "sql/test-only *FilterPushdownBenchmark" + * + * Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt". */ -object FilterPushdownBenchmark { - val conf = new SparkConf() - .setAppName("FilterPushdownBenchmark") +class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfterEachTest { + private val conf = new SparkConf() + .setAppName(this.getClass.getSimpleName) // Since `spark.master` always exists, overrides this value .set("spark.master", "local[1]") .setIfMissing("spark.driver.memory", "3g") @@ -44,8 +49,40 @@ object FilterPushdownBenchmark { .setIfMissing("orc.compression", "snappy") .setIfMissing("spark.sql.parquet.compression.codec", "snappy") + private val numRows = 1024 * 1024 * 15 + private val width = 5 + private val mid = numRows / 2 + private val blockSize = 1048576 + private val spark = SparkSession.builder().config(conf).getOrCreate() + private var out: OutputStream = _ + + override def beforeAll() { + super.beforeAll() + out = new FileOutputStream(new File("benchmarks/FilterPushdownBenchmark-results.txt")) + } + + override def beforeEach(td: TestData) { + super.beforeEach(td) + val separator = "=" * 96 + val testHeader = (separator + '\n' + td.name + '\n' + separator + '\n' + '\n').getBytes + out.write(testHeader) + } + + override def afterEach(td: TestData) { + out.write('\n') + super.afterEach(td) + } + + override def afterAll() { + try { + out.close() + } finally { + super.afterAll() + } + } + def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() path.delete() @@ -81,8 +118,7 @@ object FilterPushdownBenchmark { .withColumn("value", valueCol) .sort("value") - saveAsOrcTable(df, dir.getCanonicalPath + "/orc") - saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + saveAsTable(df, dir) } private def prepareStringDictTable( @@ -93,19 +129,22 @@ object FilterPushdownBenchmark { } val df = spark.range(numRows).selectExpr(selectExpr: _*).sort("value") - saveAsOrcTable(df, dir.getCanonicalPath + "/orc") - saveAsParquetTable(df, dir.getCanonicalPath + "/parquet") + saveAsTable(df, dir) } - private def saveAsOrcTable(df: DataFrame, dir: String): Unit = { - // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) - df.write.mode("overwrite").option("orc.dictionary.key.threshold", 1.0).orc(dir) - spark.read.orc(dir).createOrReplaceTempView("orcTable") - } + private def saveAsTable(df: DataFrame, dir: File): Unit = { + val orcPath = dir.getCanonicalPath + "/orc" + val parquetPath = dir.getCanonicalPath + "/parquet" - private def saveAsParquetTable(df: DataFrame, dir: String): Unit = { - df.write.mode("overwrite").parquet(dir) - spark.read.parquet(dir).createOrReplaceTempView("parquetTable") + // To always turn on dictionary encoding, we set 1.0 at the threshold (the default is 0.8) + df.write.mode("overwrite") + .option("orc.dictionary.key.threshold", 1.0) + .option("orc.stripe.size", blockSize).orc(orcPath) + spark.read.orc(orcPath).createOrReplaceTempView("orcTable") + + df.write.mode("overwrite") + .option("parquet.block.size", blockSize).parquet(parquetPath) + spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable") } def filterPushDownBenchmark( @@ -113,7 +152,7 @@ object FilterPushdownBenchmark { title: String, whereExpr: String, selectExpr: String = "*"): Unit = { - val benchmark = new Benchmark(title, values, minNumIters = 5) + val benchmark = new Benchmark(title, values, minNumIters = 5, output = Some(out)) Seq(false, true).foreach { pushDownEnabled => val name = s"Parquet Vectorized ${if (pushDownEnabled) s"(Pushdown)" else ""}" @@ -133,214 +172,6 @@ object FilterPushdownBenchmark { } } - /* - OpenJDK 64-Bit Server VM 1.8.0_171-b10 on Linux 4.14.33-51.37.amzn1.x86_64 - Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz - Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9201 / 9300 1.7 585.0 1.0X - Parquet Vectorized (Pushdown) 89 / 105 176.3 5.7 103.1X - Native ORC Vectorized 8886 / 8898 1.8 564.9 1.0X - Native ORC Vectorized (Pushdown) 110 / 128 143.4 7.0 83.9X - - - Select 0 string row - ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9336 / 9357 1.7 593.6 1.0X - Parquet Vectorized (Pushdown) 927 / 937 17.0 58.9 10.1X - Native ORC Vectorized 9026 / 9041 1.7 573.9 1.0X - Native ORC Vectorized (Pushdown) 257 / 272 61.1 16.4 36.3X - - - Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9209 / 9223 1.7 585.5 1.0X - Parquet Vectorized (Pushdown) 908 / 925 17.3 57.7 10.1X - Native ORC Vectorized 8878 / 8904 1.8 564.4 1.0X - Native ORC Vectorized (Pushdown) 248 / 261 63.4 15.8 37.1X - - - Select 1 string row - (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9194 / 9216 1.7 584.5 1.0X - Parquet Vectorized (Pushdown) 899 / 908 17.5 57.2 10.2X - Native ORC Vectorized 8934 / 8962 1.8 568.0 1.0X - Native ORC Vectorized (Pushdown) 249 / 254 63.3 15.8 37.0X - - - Select 1 string row - ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9332 / 9351 1.7 593.3 1.0X - Parquet Vectorized (Pushdown) 915 / 934 17.2 58.2 10.2X - Native ORC Vectorized 9049 / 9057 1.7 575.3 1.0X - Native ORC Vectorized (Pushdown) 248 / 258 63.5 15.8 37.7X - - - Select all string rows - (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 20478 / 20497 0.8 1301.9 1.0X - Parquet Vectorized (Pushdown) 20461 / 20550 0.8 1300.9 1.0X - Native ORC Vectorized 27464 / 27482 0.6 1746.1 0.7X - Native ORC Vectorized (Pushdown) 27454 / 27488 0.6 1745.5 0.7X - - - Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8489 / 8519 1.9 539.7 1.0X - Parquet Vectorized (Pushdown) 64 / 69 246.1 4.1 132.8X - Native ORC Vectorized 8064 / 8099 2.0 512.7 1.1X - Native ORC Vectorized (Pushdown) 88 / 94 178.6 5.6 96.4X - - - Select 0 int row - (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8494 / 8514 1.9 540.0 1.0X - Parquet Vectorized (Pushdown) 835 / 840 18.8 53.1 10.2X - Native ORC Vectorized 8090 / 8106 1.9 514.4 1.0X - Native ORC Vectorized (Pushdown) 249 / 257 63.2 15.8 34.1X - - - Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8552 / 8560 1.8 543.7 1.0X - Parquet Vectorized (Pushdown) 837 / 841 18.8 53.2 10.2X - Native ORC Vectorized 8178 / 8188 1.9 519.9 1.0X - Native ORC Vectorized (Pushdown) 249 / 258 63.2 15.8 34.4X - - - Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8562 / 8580 1.8 544.3 1.0X - Parquet Vectorized (Pushdown) 833 / 836 18.9 53.0 10.3X - Native ORC Vectorized 8164 / 8185 1.9 519.0 1.0X - Native ORC Vectorized (Pushdown) 245 / 254 64.3 15.6 35.0X - - - Select 1 int row - (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8540 / 8555 1.8 542.9 1.0X - Parquet Vectorized (Pushdown) 837 / 839 18.8 53.2 10.2X - Native ORC Vectorized 8182 / 8231 1.9 520.2 1.0X - Native ORC Vectorized (Pushdown) 250 / 259 62.9 15.9 34.1X - - - Select 1 int row - (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8535 / 8555 1.8 542.6 1.0X - Parquet Vectorized (Pushdown) 835 / 841 18.8 53.1 10.2X - Native ORC Vectorized 8159 / 8179 1.9 518.8 1.0X - Native ORC Vectorized (Pushdown) 244 / 250 64.5 15.5 35.0X - - - Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 9609 / 9634 1.6 610.9 1.0X - Parquet Vectorized (Pushdown) 2663 / 2672 5.9 169.3 3.6X - Native ORC Vectorized 9824 / 9850 1.6 624.6 1.0X - Native ORC Vectorized (Pushdown) 2717 / 2722 5.8 172.7 3.5X - - - Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 13592 / 13613 1.2 864.2 1.0X - Parquet Vectorized (Pushdown) 9720 / 9738 1.6 618.0 1.4X - Native ORC Vectorized 16366 / 16397 1.0 1040.5 0.8X - Native ORC Vectorized (Pushdown) 12437 / 12459 1.3 790.7 1.1X - - - Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 17580 / 17617 0.9 1117.7 1.0X - Parquet Vectorized (Pushdown) 16803 / 16827 0.9 1068.3 1.0X - Native ORC Vectorized 24169 / 24187 0.7 1536.6 0.7X - Native ORC Vectorized (Pushdown) 22147 / 22341 0.7 1408.1 0.8X - - - Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18461 / 18491 0.9 1173.7 1.0X - Parquet Vectorized (Pushdown) 18466 / 18530 0.9 1174.1 1.0X - Native ORC Vectorized 24231 / 24270 0.6 1540.6 0.8X - Native ORC Vectorized (Pushdown) 24207 / 24304 0.6 1539.0 0.8X - - - Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18414 / 18453 0.9 1170.7 1.0X - Parquet Vectorized (Pushdown) 18435 / 18464 0.9 1172.1 1.0X - Native ORC Vectorized 24430 / 24454 0.6 1553.2 0.8X - Native ORC Vectorized (Pushdown) 24410 / 24465 0.6 1552.0 0.8X - - - Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 18446 / 18457 0.9 1172.8 1.0X - Parquet Vectorized (Pushdown) 18428 / 18440 0.9 1171.6 1.0X - Native ORC Vectorized 24414 / 24450 0.6 1552.2 0.8X - Native ORC Vectorized (Pushdown) 24385 / 24472 0.6 1550.4 0.8X - - - Select 0 distinct string row - (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8322 / 8352 1.9 529.1 1.0X - Parquet Vectorized (Pushdown) 53 / 57 296.3 3.4 156.7X - Native ORC Vectorized 7903 / 7953 2.0 502.4 1.1X - Native ORC Vectorized (Pushdown) 80 / 82 197.2 5.1 104.3X - - - Select 0 distinct string row - ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8712 / 8743 1.8 553.9 1.0X - Parquet Vectorized (Pushdown) 995 / 1030 15.8 63.3 8.8X - Native ORC Vectorized 8345 / 8362 1.9 530.6 1.0X - Native ORC Vectorized (Pushdown) 84 / 87 187.6 5.3 103.9X - - - Select 1 distinct string row - (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8574 / 8610 1.8 545.1 1.0X - Parquet Vectorized (Pushdown) 1127 / 1135 14.0 71.6 7.6X - Native ORC Vectorized 8163 / 8181 1.9 519.0 1.1X - Native ORC Vectorized (Pushdown) 426 / 433 36.9 27.1 20.1X - - - Select 1 distinct string row - (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8549 / 8568 1.8 543.5 1.0X - Parquet Vectorized (Pushdown) 1124 / 1131 14.0 71.4 7.6X - Native ORC Vectorized 8163 / 8210 1.9 519.0 1.0X - Native ORC Vectorized (Pushdown) 426 / 436 36.9 27.1 20.1X - - - Select 1 distinct string row - ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 8889 / 8896 1.8 565.2 1.0X - Parquet Vectorized (Pushdown) 1161 / 1168 13.6 73.8 7.7X - Native ORC Vectorized 8519 / 8554 1.8 541.6 1.0X - Native ORC Vectorized (Pushdown) 430 / 437 36.6 27.3 20.7X - - - Select all distinct string rows - (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------------ - Parquet Vectorized 20433 / 20533 0.8 1299.1 1.0X - Parquet Vectorized (Pushdown) 20433 / 20456 0.8 1299.1 1.0X - Native ORC Vectorized 25435 / 25513 0.6 1617.1 0.8X - Native ORC Vectorized (Pushdown) 25435 / 25507 0.6 1617.1 0.8X - */ - benchmark.run() } @@ -408,14 +239,8 @@ object FilterPushdownBenchmark { } } - def main(args: Array[String]): Unit = { - val numRows = 1024 * 1024 * 15 - val width = 5 - - // Pushdown for many distinct value case + ignore("Pushdown for many distinct value case") { withTempPath { dir => - val mid = numRows / 2 - withTempTable("orcTable", "patquetTable") { Seq(true, false).foreach { useStringForValue => prepareTable(dir, numRows, width, useStringForValue) @@ -427,16 +252,122 @@ object FilterPushdownBenchmark { } } } + } - // Pushdown for few distinct value case (use dictionary encoding) + ignore("Pushdown for few distinct value case (use dictionary encoding)") { withTempPath { dir => val numDistinctValues = 200 - val mid = numDistinctValues / 2 withTempTable("orcTable", "patquetTable") { prepareStringDictTable(dir, numRows, numDistinctValues, width) - runStringBenchmark(numRows, width, mid, "distinct string") + runStringBenchmark(numRows, width, numDistinctValues / 2, "distinct string") } } } + + ignore("Pushdown benchmark for StringStartsWith") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, true) + Seq( + "value like '10%'", + "value like '1000%'", + s"value like '${mid.toString.substring(0, mid.toString.length - 1)}%'" + ).foreach { whereExpr => + val title = s"StringStartsWith filter: ($whereExpr)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + + ignore(s"Pushdown benchmark for ${DecimalType.simpleString}") { + withTempPath { dir => + Seq( + s"decimal(${Decimal.MAX_INT_DIGITS}, 2)", + s"decimal(${Decimal.MAX_LONG_DIGITS}, 2)", + s"decimal(${DecimalType.MAX_PRECISION}, 2)" + ).foreach { dt => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(dt)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = $mid").foreach { whereExpr => + val title = s"Select 1 $dt row ($whereExpr)".replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% $dt rows (value < ${numRows * percent / 100})", + s"value < ${numRows * percent / 100}", + selectExpr + ) + } + } + } + } + } + + ignore("Pushdown benchmark for InSet -> InFilters") { + withTempPath { dir => + withTempTable("orcTable", "patquetTable") { + prepareTable(dir, numRows, width, false) + Seq(5, 10, 50, 100).foreach { count => + Seq(10, 50, 90).foreach { distribution => + val filter = + Range(0, count).map(r => scala.util.Random.nextInt(numRows * distribution / 100)) + val whereExpr = s"value in(${filter.mkString(",")})" + val title = s"InSet -> InFilters (values count: $count, distribution: $distribution)" + filterPushDownBenchmark(numRows, title, whereExpr) + } + } + } + } + } + + ignore(s"Pushdown benchmark for ${ByteType.simpleString}") { + withTempPath { dir => + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", (monotonically_increasing_id() % Byte.MaxValue).cast(ByteType)) + .orderBy("value") + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST(${Byte.MaxValue / 2} AS ${ByteType.simpleString})") + .foreach { whereExpr => + val title = s"Select 1 ${ByteType.simpleString} row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% ${ByteType.simpleString} rows " + + s"(value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString}))", + s"value < CAST(${Byte.MaxValue * percent / 100} AS ${ByteType.simpleString})", + selectExpr + ) + } + } + } + } +} + +trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => + + override def beforeEach(td: TestData) { + super.beforeEach(td) + } + + override def afterEach(td: TestData) { + super.afterEach(td) + } } From 141953f4c44dbad1c2a7059e92bec5fe770af932 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Fri, 6 Jul 2018 00:08:03 -0700 Subject: [PATCH 1074/1161] [SPARK-24535][SPARKR] fix tests on java check error ## What changes were proposed in this pull request? change to skip tests if - couldn't determine java version fix problem on windows ## How was this patch tested? unit test, manual, win-builder Author: Felix Cheung Closes #21666 from felixcheung/rjavaskip. --- R/pkg/R/client.R | 24 +++++++++++++++--------- R/pkg/R/sparkR.R | 2 +- R/pkg/inst/tests/testthat/test_basic.R | 8 ++++++++ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 4c87f64e7f0e1..660f0864403e0 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -71,15 +71,20 @@ checkJavaVersion <- function() { # If java is missing from PATH, we get an error in Unix and a warning in Windows javaVersionOut <- tryCatch( - launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), - error = function(e) { - stop("Java version check failed. Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", e) - }, - warning = function(w) { - stop("Java version check failed. Please make sure Java is installed", - " and set JAVA_HOME to point to the installation directory.", w) - }) + if (is_windows()) { + # See SPARK-24535 + system2(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + } else { + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE) + }, + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) javaVersionFilter <- Filter( function(x) { grepl(" version", x) @@ -93,6 +98,7 @@ checkJavaVersion <- function() { stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", javaVersionStr)) } + return(javaVersionNum) } launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index f7c1663d32c96..d3a9cbae7d808 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,7 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) - checkJavaVersion() + invisible(checkJavaVersion()) launchBackend( args = path, sparkHome = sparkHome, diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 823d26f12feee..243f5f0298284 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,6 +18,10 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -50,6 +54,10 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { + tryCatch( checkJavaVersion(), + error = function(e) { skip("error on Java check") }, + warning = function(e) { skip("warning on Java check") } ) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) From a381bce7285ec30f58f28f523dfcfe0c13221bbf Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 6 Jul 2018 18:28:54 +0800 Subject: [PATCH 1075/1161] [SPARK-24673][SQL][PYTHON][FOLLOWUP] Support Column arguments in timezone of from_utc_timestamp/to_utc_timestamp ## What changes were proposed in this pull request? This pr supported column arguments in timezone of `from_utc_timestamp/to_utc_timestamp` (follow-up of #21693). ## How was this patch tested? Added tests. Author: Takeshi Yamamuro Closes #21723 from maropu/SPARK-24673-FOLLOWUP. --- python/pyspark/sql/functions.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4d371976364d3..55e7d575b4681 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1285,11 +1285,21 @@ def from_utc_timestamp(timestamp, tz): that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 03:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect() [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1300,11 +1310,21 @@ def to_utc_timestamp(timestamp, tz): zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 01:40:00.0'. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + :param timestamp: the column that contains timestamps + :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc + + .. versionchanged:: 2.4 + `tz` can take a :class:`Column` containing timezone ID strings. + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] """ sc = SparkContext._active_spark_context + if isinstance(tz, Column): + tz = _to_java_column(tz) return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) From 4de0425df8d2545718a0583bc26592108aebc5ac Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 7 Jul 2018 10:54:14 +0800 Subject: [PATCH 1076/1161] [SPARK-24569][SQL] Aggregator with output type Option should produce consistent schema ## What changes were proposed in this pull request? SQL `Aggregator` with output type `Option[Boolean]` creates column of type `StructType`. It's not in consistency with a Dataset of similar java class. This changes the way `definedByConstructorParams` checks given type. For `Option[_]`, it goes to check its type argument. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21611 from viirya/SPARK-24569. --- .../spark/sql/catalyst/ScalaReflection.scala | 7 ++- .../spark/sql/DatasetAggregatorSuite.scala | 60 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 11 ++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f9acc208b715e..4543bba8f6ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -798,7 +798,12 @@ object ScalaReflection extends ScalaReflection { * Whether the fields of the given type is defined entirely by its constructor parameters. */ def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { - tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + tpe.dealias match { + // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. + case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) + case _ => tpe.dealias <:< localTypeOf[Product] || + tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + } } private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0e7eaa9e88d57..538ea3c66c40e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -148,6 +148,41 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] { } +case class OptionBooleanData(name: String, isGood: Option[Boolean]) + +case class OptionBooleanAggregator(colName: String) + extends Aggregator[Row, Option[Boolean], Option[Boolean]] { + + override def zero: Option[Boolean] = None + + override def reduce(buffer: Option[Boolean], row: Row): Option[Boolean] = { + val index = row.fieldIndex(colName) + val value = if (row.isNullAt(index)) { + Option.empty[Boolean] + } else { + Some(row.getBoolean(index)) + } + merge(buffer, value) + } + + override def merge(b1: Option[Boolean], b2: Option[Boolean]): Option[Boolean] = { + if ((b1.isDefined && b1.get) || (b2.isDefined && b2.get)) { + Some(true) + } else if (b1.isDefined) { + b1 + } else { + b2 + } + } + + override def finish(reduction: Option[Boolean]): Option[Boolean] = reduction + + override def bufferEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + override def outputEncoder: Encoder[Option[Boolean]] = OptionalBoolEncoder + + def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder() +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -333,4 +368,29 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { df.groupBy($"i").agg(VeryComplexResultAgg.toColumn), Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } + + test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val group = df + .groupBy("name") + .agg(OptionBooleanAggregator("isGood").toColumn.alias("isGood")) + assert(df.schema == group.schema) + checkAnswer(group, Row("bob", true) :: Nil) + checkDataset(group.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } + + test("SPARK-24569: groupByKey with Aggregator of output type Option[Boolean]") { + val df = Seq( + OptionBooleanData("bob", Some(true)), + OptionBooleanData("bob", Some(false)), + OptionBooleanData("bob", None)).toDF() + val grouped = df.groupByKey((r: Row) => r.getString(0)) + .agg(OptionBooleanAggregator("isGood").toColumn).toDF("name", "isGood") + + assert(grouped.schema == df.schema) + checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2d20c50584c03..ce8db99d4e2f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { intercept[NullPointerException](ds.as[(Int, Int)].collect()) } + test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val a = Seq(Some(1)).toDS + val b = Seq(Some(1.2)).toDS + val expected = Seq((Some(1), Some(1.2))).toDS + val joined = a.joinWith(b, lit(true)) + assert(joined.schema == expected.schema) + checkDataset(joined, expected.collect: _*) + } + } + test("SPARK-24548: Dataset with tuple encoders should have correct schema") { val encoder = Encoders.tuple(newStringEncoder, Encoders.tuple(newStringEncoder, newStringEncoder)) From fc43690d36e7a17e45826a69ab86935fb0ee2be4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 7 Jul 2018 11:34:30 +0800 Subject: [PATCH 1077/1161] [SPARK-24749][SQL] Use sameType to compare Array's element type in ArrayContains ## What changes were proposed in this pull request? We should use `DataType.sameType` to compare element type in `ArrayContains`, otherwise nullability affects comparison result. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh Closes #21724 from viirya/SPARK-24749. --- .../sql/catalyst/expressions/collectionOperations.scala | 2 +- .../catalyst/expressions/CollectionExpressionsSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8b278f067749e..fcac3a58e6a95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1085,7 +1085,7 @@ case class ArrayContains(left: Expression, right: Expression) if (right.dataType == NullType) { TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") } else if (!left.dataType.isInstanceOf[ArrayType] - || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d7744eb4c7dc7..496ee1d496a36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -213,6 +213,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq( + StructField("a", IntegerType, true))))) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) @@ -228,6 +230,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(1), StructType(Seq( + StructField("a", IntegerType, false))))), true) + checkEvaluation(ArrayContains(a4, Literal.create(create_row(0), StructType(Seq( + StructField("a", IntegerType, false))))), false) + // binary val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), ArrayType(BinaryType)) From 74f6a92fcea9196d62c2d531c11ec7efd580b760 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 7 Jul 2018 11:37:41 +0800 Subject: [PATCH 1078/1161] [SPARK-24739][PYTHON] Make PySpark compatible with Python 3.7 ## What changes were proposed in this pull request? This PR proposes to make PySpark compatible with Python 3.7. There are rather radical change in semantic of `StopIteration` within a generator. It now throws it as a `RuntimeError`. To make it compatible, we should fix it: ```python try: next(...) except StopIteration return ``` See [release note](https://docs.python.org/3/whatsnew/3.7.html#porting-to-python-3-7) and [PEP 479](https://www.python.org/dev/peps/pep-0479/). ## How was this patch tested? Manually tested: ``` $ ./run-tests --python-executables=python3.7 Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python3.7'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Starting test(python3.7): pyspark.mllib.tests Starting test(python3.7): pyspark.sql.tests Starting test(python3.7): pyspark.streaming.tests Starting test(python3.7): pyspark.tests Finished test(python3.7): pyspark.streaming.tests (130s) Starting test(python3.7): pyspark.accumulators Finished test(python3.7): pyspark.accumulators (8s) Starting test(python3.7): pyspark.broadcast Finished test(python3.7): pyspark.broadcast (9s) Starting test(python3.7): pyspark.conf Finished test(python3.7): pyspark.conf (6s) Starting test(python3.7): pyspark.context Finished test(python3.7): pyspark.context (27s) Starting test(python3.7): pyspark.ml.classification Finished test(python3.7): pyspark.tests (200s) ... 3 tests were skipped Starting test(python3.7): pyspark.ml.clustering Finished test(python3.7): pyspark.mllib.tests (244s) Starting test(python3.7): pyspark.ml.evaluation Finished test(python3.7): pyspark.ml.classification (63s) Starting test(python3.7): pyspark.ml.feature Finished test(python3.7): pyspark.ml.clustering (48s) Starting test(python3.7): pyspark.ml.fpm Finished test(python3.7): pyspark.ml.fpm (0s) Starting test(python3.7): pyspark.ml.image Finished test(python3.7): pyspark.ml.evaluation (23s) Starting test(python3.7): pyspark.ml.linalg.__init__ Finished test(python3.7): pyspark.ml.linalg.__init__ (0s) Starting test(python3.7): pyspark.ml.recommendation Finished test(python3.7): pyspark.ml.image (20s) Starting test(python3.7): pyspark.ml.regression Finished test(python3.7): pyspark.ml.regression (58s) Starting test(python3.7): pyspark.ml.stat Finished test(python3.7): pyspark.ml.feature (90s) Starting test(python3.7): pyspark.ml.tests Finished test(python3.7): pyspark.ml.recommendation (82s) Starting test(python3.7): pyspark.ml.tuning Finished test(python3.7): pyspark.ml.stat (27s) Starting test(python3.7): pyspark.mllib.classification Finished test(python3.7): pyspark.sql.tests (362s) ... 102 tests were skipped Starting test(python3.7): pyspark.mllib.clustering Finished test(python3.7): pyspark.ml.tuning (29s) Starting test(python3.7): pyspark.mllib.evaluation Finished test(python3.7): pyspark.mllib.classification (39s) Starting test(python3.7): pyspark.mllib.feature Finished test(python3.7): pyspark.mllib.evaluation (30s) Starting test(python3.7): pyspark.mllib.fpm Finished test(python3.7): pyspark.mllib.feature (44s) Starting test(python3.7): pyspark.mllib.linalg.__init__ Finished test(python3.7): pyspark.mllib.linalg.__init__ (0s) Starting test(python3.7): pyspark.mllib.linalg.distributed Finished test(python3.7): pyspark.mllib.clustering (78s) Starting test(python3.7): pyspark.mllib.random Finished test(python3.7): pyspark.mllib.fpm (33s) Starting test(python3.7): pyspark.mllib.recommendation Finished test(python3.7): pyspark.mllib.random (12s) Starting test(python3.7): pyspark.mllib.regression Finished test(python3.7): pyspark.mllib.linalg.distributed (45s) Starting test(python3.7): pyspark.mllib.stat.KernelDensity Finished test(python3.7): pyspark.mllib.stat.KernelDensity (0s) Starting test(python3.7): pyspark.mllib.stat._statistics Finished test(python3.7): pyspark.mllib.recommendation (41s) Starting test(python3.7): pyspark.mllib.tree Finished test(python3.7): pyspark.mllib.regression (44s) Starting test(python3.7): pyspark.mllib.util Finished test(python3.7): pyspark.mllib.stat._statistics (20s) Starting test(python3.7): pyspark.profiler Finished test(python3.7): pyspark.mllib.tree (26s) Starting test(python3.7): pyspark.rdd Finished test(python3.7): pyspark.profiler (11s) Starting test(python3.7): pyspark.serializers Finished test(python3.7): pyspark.mllib.util (24s) Starting test(python3.7): pyspark.shuffle Finished test(python3.7): pyspark.shuffle (0s) Starting test(python3.7): pyspark.sql.catalog Finished test(python3.7): pyspark.serializers (15s) Starting test(python3.7): pyspark.sql.column Finished test(python3.7): pyspark.rdd (27s) Starting test(python3.7): pyspark.sql.conf Finished test(python3.7): pyspark.sql.catalog (24s) Starting test(python3.7): pyspark.sql.context Finished test(python3.7): pyspark.sql.conf (8s) Starting test(python3.7): pyspark.sql.dataframe Finished test(python3.7): pyspark.sql.column (29s) Starting test(python3.7): pyspark.sql.functions Finished test(python3.7): pyspark.sql.context (26s) Starting test(python3.7): pyspark.sql.group Finished test(python3.7): pyspark.sql.dataframe (51s) Starting test(python3.7): pyspark.sql.readwriter Finished test(python3.7): pyspark.ml.tests (266s) Starting test(python3.7): pyspark.sql.session Finished test(python3.7): pyspark.sql.group (36s) Starting test(python3.7): pyspark.sql.streaming Finished test(python3.7): pyspark.sql.functions (57s) Starting test(python3.7): pyspark.sql.types Finished test(python3.7): pyspark.sql.session (25s) Starting test(python3.7): pyspark.sql.udf Finished test(python3.7): pyspark.sql.types (10s) Starting test(python3.7): pyspark.sql.window Finished test(python3.7): pyspark.sql.readwriter (31s) Starting test(python3.7): pyspark.streaming.util Finished test(python3.7): pyspark.sql.streaming (22s) Starting test(python3.7): pyspark.util Finished test(python3.7): pyspark.util (0s) Finished test(python3.7): pyspark.streaming.util (0s) Finished test(python3.7): pyspark.sql.udf (16s) Finished test(python3.7): pyspark.sql.window (12s) ``` In my local (I have two Macs but both have the same issues), I currently faced some issues for now to install both extra dependencies PyArrow and Pandas same as Jenkins's, against Python 3.7. Author: hyukjinkwon Closes #21714 from HyukjinKwon/SPARK-24739. --- python/pyspark/rdd.py | 5 ++++- python/setup.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7e7e5822a6b20..951851804b1d8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1370,7 +1370,10 @@ def takeUpToNumLeft(iterator): iterator = iter(iterator) taken = 0 while taken < left: - yield next(iterator) + try: + yield next(iterator) + except StopIteration: + return taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) diff --git a/python/setup.py b/python/setup.py index d309e0564530a..45eb74eb87ce7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -219,6 +219,7 @@ def _supports_symlinks(): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy'] ) From 044b33b2ed2d423d798f2a632fab110c46f41567 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 7 Jul 2018 11:39:29 +0800 Subject: [PATCH 1079/1161] [SPARK-24740][PYTHON][ML] Make PySpark's tests compatible with NumPy 1.14+ ## What changes were proposed in this pull request? This PR proposes to make PySpark's tests compatible with NumPy 0.14+ NumPy 0.14.x introduced rather radical changes about its string representation. For example, the tests below are failed: ``` ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 895, in __main__.DenseMatrix.__str__ Failed example: print(dm) Expected: DenseMatrix([[ 0., 2.], [ 1., 3.]]) Got: DenseMatrix([[0., 2.], [1., 3.]]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 899, in __main__.DenseMatrix.__str__ Failed example: print(dm) Expected: DenseMatrix([[ 0., 1.], [ 2., 3.]]) Got: DenseMatrix([[0., 1.], [2., 3.]]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 939, in __main__.DenseMatrix.toArray Failed example: m.toArray() Expected: array([[ 0., 2.], [ 1., 3.]]) Got: array([[0., 2.], [1., 3.]]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 324, in __main__.DenseVector.dot Failed example: dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F')) Expected: array([ 5., 11.]) Got: array([ 5., 11.]) ********************************************************************** File "/.../spark/python/pyspark/ml/linalg/__init__.py", line 567, in __main__.SparseVector.dot Failed example: a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]])) Expected: array([ 22., 22.]) Got: array([22., 22.]) ``` See [release note](https://docs.scipy.org/doc/numpy-1.14.0/release.html#compatibility-notes). ## How was this patch tested? Manually tested: ``` $ ./run-tests --python-executables=python3.6,python2.7 --modules=pyspark-ml,pyspark-mllib Running PySpark tests. Output is in /.../spark/python/unit-tests.log Will test against the following Python executables: ['python3.6', 'python2.7'] Will test the following Python modules: ['pyspark-ml', 'pyspark-mllib'] Starting test(python2.7): pyspark.mllib.tests Starting test(python2.7): pyspark.ml.classification Starting test(python3.6): pyspark.mllib.tests Starting test(python2.7): pyspark.ml.clustering Finished test(python2.7): pyspark.ml.clustering (54s) Starting test(python2.7): pyspark.ml.evaluation Finished test(python2.7): pyspark.ml.classification (74s) Starting test(python2.7): pyspark.ml.feature Finished test(python2.7): pyspark.ml.evaluation (27s) Starting test(python2.7): pyspark.ml.fpm Finished test(python2.7): pyspark.ml.fpm (0s) Starting test(python2.7): pyspark.ml.image Finished test(python2.7): pyspark.ml.image (17s) Starting test(python2.7): pyspark.ml.linalg.__init__ Finished test(python2.7): pyspark.ml.linalg.__init__ (1s) Starting test(python2.7): pyspark.ml.recommendation Finished test(python2.7): pyspark.ml.feature (76s) Starting test(python2.7): pyspark.ml.regression Finished test(python2.7): pyspark.ml.recommendation (69s) Starting test(python2.7): pyspark.ml.stat Finished test(python2.7): pyspark.ml.regression (45s) Starting test(python2.7): pyspark.ml.tests Finished test(python2.7): pyspark.ml.stat (28s) Starting test(python2.7): pyspark.ml.tuning Finished test(python2.7): pyspark.ml.tuning (20s) Starting test(python2.7): pyspark.mllib.classification Finished test(python2.7): pyspark.mllib.classification (31s) Starting test(python2.7): pyspark.mllib.clustering Finished test(python2.7): pyspark.mllib.tests (260s) Starting test(python2.7): pyspark.mllib.evaluation Finished test(python3.6): pyspark.mllib.tests (266s) Starting test(python2.7): pyspark.mllib.feature Finished test(python2.7): pyspark.mllib.evaluation (21s) Starting test(python2.7): pyspark.mllib.fpm Finished test(python2.7): pyspark.mllib.feature (38s) Starting test(python2.7): pyspark.mllib.linalg.__init__ Finished test(python2.7): pyspark.mllib.linalg.__init__ (1s) Starting test(python2.7): pyspark.mllib.linalg.distributed Finished test(python2.7): pyspark.mllib.fpm (34s) Starting test(python2.7): pyspark.mllib.random Finished test(python2.7): pyspark.mllib.clustering (64s) Starting test(python2.7): pyspark.mllib.recommendation Finished test(python2.7): pyspark.mllib.random (15s) Starting test(python2.7): pyspark.mllib.regression Finished test(python2.7): pyspark.mllib.linalg.distributed (47s) Starting test(python2.7): pyspark.mllib.stat.KernelDensity Finished test(python2.7): pyspark.mllib.stat.KernelDensity (0s) Starting test(python2.7): pyspark.mllib.stat._statistics Finished test(python2.7): pyspark.mllib.recommendation (40s) Starting test(python2.7): pyspark.mllib.tree Finished test(python2.7): pyspark.mllib.regression (38s) Starting test(python2.7): pyspark.mllib.util Finished test(python2.7): pyspark.mllib.stat._statistics (19s) Starting test(python3.6): pyspark.ml.classification Finished test(python2.7): pyspark.mllib.tree (26s) Starting test(python3.6): pyspark.ml.clustering Finished test(python2.7): pyspark.mllib.util (27s) Starting test(python3.6): pyspark.ml.evaluation Finished test(python3.6): pyspark.ml.evaluation (30s) Starting test(python3.6): pyspark.ml.feature Finished test(python2.7): pyspark.ml.tests (234s) Starting test(python3.6): pyspark.ml.fpm Finished test(python3.6): pyspark.ml.fpm (1s) Starting test(python3.6): pyspark.ml.image Finished test(python3.6): pyspark.ml.clustering (55s) Starting test(python3.6): pyspark.ml.linalg.__init__ Finished test(python3.6): pyspark.ml.linalg.__init__ (0s) Starting test(python3.6): pyspark.ml.recommendation Finished test(python3.6): pyspark.ml.classification (71s) Starting test(python3.6): pyspark.ml.regression Finished test(python3.6): pyspark.ml.image (18s) Starting test(python3.6): pyspark.ml.stat Finished test(python3.6): pyspark.ml.stat (37s) Starting test(python3.6): pyspark.ml.tests Finished test(python3.6): pyspark.ml.regression (59s) Starting test(python3.6): pyspark.ml.tuning Finished test(python3.6): pyspark.ml.feature (93s) Starting test(python3.6): pyspark.mllib.classification Finished test(python3.6): pyspark.ml.recommendation (83s) Starting test(python3.6): pyspark.mllib.clustering Finished test(python3.6): pyspark.ml.tuning (29s) Starting test(python3.6): pyspark.mllib.evaluation Finished test(python3.6): pyspark.mllib.evaluation (26s) Starting test(python3.6): pyspark.mllib.feature Finished test(python3.6): pyspark.mllib.classification (43s) Starting test(python3.6): pyspark.mllib.fpm Finished test(python3.6): pyspark.mllib.clustering (81s) Starting test(python3.6): pyspark.mllib.linalg.__init__ Finished test(python3.6): pyspark.mllib.linalg.__init__ (2s) Starting test(python3.6): pyspark.mllib.linalg.distributed Finished test(python3.6): pyspark.mllib.fpm (48s) Starting test(python3.6): pyspark.mllib.random Finished test(python3.6): pyspark.mllib.feature (54s) Starting test(python3.6): pyspark.mllib.recommendation Finished test(python3.6): pyspark.mllib.random (18s) Starting test(python3.6): pyspark.mllib.regression Finished test(python3.6): pyspark.mllib.linalg.distributed (55s) Starting test(python3.6): pyspark.mllib.stat.KernelDensity Finished test(python3.6): pyspark.mllib.stat.KernelDensity (1s) Starting test(python3.6): pyspark.mllib.stat._statistics Finished test(python3.6): pyspark.mllib.recommendation (51s) Starting test(python3.6): pyspark.mllib.tree Finished test(python3.6): pyspark.mllib.regression (45s) Starting test(python3.6): pyspark.mllib.util Finished test(python3.6): pyspark.mllib.stat._statistics (21s) Finished test(python3.6): pyspark.mllib.tree (27s) Finished test(python3.6): pyspark.mllib.util (27s) Finished test(python3.6): pyspark.ml.tests (264s) ``` Author: hyukjinkwon Closes #21715 from HyukjinKwon/SPARK-24740. --- python/pyspark/ml/clustering.py | 6 ++++++ python/pyspark/ml/linalg/__init__.py | 5 +++++ python/pyspark/ml/stat.py | 6 ++++++ python/pyspark/mllib/clustering.py | 6 ++++++ python/pyspark/mllib/evaluation.py | 6 ++++++ python/pyspark/mllib/linalg/__init__.py | 6 ++++++ python/pyspark/mllib/linalg/distributed.py | 6 ++++++ python/pyspark/mllib/stat/_statistics.py | 6 ++++++ 8 files changed, 47 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 6d77baf7349e4..2f0660040dc7c 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -1345,8 +1345,14 @@ def assignClusters(self, dataset): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.clustering from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py index 6a611a2b5b59d..2548fd0f50b33 100644 --- a/python/pyspark/ml/linalg/__init__.py +++ b/python/pyspark/ml/linalg/__init__.py @@ -1156,6 +1156,11 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): def _test(): import doctest + try: + # Numpy 1.14+ changed it's string format. + np.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py index a06ab31a7a56a..370154fc6d62a 100644 --- a/python/pyspark/ml/stat.py +++ b/python/pyspark/ml/stat.py @@ -388,8 +388,14 @@ def summary(self, featuresCol, weightCol=None): if __name__ == "__main__": import doctest + import numpy import pyspark.ml.stat from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.ml.stat.__dict__.copy() # The small batch size here ensures that we see multiple batches, diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 0cbabab13a896..b09469b9f5c2d 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -1042,7 +1042,13 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, def _test(): import doctest + import numpy import pyspark.mllib.clustering + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 36cb03369b8c0..6c65da58e4e2b 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -532,8 +532,14 @@ def accuracy(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession import pyspark.mllib.evaluation + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.evaluation.__dict__.copy() spark = SparkSession.builder\ .master("local[4]")\ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 60d96d8d5ceb8..4afd6666400b0 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -1368,6 +1368,12 @@ def R(self): def _test(): import doctest + import numpy + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) if failure_count: sys.exit(-1) diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index bba88542167ad..7e8b15056cabe 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -1364,9 +1364,15 @@ def toCoordinateMatrix(self): def _test(): import doctest + import numpy from pyspark.sql import SparkSession from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = pyspark.mllib.linalg.distributed.__dict__.copy() spark = SparkSession.builder\ .master("local[2]")\ diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 3c75b132ecad2..937bb154c2356 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -303,7 +303,13 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest + import numpy from pyspark.sql import SparkSession + try: + # Numpy 1.14+ changed it's string format. + numpy.set_printoptions(legacy='1.13') + except TypeError: + pass globs = globals().copy() spark = SparkSession.builder\ .master("local[4]")\ From 79c66894296840cc4a5bf6c8718ecfd2b08bcca8 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 7 Jul 2018 22:16:48 +0200 Subject: [PATCH 1080/1161] [SPARK-24757][SQL] Improving the error message for broadcast timeouts ## What changes were proposed in this pull request? In the PR, I propose to provide a tip to user how to resolve the issue of timeout expiration for broadcast joins. In particular, they can increase the timeout via **spark.sql.broadcastTimeout** or disable the broadcast at all by setting **spark.sql.autoBroadcastJoinThreshold** to `-1`. ## How was this patch tested? It tested manually from `spark-shell`: ``` scala> spark.conf.set("spark.sql.broadcastTimeout", 1) scala> val df = spark.range(100).join(spark.range(15).as[Long].map { x => Thread.sleep(5000) x }).where("id = value") scala> df.count() ``` ``` org.apache.spark.SparkException: Could not execute broadcast in 1 secs. You can increase the timeout for broadcasts via spark.sql.broadcastTimeout or disable broadcast join by setting spark.sql.autoBroadcastJoinThreshold to -1 at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:150) ``` Author: Maxim Gekk Closes #21727 from MaxGekk/broadcast-timeout-error. --- .../execution/exchange/BroadcastExchangeExec.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index c55f9b8f1a7fc..a80673c705f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.exchange +import java.util.concurrent.TimeoutException + import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -140,7 +142,16 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + try { + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + } catch { + case ex: TimeoutException => + logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex) + throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1", + ex) + } } } From e2c7e09f742a7e522efd74fe8e14c2620afdb522 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 9 Jul 2018 10:21:40 +0800 Subject: [PATCH 1081/1161] [SPARK-24646][CORE] Minor change to spark.yarn.dist.forceDownloadSchemes to support wildcard '*' ## What changes were proposed in this pull request? In the case of getting tokens via customized `ServiceCredentialProvider`, it is required that `ServiceCredentialProvider` be available in local spark-submit process classpath. In this case, all the configured remote sources should be forced to download to local. For the ease of using this configuration, here propose to add wildcard '*' support to `spark.yarn.dist.forceDownloadSchemes`, also clarify the usage of this configuration. ## How was this patch tested? New UT added. Author: jerryshao Closes #21633 from jerryshao/SPARK-21917-followup. --- .../org/apache/spark/deploy/SparkSubmit.scala | 5 ++-- .../spark/internal/config/package.scala | 5 ++-- .../spark/deploy/SparkSubmitSuite.scala | 29 ++++++++++++------- docs/running-on-yarn.md | 5 ++-- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 2da778a29779d..e7310ee886103 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -385,7 +385,7 @@ private[spark] class SparkSubmit extends Logging { val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) def shouldDownload(scheme: String): Boolean = { - forceDownloadSchemes.contains(scheme) || + forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) || Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure } @@ -578,7 +578,8 @@ private[spark] class SparkSubmit extends Logging { } // Add the main application jar and any added jars to classpath in case YARN client // requires these jars. - // This assumes both primaryResource and user jars are local jars, otherwise it will not be + // This assumes both primaryResource and user jars are local jars, or already downloaded + // to local by configuring "spark.yarn.dist.forceDownloadSchemes", otherwise it will not be // added to the classpath of YARN client. if (isYarnCluster) { if (isUserJar(args.primaryResource)) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bda9795a0b925..ba892bf7f60d6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -486,10 +486,11 @@ package object config { private[spark] val FORCE_DOWNLOAD_SCHEMES = ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") - .doc("Comma-separated list of schemes for which files will be downloaded to the " + + .doc("Comma-separated list of schemes for which resources will be downloaded to the " + "local disk prior to being added to YARN's distributed cache. For use in cases " + "where the YARN service does not support schemes that are supported by Spark, like http, " + - "https and ftp.") + "https and ftp, or jars required to be in the local YARN client's classpath. Wildcard " + + "'*' is denoted to download resources for all the schemes.") .stringConf .toSequence .createWithDefault(Nil) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 545c8d0423dc3..f829fecc30840 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -995,20 +995,24 @@ class SparkSubmitSuite } test("download remote resource if it is not supported by yarn service") { - testRemoteResources(enableHttpFs = false, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = false) } test("avoid downloading remote resource if it is supported by yarn service") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = false) + testRemoteResources(enableHttpFs = true) } test("force download from blacklisted schemes") { - testRemoteResources(enableHttpFs = true, blacklistHttpFs = true) + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http")) + } + + test("force download for all the schemes") { + testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*")) } private def testRemoteResources( enableHttpFs: Boolean, - blacklistHttpFs: Boolean): Unit = { + blacklistSchemes: Seq[String] = Nil): Unit = { val hadoopConf = new Configuration() updateConfWithFakeS3Fs(hadoopConf) if (enableHttpFs) { @@ -1025,8 +1029,8 @@ class SparkSubmitSuite val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" - val forceDownloadArgs = if (blacklistHttpFs) { - Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http") + val forceDownloadArgs = if (blacklistSchemes.nonEmpty) { + Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}") } else { Nil } @@ -1044,14 +1048,19 @@ class SparkSubmitSuite val jars = conf.get("spark.yarn.dist.jars").split(",").toSet - // The URI of remote S3 resource should still be remote. - assert(jars.contains(tmpS3JarPath)) + def isSchemeBlacklisted(scheme: String) = { + blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme) + } + + if (!isSchemeBlacklisted("s3")) { + assert(jars.contains(tmpS3JarPath)) + } - if (enableHttpFs && !blacklistHttpFs) { + if (enableHttpFs && blacklistSchemes.isEmpty) { // If Http FS is supported by yarn service, the URI of remote http resource should // still be remote. assert(jars.contains(tmpHttpJarPath)) - } else { + } else if (!enableHttpFs || isSchemeBlacklisted("http")) { // If Http FS is not supported by yarn service, or http scheme is configured to be force // downloading, the URI of remote http resource should be changed to a local one. val jarName = new File(tmpHttpJar.toURI).getName diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 575da7205b529..0b265b0cb1b31 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -218,9 +218,10 @@ To use a custom metrics.properties for the application master and executors, upd
      From 034913b62b579ae003431231c0272513de8f496c Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Mon, 9 Jul 2018 21:21:38 +0900 Subject: [PATCH 1082/1161] [SPARK-23936][SQL] Implement map_concat ## What changes were proposed in this pull request? Implement map_concat high order function. This implementation does not pick a winner when the specified maps have overlapping keys. Therefore, this implementation preserves existing duplicate keys in the maps and potentially introduces new duplicates (After discussion with ueshin, we settled on option 1 from [here](https://issues.apache.org/jira/browse/SPARK-23936?focusedCommentId=16464245&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16464245)). ## How was this patch tested? New tests Manual tests Run all sbt SQL tests Run all pyspark sql tests Author: Bruce Robbins Closes #21073 from bersprockets/SPARK-23936. --- python/pyspark/sql/functions.py | 22 ++ .../sql/catalyst/CatalystTypeConverters.scala | 6 + .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 8 + .../expressions/collectionOperations.scala | 231 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 126 ++++++++++ .../org/apache/spark/sql/functions.scala | 8 + .../inputs/typeCoercion/native/mapconcat.sql | 94 +++++++ .../typeCoercion/native/mapconcat.sql.out | 143 +++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 78 ++++++ 10 files changed, 717 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 55e7d575b4681..9f61e29f9cd42 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2510,6 +2510,28 @@ def arrays_zip(*cols): return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column))) +@since(2.4) +def map_concat(*cols): + """Returns the union of all the given maps. + + :param cols: list of column names (string) or list of :class:`Column` expressions + + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +--------------------------------+ + |map3 | + +--------------------------------+ + |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| + +--------------------------------+ + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 93df73ab1eaf6..6f5fbdd79e668 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -431,6 +431,12 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case (keys: Array[_], values: Array[_]) => + // case for mapdata with duplicate keys + new ArrayBasedMapData( + new GenericArrayData(keys.map(convertToCatalyst)), + new GenericArrayData(values.map(convertToCatalyst)) + ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 80a0af672bf74..e7517e8c676e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -422,6 +422,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), expression[MapFromEntries]("map_from_entries"), + expression[MapConcat]("map_concat"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b6ca30c7398f2..72908c1f433ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -548,6 +548,14 @@ object TypeCoercion { case None => s } + case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && + !haveSameType(children) => + val types = children.map(_.dataType) + findWiderCommonType(types) match { + case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType))) + case None => m + } + case m @ CreateMap(children) if m.keys.length == m.values.length && (!haveSameType(m.keys) || !haveSameType(m.values)) => val newKeys = if (haveSameType(m.keys)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index fcac3a58e6a95..879603b66b314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -503,6 +503,237 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def prettyName: String = "map_entries" } +/** + * Returns the union of all the given maps. + */ +@ExpressionDescription( + usage = "_FUNC_(map, ...) - Returns the union of all the given maps", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); + [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] + """, since = "2.4.0") +case class MapConcat(children: Seq[Expression]) extends Expression { + + override def checkInputDataTypes(): TypeCheckResult = { + var funcName = s"function $prettyName" + if (children.exists(!_.dataType.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckFailure( + s"input to $funcName should all be of type map, but it's " + + children.map(_.dataType.simpleString).mkString("[", ", ", "]")) + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + } + } + + override def dataType: MapType = { + val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption + .getOrElse(MapType(StringType, StringType)) + val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) + .exists(_.valueContainsNull) + if (dt.valueContainsNull != valueContainsNull) { + dt.copy(valueContainsNull = valueContainsNull) + } else { + dt + } + } + + override def nullable: Boolean = children.exists(_.nullable) + + override def eval(input: InternalRow): Any = { + val maps = children.map(_.eval(input)) + if (maps.contains(null)) { + return null + } + val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) + val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) + + val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + + s"elements due to exceeding the map size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + val finalKeyArray = new Array[AnyRef](numElements.toInt) + val finalValueArray = new Array[AnyRef](numElements.toInt) + var position = 0 + for (i <- keyArrayDatas.indices) { + val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) + val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) + Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) + Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) + position += keyArray.length + } + + new ArrayBasedMapData(new GenericArrayData(finalKeyArray), + new GenericArrayData(finalValueArray)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val mapCodes = children.map(_.genCode(ctx)) + val keyType = dataType.keyType + val valueType = dataType.valueType + val argsName = ctx.freshName("args") + val hasNullName = ctx.freshName("hasNull") + val mapDataClass = classOf[MapData].getName + val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName + val arrayDataClass = classOf[ArrayData].getName + + val init = + s""" + |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; + |boolean ${ev.isNull}, $hasNullName = false; + |$mapDataClass ${ev.value} = null; + """.stripMargin + + val assignments = mapCodes.zipWithIndex.map { case (m, i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + | if (${m.isNull}) { + | $hasNullName = true; + | } + |} + """.stripMargin + } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = assignments, + funcName = "getMapConcatInputs", + extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNullName; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n") + ) + + val idxName = ctx.freshName("idx") + val numElementsName = ctx.freshName("numElems") + val finKeysName = ctx.freshName("finalKeys") + val finValsName = ctx.freshName("finalValues") + + val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { + genCodeForPrimitiveArrays(ctx, keyType, false) + } else { + genCodeForNonPrimitiveArrays(ctx, keyType) + } + + val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, valueType) + } + + val keyArgsName = ctx.freshName("keyArgs") + val valArgsName = ctx.freshName("valArgs") + + val mapMerge = + s""" + |${ev.isNull} = $hasNullName; + |if (!${ev.isNull}) { + | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; + | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; + | long $numElementsName = 0; + | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); + | } + | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + | } + | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, + | (int) $numElementsName); + | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, + | (int) $numElementsName); + | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |} + """.stripMargin + + ev.copy( + code = code""" + |$init + |$codes + |$mapMerge + """.stripMargin) + } + + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { + val counter = ctx.freshName("counter") + val arrayData = ctx.freshName("arrayData") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + + val setterCode1 = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")} + |);""".stripMargin + + val setterCode = if (checkForNull) { + s""" + |if ($argsName[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode1 + |}""".stripMargin + } else { + setterCode1 + } + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $setterCode + | $counter++; + | } + | } + | return $arrayData; + | } + |}""".stripMargin.stripPrefix("\n") + } + + private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayData = ctx.freshName("arrayObjects") + val counter = ctx.freshName("counter") + val argsName = ctx.freshName("args") + val numElemName = ctx.freshName("numElements") + + s""" + |new Object() { + | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + | } + |}""".stripMargin.stripPrefix("\n") + } + + override def prettyName: String = "map_concat" +} + /** * Returns a map created from the given array of entries. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 496ee1d496a36..173c98af323b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -98,6 +98,132 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("Map Concat") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType, + valueContainsNull = false)) + val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, + valueContainsNull = false)) + val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) + val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) + val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) + val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2), + MapType(ArrayType(IntegerType), IntegerType)) + val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4), + MapType(ArrayType(IntegerType), IntegerType)) + val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4), + MapType(MapType(IntegerType, IntegerType), IntegerType)) + val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, + valueContainsNull = false)) + val mNull = Literal.create(null, MapType(StringType, StringType)) + + // overlapping maps + checkEvaluation(MapConcat(Seq(m0, m1)), + ( + Array("a", "b", "c", "a"), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // maps with no overlap + checkEvaluation(MapConcat(Seq(m0, m2)), + Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) + + // 3 maps + checkEvaluation(MapConcat(Seq(m0, m1, m2)), + ( + Array("a", "b", "c", "a", "d", "e"), // keys + Array("1", "2", "3", "4", "4", "5") // values + ) + ) + + // null reference values + checkEvaluation(MapConcat(Seq(m3, m4)), + ( + Array("a", "b", "a", "c"), // keys + Array("1", "2", null, "3") // values + ) + ) + + // null primitive values + checkEvaluation(MapConcat(Seq(m5, m6)), + ( + Array("a", "b", "a", "c"), // keys + Array(1, 2, null, 3) // values + ) + ) + + // keys that are primitive + checkEvaluation(MapConcat(Seq(m11, m12)), + ( + Array(1, 2, 3, 4), // keys + Array("1", "2", "3", "4") // values + ) + ) + + // keys that are arrays, with overlap + checkEvaluation(MapConcat(Seq(m7, m8)), + ( + Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // keys that are maps, with overlap + checkEvaluation(MapConcat(Seq(m9, m10)), + ( + Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12), + Map(1 -> 2, 3 -> 4)), // keys + Array(1, 2, 3, 4) // values + ) + ) + + // null map + checkEvaluation(MapConcat(Seq(m0, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull, m0)), null) + checkEvaluation(MapConcat(Seq(mNull, mNull)), null) + checkEvaluation(MapConcat(Seq(mNull)), null) + + // single map + checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2")) + + // no map + checkEvaluation(MapConcat(Seq.empty), Map.empty) + + // force split expressions for input in generated code + val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") + val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") + checkEvaluation(MapConcat( + Seq( + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, + m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 + )), + (expectedKeys, expectedValues)) + + // argument checking + assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess) + assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure) + assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType) + assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType) + assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull) + assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType) + assert(MapConcat(Seq.empty).dataType.keyType == StringType) + assert(MapConcat(Seq.empty).dataType.valueType == StringType) + assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull) + assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) + assert(!MapConcat(Seq(m1, m2)).nullable) + assert(MapConcat(Seq(m1, mNull)).nullable) + } + test("MapFromEntries") { def arrayType(keyType: DataType, valueType: DataType) : DataType = { ArrayType( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f2627e69939cd..89dbba10a6bf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3627,6 +3627,14 @@ object functions { @scala.annotation.varargs def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) } + /** + * Returns the union of all the given maps. + * @group collection_funcs + * @since 2.4.0 + */ + @scala.annotation.varargs + def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + ////////////////////////////////////////////////////////////////////////////////////////////// // Mask functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql new file mode 100644 index 0000000000000..fc26397b881b5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -0,0 +1,94 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +); + +-- Concatenate maps of the same type +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps; + +-- Concatenate maps of different types +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps; + +-- Concatenate map of incompatible types 1 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps; + +-- Concatenate map of incompatible types 2 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps; + +-- Concatenate map of incompatible types 3 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps; + +-- Concatenate map of incompatible types 4 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps; + +-- Concatenate map of incompatible types 5 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out new file mode 100644 index 0000000000000..d352b7284ae87 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -0,0 +1,143 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), map(false, true), + map(1Y, 2Y), map(3Y, 4Y), + map(1S, 2S), map(3S, 4S), + map(4, 6), map(7, 8), + map(6L, 7L), map(8L, 9L), + map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809), + map(1.0D, 2.0D), map(3.0D, 4.0D), + map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)), + map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'), + map('a', 'b'), map('c', 'd'), + map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')), + map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)), + map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)), + map('a', 1), map('c', 2), + map(1, 'a'), map(2, 'c') +) AS various_maps ( + boolean_map1, boolean_map2, + tinyint_map1, tinyint_map2, + smallint_map1, smallint_map2, + int_map1, int_map2, + bigint_map1, bigint_map2, + decimal_map1, decimal_map2, + double_map1, double_map2, + float_map1, float_map2, + date_map1, date_map2, + timestamp_map1, + timestamp_map2, + string_map1, string_map2, + array_map1, array_map2, + struct_map1, struct_map2, + map_map1, map_map2, + string_int_map1, string_int_map2, + int_string_map1, int_string_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT + map_concat(boolean_map1, boolean_map2) boolean_map, + map_concat(tinyint_map1, tinyint_map2) tinyint_map, + map_concat(smallint_map1, smallint_map2) smallint_map, + map_concat(int_map1, int_map2) int_map, + map_concat(bigint_map1, bigint_map2) bigint_map, + map_concat(decimal_map1, decimal_map2) decimal_map, + map_concat(float_map1, float_map2) float_map, + map_concat(double_map1, double_map2) double_map, + map_concat(date_map1, date_map2) date_map, + map_concat(timestamp_map1, timestamp_map2) timestamp_map, + map_concat(string_map1, string_map2) string_map, + map_concat(array_map1, array_map2) array_map, + map_concat(struct_map1, struct_map2) struct_map, + map_concat(map_map1, map_map2) map_map, + map_concat(string_int_map1, string_int_map2) string_int_map, + map_concat(int_string_map1, int_string_map2) int_string_map +FROM various_maps +-- !query 1 schema +struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map,map>,string_int_map:map,int_string_map:map> +-- !query 1 output +{false:true,true:false} {1:2,3:4} {1:2,3:4} {4:6,7:8} {6:7,8:9} {9223372036854775808:9223372036854775809,9223372036854775809:9223372036854775808} {1.0:2.0,3.0:4.0} {1.0:2.0,3.0:4.0} {2016-03-12:2016-03-11,2016-03-14:2016-03-13} {2016-11-11 20:54:00.0:2016-11-09 20:54:00.0,2016-11-15 20:54:00.0:2016-11-12 20:54:00.0} {"a":"b","c":"d"} {["a","b"]:["c","d"],["e"]:["f"]} {{"col1":"a","col2":1}:{"col1":"b","col2":2},{"col1":"c","col2":3}:{"col1":"d","col2":4}} {{"a":1}:{"b":2},{"c":3}:{"d":4}} {"a":1,"c":2} {1:"a",2:"c"} + + +-- !query 2 +SELECT + map_concat(tinyint_map1, smallint_map2) ts_map, + map_concat(smallint_map1, int_map2) si_map, + map_concat(int_map1, bigint_map2) ib_map, + map_concat(decimal_map1, float_map2) df_map, + map_concat(string_map1, date_map2) std_map, + map_concat(timestamp_map1, string_map2) tst_map, + map_concat(string_map1, int_map2) sti_map, + map_concat(int_string_map1, tinyint_map2) istt_map +FROM various_maps +-- !query 2 schema +struct,si_map:map,ib_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +-- !query 2 output +{1:2,3:4} {1:2,7:8} {4:6,8:9} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} + + +-- !query 3 +SELECT + map_concat(tinyint_map1, map_map2) tm_map +FROM various_maps +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`tinyint_map1`, various_maps.`map_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,map>]; line 2 pos 4 + + +-- !query 4 +SELECT + map_concat(boolean_map1, int_map2) bi_map +FROM various_maps +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`boolean_map1`, various_maps.`int_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map]; line 2 pos 4 + + +-- !query 5 +SELECT + map_concat(int_map1, struct_map2) is_map +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`int_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map, map,struct>]; line 2 pos 4 + + +-- !query 6 +SELECT + map_concat(map_map1, array_map2) ma_map +FROM various_maps +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`array_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,array>]; line 2 pos 4 + + +-- !query 7 +SELECT + map_concat(map_map1, struct_map2) ms_map +FROM various_maps +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_concat(various_maps.`map_map1`, various_maps.`struct_map2`)' due to data type mismatch: input to function map_concat should all be the same type, but it's [map,map>, map,struct>]; line 2 pos 4 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4c28e2f1cd909..d60ed7a5ef0d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -657,6 +657,84 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) } + test("map_concat function") { + val df1 = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)), + (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)), + (null, Map[Int, Int](3 -> 300, 4 -> 400)) + ).toDF("map1", "map2") + + val expected1a = Seq( + Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)), + Row(Map(1 -> 400, 2 -> 200, 3 -> 300)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a) + checkAnswer(df1.select(map_concat('map1, 'map2)), expected1a) + + val expected1b = Seq( + Row(Map(1 -> 100, 2 -> 200)), + Row(Map(1 -> 100, 2 -> 200)), + Row(null) + ) + + checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b) + checkAnswer(df1.select(map_concat('map1)), expected1b) + + val df2 = Seq( + ( + Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200), + Map[String, Int]("3" -> 300, "4" -> 400) + ) + ).toDF("map1", "map2") + + val expected2 = Seq(Row(Map())) + + checkAnswer(df2.selectExpr("map_concat()"), expected2) + checkAnswer(df2.select(map_concat()), expected2) + + val df3 = { + val schema = StructType( + StructField("map1", MapType(StringType, IntegerType, true), false) :: + StructField("map2", MapType(StringType, IntegerType, false), false) :: Nil + ) + val data = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3, "d" -> 4)) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val expected3 = Seq( + Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)), + Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4)) + ) + + checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3) + checkAnswer(df3.select(map_concat('map1, 'map2)), expected3) + + val expectedMessage1 = "input to function map_concat should all be the same type" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, map2)").collect() + }.getMessage().contains(expectedMessage1)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, 'map2)).collect() + }.getMessage().contains(expectedMessage1)) + + val expectedMessage2 = "input to function map_concat should all be of type map" + + assert(intercept[AnalysisException] { + df2.selectExpr("map_concat(map1, 12)").collect() + }.getMessage().contains(expectedMessage2)) + + assert(intercept[AnalysisException] { + df2.select(map_concat('map1, lit(12))).collect() + }.getMessage().contains(expectedMessage2)) + } + test("map_from_entries function") { def dummyFilter(c: Column): Column = c.isNull || c.isNotNull val oneRowDF = Seq(3215).toDF("i") From 1bd3d61f4191767a94b71b42f4d00706b703e84f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 9 Jul 2018 22:59:05 +0800 Subject: [PATCH 1083/1161] [SPARK-24268][SQL] Use datatype.simpleString in error messages ## What changes were proposed in this pull request? SPARK-22893 tried to unify error messages about dataTypes. Unfortunately, still many places were missing the `simpleString` method in other to have the same representation everywhere. The PR unified the messages using alway the simpleString representation of the dataTypes in the messages. ## How was this patch tested? existing/modified UTs Author: Marco Gaido Closes #21321 from mgaido91/SPARK-24268. --- .../spark/sql/kafka010/KafkaWriteTask.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaWriter.scala | 6 +++--- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 4 ++-- .../spark/sql/kafka010/KafkaSinkSuite.scala | 4 ++-- .../scala/org/apache/spark/ml/feature/DCT.scala | 3 ++- .../apache/spark/ml/feature/FeatureHasher.scala | 5 +++-- .../org/apache/spark/ml/feature/HashingTF.scala | 2 +- .../apache/spark/ml/feature/Interaction.scala | 3 ++- .../org/apache/spark/ml/feature/NGram.scala | 2 +- .../apache/spark/ml/feature/OneHotEncoder.scala | 3 ++- .../org/apache/spark/ml/feature/RFormula.scala | 2 +- .../spark/ml/feature/StopWordsRemover.scala | 4 ++-- .../org/apache/spark/ml/feature/Tokenizer.scala | 3 ++- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/util/SchemaUtils.scala | 11 +++++++---- .../BinaryClassificationEvaluatorSuite.scala | 4 ++-- .../apache/spark/ml/feature/RFormulaSuite.scala | 2 +- .../spark/ml/feature/VectorAssemblerSuite.scala | 6 +++--- .../spark/ml/recommendation/ALSSuite.scala | 2 +- .../regression/AFTSurvivalRegressionSuite.scala | 2 +- .../apache/spark/ml/util/MLTestingUtils.scala | 6 +++--- .../expressions/complexTypeCreator.scala | 4 ++-- .../catalyst/expressions/jsonExpressions.scala | 2 +- .../catalyst/expressions/stringExpressions.scala | 5 +++-- .../sql/catalyst/json/JacksonGenerator.scala | 4 ++-- .../spark/sql/catalyst/json/JacksonParser.scala | 6 ++++-- .../sql/catalyst/json/JsonInferSchema.scala | 6 ++++-- .../spark/sql/catalyst/util/TypeUtils.scala | 5 +++-- .../spark/sql/types/AbstractDataType.scala | 9 +++++---- .../org/apache/spark/sql/types/ArrayType.scala | 5 +++-- .../org/apache/spark/sql/types/DecimalType.scala | 3 ++- .../org/apache/spark/sql/types/ObjectType.scala | 3 ++- .../org/apache/spark/sql/types/StructType.scala | 5 +++-- .../catalyst/analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/ExpressionTypeCheckingSuite.scala | 16 ++++++++-------- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- .../apache/spark/sql/types/DataTypeSuite.scala | 2 +- .../parquet/VectorizedColumnReader.java | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../spark/sql/execution/arrow/ArrowUtils.scala | 3 ++- .../execution/datasources/orc/OrcFilters.scala | 2 +- .../parquet/ParquetSchemaConverter.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 ++-- .../resources/sql-tests/results/literals.sql.out | 6 +++--- .../datasources/parquet/ParquetSchemaSuite.scala | 4 ++-- .../sql/hive/execution/HiveTableScanExec.scala | 6 +++--- 48 files changed, 108 insertions(+), 88 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index d90630a8adc93..59a84706d4f55 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -110,7 +110,7 @@ private[kafka010] abstract class KafkaRowWriter( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - "must be a StringType") + s"must be a ${StringType.simpleString}") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -118,7 +118,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.simpleString}") } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( @@ -129,7 +129,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t") + s"attribute unsupported type ${t.simpleString}") } UnsafeProjection.create( Seq(topicExpression, Cast(keyExpression, BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 15cd44812cb0c..3ec26e9edd353 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -57,7 +57,7 @@ private[kafka010] object KafkaWriter extends Logging { ).dataType match { case StringType => // good case _ => - throw new AnalysisException(s"Topic type must be a String") + throw new AnalysisException(s"Topic type must be a ${StringType.simpleString}") } schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( Literal(null, StringType) @@ -65,7 +65,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") } schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") @@ -73,7 +73,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a String or BinaryType") + s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index ddfc0c1a4be2d..0e1492ac27449 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -314,7 +314,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { /* key field wrong type */ @@ -330,7 +330,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 7079ac6453ffc..70ffd7dee89d7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -303,7 +303,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binarytype")) + "value attribute type must be a string or binary")) try { ex = intercept[StreamingQueryException] { @@ -318,7 +318,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binarytype")) + "key attribute type must be a string or binary")) } test("streaming - write to non-existing topic") { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 682787a830113..1eac1d1613d2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -69,7 +69,8 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + require(inputType.isInstanceOf[VectorUDT], + s"Input type must be ${(new VectorUDT).simpleString} but got ${inputType.simpleString}.") } override protected def outputDataType: DataType = new VectorUDT diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index d67e4819b161a..405ea467cb02a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -208,8 +208,9 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme require(dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[BooleanType], - s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " + - s"Column $fieldName was $dataType") + s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " + + s"${BooleanType.simpleString} or ${StringType.simpleString}. " + + s"Column $fieldName was ${dataType.simpleString}") } val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index db432b6fefaff..403b0a813aedd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -104,7 +104,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 4ff1d0ef356f3..5e01ec30bb2eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -261,7 +261,8 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { */ def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + assert(numFeatures.length == 1, + s"${DoubleType.simpleString} columns should only contain one feature.") val numOutputCols = numFeatures.head if (numOutputCols > 1) { assert( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index c8760f9dc178f..6445360f7fd90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -65,7 +65,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + s"Input type must be ${ArrayType(StringType).simpleString} but got $inputType.") } override protected def outputDataType: DataType = new ArrayType(StringType, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 5ab6c2dde667a..24045f0448c81 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -85,7 +85,8 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + s"Input column must be of type ${NumericType.simpleString} but got " + + schema(inputColName).dataType.simpleString) require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 55e595eee6ffb..346e1823f00b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -394,7 +394,7 @@ class RFormulaModel private[feature]( require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], - "Label column already exists and is not of type NumericType.") + s"Label column already exists and is not of type ${NumericType.simpleString}.") } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0f946dd2e015b..ead75d5b8def3 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -131,8 +131,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ArrayType(StringType) but got $inputType.") + require(inputType.sameType(ArrayType(StringType)), "Input type must be " + + s"${ArrayType(StringType).simpleString} but got ${inputType.simpleString}.") SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index cfaf6c0e610b3..5132f63af1796 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -40,7 +40,8 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType == StringType, s"Input type must be string type but got $inputType.") + require(inputType == StringType, + s"Input type must be ${StringType.simpleString} type but got ${inputType.simpleString}.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 4061154b39c14..ed3b36ee5ab2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -162,7 +162,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) schema(name).dataType match { case _: NumericType | BooleanType => None case t if t.isInstanceOf[VectorUDT] => None - case other => Some(s"Data type $other of column $name is not supported.") + case other => Some(s"Data type ${other.simpleString} of column $name is not supported.") } } if (incorrectColumns.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index d7fbe28ae7a64..51b88b3117f4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -106,7 +106,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ArrayType, but got $inputType.") + s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index d9a3f85ef9a24..b500582074398 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -41,7 +41,8 @@ private[spark] object SchemaUtils { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.$message") + s"Column $colName must be of type ${dataType.simpleString} but was actually " + + s"${actualDataType.simpleString}.$message") } /** @@ -58,7 +59,8 @@ private[spark] object SchemaUtils { val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals), s"Column $colName must be of type equal to one of the following types: " + - s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + s"${dataTypes.map(_.simpleString).mkString("[", ", ", "]")} but was actually of type " + + s"${actualDataType.simpleString}.$message") } /** @@ -71,8 +73,9 @@ private[spark] object SchemaUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + - s"NumericType but was actually of type $actualDataType.$message") + require(actualDataType.isInstanceOf[NumericType], + s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + + s"${actualDataType.simpleString}.$message") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ede284712b1c0..2b0909acf69c3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -67,8 +67,8 @@ class BinaryClassificationEvaluatorSuite evaluator.evaluate(stringDF) } assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + - "equal to one of the following types: [DoubleType, ") - assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") + "equal to one of the following types: [double, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.") } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index a250331efeb1d..0de6528c4cf22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -105,7 +105,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Int, Boolean)]( original, model, - "Label column already exists and is not of type NumericType.", + "Label column already exists and is not of type numeric.", "x") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 91fb24a268b8c..ed15a1d88a269 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -99,9 +99,9 @@ class VectorAssemblerSuite assembler.transform(df) } assert(thrown.getMessage contains - "Data type StringType of column a is not supported.\n" + - "Data type StringType of column b is not supported.\n" + - "Data type StringType of column c is not supported.") + "Data type string of column a is not supported.\n" + + "Data type string of column b is not supported.\n" + + "Data type string of column c is not supported.") } test("ML attributes") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e3dfe2faf5698..65bee4edc4965 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -612,7 +612,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { estimator.fit(strDF) } assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) + s"$column must be of type numeric but was actually of type string")) } private class NumericTypeWithEncoder[A](val numericType: NumericType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 4e4ff71c9de90..6cc73e040e82c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -385,7 +385,7 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { aft.fit(dfWithStringCensors) } assert(thrown.getMessage.contains( - "Column censor must be of type NumericType but was actually of type StringType")) + "Column censor must be of type numeric but was actually of type string")) } test("numerical stability of standardization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 5e72b4d864c1d..91a8b14625a86 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -74,7 +74,7 @@ object MLTestingUtils extends SparkFunSuite { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) estimator match { case weighted: Estimator[M] with HasWeightCol => @@ -86,7 +86,7 @@ object MLTestingUtils extends SparkFunSuite { weighted.fit(dfWithStringWeights) } assert(thrown.getMessage.contains( - "Column weight must be of type NumericType but was actually of type StringType")) + "Column weight must be of type numeric but was actually of type string")) case _ => } } @@ -104,7 +104,7 @@ object MLTestingUtils extends SparkFunSuite { evaluator.evaluate(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type NumericType but was actually of type StringType")) + "Column label must be of type numeric but was actually of type string")) } def genClassifDFWithNumericLabelCol( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0a5f8a907b50a..cf0e3765de80f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -385,8 +385,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - "Only foldable StringType expressions are allowed to appear at odd position, got:" + - s" ${invalidNames.mkString(",")}") + s"Only foldable ${StringType.simpleString} expressions are allowed to appear at odd" + + s" position, got: ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8cd86053a01c7..1bcf11d7ee737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -796,7 +796,7 @@ object JsonExprUtils { } case m: CreateMap => throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType}") + s"A type of keys and values in map() must be string, but got ${m.dataType.simpleString}") case _ => throw new AnalysisException("Must use a map() function for options") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index bedad7da334ae..70dd4df9df511 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -222,11 +222,12 @@ case class Elt(children: Seq[Expression]) extends Expression { val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) if (indexType != IntegerType) { return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + - s"have IntegerType, but it's $indexType") + s"have ${IntegerType.simpleString}, but it's ${indexType.simpleString}") } if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have StringType or BinaryType, but it's " + + s"input to function $prettyName should have ${StringType.simpleString} or " + + s"${BinaryType.simpleString}, but it's " + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9c413de752a8c..00086abbefd08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -45,8 +45,8 @@ private[sql] class JacksonGenerator( // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - "JacksonGenerator only supports to be initialized with a StructType " + - s"or MapType but got ${dataType.simpleString}") + s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString} " + + s"or ${MapType.simpleString} but got ${dataType.simpleString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index c3a4ca8f64bf6..aa1691bb40d93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -143,7 +143,8 @@ class JacksonParser( case "NaN" => Float.NaN case "Infinity" => Float.PositiveInfinity case "-Infinity" => Float.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") + case other => throw new RuntimeException( + s"Cannot parse $other as ${FloatType.simpleString}.") } } @@ -158,7 +159,8 @@ class JacksonParser( case "NaN" => Double.NaN case "Infinity" => Double.PositiveInfinity case "-Infinity" => Double.NegativeInfinity - case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") + case other => + throw new RuntimeException(s"Cannot parse $other as ${DoubleType.simpleString}.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 491ca005877f8..5f70e062d46c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -294,8 +294,10 @@ private[sql] object JsonInferSchema { // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. - assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") - assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") + assert(isSorted(fields1), + s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), + s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") val newFields = new java.util.ArrayList[StructField]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1dcda49a3af6a..a9aaf617f7837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -29,7 +29,7 @@ object TypeUtils { if (dt.isInstanceOf[NumericType] || dt == NullType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt") + TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.simpleString}") } } @@ -37,7 +37,8 @@ object TypeUtils { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") + TypeCheckResult.TypeCheckFailure( + s"$caller does not support ordering on type ${dt.simpleString}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 3041f44b116ea..c43cc748655e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -145,7 +145,7 @@ abstract class NumericType extends AtomicType { } -private[sql] object NumericType extends AbstractDataType { +private[spark] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -155,11 +155,12 @@ private[sql] object NumericType extends AbstractDataType { */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[sql] def defaultConcreteType: DataType = DoubleType + override private[spark] def defaultConcreteType: DataType = DoubleType - override private[sql] def simpleString: String = "numeric" + override private[spark] def simpleString: String = "numeric" - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] + override private[spark] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 38c40482fa4d9..8f118624f6d2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -42,7 +42,7 @@ object ArrayType extends AbstractDataType { other.isInstanceOf[ArrayType] } - override private[sql] def simpleString: String = "array" + override private[spark] def simpleString: String = "array" } /** @@ -103,7 +103,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => - throw new IllegalArgumentException(s"Type $other does not support ordered operations") + throw new IllegalArgumentException( + s"Type ${other.simpleString} does not support ordered operations") } def compare(x: ArrayData, y: ArrayData): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index dbf51c398fa47..f780ffd46a876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -48,7 +48,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException(s"DecimalType can only support precision up to 38") + throw new AnalysisException( + s"${DecimalType.simpleString} can only support precision up to ${DecimalType.MAX_PRECISION}") } // default constructor for Java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 2d49fe076786a..203e85e1c99bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = - throw new UnsupportedOperationException("null literals can't be casted to ObjectType") + throw new UnsupportedOperationException( + s"null literals can't be casted to ${ObjectType.simpleString}") override private[sql] def acceptsType(other: DataType): Boolean = other match { case ObjectType(_) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 362676b252126..0e69ef8ba73e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -426,7 +426,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") } } @@ -528,7 +528,8 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types $left and $right") + throw new SparkException(s"Failed to merge incompatible data types ${left.simpleString} " + + s"and ${right.simpleString}") } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..5e503be416a1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -514,7 +514,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 36714bd631b0e..8eec14842c7e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -109,17 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type map") assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type MapType") + "EqualNullSafe does not support ordering on type map") assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type MapType") + "LessThan does not support ordering on type map") assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type MapType") + "LessThanOrEqual does not support ordering on type map") assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type MapType") + "GreaterThan does not support ordering on type map") assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type MapType") + "GreaterThanOrEqual does not support ordering on type map") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -169,10 +169,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable StringType expressions are allowed to appear at odd position") + "Only foldable string expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), "Field name should not be null") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index cb8a1fecb80a7..b4d422d8506fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -469,7 +469,7 @@ class ExpressionParserSuite extends PlanTest { Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) - intercept("1.20E-38BD", "DecimalType can only support precision up to 38") + intercept("1.20E-38BD", "decimal can only support precision up to 38") } test("strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 5a86f4055dce7..fccd057e577d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -154,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { left.merge(right) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. " + - "Failed to merge incompatible data types FloatType and LongType")) + "Failed to merge incompatible data types float and bigint")) } test("existsRecursively") { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index d5969b55eef96..060e2ec068053 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -244,7 +244,7 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), descriptor.getType().toString(), - column.dataType().toString()); + column.dataType().simpleString()); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index c6449cd5a16b0..b068493f2dd17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -452,7 +452,7 @@ class RelationalGroupedDataset protected[sql]( require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], - "The returnType of the udf must be a StructType") + s"The returnType of the udf must be a ${StructType.simpleString}") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 93c8127681b3e..1274abffaa116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -47,7 +47,8 @@ object ArrowUtils { case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => if (timeZoneId == null) { - throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + throw new UnsupportedOperationException( + s"${TimestampType.simpleString} must supply timeZoneId parameter") } else { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4f44ae4fa1d71..c90328f7ad43f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -98,7 +98,7 @@ private[orc] object OrcFilters { case DateType => PredicateLeaf.Type.DATE case TimestampType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL - case _ => throw new UnsupportedOperationException(s"DataType: $dataType") + case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.simpleString}") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index c61be077d309f..18decad3f62f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -555,7 +555,7 @@ class SparkToParquetSchemaConverter( convertField(field.copy(dataType = udt.sqlType)) case _ => - throw new AnalysisException(s"Unsupported data type $field.dataType") + throw new AnalysisException(s"Unsupported data type ${field.dataType.simpleString}") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 685d5841ab551..f772a3336d6af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -157,7 +157,7 @@ object StatFunctions extends Logging { cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + - s"for columns with dataType ${data.get.dataType} not supported.") + s"for columns with dataType ${data.get.dataType.simpleString} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)( diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 3d49323751a10..827931d74138d 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -120,7 +120,7 @@ select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 12 @@ -216,7 +216,7 @@ select from_json('{"a":1}', 'a INT', map('mode', 1)) struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 +A type of keys and values in map() must be string, but got map;; line 1 pos 7 -- !query 21 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index b8c91dc8b59a4..7f301614523b2 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -147,7 +147,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890 @@ -159,7 +159,7 @@ struct<> -- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38 +decimal can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890.0 @@ -379,7 +379,7 @@ struct<> -- !query 39 output org.apache.spark.sql.catalyst.parser.ParseException -DecimalType can only support precision up to 38(line 1, pos 7) +decimal can only support precision up to 38(line 1, pos 7) == SQL == select 1.20E-38BD diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 9d3dfae348beb..368e52cfbda9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -430,9 +430,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) assert(col.length == 1) if (col(0).dataType == StringType) { - assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) + assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY")) } else { - assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) + assert(errMsg.endsWith("Column: [a], Expected: string, Found: INT32")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 7dcaf170f9693..40be4e8c1f5be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -78,9 +78,9 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => - require( - pred.dataType == BooleanType, - s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + require(pred.dataType == BooleanType, + s"Data type of predicate $pred must be ${BooleanType.simpleString} rather than " + + s"${pred.dataType.simpleString}.") BindReferences.bindReference(pred, relation.partitionCols) } From aec966b05e8df9d459dae88d091de1923e50e2dc Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 9 Jul 2018 14:24:23 -0700 Subject: [PATCH 1084/1161] Revert "[SPARK-24268][SQL] Use datatype.simpleString in error messages" This reverts commit 1bd3d61f4191767a94b71b42f4d00706b703e84f. --- .../spark/sql/kafka010/KafkaWriteTask.scala | 6 +++--- .../apache/spark/sql/kafka010/KafkaWriter.scala | 6 +++--- .../sql/kafka010/KafkaContinuousSinkSuite.scala | 4 ++-- .../spark/sql/kafka010/KafkaSinkSuite.scala | 4 ++-- .../scala/org/apache/spark/ml/feature/DCT.scala | 3 +-- .../apache/spark/ml/feature/FeatureHasher.scala | 5 ++--- .../org/apache/spark/ml/feature/HashingTF.scala | 2 +- .../apache/spark/ml/feature/Interaction.scala | 3 +-- .../org/apache/spark/ml/feature/NGram.scala | 2 +- .../apache/spark/ml/feature/OneHotEncoder.scala | 3 +-- .../org/apache/spark/ml/feature/RFormula.scala | 2 +- .../spark/ml/feature/StopWordsRemover.scala | 4 ++-- .../org/apache/spark/ml/feature/Tokenizer.scala | 3 +-- .../spark/ml/feature/VectorAssembler.scala | 2 +- .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 2 +- .../org/apache/spark/ml/util/SchemaUtils.scala | 11 ++++------- .../BinaryClassificationEvaluatorSuite.scala | 4 ++-- .../apache/spark/ml/feature/RFormulaSuite.scala | 2 +- .../spark/ml/feature/VectorAssemblerSuite.scala | 6 +++--- .../spark/ml/recommendation/ALSSuite.scala | 2 +- .../regression/AFTSurvivalRegressionSuite.scala | 2 +- .../apache/spark/ml/util/MLTestingUtils.scala | 6 +++--- .../expressions/complexTypeCreator.scala | 4 ++-- .../catalyst/expressions/jsonExpressions.scala | 2 +- .../catalyst/expressions/stringExpressions.scala | 5 ++--- .../sql/catalyst/json/JacksonGenerator.scala | 4 ++-- .../spark/sql/catalyst/json/JacksonParser.scala | 6 ++---- .../sql/catalyst/json/JsonInferSchema.scala | 6 ++---- .../spark/sql/catalyst/util/TypeUtils.scala | 5 ++--- .../spark/sql/types/AbstractDataType.scala | 9 ++++----- .../org/apache/spark/sql/types/ArrayType.scala | 5 ++--- .../org/apache/spark/sql/types/DecimalType.scala | 3 +-- .../org/apache/spark/sql/types/ObjectType.scala | 3 +-- .../org/apache/spark/sql/types/StructType.scala | 5 ++--- .../catalyst/analysis/AnalysisErrorSuite.scala | 2 +- .../analysis/ExpressionTypeCheckingSuite.scala | 16 ++++++++-------- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- .../apache/spark/sql/types/DataTypeSuite.scala | 2 +- .../parquet/VectorizedColumnReader.java | 2 +- .../spark/sql/RelationalGroupedDataset.scala | 2 +- .../spark/sql/execution/arrow/ArrowUtils.scala | 3 +-- .../execution/datasources/orc/OrcFilters.scala | 2 +- .../parquet/ParquetSchemaConverter.scala | 2 +- .../spark/sql/execution/stat/StatFunctions.scala | 2 +- .../sql-tests/results/json-functions.sql.out | 4 ++-- .../resources/sql-tests/results/literals.sql.out | 6 +++--- .../datasources/parquet/ParquetSchemaSuite.scala | 4 ++-- .../sql/hive/execution/HiveTableScanExec.scala | 6 +++--- 48 files changed, 88 insertions(+), 108 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 59a84706d4f55..d90630a8adc93 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -110,7 +110,7 @@ private[kafka010] abstract class KafkaRowWriter( case t => throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - s"must be a ${StringType.simpleString}") + "must be a StringType") } val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) .getOrElse(Literal(null, BinaryType)) @@ -118,7 +118,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type ${t.simpleString}") + s"attribute unsupported type $t") } val valueExpression = inputSchema .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( @@ -129,7 +129,7 @@ private[kafka010] abstract class KafkaRowWriter( case StringType | BinaryType => // good case t => throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type ${t.simpleString}") + s"attribute unsupported type $t") } UnsafeProjection.create( Seq(topicExpression, Cast(keyExpression, BinaryType), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 3ec26e9edd353..15cd44812cb0c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -57,7 +57,7 @@ private[kafka010] object KafkaWriter extends Logging { ).dataType match { case StringType => // good case _ => - throw new AnalysisException(s"Topic type must be a ${StringType.simpleString}") + throw new AnalysisException(s"Topic type must be a String") } schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( Literal(null, StringType) @@ -65,7 +65,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") + s"must be a String or BinaryType") } schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") @@ -73,7 +73,7 @@ private[kafka010] object KafkaWriter extends Logging { case StringType | BinaryType => // good case _ => throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a ${StringType.simpleString} or ${BinaryType.simpleString}") + s"must be a String or BinaryType") } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index 0e1492ac27449..ddfc0c1a4be2d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -314,7 +314,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binary")) + "value attribute type must be a string or binarytype")) try { /* key field wrong type */ @@ -330,7 +330,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binary")) + "key attribute type must be a string or binarytype")) } test("streaming - write to non-existing topic") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 70ffd7dee89d7..7079ac6453ffc 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -303,7 +303,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binary")) + "value attribute type must be a string or binarytype")) try { ex = intercept[StreamingQueryException] { @@ -318,7 +318,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binary")) + "key attribute type must be a string or binarytype")) } test("streaming - write to non-existing topic") { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 1eac1d1613d2a..682787a830113 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -69,8 +69,7 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.isInstanceOf[VectorUDT], - s"Input type must be ${(new VectorUDT).simpleString} but got ${inputType.simpleString}.") + require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") } override protected def outputDataType: DataType = new VectorUDT diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index 405ea467cb02a..d67e4819b161a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -208,9 +208,8 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme require(dataType.isInstanceOf[NumericType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[BooleanType], - s"FeatureHasher requires columns to be of ${NumericType.simpleString}, " + - s"${BooleanType.simpleString} or ${StringType.simpleString}. " + - s"Column $fieldName was ${dataType.simpleString}") + s"FeatureHasher requires columns to be of NumericType, BooleanType or StringType. " + + s"Column $fieldName was $dataType") } val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 403b0a813aedd..db432b6fefaff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -104,7 +104,7 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") + s"The input column must be ArrayType, but got $inputType.") val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 5e01ec30bb2eb..4ff1d0ef356f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -261,8 +261,7 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable { */ def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(numFeatures.length == 1, - s"${DoubleType.simpleString} columns should only contain one feature.") + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") val numOutputCols = numFeatures.head if (numOutputCols > 1) { assert( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 6445360f7fd90..c8760f9dc178f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -65,7 +65,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) override protected def validateInputType(inputType: DataType): Unit = { require(inputType.sameType(ArrayType(StringType)), - s"Input type must be ${ArrayType(StringType).simpleString} but got $inputType.") + s"Input type must be ArrayType(StringType) but got $inputType.") } override protected def outputDataType: DataType = new ArrayType(StringType, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 24045f0448c81..5ab6c2dde667a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -85,8 +85,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type ${NumericType.simpleString} but got " + - schema(inputColName).dataType.simpleString) + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 346e1823f00b8..55e595eee6ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -394,7 +394,7 @@ class RFormulaModel private[feature]( require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], - s"Label column already exists and is not of type ${NumericType.simpleString}.") + "Label column already exists and is not of type NumericType.") } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index ead75d5b8def3..0f946dd2e015b 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -131,8 +131,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType - require(inputType.sameType(ArrayType(StringType)), "Input type must be " + - s"${ArrayType(StringType).simpleString} but got ${inputType.simpleString}.") + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 5132f63af1796..cfaf6c0e610b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -40,8 +40,7 @@ class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType == StringType, - s"Input type must be ${StringType.simpleString} type but got ${inputType.simpleString}.") + require(inputType == StringType, s"Input type must be string type but got $inputType.") } override protected def outputDataType: DataType = new ArrayType(StringType, true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index ed3b36ee5ab2f..4061154b39c14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -162,7 +162,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) schema(name).dataType match { case _: NumericType | BooleanType => None case t if t.isInstanceOf[VectorUDT] => None - case other => Some(s"Data type ${other.simpleString} of column $name is not supported.") + case other => Some(s"Data type $other of column $name is not supported.") } } if (incorrectColumns.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 51b88b3117f4e..d7fbe28ae7a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -106,7 +106,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { protected def validateAndTransformSchema(schema: StructType): StructType = { val inputType = schema($(itemsCol)).dataType require(inputType.isInstanceOf[ArrayType], - s"The input column must be ${ArrayType.simpleString}, but got ${inputType.simpleString}.") + s"The input column must be ArrayType, but got $inputType.") SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index b500582074398..d9a3f85ef9a24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -41,8 +41,7 @@ private[spark] object SchemaUtils { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(actualDataType.equals(dataType), - s"Column $colName must be of type ${dataType.simpleString} but was actually " + - s"${actualDataType.simpleString}.$message") + s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } /** @@ -59,8 +58,7 @@ private[spark] object SchemaUtils { val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals), s"Column $colName must be of type equal to one of the following types: " + - s"${dataTypes.map(_.simpleString).mkString("[", ", ", "]")} but was actually of type " + - s"${actualDataType.simpleString}.$message") + s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") } /** @@ -73,9 +71,8 @@ private[spark] object SchemaUtils { msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" - require(actualDataType.isInstanceOf[NumericType], - s"Column $colName must be of type ${NumericType.simpleString} but was actually of type " + - s"${actualDataType.simpleString}.$message") + require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + + s"NumericType but was actually of type $actualDataType.$message") } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index 2b0909acf69c3..ede284712b1c0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -67,8 +67,8 @@ class BinaryClassificationEvaluatorSuite evaluator.evaluate(stringDF) } assert(thrown.getMessage.replace("\n", "") contains "Column rawPrediction must be of type " + - "equal to one of the following types: [double, ") - assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.") + "equal to one of the following types: [DoubleType, ") + assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 0de6528c4cf22..a250331efeb1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -105,7 +105,7 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest { testTransformerByInterceptingException[(Int, Boolean)]( original, model, - "Label column already exists and is not of type numeric.", + "Label column already exists and is not of type NumericType.", "x") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index ed15a1d88a269..91fb24a268b8c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -99,9 +99,9 @@ class VectorAssemblerSuite assembler.transform(df) } assert(thrown.getMessage contains - "Data type string of column a is not supported.\n" + - "Data type string of column b is not supported.\n" + - "Data type string of column c is not supported.") + "Data type StringType of column a is not supported.\n" + + "Data type StringType of column b is not supported.\n" + + "Data type StringType of column c is not supported.") } test("ML attributes") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 65bee4edc4965..e3dfe2faf5698 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -612,7 +612,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { estimator.fit(strDF) } assert(thrown.getMessage.contains( - s"$column must be of type numeric but was actually of type string")) + s"$column must be of type NumericType but was actually of type StringType")) } private class NumericTypeWithEncoder[A](val numericType: NumericType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 6cc73e040e82c..4e4ff71c9de90 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -385,7 +385,7 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { aft.fit(dfWithStringCensors) } assert(thrown.getMessage.contains( - "Column censor must be of type numeric but was actually of type string")) + "Column censor must be of type NumericType but was actually of type StringType")) } test("numerical stability of standardization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 91a8b14625a86..5e72b4d864c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -74,7 +74,7 @@ object MLTestingUtils extends SparkFunSuite { estimator.fit(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type numeric but was actually of type string")) + "Column label must be of type NumericType but was actually of type StringType")) estimator match { case weighted: Estimator[M] with HasWeightCol => @@ -86,7 +86,7 @@ object MLTestingUtils extends SparkFunSuite { weighted.fit(dfWithStringWeights) } assert(thrown.getMessage.contains( - "Column weight must be of type numeric but was actually of type string")) + "Column weight must be of type NumericType but was actually of type StringType")) case _ => } } @@ -104,7 +104,7 @@ object MLTestingUtils extends SparkFunSuite { evaluator.evaluate(dfWithStringLabels) } assert(thrown.getMessage.contains( - "Column label must be of type numeric but was actually of type string")) + "Column label must be of type NumericType but was actually of type StringType")) } def genClassifDFWithNumericLabelCol( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index cf0e3765de80f..0a5f8a907b50a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -385,8 +385,8 @@ trait CreateNamedStructLike extends Expression { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Only foldable ${StringType.simpleString} expressions are allowed to appear at odd" + - s" position, got: ${invalidNames.mkString(",")}") + "Only foldable StringType expressions are allowed to appear at odd position, got:" + + s" ${invalidNames.mkString(",")}") } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 1bcf11d7ee737..8cd86053a01c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -796,7 +796,7 @@ object JsonExprUtils { } case m: CreateMap => throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType.simpleString}") + s"A type of keys and values in map() must be string, but got ${m.dataType}") case _ => throw new AnalysisException("Must use a map() function for options") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 70dd4df9df511..bedad7da334ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -222,12 +222,11 @@ case class Elt(children: Seq[Expression]) extends Expression { val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) if (indexType != IntegerType) { return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + - s"have ${IntegerType.simpleString}, but it's ${indexType.simpleString}") + s"have IntegerType, but it's $indexType") } if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { return TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName should have ${StringType.simpleString} or " + - s"${BinaryType.simpleString}, but it's " + + s"input to function $prettyName should have StringType or BinaryType, but it's " + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) } TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 00086abbefd08..9c413de752a8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -45,8 +45,8 @@ private[sql] class JacksonGenerator( // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], - s"JacksonGenerator only supports to be initialized with a ${StructType.simpleString} " + - s"or ${MapType.simpleString} but got ${dataType.simpleString}") + "JacksonGenerator only supports to be initialized with a StructType " + + s"or MapType but got ${dataType.simpleString}") // `ValueWriter`s for all fields of the schema private lazy val rootFieldWriters: Array[ValueWriter] = dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index aa1691bb40d93..c3a4ca8f64bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -143,8 +143,7 @@ class JacksonParser( case "NaN" => Float.NaN case "Infinity" => Float.PositiveInfinity case "-Infinity" => Float.NegativeInfinity - case other => throw new RuntimeException( - s"Cannot parse $other as ${FloatType.simpleString}.") + case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") } } @@ -159,8 +158,7 @@ class JacksonParser( case "NaN" => Double.NaN case "Infinity" => Double.PositiveInfinity case "-Infinity" => Double.NegativeInfinity - case other => - throw new RuntimeException(s"Cannot parse $other as ${DoubleType.simpleString}.") + case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 5f70e062d46c8..491ca005877f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -294,10 +294,8 @@ private[sql] object JsonInferSchema { // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. - assert(isSorted(fields1), - s"${StructType.simpleString}'s fields were not sorted: ${fields1.toSeq}") - assert(isSorted(fields2), - s"${StructType.simpleString}'s fields were not sorted: ${fields2.toSeq}") + assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") val newFields = new java.util.ArrayList[StructField]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index a9aaf617f7837..1dcda49a3af6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -29,7 +29,7 @@ object TypeUtils { if (dt.isInstanceOf[NumericType] || dt == NullType) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not ${dt.simpleString}") + TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt") } } @@ -37,8 +37,7 @@ object TypeUtils { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure( - s"$caller does not support ordering on type ${dt.simpleString}") + TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index c43cc748655e8..3041f44b116ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -145,7 +145,7 @@ abstract class NumericType extends AtomicType { } -private[spark] object NumericType extends AbstractDataType { +private[sql] object NumericType extends AbstractDataType { /** * Enables matching against NumericType for expressions: * {{{ @@ -155,12 +155,11 @@ private[spark] object NumericType extends AbstractDataType { */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] - override private[spark] def defaultConcreteType: DataType = DoubleType + override private[sql] def defaultConcreteType: DataType = DoubleType - override private[spark] def simpleString: String = "numeric" + override private[sql] def simpleString: String = "numeric" - override private[spark] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[NumericType] + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 8f118624f6d2f..38c40482fa4d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -42,7 +42,7 @@ object ArrayType extends AbstractDataType { other.isInstanceOf[ArrayType] } - override private[spark] def simpleString: String = "array" + override private[sql] def simpleString: String = "array" } /** @@ -103,8 +103,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case other => - throw new IllegalArgumentException( - s"Type ${other.simpleString} does not support ordered operations") + throw new IllegalArgumentException(s"Type $other does not support ordered operations") } def compare(x: ArrayData, y: ArrayData): Int = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index f780ffd46a876..dbf51c398fa47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -48,8 +48,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { } if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException( - s"${DecimalType.simpleString} can only support precision up to ${DecimalType.MAX_PRECISION}") + throw new AnalysisException(s"DecimalType can only support precision up to 38") } // default constructor for Java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 203e85e1c99bd..2d49fe076786a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -24,8 +24,7 @@ import org.apache.spark.annotation.InterfaceStability @InterfaceStability.Evolving object ObjectType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = - throw new UnsupportedOperationException( - s"null literals can't be casted to ${ObjectType.simpleString}") + throw new UnsupportedOperationException("null literals can't be casted to ObjectType") override private[sql] def acceptsType(other: DataType): Boolean = other match { case ObjectType(_) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 0e69ef8ba73e8..362676b252126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -426,7 +426,7 @@ object StructType extends AbstractDataType { private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") + case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") } } @@ -528,8 +528,7 @@ object StructType extends AbstractDataType { leftType case _ => - throw new SparkException(s"Failed to merge incompatible data types ${left.simpleString} " + - s"and ${right.simpleString}") + throw new SparkException(s"Failed to merge incompatible data types $left and $right") } private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5e503be416a1f..5d2f8e735e3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -514,7 +514,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..36714bd631b0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -109,17 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type map") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") assertError(EqualNullSafe('mapField, 'mapField), - "EqualNullSafe does not support ordering on type map") + "EqualNullSafe does not support ordering on type MapType") assertError(LessThan('mapField, 'mapField), - "LessThan does not support ordering on type map") + "LessThan does not support ordering on type MapType") assertError(LessThanOrEqual('mapField, 'mapField), - "LessThanOrEqual does not support ordering on type map") + "LessThanOrEqual does not support ordering on type MapType") assertError(GreaterThan('mapField, 'mapField), - "GreaterThan does not support ordering on type map") + "GreaterThan does not support ordering on type MapType") assertError(GreaterThanOrEqual('mapField, 'mapField), - "GreaterThanOrEqual does not support ordering on type map") + "GreaterThanOrEqual does not support ordering on type MapType") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") @@ -169,10 +169,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable string expressions are allowed to appear at odd position") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Only foldable string expressions are allowed to appear at odd position") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), "Field name should not be null") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index b4d422d8506fc..cb8a1fecb80a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -469,7 +469,7 @@ class ExpressionParserSuite extends PlanTest { Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) - intercept("1.20E-38BD", "decimal can only support precision up to 38") + intercept("1.20E-38BD", "DecimalType can only support precision up to 38") } test("strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index fccd057e577d4..5a86f4055dce7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -154,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite { left.merge(right) }.getMessage assert(message.equals("Failed to merge fields 'b' and 'b'. " + - "Failed to merge incompatible data types float and bigint")) + "Failed to merge incompatible data types FloatType and LongType")) } test("existsRecursively") { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 060e2ec068053..d5969b55eef96 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -244,7 +244,7 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc return new SchemaColumnConvertNotSupportedException( Arrays.toString(descriptor.getPath()), descriptor.getType().toString(), - column.dataType().simpleString()); + column.dataType().toString()); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index b068493f2dd17..c6449cd5a16b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -452,7 +452,7 @@ class RelationalGroupedDataset protected[sql]( require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, "Must pass a grouped map udf") require(expr.dataType.isInstanceOf[StructType], - s"The returnType of the udf must be a ${StructType.simpleString}") + "The returnType of the udf must be a StructType") val groupingNamedExpressions = groupingExprs.map { case ne: NamedExpression => ne diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 1274abffaa116..93c8127681b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -47,8 +47,7 @@ object ArrowUtils { case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => if (timeZoneId == null) { - throw new UnsupportedOperationException( - s"${TimestampType.simpleString} must supply timeZoneId parameter") + throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") } else { new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index c90328f7ad43f..4f44ae4fa1d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -98,7 +98,7 @@ private[orc] object OrcFilters { case DateType => PredicateLeaf.Type.DATE case TimestampType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL - case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.simpleString}") + case _ => throw new UnsupportedOperationException(s"DataType: $dataType") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 18decad3f62f0..c61be077d309f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -555,7 +555,7 @@ class SparkToParquetSchemaConverter( convertField(field.copy(dataType = udt.sqlType)) case _ => - throw new AnalysisException(s"Unsupported data type ${field.dataType.simpleString}") + throw new AnalysisException(s"Unsupported data type $field.dataType") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index f772a3336d6af..685d5841ab551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -157,7 +157,7 @@ object StatFunctions extends Logging { cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + - s"for columns with dataType ${data.get.dataType.simpleString} not supported.") + s"for columns with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.treeAggregate(new CovarianceCounter)( diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 827931d74138d..3d49323751a10 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -120,7 +120,7 @@ select to_json(named_struct('a', 1, 'b', 2), map('mode', 1)) struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got map;; line 1 pos 7 +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 -- !query 12 @@ -216,7 +216,7 @@ select from_json('{"a":1}', 'a INT', map('mode', 1)) struct<> -- !query 20 output org.apache.spark.sql.AnalysisException -A type of keys and values in map() must be string, but got map;; line 1 pos 7 +A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7 -- !query 21 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 7f301614523b2..b8c91dc8b59a4 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -147,7 +147,7 @@ struct<> -- !query 15 output org.apache.spark.sql.catalyst.parser.ParseException -decimal can only support precision up to 38 +DecimalType can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890 @@ -159,7 +159,7 @@ struct<> -- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException -decimal can only support precision up to 38 +DecimalType can only support precision up to 38 == SQL == select 1234567890123456789012345678901234567890.0 @@ -379,7 +379,7 @@ struct<> -- !query 39 output org.apache.spark.sql.catalyst.parser.ParseException -decimal can only support precision up to 38(line 1, pos 7) +DecimalType can only support precision up to 38(line 1, pos 7) == SQL == select 1.20E-38BD diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 368e52cfbda9c..9d3dfae348beb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -430,9 +430,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val col = spark.read.parquet(file).schema.fields.filter(_.name.equals("a")) assert(col.length == 1) if (col(0).dataType == StringType) { - assert(errMsg.contains("Column: [a], Expected: int, Found: BINARY")) + assert(errMsg.contains("Column: [a], Expected: IntegerType, Found: BINARY")) } else { - assert(errMsg.endsWith("Column: [a], Expected: string, Found: INT32")) + assert(errMsg.endsWith("Column: [a], Expected: StringType, Found: INT32")) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 40be4e8c1f5be..7dcaf170f9693 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -78,9 +78,9 @@ case class HiveTableScanExec( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => - require(pred.dataType == BooleanType, - s"Data type of predicate $pred must be ${BooleanType.simpleString} rather than " + - s"${pred.dataType.simpleString}.") + require( + pred.dataType == BooleanType, + s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") BindReferences.bindReference(pred, relation.partitionCols) } From eb6e9880397dbac8b0b9ebc0796150b6924fc566 Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Mon, 9 Jul 2018 14:53:14 -0700 Subject: [PATCH 1085/1161] [SPARK-24759][SQL] No reordering keys for broadcast hash join ## What changes were proposed in this pull request? As the implementation of the broadcast hash join is independent of the input hash partitioning, reordering keys is not necessary. Thus, we solve this issue by simply removing the broadcast hash join from the reordering rule in EnsureRequirements. ## How was this patch tested? N/A Author: Xiao Li Closes #21728 from gatorsmile/cleanER. --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index ad95879d86f42..d96ecbaa48029 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -279,13 +279,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { */ private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { plan match { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, - right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) From 4984f1af7e48dab1ae08021a3b17c5ad6d47a87e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 10 Jul 2018 13:54:04 +0800 Subject: [PATCH 1086/1161] [MINOR] Add Sphinx into dev/requirements.txt ## What changes were proposed in this pull request? Not a big deal but this PR adds `sphinx` into `dev/requirements.txt` since we found it needed - https://github.com/apache/spark-website/pull/122#discussion_r200896018 ## How was this patch tested? manually: ``` pip install -r requirements.txt ``` Author: hyukjinkwon Closes #21735 from HyukjinKwon/minor-dev. --- dev/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/requirements.txt b/dev/requirements.txt index 79782279f8fbd..fa833ab96b8e7 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -2,3 +2,4 @@ jira==1.0.3 PyGithub==1.26.0 Unidecode==0.04.19 pypandoc==1.3.3 +sphinx From a289009567c1566a1df4bcdfdf0111e82ae3d81d Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 10 Jul 2018 15:58:14 +0800 Subject: [PATCH 1087/1161] [SPARK-24706][SQL] ByteType and ShortType support pushdown to parquet ## What changes were proposed in this pull request? `ByteType` and `ShortType` support pushdown to parquet data source. [Benchmark result](https://issues.apache.org/jira/browse/SPARK-24706?focusedCommentId=16528878&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16528878). ## How was this patch tested? unit tests Author: Yuming Wang Closes #21682 from wangyum/SPARK-24706. --- .../FilterPushdownBenchmark-results.txt | 32 +++++------ .../datasources/parquet/ParquetFilters.scala | 34 +++++++---- .../parquet/ParquetFilterSuite.scala | 56 +++++++++++++++++++ 3 files changed, 94 insertions(+), 28 deletions(-) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index 29fe4345d69da..110669b69a00d 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -542,39 +542,39 @@ Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 3726 / 3775 4.2 236.9 1.0X -Parquet Vectorized (Pushdown) 3741 / 3789 4.2 237.9 1.0X -Native ORC Vectorized 2793 / 2909 5.6 177.6 1.3X -Native ORC Vectorized (Pushdown) 530 / 561 29.7 33.7 7.0X +Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X +Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X +Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X +Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4385 / 4406 3.6 278.8 1.0X -Parquet Vectorized (Pushdown) 4398 / 4454 3.6 279.6 1.0X -Native ORC Vectorized 3420 / 3501 4.6 217.4 1.3X -Native ORC Vectorized (Pushdown) 1395 / 1432 11.3 88.7 3.1X +Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X +Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X +Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X +Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7307 / 7394 2.2 464.6 1.0X -Parquet Vectorized (Pushdown) 7411 / 7461 2.1 471.2 1.0X -Native ORC Vectorized 6501 / 7814 2.4 413.4 1.1X -Native ORC Vectorized (Pushdown) 7341 / 8637 2.1 466.7 1.0X +Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X +Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X +Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X +Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11886 / 13122 1.3 755.7 1.0X -Parquet Vectorized (Pushdown) 12557 / 14173 1.3 798.4 0.9X -Native ORC Vectorized 10758 / 11971 1.5 684.0 1.1X -Native ORC Vectorized (Pushdown) 10564 / 10713 1.5 671.6 1.1X +Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X +Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X +Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X +Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 4827f706e6016..4c9b940db2b30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -45,6 +45,8 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: decimalMetadata: DecimalMetadata) private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, null) private val ParquetIntegerType = ParquetSchemaType(null, INT32, null) private val ParquetLongType = ParquetSchemaType(null, INT64, null) private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null) @@ -60,8 +62,10 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -87,8 +91,10 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -111,8 +117,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -132,8 +139,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -153,8 +161,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => @@ -174,8 +183,9 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: } private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { - case ParquetIntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[Integer]) + case ParquetByteType | ParquetShortType | ParquetIntegerType => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case ParquetFloatType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index f2c0bda256239..067d2fea14fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -179,6 +179,62 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - tinyint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df => + assert(df.schema.head.dataType === ByteType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toByte, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toByte, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toByte, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toByte, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toByte, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toByte, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toByte) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toByte) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toByte) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toByte) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toByte) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toByte) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toByte), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toByte || '_1 > 3.toByte, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + + test("filter pushdown - smallint") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => + assert(df.schema.head.dataType === ShortType) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 === 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1.toShort, classOf[Eq[_]], 1) + checkFilterPredicate('_1 =!= 1.toShort, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('_1 < 2.toShort, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3.toShort, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1.toShort, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4.toShort, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1.toShort) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1.toShort) <=> '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2.toShort) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3.toShort) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1.toShort) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4.toShort) <= '_1, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('_1 < 4.toShort), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2.toShort || '_1 > 3.toShort, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) From 6fe32869ccb17933e77a4dbe883e36d382fbeeec Mon Sep 17 00:00:00 2001 From: sharkdtu Date: Tue, 10 Jul 2018 20:18:34 +0800 Subject: [PATCH 1088/1161] [SPARK-24678][SPARK-STREAMING] Give priority in use of 'PROCESS_LOCAL' for spark-streaming ## What changes were proposed in this pull request? Currently, `BlockRDD.getPreferredLocations` only get hosts info of blocks, which results in subsequent schedule level is not better than 'NODE_LOCAL'. We can just make a small changes, the schedule level can be improved to 'PROCESS_LOCAL' ## How was this patch tested? manual test Author: sharkdtu Closes #21658 from sharkdtu/master. --- .../main/scala/org/apache/spark/rdd/BlockRDD.scala | 2 +- .../org/apache/spark/storage/BlockManager.scala | 7 +++++-- .../apache/spark/storage/BlockManagerSuite.scala | 13 +++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 4e036c2ed49b5..23cf19d55b4ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -30,7 +30,7 @@ private[spark] class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { - @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) + @transient lazy val _locations = BlockManager.blockIdsToLocations(blockIds, SparkEnv.get) @volatile private var _isValid = true override def getPartitions: Array[Partition] = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index df1a4bef616b2..0e1c7d5fd3fa2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -45,6 +45,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ @@ -1554,7 +1555,7 @@ private[spark] class BlockManager( private[spark] object BlockManager { private val ID_GENERATOR = new IdGenerator - def blockIdsToHosts( + def blockIdsToLocations( blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = { @@ -1569,7 +1570,9 @@ private[spark] object BlockManager { val blockManagers = new HashMap[BlockId, Seq[String]] for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i).map(_.host) + blockManagers(blockIds(i)) = blockLocations(i).map { loc => + ExecutorCacheTaskLocation(loc.host, loc.executorId).toString + } } blockManagers.toMap } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..08172f0b07b75 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1422,6 +1422,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) } + test("query locations of blockIds") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200)) + when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]])) + .thenReturn(Array(blockLocations)) + val env = mock(classOf[SparkEnv]) + + val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2)) + val locs = BlockManager.blockIdsToLocations(blockIds, env, mockBlockManagerMaster) + val expectedLocs = Seq("executor_host1_1", "executor_host2_2") + assert(locs(blockIds(0)) == expectedLocs) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 var tempFileManager: TempFileManager = null From e0559f238009e02c40f65678fec691c07904e8c0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Jul 2018 23:07:10 +0800 Subject: [PATCH 1089/1161] [SPARK-21743][SQL][FOLLOWUP] free aggregate map when task ends ## What changes were proposed in this pull request? This is the first follow-up of https://github.com/apache/spark/pull/21573 , which was only merged to 2.3. This PR fixes the memory leak in another way: free the `UnsafeExternalMap` when the task ends. All the data buffers in Spark SQL are using `UnsafeExternalMap` and `UnsafeExternalSorter` under the hood, e.g. sort, aggregate, window, SMJ, etc. `UnsafeExternalSorter` registers a task completion listener to free the resource, we should apply the same thing to `UnsafeExternalMap`. TODO in the next PR: do not consume all the inputs when having limit in whole stage codegen. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #21738 from cloud-fan/limit. --- .../UnsafeFixedWidthAggregationMap.java | 17 ++++++++++----- .../spark/sql/execution/SparkStrategies.scala | 7 +------ .../aggregate/HashAggregateExec.scala | 2 +- .../TungstenAggregationIterator.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 21 ++++++++++++------- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e7715..c8cf44b51df77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. */ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). + taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 07a6fcae83b70..cfbcb9aad65c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -73,12 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => - // With whole stage codegen, Spark releases resources only when all the output data of the - // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little - // data from child plan and finishes the query without releasing resources. Here we wrap - // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and - // trigger the resource releasing work, after we consume `limit` rows. - CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } case Limit(IntegerLiteral(limit), Sort(order, true, child)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8c7b2c187cccd..2cac0cfce28de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -328,7 +328,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3c..c1911235f8df3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0e..5c15ecd42fa0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) From 32cb50835e7258625afff562939872be002232f2 Mon Sep 17 00:00:00 2001 From: Mukul Murthy Date: Tue, 10 Jul 2018 11:08:04 -0700 Subject: [PATCH 1090/1161] [SPARK-24662][SQL][SS] Support limit in structured streaming ## What changes were proposed in this pull request? Support the LIMIT operator in structured streaming. For streams in append or complete output mode, a stream with a LIMIT operator will return no more than the specified number of rows. LIMIT is still unsupported for the update output mode. This change reverts https://github.com/apache/spark/commit/e4fee395ecd93ad4579d9afbf0861f82a303e563 as part of it because it is a better and more complete implementation. ## How was this patch tested? New and existing unit tests. Author: Mukul Murthy Closes #21662 from mukulmurthy/SPARK-24662. --- .../UnsupportedOperationChecker.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 26 ++++- .../streaming/IncrementalExecution.scala | 11 +- .../streaming/StreamingGlobalLimitExec.scala | 102 +++++++++++++++++ .../sql/execution/streaming/memory.scala | 70 +----------- .../streaming/sources/memoryV2.scala | 44 ++----- .../sql/streaming/DataStreamWriter.scala | 4 +- .../execution/streaming/MemorySinkSuite.scala | 62 +--------- .../streaming/MemorySinkV2Suite.scala | 80 +------------ .../spark/sql/streaming/StreamSuite.scala | 108 ++++++++++++++++++ .../spark/sql/streaming/StreamTest.scala | 4 +- 11 files changed, 272 insertions(+), 245 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 5ced1ca200daa..f68df5d29b545 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -315,8 +315,10 @@ object UnsupportedOperationChecker { case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => - throwError("Limits are not supported on streaming DataFrames/Datasets") + case GlobalLimit(_, _) | LocalLimit(_, _) + if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update => + throwError("Limits are not supported on streaming DataFrames/Datasets in Update " + + "output mode") case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cfbcb9aad65c4..02e095b42a506 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -34,7 +35,7 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType /** @@ -349,6 +350,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan the streaming global limit operator for streams in append mode. + * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator, + * following the example of the SpecialLimits Strategy above. + * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec. + * Streams with limit in Complete mode use the stateless CollectLimitExec operator. + * Limit is unsupported for streams in Update mode. + */ + case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(rootPlan) => rootPlan match { + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + case Limit(IntegerLiteral(limit), child) + if plan.isStreaming && outputMode == InternalOutputModes.Append => + StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil + case _ => Nil + } + } + object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index c480b96626f84..6ae7f2869b0f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -59,7 +59,8 @@ class IncrementalExecution( StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: StreamingRelationStrategy :: - StreamingDeduplicationStrategy :: Nil + StreamingDeduplicationStrategy :: + StreamingGlobalLimitStrategy(outputMode) :: Nil } private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) @@ -134,8 +135,12 @@ class IncrementalExecution( stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs)) - ) + Some(offsetSeqMetadata.batchWatermarkMs))) + + case l: StreamingGlobalLimitExec => + l.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + outputMode = Some(outputMode)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala new file mode 100644 index 0000000000000..bf4af60c8cf03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.streaming.state.StateStoreOps +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType} +import org.apache.spark.util.CompletionIterator + +/** + * A physical operator for executing a streaming limit, which makes sure no more than streamLimit + * rows are returned. This operator is meant for streams in Append mode only. + */ +case class StreamingGlobalLimitExec( + streamLimit: Long, + child: SparkPlan, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None) + extends UnaryExecNode with StateStoreWriter { + + private val keySchema = StructType(Array(StructField("key", NullType))) + private val valueSchema = StructType(Array(StructField("value", LongType))) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append, + "StreamingGlobalLimitExec is only valid for streams in Append output mode") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + keySchema, + valueSchema, + indexOrdinal = None, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null))) + val numOutputRows = longMetric("numOutputRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + val updatesStartTimeNs = System.nanoTime + + val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L) + var cumulativeRowCount = preBatchRowCount + + val result = iter.filter { r => + val x = cumulativeRowCount < streamLimit + if (x) { + cumulativeRowCount += 1 + } + x + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + if (cumulativeRowCount > preBatchRowCount) { + numUpdatedStateRows += 1 + numOutputRows += cumulativeRowCount - preBatchRowCount + store.put(key, getValueRow(cumulativeRowCount)) + } + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil + + private def getValueRow(value: Long): UnsafeRow = { + UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7fa13c4aa2c01..b137f98045c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode @@ -222,60 +221,19 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow]) } /** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink with Logging { +trait MemorySinkBase extends BaseStreamingSink { def allData: Seq[Row] def latestBatchData: Seq[Row] def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] - - /** - * Truncates the given rows to return at most maxRows rows. - * @param rows The data that may need to be truncated. - * @param batchLimit Number of rows to keep in this batch; the rest will be truncated - * @param sinkLimit Total number of rows kept in this sink, for logging purposes. - * @param batchId The ID of the batch that sent these rows, for logging purposes. - * @return Truncated rows. - */ - protected def truncateRowsIfNeeded( - rows: Array[Row], - batchLimit: Int, - sinkLimit: Int, - batchId: Long): Array[Row] = { - if (rows.length > batchLimit && batchLimit >= 0) { - logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit") - rows.take(batchLimit) - } else { - rows - } - } -} - -/** - * Companion object to MemorySinkBase. - */ -object MemorySinkBase { - val MAX_MEMORY_SINK_ROWS = "maxRows" - val MAX_MEMORY_SINK_ROWS_DEFAULT = -1 - - /** - * Gets the max number of rows a MemorySink should store. This number is based on the memory - * sink row limit option if it is set. If not, we use a large value so that data truncates - * rather than causing out of memory errors. - * @param options Options for writing from which we get the max rows option - * @return The maximum number of rows a memorySink should store. - */ - def getMemorySinkCapacity(options: DataSourceOptions): Int = { - val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT) - if (maxRows >= 0) maxRows else Int.MaxValue - 10 - } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions) - extends Sink with MemorySinkBase with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink + with MemorySinkBase with Logging { private case class AddedData(batchId: Long, data: Array[Row]) @@ -283,12 +241,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() - /** The number of rows in this MemorySink. */ - private var numRows = 0 - - /** The capacity in rows of this sink. */ - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -321,23 +273,14 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - var rowsToAdd = data.collect() - synchronized { - rowsToAdd = - truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) - batches += rows - numRows += rowsToAdd.length - } + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } case Complete => - var rowsToAdd = data.collect() + val rows = AddedData(batchId, data.collect()) synchronized { - rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows - numRows = rowsToAdd.length } case _ => @@ -351,7 +294,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo def clear(): Unit = synchronized { batches.clear() - numRows = 0 } override def toString(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 29f8cca476722..f2a35a90af24a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB schema: StructType, mode: OutputMode, options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode, options) + new MemoryStreamWriter(this, mode) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -55,9 +55,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB @GuardedBy("this") private val batches = new ArrayBuffer[AddedData]() - /** The number of rows in this MemorySink. */ - private var numRows = 0 - /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { batches.flatMap(_.data) @@ -84,11 +81,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB }.mkString("\n") } - def write( - batchId: Long, - outputMode: OutputMode, - newRows: Array[Row], - sinkCapacity: Int): Unit = { + def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } @@ -96,21 +89,14 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - synchronized { - val rowsToAdd = - truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) - batches += rows - numRows += rowsToAdd.length - } + val rows = AddedData(batchId, newRows) + synchronized { batches += rows } case Complete => + val rows = AddedData(batchId, newRows) synchronized { - val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId) - val rows = AddedData(batchId, rowsToAdd) batches.clear() batches += rows - numRows = rowsToAdd.length } case _ => @@ -124,7 +110,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB def clear(): Unit = synchronized { batches.clear() - numRows = 0 } override def toString(): String = "MemorySinkV2" @@ -132,22 +117,16 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryWriter( - sink: MemorySinkV2, - batchId: Long, - outputMode: OutputMode, - options: DataSourceOptions) +class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode) extends DataSourceWriter with Logging { - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) def commit(messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(batchId, outputMode, newRows, sinkCapacity) + sink.write(batchId, outputMode, newRows) } override def abort(messages: Array[WriterCommitMessage]): Unit = { @@ -155,21 +134,16 @@ class MemoryWriter( } } -class MemoryStreamWriter( - val sink: MemorySinkV2, - outputMode: OutputMode, - options: DataSourceOptions) +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode) extends StreamWriter { - val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options) - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows, sinkCapacity) + sink.write(epochId, outputMode, newRows) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 926c0b69a03fd..3b9a56ffdde4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -250,7 +250,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) (s, r) case _ => - val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) + val s = new MemorySink(df.schema, outputMode) val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) (s, r) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index b2fd6ba27ebb8..3bc36ce55d902 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -38,7 +36,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Append) // Before adding data, check output assert(sink.latestBatchId === None) @@ -70,35 +68,9 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 1 to 9) } - test("directly add data in Append output mode with row limit") { - implicit val schema = new StructType().add(new StructField("value", IntegerType)) - - var optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - var options = new DataSourceOptions(optionsMap.toMap.asJava) - val sink = new MemorySink(schema, OutputMode.Append, options) - - // Before adding data, check output - assert(sink.latestBatchId === None) - checkAnswer(sink.latestBatchData, Seq.empty) - checkAnswer(sink.allData, Seq.empty) - - // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) - assert(sink.latestBatchId === Some(0)) - checkAnswer(sink.latestBatchData, 1 to 3) - checkAnswer(sink.allData, 1 to 3) - - // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) - assert(sink.latestBatchId === Some(1)) - checkAnswer(sink.latestBatchData, 4 to 5) - checkAnswer(sink.allData, 1 to 5) // new data should not go over the limit - } - test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Update) // Before adding data, check output assert(sink.latestBatchId === None) @@ -132,7 +104,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Complete) // Before adding data, check output assert(sink.latestBatchId === None) @@ -164,32 +136,6 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, 7 to 9) } - test("directly add data in Complete output mode with row limit") { - implicit val schema = new StructType().add(new StructField("value", IntegerType)) - - var optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - var options = new DataSourceOptions(optionsMap.toMap.asJava) - val sink = new MemorySink(schema, OutputMode.Complete, options) - - // Before adding data, check output - assert(sink.latestBatchId === None) - checkAnswer(sink.latestBatchData, Seq.empty) - checkAnswer(sink.allData, Seq.empty) - - // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) - assert(sink.latestBatchId === Some(0)) - checkAnswer(sink.latestBatchData, 1 to 3) - checkAnswer(sink.allData, 1 to 3) - - // Add batch 1 and check outputs - sink.addBatch(1, 4 to 10) - assert(sink.latestBatchId === Some(1)) - checkAnswer(sink.latestBatchData, 4 to 8) - checkAnswer(sink.allData, 4 to 8) // new data should replace old data - } - test("registering as a table in Append output mode") { val input = MemoryStream[Int] @@ -265,7 +211,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty()) + val sink = new MemorySink(schema, OutputMode.Append) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index e539510e15755..9be22d94b5654 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -17,16 +17,11 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.JavaConverters._ - import org.scalatest.BeforeAndAfter import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{OutputMode, StreamTest} -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.sql.types.StructType class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { @@ -45,7 +40,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("continuous writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty()) + val writer = new MemoryStreamWriter(sink, OutputMode.Append()) writer.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), @@ -67,7 +62,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("microbatch writer") { val sink = new MemorySinkV2 - new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit( + new MemoryWriter(sink, 0, OutputMode.Append()).commit( Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -75,7 +70,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit( + new MemoryWriter(sink, 19, OutputMode.Append()).commit( Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -85,73 +80,4 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) } - - test("continuous writer with row limit") { - val sink = new MemorySinkV2 - val optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString()) - val options = new DataSourceOptions(optionsMap.toMap.asJava) - val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options) - appendWriter.commit(0, Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))))) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - appendWriter.commit(19, Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))))) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11)) - - val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options) - completeWriter.commit(20, Array( - MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(5, Seq(Row(33))))) - assert(sink.latestBatchId.contains(20)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - completeWriter.commit(21, Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))), - MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))), - MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9))))) - assert(sink.latestBatchId.contains(21)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7)) - } - - test("microbatch writer with row limit") { - val sink = new MemorySinkV2 - val optionsMap = new scala.collection.mutable.HashMap[String, String] - optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString()) - val options = new DataSourceOptions(optionsMap.toMap.asJava) - - new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))))) - assert(sink.latestBatchId.contains(25)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4)) - new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array( - MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))), - MemoryWriterCommitMessage(3, Seq(Row(7), Row(8))))) - assert(sink.latestBatchId.contains(26)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5)) - - new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))), - MemoryWriterCommitMessage(5, Seq(Row(11), Row(12))))) - assert(sink.latestBatchId.contains(27)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12)) - new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array( - MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))), - MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18))))) - assert(sink.latestBatchId.contains(28)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17)) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c1ec1eba69fb2..ca38f04136c7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -805,6 +805,114 @@ class StreamSuite extends StreamTest { } } + test("streaming limit without state") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(0))( + AddData(inputData1, 1 to 8: _*), + CheckAnswer()) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4))( + AddData(inputData2, 1 to 8: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with state") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().limit(4))( + AddData(inputData, 1 to 2: _*), + CheckAnswer(1 to 2: _*), + AddData(inputData, 3 to 6: _*), + CheckAnswer(1 to 4: _*), + AddData(inputData, 7 to 9: _*), + CheckAnswer(1 to 4: _*)) + } + + test("streaming limit with other operators") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().where("value % 2 = 1").limit(4))( + AddData(inputData, 1 to 5: _*), + CheckAnswer(1, 3, 5), + AddData(inputData, 6 to 9: _*), + CheckAnswer(1, 3, 5, 7), + AddData(inputData, 10 to 12: _*), + CheckAnswer(1, 3, 5, 7)) + } + + test("streaming limit with multiple limits") { + val inputData1 = MemoryStream[Int] + testStream(inputData1.toDF().limit(4).limit(2))( + AddData(inputData1, 1), + CheckAnswer(1), + AddData(inputData1, 2 to 8: _*), + CheckAnswer(1, 2)) + + val inputData2 = MemoryStream[Int] + testStream(inputData2.toDF().limit(4).limit(100).limit(3))( + AddData(inputData2, 1, 2), + CheckAnswer(1, 2), + AddData(inputData2, 3 to 8: _*), + CheckAnswer(1 to 3: _*)) + } + + test("streaming limit in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(5).groupBy("value").count() + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 3: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 2), Row(2, 2), Row(3, 2), Row(4, 1), Row(5, 1))) + } + + test("streaming limits in complete mode") { + val inputData = MemoryStream[Int] + val limited = inputData.toDF().limit(4).groupBy("value").count().orderBy("value").limit(3) + testStream(limited, OutputMode.Complete())( + AddData(inputData, 1 to 9: _*), + CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)), + AddData(inputData, 2 to 6: _*), + CheckAnswer(Row(1, 1), Row(2, 2), Row(3, 2))) + } + + test("streaming limit in update mode") { + val inputData = MemoryStream[Int] + val e = intercept[AnalysisException] { + testStream(inputData.toDF().limit(5), OutputMode.Update())( + AddData(inputData, 1 to 3: _*) + ) + } + assert(e.getMessage.contains( + "Limits are not supported on streaming DataFrames/Datasets in Update output mode")) + } + + test("streaming limit in multiple partitions") { + val inputData = MemoryStream[Int] + testStream(inputData.toDF().repartition(2).limit(7))( + AddData(inputData, 1 to 10: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false), + AddData(inputData, 11 to 20: _*), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)), + false)) + } + + test("streaming limit in multiple partitions by column") { + val inputData = MemoryStream[(Int, Int)] + val df = inputData.toDF().repartition(2, $"_2").limit(7) + testStream(df)( + AddData(inputData, (1, 0), (2, 0), (3, 1), (4, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 4 && rows.forall(r => r.getInt(0) <= 4)), + false), + AddData(inputData, (5, 0), (6, 0), (7, 1), (8, 1)), + CheckAnswerRowsByFunc( + rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 8)), + false)) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index e41b4534ed51d..4c3fd58cb2e45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -45,7 +45,6 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -338,8 +337,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 - else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty()) + val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath From 6078b891da8fe7fc36579699473168ae7443284c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 10 Jul 2018 18:03:40 -0700 Subject: [PATCH 1091/1161] [SPARK-24730][SS] Add policy to choose max as global watermark when streaming query has multiple watermarks ## What changes were proposed in this pull request? Currently, when a streaming query has multiple watermark, the policy is to choose the min of them as the global watermark. This is safe to do as the global watermark moves with the slowest stream, and is therefore is safe as it does not unexpectedly drop some data as late, etc. While this is indeed the safe thing to do, in some cases, you may want the watermark to advance with the fastest stream, that is, take the max of multiple watermarks. This PR is to add that configuration. It makes the following changes. - Adds a configuration to specify max as the policy. - Saves the configuration in OffsetSeqMetadata because changing it in the middle can lead to unpredictable results. - For old checkpoints without the configuration, it assumes the default policy as min (irrespective of the policy set at the session where the query is being restarted). This is to ensure that existing queries are affected in any way. TODO - [ ] Add a test for recovery from existing checkpoints. ## How was this patch tested? New unit test Author: Tathagata Das Closes #21701 from tdas/SPARK-24730. --- .../apache/spark/sql/internal/SQLConf.scala | 15 ++ .../streaming/MicroBatchExecution.scala | 4 +- .../sql/execution/streaming/OffsetSeq.scala | 37 ++++- .../streaming/WatermarkTracker.scala | 90 ++++++++++-- .../commits/0 | 2 + .../commits/1 | 2 + .../metadata | 1 + .../offsets/0 | 4 + .../offsets/1 | 4 + .../streaming/EventTimeWatermarkSuite.scala | 136 +++++++++++++++++- 10 files changed, 276 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 50965c1abc68c..ae56cc97581a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -875,6 +875,21 @@ object SQLConf { .stringConf .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol") + val STREAMING_MULTIPLE_WATERMARK_POLICY = + buildConf("spark.sql.streaming.multipleWatermarkPolicy") + .doc("Policy to calculate the global watermark value when there are multiple watermark " + + "operators in a streaming query. The default value is 'min' which chooses " + + "the minimum watermark reported across multiple operators. Other alternative value is" + + "'max' which chooses the maximum across multiple operators." + + "Note: This configuration cannot be changed between query restarts from the same " + + "checkpoint location.") + .stringConf + .checkValue( + str => Set("min", "max").contains(str.toLowerCase), + "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " + + "Valid values are 'min' and 'max'") + .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 17ffa2a517312..16651dd060d73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -61,7 +61,7 @@ class MicroBatchExecution( case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger") } - private val watermarkTracker = new WatermarkTracker() + private var watermarkTracker: WatermarkTracker = _ override lazy val logicalPlan: LogicalPlan = { assert(queryExecutionThread eq Thread.currentThread, @@ -257,6 +257,7 @@ class MicroBatchExecution( OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf) offsetSeqMetadata = OffsetSeqMetadata( metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf) + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) watermarkTracker.setWatermark(metadata.batchWatermarkMs) } @@ -295,6 +296,7 @@ class MicroBatchExecution( case None => // We are starting this stream for the first time. logInfo(s"Starting new streaming query.") currentBatchId = 0 + watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 787174481ff08..1ae3f36c152cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS} +import org.apache.spark.sql.internal.SQLConf._ /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -86,7 +86,22 @@ case class OffsetSeqMetadata( object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) - private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS) + private val relevantSQLConfs = Seq( + SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY) + + /** + * Default values of relevant configurations that are used for backward compatibility. + * As new configurations are added to the metadata, existing checkpoints may not have those + * confs. The values in this list ensures that the confs without recovered values are + * set to a default value that ensure the same behavior of the streaming query as it was before + * the restart. + * + * Note, that this is optional; set values here if you *have* to override existing session conf + * with a specific default value for ensuring same behavior of the query as before. + */ + private val relevantSQLConfDefaultValues = Map[String, String]( + STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME + ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) @@ -115,8 +130,22 @@ object OffsetSeqMetadata extends Logging { case None => // For backward compatibility, if a config was not recorded in the offset log, - // then log it, and let the existing conf value in SparkSession prevail. - logWarning (s"Conf '$confKey' was not found in the offset log, using existing value") + // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or + // let the existing conf value in SparkSession prevail. + relevantSQLConfDefaultValues.get(confKey) match { + + case Some(defaultValue) => + sessionConf.set(confKey, defaultValue) + logWarning(s"Conf '$confKey' was not found in the offset log, " + + s"using default value '$defaultValue'") + + case None => + val valueStr = sessionConf.getOption(confKey).map { v => + s" Using existing session conf value '$v'." + }.getOrElse { " No value set in session conf." } + logWarning(s"Conf '$confKey' was not found in the offset log. $valueStr") + + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala index 80865669558dd..7b30db44a2090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala @@ -20,15 +20,68 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf -class WatermarkTracker extends Logging { +/** + * Policy to define how to choose a new global watermark value if there are + * multiple watermark operators in a streaming query. + */ +sealed trait MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long +} + +object MultipleWatermarkPolicy { + val DEFAULT_POLICY_NAME = "min" + + def apply(policyName: String): MultipleWatermarkPolicy = { + policyName.toLowerCase match { + case DEFAULT_POLICY_NAME => MinWatermark + case "max" => MaxWatermark + case _ => + throw new IllegalArgumentException(s"Could not recognize watermark policy '$policyName'") + } + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. + * Note that this is the safe (hence default) policy as the global watermark will advance + * only if all the individual operator watermarks have advanced. In other words, in a + * streaming query with multiple input streams and watermarks defined on all of them, + * the global watermark will advance as slowly as the slowest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most conservative one. + */ +case object MinWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.min + } +} + +/** + * Policy to choose the *min* of the operator watermark values as the global watermark value. So the + * global watermark will advance if any of the individual operator watermarks has advanced. + * In other words, in a streaming query with multiple input streams and watermarks defined on all + * of them, the global watermark will advance as fast as the fastest input. So if there is watermark + * based state cleanup or late-data dropping, then this policy is the most aggressive one and + * may lead to unexpected behavior if the data of the slow stream is delayed. + */ +case object MaxWatermark extends MultipleWatermarkPolicy { + def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = { + assert(operatorWatermarks.nonEmpty) + operatorWatermarks.max + } +} + +/** Tracks the watermark value of a streaming query based on a given `policy` */ +case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging { private val operatorToWatermarkMap = mutable.HashMap[Int, Long]() - private var watermarkMs: Long = 0 - private var updated = false + private var globalWatermarkMs: Long = 0 def setWatermark(newWatermarkMs: Long): Unit = synchronized { - watermarkMs = newWatermarkMs + globalWatermarkMs = newWatermarkMs } def updateWatermark(executedPlan: SparkPlan): Unit = synchronized { @@ -37,7 +90,6 @@ class WatermarkTracker extends Logging { } if (watermarkOperators.isEmpty) return - watermarkOperators.zipWithIndex.foreach { case (e, index) if e.eventTimeStats.value.count > 0 => logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") @@ -58,16 +110,28 @@ class WatermarkTracker extends Logging { // This is the safest option, because only the global watermark is fault-tolerant. Making // it the minimum of all individual watermarks guarantees it will never advance past where // any individual watermark operator would be if it were in a plan by itself. - val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2 - if (newWatermarkMs > watermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") - watermarkMs = newWatermarkMs - updated = true + val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq) + if (chosenGlobalWatermark > globalWatermarkMs) { + logInfo(s"Updating event-time watermark from $globalWatermarkMs to $chosenGlobalWatermark ms") + globalWatermarkMs = chosenGlobalWatermark } else { - logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs") - updated = false + logDebug(s"Event time watermark didn't move: $chosenGlobalWatermark < $globalWatermarkMs") } } - def currentWatermark: Long = synchronized { watermarkMs } + def currentWatermark: Long = synchronized { globalWatermarkMs } +} + +object WatermarkTracker { + def apply(conf: RuntimeConfig): WatermarkTracker = { + // If the session has been explicitly configured to use non-default policy then use it, + // otherwise use the default `min` policy as thats the safe thing to do. + // When recovering from a checkpoint location, it is expected that the `conf` will already + // be configured with the value present in the checkpoint. If there is no policy explicitly + // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced + // through defaults specified in OffsetSeqMetadata.setSessionConf(). + val policyName = conf.get( + SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME) + new WatermarkTracker(MultipleWatermarkPolicy(policyName)) + } } diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata new file mode 100644 index 0000000000000..d6be7fbffa9b7 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/metadata @@ -0,0 +1 @@ +{"id":"549eeb1a-d762-420c-bb44-3fd6d73a5268"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 new file mode 100644 index 0000000000000..43db49d052894 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1531172902041,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 new file mode 100644 index 0000000000000..8cc898e81017f --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/offsets/1 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":10000,"batchTimestampMs":1531172902217,"conf":{"spark.sql.shuffle.partitions":"10","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 +0 \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 7e8fde1ff8e56..58ed9790ea123 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -18,18 +18,22 @@ package org.apache.spark.sql.streaming import java.{util => ju} +import java.io.File import java.text.SimpleDateFormat import java.util.{Calendar, Date} +import org.apache.commons.io.FileUtils import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ +import org.apache.spark.util.Utils class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { @@ -484,6 +488,136 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche testWithFlag(false) } + test("MultipleWatermarkPolicy: max") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15), // max(20 - 10, 30 - 15) = 15 + StopStream, + StartStream(), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: min") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 10), // min(20 - 10, 30 - 15) = 10 + StopStream, + StartStream(), + checkWatermark(input1, 10), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // min(120 - 10, 130 - 15) = 110, policy recovered correctly + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) // does not advance when only one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from checkpoints ignores session conf") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val checkpointDir = Utils.createTempDir().getCanonicalFile + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + MultiAddData(input1, 20)(input2, 30), + CheckLastBatch(20, 30), + checkWatermark(input1, 15) // max(20 - 10, 30 - 15) = 15 + ) + } + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "min") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + checkWatermark(input1, 15), // watermark recovered correctly + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input1, 115), // max(120 - 10, 130 - 15) = 115, policy recovered correctly + AddData(input1, 150), + CheckLastBatch(150), + checkWatermark(input1, 140) // should advance even if one of the input has data + ) + } + } + + test("MultipleWatermarkPolicy: recovery from Spark ver 2.3.1 checkpoints ensures min policy") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-for-multi-watermark-policy/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + input1.addData(20) + input2.addData(30) + input1.addData(10) + + withSQLConf(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key -> "max") { + testStream(dfWithMultipleWatermarks(input1, input2))( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + Execute { _.processAllAvailable() }, + MultiAddData(input1, 120)(input2, 130), + CheckLastBatch(120, 130), + checkWatermark(input2, 110), // should calculate 'min' even if session conf has 'max' policy + AddData(input2, 150), + CheckLastBatch(150), + checkWatermark(input2, 110) + ) + } + } + + test("MultipleWatermarkPolicy: fail on incorrect conf values") { + val invalidValues = Seq("", "random") + invalidValues.foreach { value => + val e = intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key, value) + } + assert(e.getMessage.toLowerCase.contains("valid values are 'min' and 'max'")) + } + } + + private def dfWithMultipleWatermarks( + input1: MemoryStream[Int], + input2: MemoryStream[Int]): Dataset[_] = { + val df1 = input1.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + val df2 = input2.toDF + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "15 seconds") + df1.union(df2).select($"eventTime".cast("int")) + } + + private def checkWatermark(input: MemoryStream[Int], watermark: Long) = Execute { q => + input.addData(1) + q.processAllAvailable() + assert(q.lastProgress.eventTime.get("watermark") == formatTimestamp(watermark)) + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => q.processAllAvailable() val progressWithData = q.recentProgress.lastOption.get From 1f94bf492c3bce3b61f7fec6132b50e06dea94a8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 11 Jul 2018 10:10:07 +0800 Subject: [PATCH 1092/1161] [SPARK-24530][PYTHON] Add a control to force Python version in Sphinx via environment variable, SPHINXPYTHON ## What changes were proposed in this pull request? This PR proposes to add `SPHINXPYTHON` environment variable to control the Python version used by Sphinx. The motivation of this environment variable is, it seems not properly rendering some signatures in the Python documentation when Python 2 is used by Sphinx. See the JIRA's case. It should be encouraged to use Python 3, but looks we will probably live with this problem for a long while in any event. For the default case of `make html`, it keeps previous behaviour and use `SPHINXBUILD` as it was. If `SPHINXPYTHON` is set, then it forces Sphinx to use the specific Python version. ``` $ SPHINXPYTHON=python3 make html python3 -msphinx -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` 1. if `SPHINXPYTHON` is set, use Python. If `SPHINXBUILD` is set, use sphinx-build. 2. If both are set, `SPHINXBUILD` has a higher priority over `SPHINXPYTHON` 3. By default, `SPHINXBUILD` is used as 'sphinx-build'. Probably, we can somehow work around this via explicitly setting `SPHINXBUILD` but `sphinx-build` can't be easily distinguished since it (at least in my environment and up to my knowledge) doesn't replace `sphinx-build` when newer Sphinx is installed in different Python version. It confuses and doesn't warn for its Python version. ## How was this patch tested? Manually tested: **`python` (Python 2.7) in the path with Sphinx:** ``` $ make html sphinx-build -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` **`python` (Python 2.7) in the path without Sphinx:** ``` $ make html Makefile:8: *** The 'sphinx-build' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the 'sphinx-build' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/. Stop. ``` **`SPHINXPYTHON` set `python` (Python 2.7) with Sphinx:** ``` $ SPHINXPYTHON=python make html Makefile:35: *** Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.. Stop. ``` **`SPHINXPYTHON` set `python` (Python 2.7) without Sphinx:** ``` $ SPHINXPYTHON=python make html Makefile:35: *** Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.. Stop. ``` **`SPHINXPYTHON` set `python3` with Sphinx:** ``` $ SPHINXPYTHON=python3 make html python3 -msphinx -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` **`SPHINXPYTHON` set `python3` without Sphinx:** ``` $ SPHINXPYTHON=python3 make html Makefile:39: *** Python executable 'python3' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/. Stop. ``` **`SPHINXBUILD` set:** ``` $ SPHINXBUILD=sphinx-build make html sphinx-build -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` **Both `SPHINXPYTHON` and `SPHINXBUILD` are set:** ``` $ SPHINXBUILD=sphinx-build SPHINXPYTHON=python make html sphinx-build -b html -d _build/doctrees . _build/html Running Sphinx v1.7.5 ... ``` Author: hyukjinkwon Closes #21659 from HyukjinKwon/SPARK-24530. --- python/docs/Makefile | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/python/docs/Makefile b/python/docs/Makefile index b8e079483c90c..1ed1f33af2326 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -1,19 +1,44 @@ # Makefile for Sphinx documentation # +ifndef SPHINXBUILD +ifndef SPHINXPYTHON +SPHINXBUILD = sphinx-build +endif +endif + +ifdef SPHINXBUILD +# User-friendly check for sphinx-build. +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +else +# Note that there is an issue with Python version and Sphinx in PySpark documentation generation. +# Please remove this check below when this issue is fixed. See SPARK-24530 for more details. +PYTHON_VERSION_CHECK = $(shell $(SPHINXPYTHON) -c 'import sys; print(sys.version_info < (3, 0, 0))') +ifeq ($(PYTHON_VERSION_CHECK), True) +$(error Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.) +endif +# Check if Sphinx is installed. +ifeq ($(shell $(SPHINXPYTHON) -c 'import sphinx' >/dev/null 2>&1; echo $$?), 1) +$(error Python executable '$(SPHINXPYTHON)' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif +# Use 'SPHINXPYTHON -msphinx' instead of 'sphinx-build'. See https://github.com/sphinx-doc/sphinx/pull/3523 for more details. +SPHINXBUILD = $(SPHINXPYTHON) -msphinx +endif + # You can set these variables from the command line. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build +# You can set SPHINXBUILD to specify Sphinx build executable or SPHINXPYTHON to specify the Python executable used in Sphinx. +# They follow: +# 1. if SPHINXPYTHON is set, use Python. If SPHINXBUILD is set, use sphinx-build. +# 2. If both are set, SPHINXBUILD has a higher priority over SPHINXPYTHON +# 3. By default, SPHINXBUILD is used as 'sphinx-build'. export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter From 74a8d6308bfa6e7ed4c64e1175c77eb3114baed5 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Wed, 11 Jul 2018 12:21:03 +0800 Subject: [PATCH 1093/1161] [SPARK-24165][SQL] Fixing conditional expressions to handle nullability of nested types ## What changes were proposed in this pull request? This PR is proposing a fix for the output data type of ```If``` and ```CaseWhen``` expression. Upon till now, the implementation of exprassions has ignored nullability of nested types from different execution branches and returned the type of the first branch. This could lead to an unwanted ```NullPointerException``` from other expressions depending on a ```If```/```CaseWhen``` expression. Example: ``` val rows = new util.ArrayList[Row]() rows.add(Row(true, ("a", 1))) rows.add(Row(false, (null, 2))) val schema = StructType(Seq( StructField("cond", BooleanType, false), StructField("s", StructType(Seq( StructField("val1", StringType, true), StructField("val2", IntegerType, false) )), false) )) val df = spark.createDataFrame(rows, schema) df .select(when('cond, struct(lit("x").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") .select('res.getField("val1")) .show() ``` Exception: ``` Exception in thread "main" java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:109) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source) at org.apache.spark.sql.execution.LocalTableScanExec$$anonfun$unsafeRows$1.apply(LocalTableScanExec.scala:44) at org.apache.spark.sql.execution.LocalTableScanExec$$anonfun$unsafeRows$1.apply(LocalTableScanExec.scala:44) ... ``` Output schema: ``` root |-- res.val1: string (nullable = false) ``` ## How was this patch tested? New test cases added into - DataFrameSuite.scala - conditionalExpressions.scala Author: Marek Novotny Closes #21687 from mn-mikke/SPARK-24165. --- .../sql/catalyst/analysis/TypeCoercion.scala | 22 ++++-- .../sql/catalyst/expressions/Expression.scala | 37 +++++++++- .../expressions/conditionalExpressions.scala | 28 ++++---- .../ConditionalExpressionSuite.scala | 70 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 58 +++++++++++++++ 5 files changed, 195 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 72908c1f433ee..e8331c90ea0f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -172,6 +172,18 @@ object TypeCoercion { case _ => None } + /** + * The method finds a common type for data types that differ only in nullable, containsNull + * and valueContainsNull flags. If the input types are too different, None is returned. + */ + def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = { + if (t1 == t2) { + Some(t1) + } else { + findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags) + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -660,8 +672,8 @@ object TypeCoercion { object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual => + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -693,10 +705,10 @@ object TypeCoercion { plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. - case i @ If(pred, left, right) if left.dataType != right.dataType => + case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType) + val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b9fa41a47d0f..44c5556ff9ccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -695,6 +695,41 @@ abstract class TernaryExpression extends Expression { } } +/** + * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type. + * This logic is usually utilized by expressions combining data from multiple child expressions + * of non-primitive types (e.g. [[CaseWhen]]). + */ +trait ComplexTypeMergingExpression extends Expression { + + /** + * A collection of data types used for resolution the output type of the expression. By default, + * data types of all child expressions. The collection must not be empty. + */ + @transient + lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) + + /** + * A method determining whether the input types are equal ignoring nullable, containsNull and + * valueContainsNull flags and thus convenient for resolution of the final data type. + */ + def areInputTypesForMergingEqual: Boolean = { + inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall { + case Seq(dt1, dt2) => dt1.sameType(dt2) + } + } + + override def dataType: DataType = { + require( + inputTypesForMerging.nonEmpty, + "The collection of input data types must not be empty.") + require( + areInputTypesForMergingEqual, + "All input types must be the same except nullable, containsNull, valueContainsNull flags.") + inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) + } +} + /** * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages * and Hive function wrappers. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 77ac6c088022e..e6377b7d87b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -33,7 +33,12 @@ import org.apache.spark.sql.types._ """) // scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) - extends Expression { + extends ComplexTypeMergingExpression { + + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + Seq(trueValue.dataType, falseValue.dataType) + } override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable @@ -43,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + s"not ${predicate.dataType.simpleString}") - } else if (!trueValue.dataType.sameType(falseValue.dataType)) { + } else if (!areInputTypesForMergingEqual) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { @@ -51,8 +56,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def dataType: DataType = trueValue.dataType - override def eval(input: InternalRow): Any = { if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) { trueValue.eval(input) @@ -118,27 +121,24 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi case class CaseWhen( branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with Serializable { + extends ComplexTypeMergingExpression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue // both then and else expressions should be considered. - def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType) - - def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) + @transient + override lazy val inputTypesForMerging: Seq[DataType] = { + branches.map(_._2.dataType) ++ elseValue.map(_.dataType) } - override def dataType: DataType = branches.head._2.dataType - override def nullable: Boolean = { // Result is nullable if any of the branch is nullable, or if the else value is nullable branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true) } override def checkInputDataTypes(): TypeCheckResult = { - // Make sure all branch conditions are boolean types. - if (valueTypesEqual) { + if (areInputTypesForMergingEqual) { + // Make sure all branch conditions are boolean types. if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index a099119732e25..e068c32500cfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true) } + test("if/case when - null flags of non-primitive types") { + val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true)) + val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false)) + val structWithNulls = Literal.create( + create_row(null, null), + StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true)))) + val structWithoutNulls = Literal.create( + create_row(1, "a"), + StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false)))) + val mapWithNulls = Literal.create(Map(1 -> null), MapType(IntegerType, StringType, true)) + val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false)) + + val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls) + val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls) + val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls) + val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls) + val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls) + val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls) + val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls) + val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls) + val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls) + val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls) + val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls) + + val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls) + val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls) + val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls) + val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls) + val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls) + val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, structWithoutNulls)), structWithNulls) + val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls) + val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls) + val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls) + + def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = { + assert(expectedType == result.dataType) + checkEvaluation(result, expectedValue) + } + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf2) + checkResult(structWithNulls.dataType, structWithNulls.value, structIf3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4) + + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2) + checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3) + checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2) + checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3) + checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2) + checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3) + checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4) + } + test("case key when") { val row = create_row(null, 1, 2, "a", "b", "c") val c1 = 'a.int.at(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ea00d22bff001..9d7645d232d08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2320,6 +2320,64 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) } + test("SPARK-24165: CaseWhen/If - nullability of nested types") { + val rows = new java.util.ArrayList[Row]() + rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x"))) + rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null))) + val schema = StructType(Seq( + StructField("cond", BooleanType, true), + StructField("s", StructType(Seq( + StructField("val1", StringType, true), + StructField("val2", IntegerType, false) + )), false), + StructField("a", ArrayType(StringType, true)), + StructField("m", MapType(IntegerType, StringType, true)) + )) + + val sourceDF = spark.createDataFrame(rows, schema) + + val structWhenDF = sourceDF + .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res") + .select('res.getField("val1")) + val arrayWhenDF = sourceDF + .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res") + .select('res.getItem(0)) + val mapWhenDF = sourceDF + .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res") + .select('res.getItem(0)) + + val structIfDF = sourceDF + .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res") + .select('res.getField("val1")) + val arrayIfDF = sourceDF + .select(expr("if(cond, array('a', 'b'), a)") as "res") + .select('res.getItem(0)) + val mapIfDF = sourceDF + .select(expr("if(cond, map(0, 'a'), m)") as "res") + .select('res.getItem(0)) + + def checkResult(df: DataFrame, codegenExpected: Boolean): Unit = { + assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == codegenExpected) + checkAnswer(df, Seq(Row("a"), Row(null))) + } + + // without codegen + checkResult(structWhenDF, false) + checkResult(arrayWhenDF, false) + checkResult(mapWhenDF, false) + checkResult(structIfDF, false) + checkResult(arrayIfDF, false) + checkResult(mapIfDF, false) + + // with codegen + checkResult(structWhenDF.filter('cond.isNotNull), true) + checkResult(arrayWhenDF.filter('cond.isNotNull), true) + checkResult(mapWhenDF.filter('cond.isNotNull), true) + checkResult(structIfDF.filter('cond.isNotNull), true) + checkResult(arrayIfDF.filter('cond.isNotNull), true) + checkResult(mapIfDF.filter('cond.isNotNull), true) + } + test("Uuid expressions should produce same results at retries in the same DataFrame") { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) From 5ff1b9ba1983d5601add62aef64a3e87d07050eb Mon Sep 17 00:00:00 2001 From: Andrew Korzhuev Date: Tue, 10 Jul 2018 22:53:44 -0700 Subject: [PATCH 1094/1161] [SPARK-23529][K8S] Support mounting volumes This PR continues #21095 and intersects with #21238. I've added volume mounts as a separate step and added PersistantVolumeClaim support. There is a fundamental problem with how we pass the options through spark conf to fabric8. For each volume type and all possible volume options we would have to implement some custom code to map config values to fabric8 calls. This will result in big body of code we would have to support and means that Spark will always be somehow out of sync with k8s. I think there needs to be a discussion on how to proceed correctly (eg use PodPreset instead) ---- Due to the complications of provisioning and managing actual resources this PR addresses only volume mounting of already present resources. ---- - [x] emptyDir support - [x] Testing - [x] Documentation - [x] KubernetesVolumeUtils tests Author: Andrew Korzhuev Author: madanadit Closes #21260 from andrusha/k8s-vol. --- docs/running-on-kubernetes.md | 48 ++++++ .../org/apache/spark/deploy/k8s/Config.scala | 12 ++ .../spark/deploy/k8s/KubernetesConf.scala | 11 ++ .../spark/deploy/k8s/KubernetesUtils.scala | 2 - .../deploy/k8s/KubernetesVolumeSpec.scala | 38 +++++ .../deploy/k8s/KubernetesVolumeUtils.scala | 110 +++++++++++++ .../k8s/features/BasicDriverFeatureStep.scala | 5 +- .../features/BasicExecutorFeatureStep.scala | 5 +- .../features/MountVolumesFeatureStep.scala | 79 ++++++++++ .../k8s/submit/KubernetesDriverBuilder.scala | 31 ++-- .../k8s/KubernetesExecutorBuilder.scala | 38 ++--- .../k8s/KubernetesVolumeUtilsSuite.scala | 106 +++++++++++++ .../BasicDriverFeatureStepSuite.scala | 23 +-- .../BasicExecutorFeatureStepSuite.scala | 3 + ...ubernetesCredentialsFeatureStepSuite.scala | 3 + .../DriverServiceFeatureStepSuite.scala | 6 + .../features/EnvSecretsFeatureStepSuite.scala | 1 + .../features/LocalDirsFeatureStepSuite.scala | 3 +- .../MountSecretsFeatureStepSuite.scala | 1 + .../MountVolumesFeatureStepSuite.scala | 144 ++++++++++++++++++ .../bindings/JavaDriverFeatureStepSuite.scala | 1 + .../PythonDriverFeatureStepSuite.scala | 2 + .../spark/deploy/k8s/submit/ClientSuite.scala | 1 + .../submit/KubernetesDriverBuilderSuite.scala | 45 +++++- .../k8s/KubernetesExecutorBuilderSuite.scala | 38 ++++- 25 files changed, 705 insertions(+), 51 deletions(-) create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala create mode 100644 resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 408e446ea4822..7149616e534aa 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -629,6 +629,54 @@ specific to Spark on Kubernetes. Add as an environment variable to the executor container with name EnvName (case sensitive), the value referenced by key key in the data of the referenced Kubernetes Secret. For example, spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index bf33179ae3dab..f9a77e71ad618 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -220,11 +220,23 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." + val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes." val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." + val KUBERNETES_EXECUTOR_VOLUMES_PREFIX = "spark.kubernetes.executor.volumes." + + val KUBERNETES_VOLUMES_HOSTPATH_TYPE = "hostPath" + val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim" + val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir" + val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" + val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" + val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" + val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" + val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index b0ccaa36b01ed..51d205fdb68d1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -59,6 +59,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleSecretNamesToMountPaths: Map[String, String], roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String], + roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]], sparkFiles: Seq[String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -155,6 +156,12 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get) + // Also parse executor volumes in order to verify configuration + // before the driver pod is created + KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) val sparkFiles = sparkConf .getOption("spark.files") @@ -171,6 +178,7 @@ private[spark] object KubernetesConf { driverSecretNamesToMountPaths, driverSecretEnvNamesToKeyRefs, driverEnvs, + driverVolumes, sparkFiles) } @@ -203,6 +211,8 @@ private[spark] object KubernetesConf { val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap + val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix( + sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get) KubernetesConf( sparkConf.clone(), @@ -214,6 +224,7 @@ private[spark] object KubernetesConf { executorMountSecrets, executorEnvSecrets, executorEnv, + executorVolumes, Seq.empty[String]) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 593fb531a004d..66fff267545dc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -16,8 +16,6 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference - import org.apache.spark.SparkConf import org.apache.spark.util.Utils diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala new file mode 100644 index 0000000000000..b1762d1efe2ea --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +private[spark] sealed trait KubernetesVolumeSpecificConf + +private[spark] case class KubernetesHostPathVolumeConf( + hostPath: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesPVCVolumeConf( + claimName: String) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesEmptyDirVolumeConf( + medium: Option[String], + sizeLimit: Option[String]) + extends KubernetesVolumeSpecificConf + +private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf]( + volumeName: String, + mountPath: String, + mountReadOnly: Boolean, + volumeConf: T) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala new file mode 100644 index 0000000000000..713df5fffc3a2 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import java.util.NoSuchElementException + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ + +private[spark] object KubernetesVolumeUtils { + /** + * Extract Spark volume configuration properties with a given name prefix. + * + * @param sparkConf Spark configuration + * @param prefix the given property name prefix + * @return a Map storing with volume name as key and spec as value + */ + def parseVolumesWithPrefix( + sparkConf: SparkConf, + prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = { + val properties = sparkConf.getAllWithPrefix(prefix).toMap + + getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" + val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" + + for { + path <- properties.getTry(pathKey) + volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName) + } yield KubernetesVolumeSpec( + volumeName = volumeName, + mountPath = path, + mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), + volumeConf = volumeConf + ) + } + } + + /** + * Get unique pairs of volumeType and volumeName, + * assuming options are formatted in this way: + * `volumeType`.`volumeName`.`property` = `value` + * @param properties flat mapping of property names to values + * @return Set[(volumeType, volumeName)] + */ + private def getVolumeTypesAndNames( + properties: Map[String, String] + ): Set[(String, String)] = { + properties.keys.flatMap { k => + k.split('.').toList match { + case tpe :: name :: _ => Some((tpe, name)) + case _ => None + } + }.toSet + } + + private def parseVolumeSpecificConf( + options: Map[String, String], + volumeType: String, + volumeName: String): Try[KubernetesVolumeSpecificConf] = { + volumeType match { + case KUBERNETES_VOLUMES_HOSTPATH_TYPE => + val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + for { + path <- options.getTry(pathKey) + } yield KubernetesHostPathVolumeConf(path) + + case KUBERNETES_VOLUMES_PVC_TYPE => + val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" + for { + claimName <- options.getTry(claimNameKey) + } yield KubernetesPVCVolumeConf(claimName) + + case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => + val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" + val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY" + Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey))) + + case _ => + Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported")) + } + } + + /** + * Convenience wrapper to accumulate key lookup errors + */ + implicit private class MapOps[A, B](m: Map[A, B]) { + def getTry(key: A): Try[B] = { + m + .get(key) + .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala index 143dc8a12304e..7e67b51de6e04 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -19,10 +19,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ import scala.collection.mutable -import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit._ @@ -103,6 +103,7 @@ private[spark] class BasicDriverFeatureStep( .addToImagePullSecrets(conf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(driverPod, driverContainer) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 91c54a9776982..abaeff0313a79 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -18,10 +18,10 @@ package org.apache.spark.deploy.k8s.features import scala.collection.JavaConverters._ -import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} +import io.fabric8.kubernetes.api.model._ import org.apache.spark.SparkException -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} @@ -173,6 +173,7 @@ private[spark] class BasicExecutorFeatureStep( .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) .endSpec() .build() + SparkPod(executorPod, containerWithLimitCores) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala new file mode 100644 index 0000000000000..bb0e2b3128efd --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model._ + +import org.apache.spark.deploy.k8s._ + +private[spark] class MountVolumesFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + + override def configurePod(pod: SparkPod): SparkPod = { + val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip + + val podWithVolumes = new PodBuilder(pod.pod) + .editSpec() + .addToVolumes(volumes.toSeq: _*) + .endSpec() + .build() + + val containerWithVolumeMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(volumeMounts.toSeq: _*) + .build() + + SparkPod(podWithVolumes, containerWithVolumeMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def constructVolumes( + volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]] + ): Iterable[(VolumeMount, Volume)] = { + volumeSpecs.map { spec => + val volumeMount = new VolumeMountBuilder() + .withMountPath(spec.mountPath) + .withReadOnly(spec.mountReadOnly) + .withName(spec.volumeName) + .build() + + val volumeBuilder = spec.volumeConf match { + case KubernetesHostPathVolumeConf(hostPath) => + new VolumeBuilder() + .withHostPath(new HostPathVolumeSource(hostPath)) + + case KubernetesPVCVolumeConf(claimName) => + new VolumeBuilder() + .withPersistentVolumeClaim( + new PersistentVolumeClaimVolumeSource(claimName, spec.mountReadOnly)) + + case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => + new VolumeBuilder() + .withEmptyDir( + new EmptyDirVolumeSource(medium.getOrElse(""), + new Quantity(sizeLimit.orNull))) + } + + val volume = volumeBuilder.withName(spec.volumeName).build() + + (volumeMount, volume) + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 0dd1c37661707..7208e3d377593 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} private[spark] class KubernetesDriverBuilder( @@ -33,10 +33,13 @@ private[spark] class KubernetesDriverBuilder( new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = - new EnvSecretsFeatureStep(_), - provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] - => LocalDirsFeatureStep) = + new EnvSecretsFeatureStep(_), + provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => LocalDirsFeatureStep = new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_), provideJavaStep: ( KubernetesConf[KubernetesDriverSpecificConf] => JavaDriverFeatureStep) = @@ -54,11 +57,15 @@ private[spark] class KubernetesDriverBuilder( provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None - - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map { case JavaMainAppResource(_) => @@ -67,10 +74,8 @@ private[spark] class KubernetesDriverBuilder( providePythonStep(kubernetesConf)} .getOrElse(provideJavaStep(kubernetesConf)) - val allFeatures: Seq[KubernetesFeatureConfigStep] = - (baseFeatures :+ bindingsStep) ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = (baseFeatures :+ bindingsStep) ++ + secretFeature ++ envSecretFeature ++ volumesFeature var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index 769a0a5a63047..364b6fb367722 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,37 +17,41 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( - provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf]) + => BasicExecutorFeatureStep = new BasicExecutorFeatureStep(_), - provideSecretsStep: - (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) + => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = - new LocalDirsFeatureStep(_)) { + new LocalDirsFeatureStep(_), + provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountVolumesFeatureStep) = + new MountVolumesFeatureStep(_)) { def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { - val baseFeatures = Seq( - provideBasicStep(kubernetesConf), - provideLocalDirsStep(kubernetesConf)) - val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { - Some(provideSecretsStep(kubernetesConf)) } else None + val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) + val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + Seq(provideSecretsStep(kubernetesConf)) + } else Nil + val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + Seq(provideEnvSecretsStep(kubernetesConf)) + } else Nil + val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) { + Seq(provideVolumesStep(kubernetesConf)) + } else Nil - val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { - Some(provideEnvSecretsStep(kubernetesConf)) } else None - - val allFeatures: Seq[KubernetesFeatureConfigStep] = - baseFeatures ++ - maybeRoleSecretNamesStep.toSeq ++ - maybeProvideSecretsStep.toSeq + val allFeatures = baseFeatures ++ secretFeature ++ secretEnvFeature ++ volumesFeature var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala new file mode 100644 index 0000000000000..d795d159773a8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class KubernetesVolumeUtilsSuite extends SparkFunSuite { + test("Parses hostPath volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath")) + } + + test("Parses persistentVolumeClaim volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimeName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf("claimeName")) + } + + test("Parses emptyDir volumes correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.options.medium", "medium") + sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(Some("medium"), Some("5G"))) + } + + test("Parses emptyDir volume options can be optional") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly === true) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] === + KubernetesEmptyDirVolumeConf(None, None)) + } + + test("Defaults optional readOnly to false") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get + assert(volumeSpec.mountReadOnly === false) + } + + test("Gracefully fails on missing mount key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path") + } + + test("Gracefully fails on missing option key") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true") + sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.isFailure === true) + assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path") + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index 04b909db9d9f3..165f46a07df2f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -50,6 +50,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { TEST_IMAGE_PULL_SECRETS.map { secret => new LocalObjectReferenceBuilder().withName(secret).build() } + private val emptyDriverSpecificConf = KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS) + test("Check the pod respects all configurations from the user.") { val sparkConf = new SparkConf() @@ -62,11 +68,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -74,6 +76,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -143,6 +146,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val pythonKubernetesConf = KubernetesConf( pythonSparkConf, @@ -158,6 +162,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, Seq.empty[String]) val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf) val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf) @@ -176,11 +181,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { .set(CONTAINER_IMAGE, "spark-driver:latest") val kubernetesConf = KubernetesConf( sparkConf, - KubernetesDriverSpecificConf( - Some(JavaMainAppResource("")), - APP_NAME, - MAIN_CLASS, - APP_ARGS), + emptyDriverSpecificConf, RESOURCE_NAME_PREFIX, APP_ID, DRIVER_LABELS, @@ -188,7 +189,9 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, DRIVER_ENVS, + Nil, allFiles) + val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() val expectedSparkConf = Map( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index f06030aa55c0c..a44fa1f2ffc63 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -89,6 +89,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) @@ -128,6 +129,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -148,6 +150,7 @@ class BasicExecutorFeatureStepSuite Map.empty, Map.empty, Map("qux" -> "quux"), + Nil, Seq.empty[String])) val executor = step.configurePod(SparkPod.initialPod()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 7cea83591f3e8..7e916b3854404 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -61,6 +61,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -92,6 +93,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -130,6 +132,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index 77d38bf19cd10..8b91e93eecd8c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -67,6 +67,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -98,6 +99,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -119,6 +121,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String])) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -149,6 +152,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) val driverService = configurationStep @@ -176,6 +180,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver bind address should not be allowed.") @@ -201,6 +206,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala index af6b35eae484a..1c8d84b76c56b 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -45,6 +45,7 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{ Map.empty, envVarsToKeys, Map.empty, + Nil, Seq.empty[String]) val step = new EnvSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index bd6ce4b42fc8e..a339827b819a9 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -21,7 +21,7 @@ import org.mockito.Mockito import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod} class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val defaultLocalDir = "/var/data/default-local-dir" @@ -45,6 +45,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index eff75b8a15daa..2b49b72dfa569 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -43,6 +43,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { secretNamesToMountPaths, Map.empty, Map.empty, + Nil, Seq.empty[String]) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala new file mode 100644 index 0000000000000..d309aa94ec115 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class MountVolumesFeatureStepSuite extends SparkFunSuite { + private val sparkConf = new SparkConf(false) + private val emptyKubernetesConf = KubernetesConf( + sparkConf = sparkConf, + roleSpecificConf = KubernetesDriverSpecificConf( + None, + "app-name", + "main", + Seq.empty), + appResourceNamePrefix = "resource", + appId = "app-id", + roleLabels = Map.empty, + roleAnnotations = Map.empty, + roleSecretNamesToMountPaths = Map.empty, + roleSecretEnvNamesToKeyRefs = Map.empty, + roleEnvs = Map.empty, + roleVolumes = Nil, + sparkFiles = Nil) + + test("Mounts hostPath volumes") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts pesistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === true) + + } + + test("Mounts emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "Memory") + assert(emptyDir.getSizeLimit.getAmount === "6G") + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts emptyDir with no options") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir + assert(emptyDir.getMedium === "") + assert(emptyDir.getSizeLimit.getAmount === null) + assert(configuredPod.container.getVolumeMounts.size() === 1) + assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") + assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") + assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false) + } + + test("Mounts multiple volumes") { + val hpVolumeConf = KubernetesVolumeSpec( + "hpVolume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/hostPath/tmp") + ) + val pvcVolumeConf = KubernetesVolumeSpec( + "checkpointVolume", + "/checkpoints", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil + val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala index 0f2bf2fa1d9b5..18874afe6e53a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala @@ -42,6 +42,7 @@ class JavaDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new JavaDriverFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala index a1f9a5d9e264e..a5dac6869327d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala @@ -52,6 +52,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) @@ -88,6 +89,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite { roleSecretNamesToMountPaths = Map.empty, roleSecretEnvNamesToKeyRefs = Map.empty, roleEnvs = Map.empty, + roleVolumes = Nil, sparkFiles = Seq.empty[String]) val step = new PythonDriverFeatureStep(kubernetesConf) val driverContainerwithPySpark = step.configurePod(baseDriverPod).container diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index d045d9ae89c07..4d8e79189ff32 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -141,6 +141,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index 4e8c300543430..046e578b94629 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep} @@ -31,6 +32,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val JAVA_STEP_TYPE = "java-bindings" private val PYSPARK_STEP_TYPE = "pyspark-bindings" private val ENV_SECRETS_STEP_TYPE = "env-secrets" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -56,6 +58,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, @@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { _ => secretsStep, _ => envSecretsStep, _ => localDirsStep, + _ => mountVolumesStep, _ => javaStep, _ => pythonStep) @@ -82,6 +88,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -107,6 +114,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), Map("EnvName" -> "SecretName:secretKey"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -134,6 +142,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -159,6 +168,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -169,6 +179,39 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { PYSPARK_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/path")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE, + JAVA_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) : Unit = { assert(resolvedSpec.systemProperties.size === stepTypes.size) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index a6bc8bce32926..d0b4127065eb7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.scheduler.cluster.k8s import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s._ +import org.apache.spark.deploy.k8s.features._ class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" + private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) @@ -36,12 +37,15 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) + private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, _ => envSecretsStep, - _ => localDirsStep) + _ => localDirsStep, + _ => mountVolumesStep) test("Basic steps are consistently applied.") { val conf = KubernetesConf( @@ -55,6 +59,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -72,6 +77,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map("secret" -> "secretMountPath"), Map("secret-name" -> "secret-key"), Map.empty, + Nil, Seq.empty[String]) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -81,6 +87,32 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { ENV_SECRETS_STEP_TYPE) } + test("Apply volumes step if mounts are present.") { + val volumeSpec = KubernetesVolumeSpec( + "volume", + "/tmp", + false, + KubernetesHostPathVolumeConf("/checkpoint")) + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty, + Map.empty, + volumeSpec :: Nil, + Seq.empty[String]) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + LOCAL_DIRS_STEP_TYPE, + MOUNT_VOLUMES_STEP_TYPE) + } + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) stepTypes.foreach { stepType => From 006e798e477b6871ad3ba4417d354d23f45e4013 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 10 Jul 2018 23:18:07 -0700 Subject: [PATCH 1095/1161] [SPARK-23461][R] vignettes should include model predictions for some ML models ## What changes were proposed in this pull request? Add model predictions for Linear Support Vector Machine (SVM) Classifier, Logistic Regression, GBT, RF and DecisionTree in vignettes. ## How was this patch tested? Manually ran the test and checked the result. Author: Huaxin Gao Closes #21678 from huaxingao/spark-23461. --- R/pkg/vignettes/sparkr-vignettes.Rmd | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index d4713de7806a1..68a18ab57b28d 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -590,6 +590,7 @@ summary(model) Predict values on training data ```{r} prediction <- predict(model, training) +head(select(prediction, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Logistic Regression @@ -613,6 +614,7 @@ summary(model) Predict values on training data ```{r} fitted <- predict(model, training) +head(select(fitted, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` Multinomial logistic regression against three classes @@ -807,6 +809,7 @@ df <- createDataFrame(t) dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2) summary(dtModel) predictions <- predict(dtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Gradient-Boosted Trees @@ -822,6 +825,7 @@ df <- createDataFrame(t) gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2) summary(gbtModel) predictions <- predict(gbtModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Random Forest @@ -837,6 +841,7 @@ df <- createDataFrame(t) rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2) summary(rfModel) predictions <- predict(rfModel, df) +head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction")) ``` #### Bisecting k-Means From 592cc84583d74c78e4cdf34a3b82692c8de8f4a9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 11 Jul 2018 23:43:06 +0800 Subject: [PATCH 1096/1161] [SPARK-24562][TESTS] Support different configs for same test in SQLQueryTestSuite ## What changes were proposed in this pull request? The PR proposes to add support for running the same SQL test input files against different configs leading to the same result. ## How was this patch tested? Involved UTs Author: Marco Gaido Closes #21568 from mgaido91/SPARK-24562. --- .../sql-tests/inputs/join-empty-relation.sql | 5 ++ .../sql-tests/inputs/natural-join.sql | 5 ++ .../resources/sql-tests/inputs/outer-join.sql | 5 ++ .../exists-joins-and-set-ops.sql | 4 ++ .../inputs/subquery/in-subquery/in-joins.sql | 4 ++ .../subquery/in-subquery/not-in-joins.sql | 4 ++ .../apache/spark/sql/SQLQueryTestSuite.scala | 53 ++++++++++++++++--- 7 files changed, 74 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql index 8afa3270f4de4..2e6a5f362a8fa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql index 71a50157b766c..e0abeda3eb44f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + create temporary view nt1 as select * from values ("one", 1), ("two", 2), diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql index cdc6c81e10047..ce09c21568f13 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -1,3 +1,8 @@ +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false + -- SPARK-17099: Incorrect result when HAVING clause is added to group by query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (-234), (145), (367), (975), (298) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql index cc4ed64affec7..cefc3fe6272ab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql @@ -1,5 +1,9 @@ -- Tests EXISTS subquery support. Tests Exists subquery -- used in Joins (Both when joins occurs in outer and suquery blocks) +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES (100, "emp 1", date "2005-01-01", 100.00D, 10), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index 880175fd7add0..22f3eafd6a02d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for IN JOINS in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index e09b91f18de0a..4f8ca8bfb27c1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -1,5 +1,9 @@ -- A test suite for not-in-joins in parent side, subquery, and both predicate subquery -- It includes correlated cases. +-- List of configuration the test suite is run against: +--SET spark.sql.autoBroadcastJoinThreshold=10485760 +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true +--SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false create temporary view t1 as select * from values ("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index beac9699585d5..826408c7161e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -54,6 +54,7 @@ import org.apache.spark.sql.types.StructType * The format for input files is simple: * 1. A list of SQL queries separated by semicolon. * 2. Lines starting with -- are treated as comments and ignored. + * 3. Lines starting with --SET are used to run the file with the following set of configs. * * For example: * {{{ @@ -138,18 +139,58 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { private def runTest(testCase: TestCase): Unit = { val input = fileToString(new File(testCase.inputFile)) + val (comments, code) = input.split("\n").partition(_.startsWith("--")) + val configSets = { + val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) + val configs = configLines.map(_.split(",").map { confAndValue => + val (conf, value) = confAndValue.span(_ != '=') + conf.trim -> value.substring(1).trim + }) + // When we are regenerating the golden files we don't need to run all the configs as they + // all need to return the same result + if (regenerateGoldenFiles && configs.nonEmpty) { + configs.take(1) + } else { + configs + } + } // List of SQL queries to run - val queries: Seq[String] = { - val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") - // note: this is not a robust way to split queries using semicolon, but works for now. - cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + // note: this is not a robust way to split queries using semicolon, but works for now. + val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + + if (configSets.isEmpty) { + runQueries(queries, testCase.resultFile, None) + } else { + configSets.foreach { configSet => + try { + runQueries(queries, testCase.resultFile, Some(configSet)) + } catch { + case e: Throwable => + val configs = configSet.map { + case (k, v) => s"$k=$v" + } + logError(s"Error using configs: ${configs.mkString(",")}") + throw e + } + } } + } + private def runQueries( + queries: Seq[String], + resultFileName: String, + configSet: Option[Seq[(String, String)]]): Unit = { // Create a local SparkSession to have stronger isolation between different test cases. // This does not isolate catalog changes. val localSparkSession = spark.newSession() loadTestData(localSparkSession) + if (configSet.isDefined) { + // Execute the list of set operation in order to add the desired configs + val setOperations = configSet.get.map { case (key, value) => s"set $key=$value" } + logInfo(s"Setting configs: ${setOperations.mkString(", ")}") + setOperations.foreach(localSparkSession.sql) + } // Run the SQL queries preparing them for comparison. val outputs: Seq[QueryOutput] = queries.map { sql => val (schema, output) = getNormalizedResult(localSparkSession, sql) @@ -167,7 +208,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { s"-- Number of queries: ${outputs.size}\n\n\n" + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" } - val resultFile = new File(testCase.resultFile) + val resultFile = new File(resultFileName) val parent = resultFile.getParentFile if (!parent.exists()) { assert(parent.mkdirs(), "Could not create directory: " + parent) @@ -177,7 +218,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Read back the golden file. val expectedOutputs: Seq[QueryOutput] = { - val goldenOutput = fileToString(new File(testCase.resultFile)) + val goldenOutput = fileToString(new File(resultFileName)) val segments = goldenOutput.split("-- !query.+\n") // each query has 3 segments, plus the header From ebf4bfb966389342bfd9bdb8e3b612828c18730c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 11 Jul 2018 09:29:19 -0700 Subject: [PATCH 1097/1161] [SPARK-24208][SQL] Fix attribute deduplication for FlatMapGroupsInPandas ## What changes were proposed in this pull request? A self-join on a dataset which contains a `FlatMapGroupsInPandas` fails because of duplicate attributes. This happens because we are not dealing with this specific case in our `dedupAttr` rules. The PR fix the issue by adding the management of the specific case ## How was this patch tested? added UT + manual tests Author: Marco Gaido Author: Marco Gaido Closes #21737 from mgaido91/SPARK-24208. --- python/pyspark/sql/tests.py | 16 ++++++++++++++++ .../spark/sql/catalyst/analysis/Analyzer.scala | 4 ++++ .../apache/spark/sql/GroupedDatasetSuite.scala | 12 ++++++++++++ 3 files changed, 32 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8d738069adb3d..4404dbe40590a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5925,6 +5925,22 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e187133d03b17..c078efdfc0000 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -738,6 +738,10 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(output = output.map(_.newInstance()))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala index 147c0b61f5017..bd54ea415ca88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext { } datasetWithUDF.unpersist(true) } + + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF( + "pyUDF", + null, + StructType(Seq(StructField("s", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true)) + val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s") + df1.queryExecution.assertAnalyzed() + } } From 290c30a53fc2f46001846ab8abafcc69b853ba98 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Wed, 11 Jul 2018 13:48:28 -0500 Subject: [PATCH 1098/1161] [SPARK-24470][CORE] RestSubmissionClient to be robust against 404 & non json responses ## What changes were proposed in this pull request? Added check for 404, to avoid json parsing on not found response and to avoid returning malformed or bad request when it was a not found http response. Not sure if I need to add an additional check on non json response [if(connection.getHeaderField("Content-Type").contains("text/html")) then exception] as non-json is a subset of malformed json and covered in flow. ## How was this patch tested? ./dev/run-tests Author: Rekha Joshi Closes #21684 from rekhajoshm/SPARK-24470. --- .../deploy/rest/RestSubmissionClient.scala | 60 ++++++++++++------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 742a95841a138..31a8e3e60c067 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -233,30 +233,44 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = { import scala.concurrent.ExecutionContext.Implicits.global val responseFuture = Future { - val dataStream = - if (connection.getResponseCode == HttpServletResponse.SC_OK) { - connection.getInputStream - } else { - connection.getErrorStream + val responseCode = connection.getResponseCode + + if (responseCode != HttpServletResponse.SC_OK) { + val errString = Some(Source.fromInputStream(connection.getErrorStream()) + .getLines().mkString("\n")) + if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR && + !connection.getContentType().contains("application/json")) { + throw new SubmitRestProtocolException(s"Server responded with exception:\n${errString}") + } + logError(s"Server responded with error:\n${errString}") + val error = new ErrorResponse + if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) { + error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION + } + error.message = errString.get + error + } else { + val dataStream = connection.getInputStream + + // If the server threw an exception while writing a response, it will not have a body + if (dataStream == null) { + throw new SubmitRestProtocolException("Server returned empty body") + } + val responseJson = Source.fromInputStream(dataStream).mkString + logDebug(s"Response from the server:\n$responseJson") + val response = SubmitRestProtocolMessage.fromJson(responseJson) + response.validate() + response match { + // If the response is an error, log the message + case error: ErrorResponse => + logError(s"Server responded with error:\n${error.message}") + error + // Otherwise, simply return the response + case response: SubmitRestProtocolResponse => response + case unexpected => + throw new SubmitRestProtocolException( + s"Message received from server was not a response:\n${unexpected.toJson}") } - // If the server threw an exception while writing a response, it will not have a body - if (dataStream == null) { - throw new SubmitRestProtocolException("Server returned empty body") - } - val responseJson = Source.fromInputStream(dataStream).mkString - logDebug(s"Response from the server:\n$responseJson") - val response = SubmitRestProtocolMessage.fromJson(responseJson) - response.validate() - response match { - // If the response is an error, log the message - case error: ErrorResponse => - logError(s"Server responded with error:\n${error.message}") - error - // Otherwise, simply return the response - case response: SubmitRestProtocolResponse => response - case unexpected => - throw new SubmitRestProtocolException( - s"Message received from server was not a response:\n${unexpected.toJson}") } } From 59c3c233f4366809b6b4db39b3d32c194c98d5ab Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 11 Jul 2018 13:56:09 -0500 Subject: [PATCH 1099/1161] [SPARK-23254][ML] Add user guide entry and example for DataFrame multivariate summary ## What changes were proposed in this pull request? Add user guide and scala/java/python examples for `ml.stat.Summarizer` ## How was this patch tested? Doc generated snapshot: ![image](https://user-images.githubusercontent.com/19235986/38987108-45646044-4401-11e8-9ba8-ae94ba96cbf9.png) ![image](https://user-images.githubusercontent.com/19235986/38987096-36dcc73c-4401-11e8-87f9-5b91e7f9e27b.png) ![image](https://user-images.githubusercontent.com/19235986/38987088-2d1c1eaa-4401-11e8-80b5-8c40d529a120.png) ![image](https://user-images.githubusercontent.com/19235986/38987077-22ce8be0-4401-11e8-8199-c3a4d8d23201.png) Author: WeichenXu Closes #20446 from WeichenXu123/summ_guide. --- docs/ml-statistics.md | 28 ++++++++ .../examples/ml/JavaSummarizerExample.java | 71 +++++++++++++++++++ .../src/main/python/ml/summarizer_example.py | 59 +++++++++++++++ .../spark/examples/ml/SummarizerExample.scala | 61 ++++++++++++++++ 4 files changed, 219 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java create mode 100644 examples/src/main/python/ml/summarizer_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala diff --git a/docs/ml-statistics.md b/docs/ml-statistics.md index abfb3cab1e566..6c82b3bb94b24 100644 --- a/docs/ml-statistics.md +++ b/docs/ml-statistics.md @@ -89,4 +89,32 @@ Refer to the [`ChiSquareTest` Python docs](api/python/index.html#pyspark.ml.stat {% include_example python/ml/chi_square_test_example.py %} + + +## Summarizer + +We provide vector column summary statistics for `Dataframe` through `Summarizer`. +Available metrics are the column-wise max, min, mean, variance, and number of nonzeros, as well as the total count. + +
      +
      +The following example demonstrates using [`Summarizer`](api/scala/index.html#org.apache.spark.ml.stat.Summarizer$) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example scala/org/apache/spark/examples/ml/SummarizerExample.scala %} +
      + +
      +The following example demonstrates using [`Summarizer`](api/java/org/apache/spark/ml/stat/Summarizer.html) +to compute the mean and variance for a vector column of the input dataframe, with and without a weight column. + +{% include_example java/org/apache/spark/examples/ml/JavaSummarizerExample.java %} +
      + +
      +Refer to the [`Summarizer` Python docs](api/python/index.html#pyspark.ml.stat.Summarizer$) for details on the API. + +{% include_example python/ml/summarizer_example.py %} +
      +
      \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java new file mode 100644 index 0000000000000..e9b84365d86ed --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSummarizerExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.*; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.stat.Summarizer; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaSummarizerExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaSummarizerExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Vectors.dense(2.0, 3.0, 5.0), 1.0), + RowFactory.create(Vectors.dense(4.0, 6.0, 7.0), 2.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("weight", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + Row result1 = df.select(Summarizer.metrics("mean", "variance") + .summary(new Column("features"), new Column("weight")).as("summary")) + .select("summary.mean", "summary.variance").first(); + System.out.println("with weight: mean = " + result1.getAs(0).toString() + + ", variance = " + result1.getAs(1).toString()); + + Row result2 = df.select( + Summarizer.mean(new Column("features")), + Summarizer.variance(new Column("features")) + ).first(); + System.out.println("without weight: mean = " + result2.getAs(0).toString() + + ", variance = " + result2.getAs(1).toString()); + // $example off$ + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/summarizer_example.py b/examples/src/main/python/ml/summarizer_example.py new file mode 100644 index 0000000000000..8835f189a1ad4 --- /dev/null +++ b/examples/src/main/python/ml/summarizer_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example for summarizer. +Run with: + bin/spark-submit examples/src/main/python/ml/summarizer_example.py +""" +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.stat import Summarizer +from pyspark.sql import Row +from pyspark.ml.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("SummarizerExample") \ + .getOrCreate() + sc = spark.sparkContext + + # $example on$ + df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)), + Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF() + + # create summarizer for multiple metrics "mean" and "count" + summarizer = Summarizer.metrics("mean", "count") + + # compute statistics for multiple metrics with weight + df.select(summarizer.summary(df.features, df.weight)).show(truncate=False) + + # compute statistics for multiple metrics without weight + df.select(summarizer.summary(df.features)).show(truncate=False) + + # compute statistics for single metric "mean" with weight + df.select(Summarizer.mean(df.features, df.weight)).show(truncate=False) + + # compute statistics for single metric "mean" without weight + df.select(Summarizer.mean(df.features)).show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala new file mode 100644 index 0000000000000..2f54d1d81bc48 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SummarizerExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.stat.Summarizer +// $example off$ +import org.apache.spark.sql.SparkSession + +object SummarizerExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("SummarizerExample") + .getOrCreate() + + import spark.implicits._ + import Summarizer._ + + // $example on$ + val data = Seq( + (Vectors.dense(2.0, 3.0, 5.0), 1.0), + (Vectors.dense(4.0, 6.0, 7.0), 2.0) + ) + + val df = data.toDF("features", "weight") + + val (meanVal, varianceVal) = df.select(metrics("mean", "variance") + .summary($"features", $"weight").as("summary")) + .select("summary.mean", "summary.variance") + .as[(Vector, Vector)].first() + + println(s"with weight: mean = ${meanVal}, variance = ${varianceVal}") + + val (meanVal2, varianceVal2) = df.select(mean($"features"), variance($"features")) + .as[(Vector, Vector)].first() + + println(s"without weight: mean = ${meanVal2}, sum = ${varianceVal2}") + // $example off$ + + spark.stop() + } +} +// scalastyle:on println From ff7f6ef75c80633480802d537e66432e3bea4785 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 11 Jul 2018 12:44:42 -0700 Subject: [PATCH 1100/1161] [SPARK-24697][SS] Fix the reported start offsets in streaming query progress ## What changes were proposed in this pull request? In ProgressReporter for streams, we use the `committedOffsets` as the startOffset and `availableOffsets` as the end offset when reporting the status of a trigger in `finishTrigger`. This is a bad pattern that has existed since the beginning of ProgressReporter and it is bad because its super hard to reason about when `availableOffsets` and `committedOffsets` are updated, and when they are recorded. Case in point, this bug silently existed in ContinuousExecution, since before MicroBatchExecution was refactored. The correct fix it to record the offsets explicitly. This PR adds a simple method which is explicitly called from MicroBatch/ContinuousExecition before updating the `committedOffsets`. ## How was this patch tested? Added new tests Author: Tathagata Das Closes #21744 from tdas/SPARK-24697. --- .../streaming/MicroBatchExecution.scala | 3 +++ .../streaming/ProgressReporter.scala | 21 +++++++++++++++---- .../continuous/ContinuousExecution.scala | 3 +++ .../sql/streaming/StreamingQuerySuite.scala | 6 ++++-- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 16651dd060d73..45c43f549d24f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -184,6 +184,9 @@ class MicroBatchExecution( isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) } + // Record the trigger offset range for progress reporting *before* processing the batch + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) + // Remember whether the current batch has data or not. This will be required later // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed // to false as the batch would have already processed the available data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 16ad3ef9a3d4a..47f4b52e6e34c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -56,8 +56,6 @@ trait ProgressReporter extends Logging { protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution protected def newData: Map[BaseStreamingSource, LogicalPlan] - protected def availableOffsets: StreamProgress - protected def committedOffsets: StreamProgress protected def sources: Seq[BaseStreamingSource] protected def sink: BaseStreamingSink protected def offsetSeqMetadata: OffsetSeqMetadata @@ -68,8 +66,11 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L + private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L + private val currentDurationsMs = new mutable.HashMap[String, Long]() /** Flag that signals whether any error with input metrics have already been logged */ @@ -114,9 +115,20 @@ trait ProgressReporter extends Logging { lastTriggerStartTimestamp = currentTriggerStartTimestamp currentTriggerStartTimestamp = triggerClock.getTimeMillis() currentStatus = currentStatus.copy(isTriggerActive = true) + currentTriggerStartOffsets = null + currentTriggerEndOffsets = null currentDurationsMs.clear() } + /** + * Record the offsets range this trigger will process. Call this before updating + * `committedOffsets` in `StreamExecution` to make sure that the correct range is recorded. + */ + protected def recordTriggerOffsets(from: StreamProgress, to: StreamProgress): Unit = { + currentTriggerStartOffsets = from.mapValues(_.json) + currentTriggerEndOffsets = to.mapValues(_.json) + } + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { progressBuffer.synchronized { progressBuffer += newProgress @@ -130,6 +142,7 @@ trait ProgressReporter extends Logging { /** Finalizes the query progress and adds it to list of recent status updates. */ protected def finishTrigger(hasNewData: Boolean): Unit = { + assert(currentTriggerStartOffsets != null && currentTriggerEndOffsets != null) currentTriggerEndTimestamp = triggerClock.getTimeMillis() val executionStats = extractExecutionStats(hasNewData) @@ -147,8 +160,8 @@ trait ProgressReporter extends Logging { val numRecords = executionStats.inputRows.getOrElse(source, 0L) new SourceProgress( description = source.toString, - startOffset = committedOffsets.get(source).map(_.json).orNull, - endOffset = availableOffsets.get(source).map(_.json).orNull, + startOffset = currentTriggerStartOffsets.get(source).orNull, + endOffset = currentTriggerEndOffsets.get(source).orNull, numInputRows = numRecords, inputRowsPerSecond = numRecords / inputTimeSec, processedRowsPerSecond = numRecords / processingTimeSec diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a0bb8292d7766..e991dbc81696d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -309,7 +309,10 @@ class ContinuousExecution( def commit(epoch: Long): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit") + synchronized { + // Record offsets before updating `committedOffsets` + recordTriggerOffsets(from = committedOffsets, to = availableOffsets) if (queryExecutionThread.isAlive) { commitLog.add(epoch) val offset = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index dcf6cb5d609ee..936a076d647b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -335,8 +335,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.sources.length === 1) assert(progress.sources(0).description contains "MemoryStream") - assert(progress.sources(0).startOffset === "0") - assert(progress.sources(0).endOffset !== null) + assert(progress.sources(0).startOffset === null) // no prior offset + assert(progress.sources(0).endOffset === "0") assert(progress.sources(0).processedRowsPerSecond === 4.0) // 2 rows processed in 500 ms assert(progress.stateOperators.length === 1) @@ -362,6 +362,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(query.lastProgress.batchId === 1) assert(query.lastProgress.inputRowsPerSecond === 2.0) assert(query.lastProgress.sources(0).inputRowsPerSecond === 2.0) + assert(query.lastProgress.sources(0).startOffset === "0") + assert(query.lastProgress.sources(0).endOffset === "1") true }, From e008ad175256a3192fdcbd2c4793044d52f46d57 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 11 Jul 2018 17:30:43 -0700 Subject: [PATCH 1101/1161] [SPARK-24782][SQL] Simplify conf retrieval in SQL expressions ## What changes were proposed in this pull request? The PR simplifies the retrieval of config in `size`, as we can access them from tasks too thanks to SPARK-24250. ## How was this patch tested? existing UTs Author: Marco Gaido Closes #21736 from mgaido91/SPARK-24605_followup. --- .../expressions/collectionOperations.scala | 10 +-- .../expressions/jsonExpressions.scala | 16 ++--- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 - .../CollectionExpressionsSuite.scala | 27 ++++---- .../expressions/JsonExpressionsSuite.scala | 65 +++++++++---------- .../org/apache/spark/sql/functions.scala | 4 +- 6 files changed, 57 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 879603b66b314..e217d37184511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -89,15 +89,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression > SELECT _FUNC_(NULL); -1 """) -case class Size( - child: Expression, - legacySizeOfNull: Boolean) - extends UnaryExpression with ExpectsInputTypes { +case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { - def this(child: Expression) = - this( - child, - legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL)) + val legacySizeOfNull = SQLConf.get.legacySizeOfNull override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8cd86053a01c7..63943b1a4351a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -514,10 +514,11 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String], - forceNullableSchema: Boolean) + timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA) + // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. @@ -531,8 +532,7 @@ case class JsonToStructs( schema = JsonExprUtils.evalSchemaExpr(schema), options = options, child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) @@ -541,13 +541,7 @@ case class JsonToStructs( schema = JsonExprUtils.evalSchemaExpr(schema), options = JsonExprUtils.convertToMapData(options), child = child, - timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) - - // Used in `org.apache.spark.sql.functions` - def this(schema: DataType, options: Map[String, String], child: Expression) = - this(schema, options, child, timeZoneId = None, - forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) + timeZoneId = None) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { case _: StructType | ArrayType(_: StructType, _) | _: MapType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index e431c9523a9da..4b4722b0b2117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * The active config object within the current scope. - * Note that if you want to refer config values during execution, you have to capture them - * in Driver and use the captured values in Executors. * See [[SQLConf.get]] for more information. */ def conf: SQLConf = SQLConf.get diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 173c98af323b1..a838a2eedc03c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -24,43 +24,48 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeTestUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH import org.apache.spark.unsafe.types.CalendarInterval class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = { + def testSize(sizeOfNull: Any): Unit = { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) - checkEvaluation(Size(a0, legacySizeOfNull), 3) - checkEvaluation(Size(a1, legacySizeOfNull), 0) - checkEvaluation(Size(a2, legacySizeOfNull), 2) + checkEvaluation(Size(a0), 3) + checkEvaluation(Size(a1), 0) + checkEvaluation(Size(a2), 2) val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) - checkEvaluation(Size(m0, legacySizeOfNull), 2) - checkEvaluation(Size(m1, legacySizeOfNull), 0) - checkEvaluation(Size(m2, legacySizeOfNull), 1) + checkEvaluation(Size(m0), 2) + checkEvaluation(Size(m1), 0) + checkEvaluation(Size(m2), 1) checkEvaluation( - Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull), + Size(Literal.create(null, MapType(StringType, StringType))), expected = sizeOfNull) checkEvaluation( - Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull), + Size(Literal.create(null, ArrayType(StringType))), expected = sizeOfNull) } test("Array and Map Size - legacy") { - testSize(legacySizeOfNull = true, sizeOfNull = -1) + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") { + testSize(sizeOfNull = -1) + } } test("Array and Map Size") { - testSize(legacySizeOfNull = false, sizeOfNull = null) + withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") { + testSize(sizeOfNull = null) + } } test("MapKeys/MapValues") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 52203b9e337ba..04f1c8ce0b83d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), InternalRow(1) ) } @@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val jsonData = """{"a" 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId), null ) // Other modes should still return `null`. checkEvaluation( - JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true), + JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId), null ) } @@ -416,62 +416,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with val input = """[{"a": 1}, {"a": 2}]""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: InternalRow(2) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=object, schema=array, output=array of single row") { val input = """{"a": 1}""" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(1) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=array, output=empty array") { val input = "[ ]" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=array, output=array of single row with null") { val input = "{ }" val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) val output = InternalRow(null) :: Nil - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array of single object, schema=struct, output=single row") { val input = """[{"a": 1}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(1) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=array, schema=struct, output=null") { val input = """[{"a": 1}, {"a": 2}]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty array, schema=struct, output=null") { val input = """[]""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = null - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json - input=empty object, schema=struct, output=single row with null") { val input = """{ }""" val schema = StructType(StructField("a", IntegerType) :: Nil) val output = InternalRow(null) - checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output) + checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output) } test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), null ) } @@ -479,7 +479,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-20549: from_json bad UTF-8") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), null) } @@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with c.set(2016, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 123) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId), InternalRow(c.getTimeInMillis * 1000L) ) // The result doesn't change because the json string includes timezone string ("Z" here), // which means the string represents the timestamp string in the timezone regardless of // the timeZoneId parameter. checkEvaluation( - JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true), + JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")), InternalRow(c.getTimeInMillis * 1000L) ) @@ -512,8 +512,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with schema, Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), Literal(jsonData2), - Option(tz.getID), - true), + Option(tz.getID)), InternalRow(c.getTimeInMillis * 1000L) ) checkEvaluation( @@ -522,8 +521,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", DateTimeUtils.TIMEZONE_OPTION -> tz.getID), Literal(jsonData2), - gmtId, - true), + gmtId), InternalRow(c.getTimeInMillis * 1000L) ) } @@ -532,7 +530,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("SPARK-19543: from_json empty input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true), + JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), null ) } @@ -687,23 +685,24 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json missing fields") { for (forceJsonNullableSchema <- Seq(false, true)) { - val input = - """{ + withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) { + val input = + """{ | "a": 1, | "c": "foo" |} |""".stripMargin - val jsonSchema = new StructType() - .add("a", LongType, nullable = false) - .add("b", StringType, nullable = false) - .add("c", StringType, nullable = false) - val output = InternalRow(1L, null, UTF8String.fromString("foo")) - val expr = JsonToStructs( - jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema) - checkEvaluation(expr, output) - val schema = expr.dataType - val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema - assert(schemaToCompare == schema) + val jsonSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema + assert(schemaToCompare == schema) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 89dbba10a6bf1..6b956ddb48561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3304,7 +3304,7 @@ object functions { * @since 2.2.0 */ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { - new JsonToStructs(schema, options, e.expr) + JsonToStructs(schema, options, e.expr) } /** @@ -3495,7 +3495,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = withExpr { new Size(e.expr) } + def size(e: Column): Column = withExpr { Size(e.expr) } /** * Sorts the input array for the given column in ascending order, From 3ab48f985c7f96bc9143caad99bf3df7cc984583 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 11 Jul 2018 17:38:43 -0700 Subject: [PATCH 1102/1161] [SPARK-24761][SQL] Adding of isModifiable() to RuntimeConfig ## What changes were proposed in this pull request? In the PR, I propose to extend `RuntimeConfig` by new method `isModifiable()` which returns `true` if a config parameter can be modified at runtime (for current session state). For static SQL and core parameters, the method returns `false`. ## How was this patch tested? Added new test to `RuntimeConfigSuite` for checking Spark core and SQL parameters. Author: Maxim Gekk Closes #21730 from MaxGekk/is-modifiable. --- python/pyspark/sql/conf.py | 8 ++++++++ .../org/apache/spark/sql/internal/SQLConf.scala | 4 ++++ .../scala/org/apache/spark/sql/RuntimeConfig.scala | 11 +++++++++++ .../org/apache/spark/sql/RuntimeConfigSuite.scala | 14 ++++++++++++++ 4 files changed, 37 insertions(+) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index db49040e17b63..f80bf598c2211 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -63,6 +63,14 @@ def _checkType(self, obj, identifier): raise TypeError("expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__)) + @ignore_unicode_prefix + @since(2.4) + def isModifiable(self, key): + """Indicates whether the configuration property with the given key + is modifiable in the current session. + """ + return self._jconf.isModifiable(key) + def _test(): import os diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ae56cc97581a5..14dd5281fbcb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1907,4 +1907,8 @@ class SQLConf extends Serializable with Logging { } cloned } + + def isModifiable(key: String): Boolean = { + sqlConfEntries.containsKey(key) && !staticConfKeys.contains(key) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index b352e332bc7e0..3c39579149fff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -132,6 +132,17 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.unsetConf(key) } + /** + * Indicates whether the configuration property with the given key + * is modifiable in the current session. + * + * @return `true` if the configuration property is modifiable. For static SQL, Spark Core, + * invalid (not existing) and other non-modifiable configuration properties, + * the returned value is `false`. + * @since 2.4.0 + */ + def isModifiable(key: String): Boolean = sqlConf.isModifiable(key) + /** * Returns whether a particular key is set. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index cfe2e9f2dbc44..cdcea09ad9758 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -54,4 +54,18 @@ class RuntimeConfigSuite extends SparkFunSuite { conf.get("k1") } } + + test("SPARK-24761: is a config parameter modifiable") { + val conf = newConf() + + // SQL configs + assert(!conf.isModifiable("spark.sql.sources.schemaStringLengthThreshold")) + assert(conf.isModifiable("spark.sql.streaming.checkpointLocation")) + // Core configs + assert(!conf.isModifiable("spark.task.cpus")) + assert(!conf.isModifiable("spark.executor.cores")) + // Invalid config parameters + assert(!conf.isModifiable("")) + assert(!conf.isModifiable("invalid config parameter")) + } } From 5ad4735bdad558fe564a0391e207c62743647ab1 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Jul 2018 09:52:23 +0800 Subject: [PATCH 1103/1161] [SPARK-24529][BUILD][TEST-MAVEN] Add spotbugs into maven build process ## What changes were proposed in this pull request? This PR enables a Java bytecode check tool [spotbugs](https://spotbugs.github.io/) to avoid possible integer overflow at multiplication. When an violation is detected, the build process is stopped. Due to the tool limitation, some other checks will be enabled. In this PR, [these patterns](http://spotbugs-in-kengo-toda.readthedocs.io/en/lqc-list-detectors/detectors.html#findpuzzlers) in `FindPuzzlers` can be detected. This check is enabled at `compile` phase. Thus, `mvn compile` or `mvn package` launches this check. ## How was this patch tested? Existing UTs Author: Kazuaki Ishizaki Closes #21542 from kiszk/SPARK-24529. --- .../util/collection/ExternalSorter.scala | 4 ++-- .../apache/spark/ml/image/HadoopUtils.scala | 8 ++++--- pom.xml | 22 +++++++++++++++++++ resource-managers/kubernetes/core/pom.xml | 6 +++++ .../kubernetes/integration-tests/pom.xml | 5 +++++ .../expressions/collectionOperations.scala | 2 +- .../expressions/conditionalExpressions.scala | 4 ++-- .../sql/catalyst/parser/AstBuilder.scala | 2 +- 8 files changed, 44 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..b159200d79222 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -368,8 +368,8 @@ private[spark] class ExternalSorter[K, V, C]( val bufferedIters = iterators.filter(_.hasNext).map(_.buffered) type Iter = BufferedIterator[Product2[K, C]] val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] { - // Use the reverse of comparator.compare because PriorityQueue dequeues the max - override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1) + // Use the reverse order because PriorityQueue dequeues the max + override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1) }) heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true new Iterator[Product2[K, C]] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala index 8c975a2fba8ca..f1579ec5844a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala @@ -42,9 +42,11 @@ private object RecursiveFlag { val old = Option(hadoopConf.get(flagName)) hadoopConf.set(flagName, value.toString) try f finally { - old match { - case Some(v) => hadoopConf.set(flagName, v) - case None => hadoopConf.unset(flagName) + // avoid false positive of DLS_DEAD_LOCAL_STORE_IN_RETURN by SpotBugs + if (old.isDefined) { + hadoopConf.set(flagName, old.get) + } else { + hadoopConf.unset(flagName) } } } diff --git a/pom.xml b/pom.xml index cd567e227f331..6dee6fce3ffc4 100644 --- a/pom.xml +++ b/pom.xml @@ -2606,6 +2606,28 @@ + + com.github.spotbugs + spotbugs-maven-plugin + 3.1.3 + + ${basedir}/target/scala-${scala.binary.version}/classes + ${basedir}/target/scala-${scala.binary.version}/test-classes + Max + Low + true + FindPuzzlers + false + + + + + check + + compile + + + diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index a6dd47a6b7d95..920f0f6ebf2c8 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -47,6 +47,12 @@ test + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + + io.fabric8 kubernetes-client diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 6a2fff891098b..29334cc6d891d 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -63,6 +63,11 @@ kubernetes-client ${kubernetes-client.version} + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e217d37184511..b8f2aa3e624ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -993,7 +993,7 @@ trait ArraySortLike extends ExpectsInputTypes { } else if (o2 == null) { nullOrder } else { - -ordering.compare(o1, o2) + ordering.compare(o2, o1) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index e6377b7d87b53..30ce9e4743da9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -294,7 +294,7 @@ object CaseWhen { case cond :: value :: Nil => Some((cond, value)) case value :: Nil => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } @@ -309,7 +309,7 @@ object CaseKeyWhen { case Seq(cond, value) => Some((EqualTo(key, cond), value)) case Seq(value) => None }.toArray.toSeq // force materialization to make the seq serializable - val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None + val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 383ebde3229d6..f398b479dc273 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1507,7 +1507,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case "TIMESTAMP" => Literal(Timestamp.valueOf(value)) case "X" => - val padding = if (value.length % 2 == 1) "0" else "" + val padding = if (value.length % 2 != 0) "0" else "" Literal(DatatypeConverter.parseHexBinary(padding + value)) case other => throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) From 301bff70637983426d76b106b7c659c1f28ed7bf Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 12 Jul 2018 17:42:29 +0900 Subject: [PATCH 1104/1161] [SPARK-23914][SQL] Add array_union function ## What changes were proposed in this pull request? The PR adds the SQL function `array_union`. The behavior of the function is based on Presto's one. This function returns returns an array of the elements in the union of array1 and array2. Note: The order of elements in the result is not defined. ## How was this patch tested? Added UTs Author: Kazuaki Ishizaki Closes #21061 from kiszk/SPARK-23914. --- python/pyspark/sql/functions.py | 19 ++ .../catalyst/expressions/UnsafeArrayData.java | 19 +- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 319 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 81 +++++ .../org/apache/spark/sql/functions.scala | 11 + .../spark/sql/DataFrameFunctionsSuite.scala | 52 +++ 7 files changed, 499 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9f61e29f9cd42..5ef73987a66a6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2033,6 +2033,25 @@ def array_distinct(col): return Column(sc._jvm.functions.array_distinct(_to_java_column(col))) +@ignore_unicode_prefix +@since(2.4) +def array_union(col1, col2): + """ + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. + + :param col1: name of column containing array + :param col2: name of column containing array + + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=[u'b', u'a', u'c', u'd', u'f'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_union(_to_java_column(col1), _to_java_column(col2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 4dd2b7365652a..cf2a5ed2e27f9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -450,7 +450,7 @@ public double[] toDoubleArray() { return values; } - private static UnsafeArrayData fromPrimitiveArray( + public static UnsafeArrayData fromPrimitiveArray( Object arr, int offset, int length, int elementSize) { final long headerInBytes = calculateHeaderPortionInBytes(length); final long valueRegionInBytes = (long)elementSize * length; @@ -463,14 +463,27 @@ private static UnsafeArrayData fromPrimitiveArray( final long[] data = new long[(int)totalSizeInLongs]; Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length); - Platform.copyMemory(arr, offset, data, - Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + if (arr != null) { + Platform.copyMemory(arr, offset, data, + Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes); + } UnsafeArrayData result = new UnsafeArrayData(); result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8); return result; } + public static UnsafeArrayData forPrimitiveArray(int offset, int length, int elementSize) { + return fromPrimitiveArray(null, offset, length, elementSize); + } + + public static boolean shouldUseGenericArrayData(int elementSize, int length) { + final long headerInBytes = calculateHeaderPortionInBytes(length); + final long valueRegionInBytes = (long)elementSize * length; + final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8; + return totalSizeInLongs > Integer.MAX_VALUE / 8; + } + public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) { return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e7517e8c676e3..1d9e470c9ba83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -414,6 +414,7 @@ object FunctionRegistry { expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), + expression[ArrayUnion]("array_union"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[ElementAt]("element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b8f2aa3e624ce..0f4f4f1601b4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3486,3 +3486,322 @@ case class ArrayDistinct(child: Expression) override def prettyName: String = "array_distinct" } + +/** + * Will become common base class for [[ArrayUnion]], ArrayIntersect, and ArrayExcept. + */ +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + override def dataType: DataType = { + val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType]) + ArrayType(elementType, dataTypes.exists(_.containsNull)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + @transient protected lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient protected lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } +} + +object ArraySetLike { + def throwUnionLengthOverflowException(length: Int): Unit = { + throw new RuntimeException(s"Unsuccessful try to union arrays with $length " + + s"elements due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } +} + + +/** + * Returns an array of the elements in the union of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the union of array1 and array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 2, 3, 5) + """, + since = "2.4.0") +case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike { + var hsInt: OpenHashSet[Int] = _ + var hsLong: OpenHashSet[Long] = _ + + def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getInt(idx) + if (!hsInt.contains(elem)) { + if (resultArray != null) { + resultArray.setInt(pos, elem) + } + hsInt.add(elem) + true + } else { + false + } + } + + def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { + val elem = array.getLong(idx) + if (!hsLong.contains(elem)) { + if (resultArray != null) { + resultArray.setLong(pos, elem) + } + hsLong.add(elem) + true + } else { + false + } + } + + def evalIntLongPrimitiveType( + array1: ArrayData, + array2: ArrayData, + resultArray: ArrayData, + isLongType: Boolean): Int = { + // store elements into resultArray + var nullElementSize = 0 + var pos = 0 + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + val size = if (!isLongType) hsInt.size else hsLong.size + if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(size) + } + if (array.isNullAt(i)) { + if (nullElementSize == 0) { + if (resultArray != null) { + resultArray.setNullAt(pos) + } + pos += 1 + nullElementSize = 1 + } + } else { + val assigned = if (!isLongType) { + assignInt(array, i, resultArray, pos) + } else { + assignLong(array, i, resultArray, pos) + } + if (assigned) { + pos += 1 + } + } + i += 1 + } + } + pos + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + if (elementTypeSupportEquals) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + // calculate result array size + hsInt = new OpenHashSet[Int] + val elements = evalIntLongPrimitiveType(array1, array2, null, false) + hsInt = new OpenHashSet[Int] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + IntegerType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, false) + resultArray + case LongType => + // avoid boxing of primitive long array elements + // calculate result array size + hsLong = new OpenHashSet[Long] + val elements = evalIntLongPrimitiveType(array1, array2, null, true) + hsLong = new OpenHashSet[Long] + val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( + LongType.defaultSize, elements)) { + new GenericArrayData(new Array[Any](elements)) + } else { + UnsafeArrayData.forPrimitiveArray( + Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) + } + evalIntLongPrimitiveType(array1, array2, resultArray, true) + resultArray + case _ => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + } + new GenericArrayData(arrayBuffer) + } + } else { + ArrayUnion.unionOrdering(array1, array2, elementType, ordering) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case ByteType | ShortType | IntegerType | LongType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", + if (elementType == LongType) "Long" else "Int", + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + if (elementType == LongType) "(long)" else "(int)", + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", + s"get($i, $et)", s"update($pos, $value)", "Object", "", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $foundNullElement = false; + |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | $foundNullElement = true; + | } else { + | $hs.add$postFix($array.$getter); + | } + | } + |} + |int $size = $hs.size() + ($foundNullElement ? 1 : 0); + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$foundNullElement = false; + |int $pos = 0; + |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { + | $arrayData $array = $arrays[$arrayDataIdx]; + | for (int $i = 0; $i < $array.numElements(); $i++) { + | if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $foundNullElement = true; + | } + | } else { + | $javaTypeName $value = $array.$getter; + | if (!$hs.contains($castOp $value)) { + | $hs.add$postFix($value); + | ${ev.value}.$setter; + | $pos++; + | } + | } + | } + |} + """.stripMargin + } else { + val arrayUnion = classOf[ArrayUnion].getName + val et = ctx.addReferenceObj("elementTypeUnion", elementType) + val order = ctx.addReferenceObj("orderingUnion", ordering) + val method = "unionOrdering" + s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" + } + }) + } + + override def prettyName: String = "array_union" +} + +object ArrayUnion { + def unionOrdering( + array1: ArrayData, + array2: ArrayData, + elementType: DataType, + ordering: Ordering[Any]): ArrayData = { + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } + } else { + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem + } + })) + new GenericArrayData(arrayBuffer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a838a2eedc03c..85d6a1befed6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1304,4 +1304,85 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) } + + test("Array Union") { + val a00 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, containsNull = false)) + val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) + val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) + val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false)) + val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false)) + val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false)) + val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false)) + + val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) + val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) + val a12 = Literal.create(Seq(1L, 2L, null, 4L, 5L), ArrayType(LongType, containsNull = true)) + val a13 = Literal.create(Seq(-5L, 4L, -3L, 2L, -1L), ArrayType(LongType, containsNull = false)) + val a14 = Literal.create(Seq.empty[Long], ArrayType(LongType, containsNull = false)) + + val a20 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType, containsNull = false)) + val a21 = Literal.create(Seq("c", "d", "a", "f"), ArrayType(StringType, containsNull = false)) + val a22 = Literal.create(Seq("b", null, "a", "g"), ArrayType(StringType, containsNull = true)) + + val a30 = Literal.create(Seq(null, null), ArrayType(IntegerType)) + val a31 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayUnion(a00, a01), Seq(1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) + checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) + checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) + checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4)) + + checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) + checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkEvaluation(ArrayUnion(a13, a12), Seq(-5L, 4L, -3L, 2L, -1L, 1L, null, 5L)) + checkEvaluation(ArrayUnion(a12, a14), Seq(1L, 2L, null, 4L, 5L)) + + checkEvaluation(ArrayUnion(a20, a21), Seq("b", "a", "c", "d", "f")) + checkEvaluation(ArrayUnion(a20, a22), Seq("b", "a", "c", null, "g")) + + checkEvaluation(ArrayUnion(a30, a30), Seq(null)) + checkEvaluation(ArrayUnion(a20, a31), null) + checkEvaluation(ArrayUnion(a31, a20), null) + + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]]( + Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](1, 2)), ArrayType(BinaryType)) + val b4 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), null), ArrayType(BinaryType)) + val b5 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), ArrayType(BinaryType)) + val b6 = Literal.create(Seq.empty, ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + + checkEvaluation(ArrayUnion(b0, b1), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](2, 1), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b0, b2), + Seq(Array[Byte](5, 6), Array[Byte](1, 2), Array[Byte](4, 3))) + checkEvaluation(ArrayUnion(b2, b4), Seq(Array[Byte](1, 2), Array[Byte](4, 3), null)) + checkEvaluation(ArrayUnion(b3, b0), + Seq(Array[Byte](1, 2), Array[Byte](4, 3), Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b0), Seq(Array[Byte](1, 2), null, Array[Byte](5, 6))) + checkEvaluation(ArrayUnion(b4, b5), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b6, b4), Seq(Array[Byte](1, 2), null)) + checkEvaluation(ArrayUnion(b4, arrayWithBinaryNull), Seq(Array[Byte](1, 2), null)) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + checkEvaluation(ArrayUnion(aa0, aa1), + Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](5, 6), Seq[Int](2, 1))) + + assert(ArrayUnion(a00, a01).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull === true) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6b956ddb48561..b98ab11e56feb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3204,6 +3204,7 @@ object functions { /** * Remove all elements that equal to element from the given array. + * * @group collection_funcs * @since 2.4.0 */ @@ -3218,6 +3219,16 @@ object functions { */ def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) } + /** + * Returns an array of the elements in the union of the given two arrays, without duplicates. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_union(col1: Column, col2: Column): Column = withExpr { + ArrayUnion(col1.expr, col2.expr) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d60ed7a5ef0d9..d4615714cff03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1198,6 +1198,58 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } + test("array_union functions") { + val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") + val ans1 = Row(Seq(1, 2, 3, 4)) + checkAnswer(df1.select(array_union($"a", $"b")), ans1) + checkAnswer(df1.selectExpr("array_union(a, b)"), ans1) + + val df2 = Seq((Array[Integer](1, 2, null, 4, 5), Array(-5, 4, -3, 2, -1))).toDF("a", "b") + val ans2 = Row(Seq(1, 2, null, 4, 5, -5, -3, -1)) + checkAnswer(df2.select(array_union($"a", $"b")), ans2) + checkAnswer(df2.selectExpr("array_union(a, b)"), ans2) + + val df3 = Seq((Array(1L, 2L, 3L), Array(4L, 2L))).toDF("a", "b") + val ans3 = Row(Seq(1L, 2L, 3L, 4L)) + checkAnswer(df3.select(array_union($"a", $"b")), ans3) + checkAnswer(df3.selectExpr("array_union(a, b)"), ans3) + + val df4 = Seq((Array[java.lang.Long](1L, 2L, null, 4L, 5L), Array(-5L, 4L, -3L, 2L, -1L))) + .toDF("a", "b") + val ans4 = Row(Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) + checkAnswer(df4.select(array_union($"a", $"b")), ans4) + checkAnswer(df4.selectExpr("array_union(a, b)"), ans4) + + val df5 = Seq((Array("b", "a", "c"), Array("b", null, "a", "g"))).toDF("a", "b") + val ans5 = Row(Seq("b", "a", "c", null, "g")) + checkAnswer(df5.select(array_union($"a", $"b")), ans5) + checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) + + val df6 = Seq((null, Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df6.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df6.selectExpr("array_union(a, b)") + } + + val df7 = Seq((null, null)).toDF("a", "b") + intercept[AnalysisException] { + df7.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df7.selectExpr("array_union(a, b)") + } + + val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") + intercept[AnalysisException] { + df8.select(array_union($"a", $"b")) + } + intercept[AnalysisException] { + df8.selectExpr("array_union(a, b)") + } + } + test("concat function - arrays") { val nseqi : Seq[Int] = null val nseqs : Seq[String] = null From e6c6f90a55241905c420afbc803dd3bd6961d66b Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 13 Jul 2018 00:26:49 +0800 Subject: [PATCH 1105/1161] [SPARK-24691][SQL] Dispatch the type support check in FileFormat implementation ## What changes were proposed in this pull request? With https://github.com/apache/spark/pull/21389, data source schema is validated on driver side before launching read/write tasks. However, 1. Putting all the validations together in `DataSourceUtils` is tricky and hard to maintain. On second thought after review, I find that the `OrcFileFormat` in hive package is not matched, so that its validation wrong. 2. `DataSourceUtils.verifyWriteSchema` and `DataSourceUtils.verifyReadSchema` is not supposed to be called in every file format. We can move them to some upper entry. So, I propose we can add a new method `validateDataType` in FileFormat. File format implementation can override the method to specify its supported/non-supported data types. Although we should focus on data source V2 API, `FileFormat` should remain workable for some time. Adding this new method should be helpful. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21667 from gengliangwang/refactorSchemaValidate. --- .../execution/datasources/DataSource.scala | 1 + .../datasources/DataSourceUtils.scala | 68 ++-------- .../execution/datasources/FileFormat.scala | 9 +- .../datasources/FileFormatWriter.scala | 4 +- .../datasources/csv/CSVFileFormat.scala | 11 +- .../datasources/json/JsonFileFormat.scala | 23 +++- .../datasources/orc/OrcFileFormat.scala | 21 +++- .../parquet/ParquetFileFormat.scala | 19 ++- .../datasources/text/TextFileFormat.scala | 10 +- .../spark/sql/FileBasedDataSourceSuite.scala | 119 ++++++++++++------ .../sql/execution/command/DDLSuite.scala | 2 +- .../datasources/FileSourceStrategySuite.scala | 2 +- .../spark/sql/hive/orc/OrcFileFormat.scala | 21 +++- .../sql/hive/orc/HiveOrcSourceSuite.scala | 19 +-- 14 files changed, 189 insertions(+), 140 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f16d824201e77..0c3d9a4895fe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -396,6 +396,7 @@ case class DataSource( hs.partitionSchema.map(_.name), "in the partition schema", equality) + DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema) case _ => SchemaUtils.checkColumnNameDuplication( relation.schema.map(_.name), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index c5347218c4b40..82e99190ecf14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat -import org.apache.spark.sql.execution.datasources.json.JsonFileFormat -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.types._ @@ -42,65 +39,14 @@ object DataSourceUtils { /** * Verify if the schema is supported in datasource. This verification should be done - * in a driver side, e.g., `prepareWrite`, `buildReader`, and `buildReaderWithPartitionValues` - * in `FileFormat`. - * - * Unsupported data types of csv, json, orc, and parquet are as follows; - * csv -> R/W: Interval, Null, Array, Map, Struct - * json -> W: Interval - * orc -> W: Interval, Null - * parquet -> R/W: Interval, Null + * in a driver side. */ private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = { - def throwUnsupportedException(dataType: DataType): Unit = { - throw new UnsupportedOperationException( - s"$format data source does not support ${dataType.simpleString} data type.") + schema.foreach { field => + if (!format.supportDataType(field.dataType, isReadPath)) { + throw new AnalysisException( + s"$format data source does not support ${field.dataType.simpleString} data type.") + } } - - def verifyType(dataType: DataType): Unit = dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | - StringType | BinaryType | DateType | TimestampType | _: DecimalType => - - // All the unsupported types for CSV - case _: NullType | _: CalendarIntervalType | _: StructType | _: ArrayType | _: MapType - if format.isInstanceOf[CSVFileFormat] => - throwUnsupportedException(dataType) - - case st: StructType => st.foreach { f => verifyType(f.dataType) } - - case ArrayType(elementType, _) => verifyType(elementType) - - case MapType(keyType, valueType, _) => - verifyType(keyType) - verifyType(valueType) - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - // Interval type not supported in all the write path - case _: CalendarIntervalType if !isReadPath => - throwUnsupportedException(dataType) - - // JSON and ORC don't support an Interval type, but we pass it in read pass - // for back-compatibility. - case _: CalendarIntervalType if format.isInstanceOf[JsonFileFormat] || - format.isInstanceOf[OrcFileFormat] => - - // Interval type not supported in the other read path - case _: CalendarIntervalType => - throwUnsupportedException(dataType) - - // For JSON & ORC backward-compatibility - case _: NullType if format.isInstanceOf[JsonFileFormat] || - (isReadPath && format.isInstanceOf[OrcFileFormat]) => - - // Null type not supported in the other path - case _: NullType => - throwUnsupportedException(dataType) - - // We keep this default case for safeguards - case _ => throwUnsupportedException(dataType) - } - - schema.foreach(field => verifyType(field.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 023e127888290..2c162e23644ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** @@ -57,7 +57,7 @@ trait FileFormat { dataSchema: StructType): OutputWriterFactory /** - * Returns whether this format support returning columnar batch or not. + * Returns whether this format supports returning columnar batch or not. * * TODO: we should just have different traits for the different formats. */ @@ -152,6 +152,11 @@ trait FileFormat { } } + /** + * Returns whether this format supports the given [[DataType]] in read/write path. + * By default all data types are supported. + */ + def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 52da8356ab835..7c6ab4bc922fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -96,9 +96,11 @@ object FileFormatWriter extends Logging { val caseInsensitiveOptions = CaseInsensitiveMap(options) + val dataSchema = dataColumns.toStructType + DataSourceUtils.verifyWriteSchema(fileFormat, dataSchema) // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index fa366ccce6b61..aeb40e5a4131d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -66,7 +66,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions( options, @@ -98,7 +97,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -153,6 +151,15 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } + } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 383bff1375a93..a9241afba537b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSON import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { @@ -65,8 +65,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val conf = job.getConfiguration val parsedOptions = new JSONOptions( options, @@ -98,8 +96,6 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) @@ -148,6 +144,23 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => true + + case _ => false + } } private[json] class JsonOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index df488a748e3e5..3a8c0add8c2f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -89,8 +89,6 @@ class OrcFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) val conf = job.getConfiguration @@ -143,8 +141,6 @@ class OrcFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - if (sparkSession.sessionState.conf.orcFilterPushDown) { OrcFilters.createFilter(dataSchema, filters).foreach { f => OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) @@ -228,4 +224,21 @@ class OrcFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 52a18abb55241..b86b97ec7b103 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -78,8 +78,6 @@ class ParquetFileFormat job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) - val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf) val conf = ContextUtil.getConfiguration(job) @@ -303,8 +301,6 @@ class ParquetFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, @@ -454,6 +450,21 @@ class ParquetFileFormat } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _ => false + } } object ParquetFileFormat extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index e93908da43535..8661a5395ac44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{DataType, StringType, StructType} import org.apache.spark.util.SerializableConfiguration /** @@ -47,11 +47,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { throw new AnalysisException( s"Text data source supports only a single column, and you have ${schema.size} columns.") } - val tpe = schema(0).dataType - if (tpe != StringType) { - throw new AnalysisException( - s"Text data source supports only a string column, but you have ${tpe.simpleString}.") - } } override def isSplitable( @@ -141,6 +136,9 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = + dataType == StringType } class TextOutputWriter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 86f9647b4ac4c..a7ce952b70ac1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -205,63 +205,121 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } } + // Text file format only supports string type + test("SPARK-24691 error handling for unsupported types - text") { + withTempDir { dir => + // write path + val textDir = new File(dir, "text").getCanonicalPath + var msg = intercept[AnalysisException] { + Seq(1).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + Seq(1.2).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + Seq(true).toDF.write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + + msg = intercept[AnalysisException] { + Seq(1).toDF("a").selectExpr("struct(a)").write.text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support struct data type")) + + msg = intercept[AnalysisException] { + Seq((Map("Tesla" -> 3))).toDF("cars").write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support map data type")) + + msg = intercept[AnalysisException] { + Seq((Array("Tesla", "Chevy", "Ford"))).toDF("brands") + .write.mode("overwrite").text(textDir) + }.getMessage + assert(msg.contains("Text data source does not support array data type")) + + // read path + Seq("aaa").toDF.write.mode("overwrite").text(textDir) + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support int data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", DoubleType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support double data type")) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", BooleanType, true) :: Nil) + spark.read.schema(schema).text(textDir).collect() + }.getMessage + assert(msg.contains("Text data source does not support boolean data type")) + } + } + // Unsupported data types of csv, json, orc, and parquet are as follows; - // csv -> R/W: Interval, Null, Array, Map, Struct - // json -> W: Interval - // orc -> W: Interval, Null + // csv -> R/W: Null, Array, Map, Struct + // json -> R/W: Interval + // orc -> R/W: Interval, W: Null // parquet -> R/W: Interval, Null test("SPARK-24204 error handling for unsupported Array/Map/Struct types - csv") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support struct data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a struct") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support struct data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support map data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a map") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support map data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands") .write.mode("overwrite").csv(csvDir) }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType.fromDDL("a array") spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage assert(msg.contains("CSV data source does not support array data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") .write.mode("overwrite").csv(csvDir) }.getMessage - assert(msg.contains("CSV data source does not support array data type")) + assert(msg.contains("CSV data source does not support mydensevector data type")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").csv(csvDir) spark.read.schema(schema).csv(csvDir).collect() }.getMessage - assert(msg.contains("CSV data source does not support array data type.")) + assert(msg.contains("CSV data source does not support mydensevector data type.")) } } @@ -276,17 +334,17 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo }.getMessage assert(msg.contains("Cannot save interval data type into external storage.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + .contains(s"$format data source does not support interval data type.")) } // read path Seq("parquet", "csv").foreach { format => - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -294,26 +352,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support calendarinterval data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() }.getMessage assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - } - - // We expect the types below should be passed for backward-compatibility - Seq("orc", "json").foreach { format => - // Interval type - var schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - - // UDT having interval data - schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() + .contains(s"$format data source does not support interval data type.")) } } } @@ -324,13 +369,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("orc").foreach { format => // write path - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { sql("select null").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage @@ -353,13 +398,13 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo Seq("parquet", "csv").foreach { format => // write path - var msg = intercept[UnsupportedOperationException] { + var msg = intercept[AnalysisException] { sql("select null").write.format(format).mode("overwrite").save(tempDir) }.getMessage assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new NullData()) sql("select testType()").write.format(format).mode("overwrite").save(tempDir) }.getMessage @@ -367,7 +412,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo .contains(s"$format data source does not support null data type.")) // read path - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", NullType, true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() @@ -375,7 +420,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo assert(msg.toLowerCase(Locale.ROOT) .contains(s"$format data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) spark.range(1).write.format(format).mode("overwrite").save(tempDir) spark.read.schema(schema).format(format).load(tempDir).collect() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 270ed7f80197c..ca95aad3976e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2513,7 +2513,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("alter datasource table add columns - text format not supported") { withTable("t1") { - sql("CREATE TABLE t1 (c1 int) USING text") + sql("CREATE TABLE t1 (c1 string) USING text") val e = intercept[AnalysisException] { sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") }.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 8764f0c42cf9f..bceaf1a9ec061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index dd2144c5fcea8..20090696ec3fc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.hive.{HiveInspectors, HiveShim} import org.apache.spark.sql.sources.{Filter, _} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration /** @@ -72,7 +72,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - DataSourceUtils.verifyWriteSchema(this, dataSchema) val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) @@ -123,7 +122,6 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { - DataSourceUtils.verifyReadSchema(this, dataSchema) if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates @@ -178,6 +176,23 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } } } + + override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) } + + case ArrayType(elementType, _) => supportDataType(elementType, isReadPath) + + case MapType(keyType, valueType, _) => + supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath) + + case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath) + + case _: NullType => isReadPath + + case _ => false + } } private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala index 69009e1b520c2..fb4957ed943a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcSourceSuite.scala @@ -146,38 +146,31 @@ class HiveOrcSourceSuite extends OrcSuite with TestHiveSingleton { }.getMessage assert(msg.contains("Cannot save interval data type into external storage.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { sql("select null").write.mode("overwrite").orc(orcDir) }.getMessage assert(msg.contains("ORC data source does not support null data type.")) - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.mode("overwrite").orc(orcDir) }.getMessage - assert(msg.contains("ORC data source does not support calendarinterval data type.")) + assert(msg.contains("ORC data source does not support interval data type.")) // read path - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }.getMessage assert(msg.contains("ORC data source does not support calendarinterval data type.")) - msg = intercept[UnsupportedOperationException] { - val schema = StructType(StructField("a", NullType, true) :: Nil) - spark.range(1).write.mode("overwrite").orc(orcDir) - spark.read.schema(schema).orc(orcDir).collect() - }.getMessage - assert(msg.contains("ORC data source does not support null data type.")) - - msg = intercept[UnsupportedOperationException] { + msg = intercept[AnalysisException] { val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) spark.range(1).write.mode("overwrite").orc(orcDir) spark.read.schema(schema).orc(orcDir).collect() }.getMessage - assert(msg.contains("ORC data source does not support calendarinterval data type.")) + assert(msg.contains("ORC data source does not support interval data type.")) } } } From 9fa4a1ed38713e2d18a3320d3fc56f9f6db07b06 Mon Sep 17 00:00:00 2001 From: Yash Sharma Date: Thu, 12 Jul 2018 10:04:47 -0700 Subject: [PATCH 1106/1161] =?UTF-8?q?[SPARK-20168][STREAMING=20KINESIS]=20?= =?UTF-8?q?Setting=20the=20timestamp=20directly=20would=20cause=20exceptio?= =?UTF-8?q?n=20on=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Setting the timestamp directly would cause exception on reading stream, it can be set directly only if the mode is not AT_TIMESTAMP ## What changes were proposed in this pull request? The last patch in the kinesis streaming receiver sets the timestamp for the mode AT_TIMESTAMP, but this mode can only be set via the `baseClientLibConfiguration.withTimestampAtInitialPositionInStream() ` and can't be set directly using `.withInitialPositionInStream()` This patch fixes the issue. ## How was this patch tested? Kinesis Receiver doesn't expose the internal state outside, so couldn't find the right way to test this change. Seeking for tips from other contributors here. Author: Yash Sharma Closes #21541 from yashs360/ysharma/fix_kinesis_bug. --- .../org/apache/spark/streaming/kinesis/KinesisReceiver.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index fa0de6298a5f1..69c52365b1bf8 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -160,7 +160,6 @@ private[kinesis] class KinesisReceiver[T]( cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), workerId) .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPosition.getPosition) .withTaskBackoffTimeMillis(500) .withRegionName(regionName) @@ -169,7 +168,8 @@ private[kinesis] class KinesisReceiver[T]( initialPosition match { case ts: AtTimestamp => baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) - case _ => baseClientLibConfiguration + case _ => + baseClientLibConfiguration.withInitialPositionInStream(initialPosition.getPosition) } } From 1055c94cdf072bfce5e36bb6552fe9b148bb9d17 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Thu, 12 Jul 2018 15:36:02 -0500 Subject: [PATCH 1107/1161] [SPARK-24610] fix reading small files via wholeTextFiles ## What changes were proposed in this pull request? The `WholeTextFileInputFormat` determines the `maxSplitSize` for the file/s being read using the `wholeTextFiles` method. While this works well for large files, for smaller files where the maxSplitSize is smaller than the defaults being used with configs like hive-site.xml or explicitly passed in the form of `mapreduce.input.fileinputformat.split.minsize.per.node` or `mapreduce.input.fileinputformat.split.minsize.per.rack` , it just throws up an exception. ```java java.io.IOException: Minimum split size pernode 123456 cannot be larger than maximum split size 9962 at org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat.getSplits(CombineFileInputFormat.java:200) at org.apache.spark.rdd.WholeTextFileRDD.getPartitions(WholeTextFileRDD.scala:50) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:252) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:250) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:250) at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:35) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:252) at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:250) at scala.Option.getOrElse(Option.scala:121) at org.apache.spark.rdd.RDD.partitions(RDD.scala:250) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2096) at org.apache.spark.rdd.RDD.count(RDD.scala:1158) ... 48 elided ` This change checks the maxSplitSize against the minSplitSizePerNode and minSplitSizePerRack and set them if `maxSplitSize < minSplitSizePerNode/Rack` ## How was this patch tested? Test manually setting the conf while launching the job and added unit test. Author: Dhruve Ashar Closes #21601 from dhruve/bug/SPARK-24610. --- .../input/WholeTextFileInputFormat.scala | 13 +++ .../input/WholeTextFileInputFormatSuite.scala | 96 +++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index f47cd38d712c3..04c5c4b90e8a1 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,6 +53,19 @@ private[spark] class WholeTextFileInputFormat val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong + + // For small files we need to ensure the min split size per node & rack <= maxSplitSize + val config = context.getConfiguration + val minSplitSizePerNode = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERNODE, 0L) + val minSplitSizePerRack = config.getLong(CombineFileInputFormat.SPLIT_MINSIZE_PERRACK, 0L) + + if (maxSplitSize < minSplitSizePerNode) { + super.setMinSplitSizeNode(maxSplitSize) + } + + if (maxSplitSize < minSplitSizePerRack) { + super.setMinSplitSizeRack(maxSplitSize) + } super.setMaxSplitSize(maxSplitSize) } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala new file mode 100644 index 0000000000000..817dc082b7d38 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileInputFormatSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.collection.immutable.IndexedSeq + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Tests the correctness of + * [[org.apache.spark.input.WholeTextFileInputFormat WholeTextFileInputFormat]]. A temporary + * directory containing files is created as fake input which is deleted in the end. + */ +class WholeTextFileInputFormatSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { + private var sc: SparkContext = _ + + override def beforeAll() { + super.beforeAll() + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + } + + override def afterAll() { + try { + sc.stop() + } finally { + super.afterAll() + } + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val path = s"${inputDir.toString}/$fileName" + val out = new DataOutputStream(new FileOutputStream(path)) + out.write(contents, 0, contents.length) + out.close() + } + + test("for small files minimum split size per node and per rack should be less than or equal to " + + "maximum split size.") { + var dir : File = null; + try { + dir = Utils.createTempDir() + logInfo(s"Local disk address is ${dir.toString}.") + + // Set the minsize per node and rack to be larger than the size of the input file. + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.node", 123456) + sc.hadoopConfiguration.setLong( + "mapreduce.input.fileinputformat.split.minsize.per.rack", 123456) + + WholeTextFileInputFormatSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, false) + } + // ensure spark job runs successfully without exceptions from the CombineFileInputFormat + assert(sc.wholeTextFiles(dir.toString).count == 3) + } finally { + Utils.deleteRecursively(dir) + } + } +} + +/** + * Files to be tested are defined here. + */ +object WholeTextFileInputFormatSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} From 395860a986987886df6d60fd9b26afd818b2cb39 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Thu, 12 Jul 2018 13:55:25 -0700 Subject: [PATCH 1108/1161] [SPARK-24768][SQL] Have a built-in AVRO data source implementation ## What changes were proposed in this pull request? Apache Avro (https://avro.apache.org) is a popular data serialization format. It is widely used in the Spark and Hadoop ecosystem, especially for Kafka-based data pipelines. Using the external package https://github.com/databricks/spark-avro, Spark SQL can read and write the avro data. Making spark-Avro built-in can provide a better experience for first-time users of Spark SQL and structured streaming. We expect the built-in Avro data source can further improve the adoption of structured streaming. The proposal is to inline code from spark-avro package (https://github.com/databricks/spark-avro). The target release is Spark 2.4. [Built-in AVRO Data Source In Spark 2.4.pdf](https://github.com/apache/spark/files/2181511/Built-in.AVRO.Data.Source.In.Spark.2.4.pdf) ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21742 from gengliangwang/export_avro. --- dev/run-tests.py | 2 +- dev/sparktestsupport/modules.py | 10 + external/avro/pom.xml | 73 ++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../spark/sql/avro/AvroFileFormat.scala | 289 +++++++ .../spark/sql/avro/AvroOutputWriter.scala | 164 ++++ .../sql/avro/AvroOutputWriterFactory.scala | 38 + .../spark/sql/avro/SchemaConverters.scala | 406 +++++++++ .../org/apache/spark/sql/avro/package.scala | 39 + .../avro/src/test/resources/episodes.avro | Bin 0 -> 597 bytes .../avro/src/test/resources/log4j.properties | 49 ++ .../test-random-partitioned/part-r-00000.avro | Bin 0 -> 1768 bytes .../test-random-partitioned/part-r-00001.avro | Bin 0 -> 2313 bytes .../test-random-partitioned/part-r-00002.avro | Bin 0 -> 1621 bytes .../test-random-partitioned/part-r-00003.avro | Bin 0 -> 2117 bytes .../test-random-partitioned/part-r-00004.avro | Bin 0 -> 3282 bytes .../test-random-partitioned/part-r-00005.avro | Bin 0 -> 1550 bytes .../test-random-partitioned/part-r-00006.avro | Bin 0 -> 1729 bytes .../test-random-partitioned/part-r-00007.avro | Bin 0 -> 1897 bytes .../test-random-partitioned/part-r-00008.avro | Bin 0 -> 3420 bytes .../test-random-partitioned/part-r-00009.avro | Bin 0 -> 1796 bytes .../test-random-partitioned/part-r-00010.avro | Bin 0 -> 3872 bytes external/avro/src/test/resources/test.avro | Bin 0 -> 1365 bytes external/avro/src/test/resources/test.avsc | 53 ++ external/avro/src/test/resources/test.json | 42 + .../org/apache/spark/sql/avro/AvroSuite.scala | 812 ++++++++++++++++++ .../avro/SerializableConfigurationSuite.scala | 50 ++ .../org/apache/spark/sql/avro/TestUtils.scala | 156 ++++ pom.xml | 1 + project/SparkBuild.scala | 12 +- 30 files changed, 2191 insertions(+), 6 deletions(-) create mode 100644 external/avro/pom.xml create mode 100644 external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100755 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala create mode 100755 external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala create mode 100644 external/avro/src/test/resources/episodes.avro create mode 100644 external/avro/src/test/resources/log4j.properties create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro create mode 100755 external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro create mode 100644 external/avro/src/test/resources/test.avro create mode 100644 external/avro/src/test/resources/test.avsc create mode 100644 external/avro/src/test/resources/test.json create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala create mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala create mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/dev/run-tests.py b/dev/run-tests.py index cd4590864b7d7..d9d3789ac1255 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', + ['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index dfea762db98c6..2aa355504bf29 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -170,6 +170,16 @@ def __hash__(self): ] ) +avro = Module( + name="avro", + dependencies=[sql], + source_file_regexes=[ + "external/avro", + ], + sbt_test_goals=[ + "avro/test", + ] +) sql_kafka = Module( name="sql-kafka-0-10", diff --git a/external/avro/pom.xml b/external/avro/pom.xml new file mode 100644 index 0000000000000..42e865bc38824 --- /dev/null +++ b/external/avro/pom.xml @@ -0,0 +1,73 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.4.0-SNAPSHOT + ../../pom.xml + + + spark-sql-avro_2.11 + + avro + + jar + Spark Avro + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..95835f0d4ca49 --- /dev/null +++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.avro.AvroFileFormat diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala new file mode 100755 index 0000000000000..46e5a189c5eb3 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.net.URI +import java.util.zip.Deflater + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.file.{DataFileConstants, DataFileReader} +import org.apache.avro.generic.{GenericDatumReader, GenericRecord} +import org.apache.avro.mapred.{AvroOutputFormat, FsInput} +import org.apache.avro.mapreduce.AvroJob +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.Job +import org.slf4j.LoggerFactory + +import org.apache.spark.TaskContext +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} +import org.apache.spark.sql.sources.{DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { + private val log = LoggerFactory.getLogger(getClass) + + override def equals(other: Any): Boolean = other match { + case _: AvroFileFormat => true + case _ => false + } + + // Dummy hashCode() to appease ScalaStyle. + override def hashCode(): Int = super.hashCode() + + override def inferSchema( + spark: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + val conf = spark.sparkContext.hadoopConfiguration + + // Schema evolution is not supported yet. Here we only pick a single random sample file to + // figure out the schema of the whole dataset. + val sampleFile = + if (conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true)) { + files.find(_.getPath.getName.endsWith(".avro")).getOrElse { + throw new FileNotFoundException( + "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " + + " is set to true. Do all input files have \".avro\" extension?" + ) + } + } else { + files.headOption.getOrElse { + throw new FileNotFoundException("No Avro files found.") + } + } + + // User can specify an optional avro json schema. + val avroSchema = options.get(AvroFileFormat.AvroSchema) + .map(new Schema.Parser().parse) + .getOrElse { + val in = new FsInput(sampleFile.getPath, conf) + try { + val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()) + try { + reader.getSchema + } finally { + reader.close() + } + } finally { + in.close() + } + } + + SchemaConverters.toSqlType(avroSchema).dataType match { + case t: StructType => Some(t) + case _ => throw new RuntimeException( + s"""Avro schema cannot be converted to a Spark SQL StructType: + | + |${avroSchema.toString(true)} + |""".stripMargin) + } + } + + override def shortName(): String = "avro" + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = true + + override def prepareWrite( + spark: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + val recordName = options.getOrElse("recordName", "topLevelRecord") + val recordNamespace = options.getOrElse("recordNamespace", "") + val build = SchemaBuilder.record(recordName).namespace(recordNamespace) + val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace) + + AvroJob.setOutputKeySchema(job, outputAvroSchema) + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val COMPRESS_KEY = "mapred.output.compress" + + spark.conf.get(AVRO_COMPRESSION_CODEC, "snappy") match { + case "uncompressed" => + log.info("writing uncompressed Avro records") + job.getConfiguration.setBoolean(COMPRESS_KEY, false) + + case "snappy" => + log.info("compressing Avro output using Snappy") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.SNAPPY_CODEC) + + case "deflate" => + val deflateLevel = spark.conf.get( + AVRO_DEFLATE_LEVEL, Deflater.DEFAULT_COMPRESSION.toString).toInt + log.info(s"compressing Avro output using deflate (level=$deflateLevel)") + job.getConfiguration.setBoolean(COMPRESS_KEY, true) + job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, DataFileConstants.DEFLATE_CODEC) + job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + + case unknown: String => + log.error(s"unsupported compression codec $unknown") + } + + new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace) + } + + override def buildReader( + spark: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + + val broadcastedConf = + spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) + val conf = broadcastedConf.value.value + val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse) + + // TODO Removes this check once `FileFormat` gets a general file filtering interface method. + // Doing input file filtering is improper because we may generate empty tasks that process no + // input files but stress the scheduler. We should probably add a more general input file + // filtering mechanism for `FileFormat` data sources. See SPARK-16317. + if ( + conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true) && + !file.filePath.endsWith(".avro") + ) { + Iterator.empty + } else { + val reader = { + val in = new FsInput(new Path(new URI(file.filePath)), conf) + try { + val datumReader = userProvidedSchema match { + case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema) + case _ => new GenericDatumReader[GenericRecord]() + } + DataFileReader.openReader(in, datumReader) + } catch { + case NonFatal(e) => + log.error("Exception while opening DataFileReader", e) + in.close() + throw e + } + } + + // Ensure that the reader is closed even if the task fails or doesn't consume the entire + // iterator of records. + Option(TaskContext.get()).foreach { taskContext => + taskContext.addTaskCompletionListener { _ => + reader.close() + } + } + + reader.sync(file.start) + val stop = file.start + file.length + + val rowConverter = SchemaConverters.createConverterToSQL( + userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) + + new Iterator[InternalRow] { + // Used to convert `Row`s containing data columns into `InternalRow`s. + private val encoderForDataColumns = RowEncoder(requiredSchema) + + private[this] var completed = false + + override def hasNext: Boolean = { + if (completed) { + false + } else { + val r = reader.hasNext && !reader.pastSync(stop) + if (!r) { + reader.close() + completed = true + } + r + } + } + + override def next(): InternalRow = { + if (reader.pastSync(stop)) { + throw new NoSuchElementException("next on empty iterator") + } + val record = reader.next() + val safeDataRow = rowConverter(record).asInstanceOf[GenericRow] + + // The safeDataRow is reused, we must do a copy + encoderForDataColumns.toRow(safeDataRow) + } + } + } + } + } +} + +private[avro] object AvroFileFormat { + val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" + + val AvroSchema = "avroSchema" + + class SerializableConfiguration(@transient var value: Configuration) + extends Serializable with KryoSerializable { + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + value.write(dos) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + value = new Configuration(false) + value.readFields(new DataInputStream(in)) + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala new file mode 100644 index 0000000000000..830bf3c0570bf --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{IOException, OutputStream} +import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} +import java.util.HashMap + +import scala.collection.immutable.Map + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.generic.GenericData.Record +import org.apache.avro.generic.GenericRecord +import org.apache.avro.mapred.AvroKey +import org.apache.avro.mapreduce.AvroKeyOutputFormat +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.NullWritable +import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.execution.datasources.OutputWriter +import org.apache.spark.sql.types._ + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[avro] class AvroOutputWriter( + path: String, + context: TaskAttemptContext, + schema: StructType, + recordName: String, + recordNamespace: String) extends OutputWriter { + + private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) + // copy of the old conversion logic after api change in SPARK-19085 + private lazy val internalRowConverter = + CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row] + + /** + * Overrides the couple of methods responsible for generating the output streams / files so + * that the data can be correctly partitioned + */ + private val recordWriter: RecordWriter[AvroKey[GenericRecord], NullWritable] = + new AvroKeyOutputFormat[GenericRecord]() { + + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + new Path(path) + } + + @throws(classOf[IOException]) + override def getAvroFileOutputStream(c: TaskAttemptContext): OutputStream = { + val path = getDefaultWorkFile(context, ".avro") + path.getFileSystem(context.getConfiguration).create(path) + } + + }.getRecordWriter(context) + + override def write(internalRow: InternalRow): Unit = { + val row = internalRowConverter(internalRow) + val key = new AvroKey(converter(row).asInstanceOf[GenericRecord]) + recordWriter.write(key, NullWritable.get()) + } + + override def close(): Unit = recordWriter.close(context) + + /** + * This function constructs converter function for a given sparkSQL datatype. This is used in + * writing Avro records out to disk + */ + private def createConverterToAvro( + dataType: DataType, + structName: String, + recordNamespace: String): (Any) => Any = { + dataType match { + case BinaryType => (item: Any) => item match { + case null => null + case bytes: Array[Byte] => ByteBuffer.wrap(bytes) + } + case ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | StringType | BooleanType => identity + case _: DecimalType => (item: Any) => if (item == null) null else item.toString + case TimestampType => (item: Any) => + if (item == null) null else item.asInstanceOf[Timestamp].getTime + case DateType => (item: Any) => + if (item == null) null else item.asInstanceOf[Date].getTime + case ArrayType(elementType, _) => + val elementConverter = createConverterToAvro( + elementType, + structName, + SchemaConverters.getNewRecordNamespace(elementType, recordNamespace, structName)) + (item: Any) => { + if (item == null) { + null + } else { + val sourceArray = item.asInstanceOf[Seq[Any]] + val sourceArraySize = sourceArray.size + val targetArray = new Array[Any](sourceArraySize) + var idx = 0 + while (idx < sourceArraySize) { + targetArray(idx) = elementConverter(sourceArray(idx)) + idx += 1 + } + targetArray + } + } + case MapType(StringType, valueType, _) => + val valueConverter = createConverterToAvro( + valueType, + structName, + SchemaConverters.getNewRecordNamespace(valueType, recordNamespace, structName)) + (item: Any) => { + if (item == null) { + null + } else { + val javaMap = new HashMap[String, Any]() + item.asInstanceOf[Map[String, Any]].foreach { case (key, value) => + javaMap.put(key, valueConverter(value)) + } + javaMap + } + } + case structType: StructType => + val builder = SchemaBuilder.record(structName).namespace(recordNamespace) + val schema: Schema = SchemaConverters.convertStructToAvro( + structType, builder, recordNamespace) + val fieldConverters = structType.fields.map(field => + createConverterToAvro( + field.dataType, + field.name, + SchemaConverters.getNewRecordNamespace(field.dataType, recordNamespace, field.name))) + (item: Any) => { + if (item == null) { + null + } else { + val record = new Record(schema) + val convertersIterator = fieldConverters.iterator + val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator + val rowIterator = item.asInstanceOf[Row].toSeq.iterator + + while (convertersIterator.hasNext) { + val converter = convertersIterator.next() + record.put(fieldNamesIterator.next(), converter(rowIterator.next())) + } + record + } + } + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala new file mode 100644 index 0000000000000..5b2ce7d7d8e0f --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.types.StructType + +private[avro] class AvroOutputWriterFactory( + schema: StructType, + recordName: String, + recordNamespace: String) extends OutputWriterFactory { + + override def getFileExtension(context: TaskAttemptContext): String = ".avro" + + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new AvroOutputWriter(path, context, schema, recordName, recordNamespace) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala new file mode 100644 index 0000000000000..01f8c74982535 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ +import org.apache.avro.SchemaBuilder._ +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.generic.GenericFixed + +import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.types._ + +/** + * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice + * versa. + */ +object SchemaConverters { + + class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) + + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * This function takes an avro schema and returns a sql schema. + */ + def toSqlType(avroSchema: Schema): SchemaType = { + avroSchema.getType match { + case INT => SchemaType(IntegerType, nullable = false) + case STRING => SchemaType(StringType, nullable = false) + case BOOLEAN => SchemaType(BooleanType, nullable = false) + case BYTES => SchemaType(BinaryType, nullable = false) + case DOUBLE => SchemaType(DoubleType, nullable = false) + case FLOAT => SchemaType(FloatType, nullable = false) + case LONG => SchemaType(LongType, nullable = false) + case FIXED => SchemaType(BinaryType, nullable = false) + case ENUM => SchemaType(StringType, nullable = false) + + case RECORD => + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlType(f.schema()) + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + + SchemaType(StructType(fields), nullable = false) + + case ARRAY => + val schemaType = toSqlType(avroSchema.getElementType) + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + + case MAP => + val schemaType = toSqlType(avroSchema.getValueType) + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + + case UNION => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + // In case of a union with null, eliminate it and make a recursive call + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + toSqlType(remainingUnionTypes.head).copy(nullable = true) + } else { + toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => + toSqlType(avroSchema.getTypes.get(0)) + case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => + SchemaType(LongType, nullable = false) + case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => + SchemaType(DoubleType, nullable = false) + case _ => + // Convert complex unions to struct types where field names are member0, member1, etc. + // This is consistent with the behavior when converting between Avro and Parquet. + val fields = avroSchema.getTypes.asScala.zipWithIndex.map { + case (s, i) => + val schemaType = toSqlType(s) + // All fields are nullable because only one of them is set at a time + StructField(s"member$i", schemaType.dataType, nullable = true) + } + + SchemaType(StructType(fields), nullable = false) + } + + case other => throw new IncompatibleSchemaException(s"Unsupported type $other") + } + } + + /** + * This function converts sparkSQL StructType into avro schema. This method uses two other + * converter methods in order to do the conversion. + */ + def convertStructToAvro[T]( + structType: StructType, + schemaBuilder: RecordBuilder[T], + recordNamespace: String): T = { + val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields() + structType.fields.foreach { field => + val newField = fieldsAssembler.name(field.name).`type`() + + if (field.nullable) { + convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace) + .noDefault + } else { + convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace) + .noDefault + } + } + fieldsAssembler.endRecord() + } + + /** + * Returns a converter function to convert row in avro format to GenericRow of catalyst. + * + * @param sourceAvroSchema Source schema before conversion inferred from avro file by passed in + * by user. + * @param targetSqlType Target catalyst sql type after the conversion. + * @return returns a converter function to convert row in avro format to GenericRow of catalyst. + */ + private[avro] def createConverterToSQL( + sourceAvroSchema: Schema, + targetSqlType: DataType): AnyRef => AnyRef = { + + def createConverter(avroSchema: Schema, + sqlType: DataType, path: List[String]): AnyRef => AnyRef = { + val avroType = avroSchema.getType + (sqlType, avroType) match { + // Avro strings are in Utf8, so we have to call toString on them + case (StringType, STRING) | (StringType, ENUM) => + (item: AnyRef) => item.toString + // Byte arrays are reused by avro, so we have to make a copy of them. + case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) | + (FloatType, FLOAT) | (LongType, LONG) => + identity + case (TimestampType, LONG) => + (item: AnyRef) => new Timestamp(item.asInstanceOf[Long]) + case (DateType, LONG) => + (item: AnyRef) => new Date(item.asInstanceOf[Long]) + case (BinaryType, FIXED) => + (item: AnyRef) => item.asInstanceOf[GenericFixed].bytes().clone() + case (BinaryType, BYTES) => + (item: AnyRef) => + val byteBuffer = item.asInstanceOf[ByteBuffer] + val bytes = new Array[Byte](byteBuffer.remaining) + byteBuffer.get(bytes) + bytes + case (struct: StructType, RECORD) => + val length = struct.fields.length + val converters = new Array[AnyRef => AnyRef](length) + val avroFieldIndexes = new Array[Int](length) + var i = 0 + while (i < length) { + val sqlField = struct.fields(i) + val avroField = avroSchema.getField(sqlField.name) + if (avroField != null) { + val converter = (item: AnyRef) => { + if (item == null) { + item + } else { + createConverter(avroField.schema, sqlField.dataType, path :+ sqlField.name)(item) + } + } + converters(i) = converter + avroFieldIndexes(i) = avroField.pos() + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s"Cannot find non-nullable field ${sqlField.name} at path ${path.mkString(".")} " + + "in Avro schema\n" + + s"Source Avro schema: $sourceAvroSchema.\n" + + s"Target Catalyst type: $targetSqlType") + } + i += 1 + } + + (item: AnyRef) => + val record = item.asInstanceOf[GenericRecord] + val result = new Array[Any](length) + var i = 0 + while (i < converters.length) { + if (converters(i) != null) { + val converter = converters(i) + result(i) = converter(record.get(avroFieldIndexes(i))) + } + i += 1 + } + new GenericRow(result) + case (arrayType: ArrayType, ARRAY) => + val elementConverter = createConverter(avroSchema.getElementType, arrayType.elementType, + path) + val allowsNull = arrayType.containsNull + (item: AnyRef) => + item.asInstanceOf[java.lang.Iterable[AnyRef]].asScala.map { element => + if (element == null && !allowsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementConverter(element) + } + } + case (mapType: MapType, MAP) if mapType.keyType == StringType => + val valueConverter = createConverter(avroSchema.getValueType, mapType.valueType, path) + val allowsNull = mapType.valueContainsNull + (item: AnyRef) => + item.asInstanceOf[java.util.Map[AnyRef, AnyRef]].asScala.map { case (k, v) => + if (v == null && !allowsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + (k.toString, valueConverter(v)) + } + }.toMap + case (sqlType, UNION) => + if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { + val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) + if (remainingUnionTypes.size == 1) { + createConverter(remainingUnionTypes.head, sqlType, path) + } else { + createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path) + } + } else avroSchema.getTypes.asScala.map(_.getType) match { + case Seq(t1) => createConverter(avroSchema.getTypes.get(0), sqlType, path) + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlType == LongType => + (item: AnyRef) => + item match { + case l: java.lang.Long => l + case i: java.lang.Integer => new java.lang.Long(i.longValue()) + } + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlType == DoubleType => + (item: AnyRef) => + item match { + case d: java.lang.Double => d + case f: java.lang.Float => new java.lang.Double(f.doubleValue()) + } + case other => + sqlType match { + case t: StructType if t.fields.length == avroSchema.getTypes.size => + val fieldConverters = t.fields.zip(avroSchema.getTypes.asScala).map { + case (field, schema) => + createConverter(schema, field.dataType, path :+ field.name) + } + (item: AnyRef) => + val i = GenericData.get().resolveUnion(avroSchema, item) + val converted = new Array[Any](fieldConverters.length) + converted(i) = fieldConverters(i)(item) + new GenericRow(converted) + case _ => throw new IncompatibleSchemaException( + s"Cannot convert Avro schema to catalyst type because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $other, sqlType = $sqlType). \n" + + s"Source Avro schema: $sourceAvroSchema.\n" + + s"Target Catalyst type: $targetSqlType") + } + } + case (left, right) => + throw new IncompatibleSchemaException( + s"Cannot convert Avro schema to catalyst type because schema at path " + + s"${path.mkString(".")} is not compatible (avroType = $right, sqlType = $left). \n" + + s"Source Avro schema: $sourceAvroSchema.\n" + + s"Target Catalyst type: $targetSqlType") + } + } + createConverter(sourceAvroSchema, targetSqlType, List.empty[String]) + } + + /** + * This function is used to convert some sparkSQL type to avro type. Note that this function won't + * be used to construct fields of avro record (convertFieldTypeToAvro is used for that). + */ + private def convertTypeToAvro[T]( + dataType: DataType, + schemaBuilder: BaseTypeBuilder[T], + structName: String, + recordNamespace: String): T = { + dataType match { + case ByteType => schemaBuilder.intType() + case ShortType => schemaBuilder.intType() + case IntegerType => schemaBuilder.intType() + case LongType => schemaBuilder.longType() + case FloatType => schemaBuilder.floatType() + case DoubleType => schemaBuilder.doubleType() + case _: DecimalType => schemaBuilder.stringType() + case StringType => schemaBuilder.stringType() + case BinaryType => schemaBuilder.bytesType() + case BooleanType => schemaBuilder.booleanType() + case TimestampType => schemaBuilder.longType() + case DateType => schemaBuilder.longType() + + case ArrayType(elementType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) + val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace) + schemaBuilder.array().items(elementSchema) + + case MapType(StringType, valueType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) + val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) + schemaBuilder.map().values(valueSchema) + + case structType: StructType => + convertStructToAvro( + structType, + schemaBuilder.record(structName).namespace(recordNamespace), + recordNamespace) + + case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") + } + } + + /** + * This function is used to construct fields of the avro record, where schema of the field is + * specified by avro representation of dataType. Since builders for record fields are different + * from those for everything else, we have to use a separate method. + */ + private def convertFieldTypeToAvro[T]( + dataType: DataType, + newFieldBuilder: BaseFieldTypeBuilder[T], + structName: String, + recordNamespace: String): FieldDefault[T, _] = { + dataType match { + case ByteType => newFieldBuilder.intType() + case ShortType => newFieldBuilder.intType() + case IntegerType => newFieldBuilder.intType() + case LongType => newFieldBuilder.longType() + case FloatType => newFieldBuilder.floatType() + case DoubleType => newFieldBuilder.doubleType() + case _: DecimalType => newFieldBuilder.stringType() + case StringType => newFieldBuilder.stringType() + case BinaryType => newFieldBuilder.bytesType() + case BooleanType => newFieldBuilder.booleanType() + case TimestampType => newFieldBuilder.longType() + case DateType => newFieldBuilder.longType() + + case ArrayType(elementType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) + val elementSchema = convertTypeToAvro( + elementType, + builder, + structName, + getNewRecordNamespace(elementType, recordNamespace, structName)) + newFieldBuilder.array().items(elementSchema) + + case MapType(StringType, valueType, _) => + val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) + val valueSchema = convertTypeToAvro( + valueType, + builder, + structName, + getNewRecordNamespace(valueType, recordNamespace, structName)) + newFieldBuilder.map().values(valueSchema) + + case structType: StructType => + convertStructToAvro( + structType, + newFieldBuilder.record(structName).namespace(s"$recordNamespace.$structName"), + s"$recordNamespace.$structName") + + case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") + } + } + + /** + * Returns a new namespace depending on the data type of the element. + * If the data type is a StructType it returns the current namespace concatenated + * with the element name, otherwise it returns the current namespace as it is. + */ + private[avro] def getNewRecordNamespace( + elementDataType: DataType, + currentRecordNamespace: String, + elementName: String): String = { + + elementDataType match { + case StructType(_) => s"$currentRecordNamespace.$elementName" + case _ => currentRecordNamespace + } + } + + private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = { + if (isNullable) { + SchemaBuilder.builder().nullable() + } else { + SchemaBuilder.builder() + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala new file mode 100755 index 0000000000000..b3c8a669cf820 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object avro { + /** + * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using + * the DataFileWriter + */ + implicit class AvroDataFrameWriter[T](writer: DataFrameWriter[T]) { + def avro: String => Unit = writer.format("avro").save + } + + /** + * Adds a method, `avro`, to DataFrameReader that allows you to read avro files using + * the DataFileReader + */ + implicit class AvroDataFrameReader(reader: DataFrameReader) { + def avro: String => DataFrame = reader.format("avro").load + + @scala.annotation.varargs + def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*) + } +} diff --git a/external/avro/src/test/resources/episodes.avro b/external/avro/src/test/resources/episodes.avro new file mode 100644 index 0000000000000000000000000000000000000000..58a028ce19e6a1e964465dba8f9fbf5c84b3c3e5 GIT binary patch literal 597 zcmZ`$$w~u35Ov8x#Dj=L5s_jL1qmVRM7@a%A}%2+f)b^isW@$Vx`*!0$Pn`jp8Nt& zeudv=PZl@uSoPMfKD&P$pU7gYWL|p#h4`N7Iwpz8*>)6pQu$8K5g4X3MNCVd^l+mi z^wPBcH1bcl6SZw%Sr`PO_y|f#X$L9-g z;dAr)w);_-ea$!*mc7p@CSd|NlpVELhMh<;4y8h|knQ6Gw{;CytVP*k1x_$Y;bL~} zP%34!WeX0_W;dkQhBBN}WGK8R1;wpeZEAH#z@;EmCg2I|28{bqD#NLaM@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`q4)_ z4$liO^cG152#C1yMoOBpsIolT{kr(v`xw3Q^aHBW4<*kTe|R?Ed|m!p0};Vgja0tv z^0NQSvV@mTJX+?d_R`?juLpDYZ?l{K!|lO@#j3M9R!GPRCFkb&$j^M&lzz*=zqXM# zsNrZ`5s?Tm34X+aF~ge-=lK+ z2L&D(4&@8abtV@!zo-l}lk%Cm$yRMq@$>t&Ket}CzUHQT%Vx^AYX|R6cVK-rRd8j` zOTR5yPp$nu6^}@m>DI|rUwOUhz@FC}U7Lg=b6d_{+`1-vk5BKc(#2Doxw0d#H~rkS z>B_}r%bU$JUshk+u`iG#hrO-m)Zq|*ZMK;OR@s{;vvnSzoJ1dQtC(q4u@wlk?es!+mtYw`gb3HR! zlFhy~`RA@k4N1A}alNE7Cs<_Xv%EKZvz-q_I6ZH2t+4v}K}xeex=j9SE@xNu_tbCuxjBt5tMljYE-QGu oc=J`GqRT&@xU6vJWtjc`@x7({=kF?=`0vk0?&b4-%A@BG085M0ApigX literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro new file mode 100755 index 0000000000000000000000000000000000000000..1ca623a07dcf36e222f502144212cf90a03bec5a GIT binary patch literal 2313 zcma)6c~}$o78OK9KqLrC6%`^;sR3jW5o8G>$QDr67$6ME1QN1fl3_6fWhtway&}sC zqGey|M%V-`OGT*CqEP|G9l;{{fM{5x3D7)RzxRFpYv!JN&$;I}b7ouv$xGDGSj^bQ*4{gQGzd zio_TezFaQT{8pG*Vu>u`D0EuTKY3#7NVvEox-5!(%_UOk01HQ;LxB`<#mYrkR4+GH z@`$7ekYFU4l^k*r7cW|Ro02gm>6Ga08m&C+V$bm3Nr=amBn($dfHa8uwZJmEY{4BO zi5~znk{U>-h@%8|cSG48aTj4nkD!iwh;M8iP%f@$Tk-8-XHOeI+;3LN-L#S&Qr(v8 zR1Vs!C4_U?3{(GFDznj7FR>IAswtO}q?ux;ocxU`<5|P@s;3j-Fq_a4s?zNs2m&^cy=$fyhTFzEv6sTZLfgWDR|cL3 zG48Ci+FwrC{Ujyx8e!}X0qtQ{9fVHofs%mruJk~Cy=($!pO$L~M^6y-3peruZ^V-I zw)e(IuGNj1%!0zWRJDzEaxEYD^Rq#l&4iY>^;;+FxEgx+-<}*%nKE~^`pgH+{qJ7a z`JDgqOP2H2H36CF{kvH-h4L^*imYQQpGa=kgX>RzY~PTvTV}g?)Nqukp{k&H zUSZ68`1Bv+_HDxF>0jmsXUC(Ouh)Ef;~BkB^dYTBKVu3@4|!WxMo92X$J& zd#0*t56R;WHs&Fmd(herUdpB(8xyD&u5BM*j%kxpT7@Df45&r-lEf zUmcIIag=_re-@+o2X0{fKslPOj;FlFoi^{B#?X0T*?VXlXc{Xn2b9{GXJTJ@koN6q1l*b<2dWF z&b(bO@}jc5qht1&+atDh$F~$YLf_?CD|_rnA4F+Gm958047*C9sP2sZ>sj&Plf)R~ zf~uQ0vd;Z}5_#rI*uBS1cfVaN%!8d9bO2DJSDG-MzrF9$pE=)=pY1FN;JKo=ur)JW zl6I!s#sW##8r#CEYcjp|wolmy`Ka4_Q~|g88%8bc159;moWCVxs05qXURB{>A6(6C zwqf%t-}W7GkP50S+g&nn*H+<{0NTZ)^r@K~SYsT&$*3GNh}h=lqkI;R@VY;$*x_8d zn#;#4zS{}W5-%htpjt|@RO@AZ?v-OrJPAKwe|FKiWU|XTR15DFKe@m4eghWvhWr}K z3-*Z;I5#4Y@!*CLwBS_((o#BRzjsd9(`q#1WRZ@dLRL~CK@M{df7zo(*LRB8ogH^A z|M}DW^=MLcOjlq28#O7Kx_+_UC!e!IOX@pWIp5CL&W?^pfAD4**!%O<`k{G(A)>U3 zv7yT~AF1PHq}>>zPPiFc-lUlC>2Yy7_}K|RcVWH%L-{dMTW>usaw3&(F>|>t`b7no l=Dn%ucfNM=JvwJ!aUC~mGJJtO{<_X0i7T92koxBj^iS{0*{lEn literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro new file mode 100755 index 0000000000000000000000000000000000000000..a12e9459e7461cc6a7d6bcdda9f6a5b76a6a840a GIT binary patch literal 1621 zcmeZI%3@>@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`h%K!X7&ekGnQ^-`@$(Y$#!+Dt3>Pb zy#Dsb4-0n0@I4gL6aT{hgPliQ``{nW^tAJ5=Jnggzpsg6bDbix@=|7Ms$$m!B_&S< zPO)1V4zm|1MpSO(cG*(!`x^JXt@F0}?a?nje5>BNcy1!2z{?pm8AmZ?*^C7Oak59|jSH8WVHPuDp(v{xm z^ZV_0J<48xdUa$};qTe}b+;CuU0D?S-QroGw)GtUzDq0%y5=m|^lU;k=i;`DNhOOI zCS6>2<)P2nmUga&%}QnWF3GVjXnq*coIY#DS3~u6W^O-4J&Mhyb>@87xl*vF>4A>6 z;gVRh$=N^GJX*vmZ2Ux_diSfEXFAUp6>j94`^|9Ef?D5`hc=|zT-x$1Jxkr9Yr$7KlU}z^&QOkJ+VyjO zfBwI%Zm(+(&P=-#{-00zX;e~XcUo8Z)5tkeh5?INo<5lp<29$Rhh@?NPN~W^ug=Ri z&Ig2YaNZ6$mLOg@`H;bKai&|#7fd=Wc*5scmQ6{b_t{uUZAIxvF^he*AJ0`y$uzP$ z=<%v0if^LzQOObmDe0&=T8zw}FNlbA+$(7Qtnur`VT<=NeO4EY3zF1S{PZL;p1QCL zNqRe6Y%I{OUhZ?^>c;fqAUnV3-KLRVmm@slvSPT+@yvSI%xANEBpHJ)26C(hAypM@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`` ztqlL{$Gh|rBb_3`Hg4j)#i_LHx}Pqu(f8Z--*<;{bZ}@bGgx+Ufr1e0Q4zMzM@~H} z6mCS^(2USvJorFW!e>re+rw15eU;CD-t%3Z8MVJSe2!>)fV+qDM9ZWVJ|}&{*t0KL zCAyVd2??9w+J8nhb!)Rs#{J@T77Fq%3+7cH3q2|5bf5dwy5}lst&U2uO7FXL!Dsbc9OlY=GA|Eh8uoJ?^bGE~;LD?+gmPv z^Afklj;Y@Id;ULr<~YHlz?MffGATuW&##9+m91?xq*|EfN2 ze{a6*YvJ38eZ2>kO`IQ|9mag^_?eguPlG0|I%;(+=Eu4v#fAWky|n zWA^>n*&M^Iy4{N_!!+C`{nYC_I4eN;>~5R%vyTci^@Hz)NNtHd(6srK!%VZ20n6>h zuYK4#^OcLa(}7p4+b-@pvi5_=1c?O+Q-63geYn=)rlKMfw6JsfDM9TIIy0oU&g3=X z2=&RG`(w)7SF!=ux;YPvyuQsKP;f@)<2Loll3^!;&UCY1(Yn4N@Zz_O&WTQGVn&X7 zA+~%|PrC>PFdS&Hbd3~AE)YAtv_HFN=f&Sl-l3%#E*Uys%brXup17=a-n1)zmy>p% zO}b{MyU*_X%l`bD-#3qHOLt$6|5|L$cj#oiUrg<LZo|eU=EAeREYq`*S9BVc`tu@)c;q=kZ`FDQ4YZt!mU-#|3 zaes~B+8E8#y}d6V?A!eHslC}5@qF8Y&sVz7*BY&z9wz_q&+p0XXA>nl7AU4N%(D7? z>TC4rTyqC`+dDshe0sZf?WO9pZ?jjYufMnXd6vk^k2#4ypFeu}{=FPS6tB-aU;fo1 zk#85@*&1w^D#FlynbE83;(`t!D z$8FCF_A1Wx=FYTwC-Qt}vryy1vaQ+zT&fQKUl_$@zBhKO?+88eU5Qz2lF>fh^+8wN zpGu^y5mt>BG)up{rPXHs*FCqK44%l9cln92dO)UCiOTjZJ)!# zdt!>(tQo0WMYgR9ZkQlt@VrUm^NnnSM=Km+7Yc7YbE*7dW!^&9DcXq(l2%uDp z++Dd_D#<1NvYnRlG>N1*rgu+^Zf#T$*G;^8`iP*;W`__NW@Y!BCRw+OY>B5|e)!Ye sU-|#a()9MT&QI2Uj%avsbZwiRZ2jBLE&b=_+WdL>P4nM}|LBDe08rC;u>b%7 literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro new file mode 100755 index 0000000000000000000000000000000000000000..af56dfc8083dc409b2afb5a1970897b148dd8ef8 GIT binary patch literal 3282 zcma)8c{o)2A8tX}B}>+H2s5agJ#raa$QF{dF~(q+1+zqqEh9@&mTXhmqAa%<`;wH9 ztR>7KSrXYol9aCP`nmT$zvsDso%4A=pZEQ|-*evcJjcx0V=n^_jOc?s0mr%^;2bUp zR}>77;M;_7aCZb6mdW7{;QhQ1fEwU~fMb1J09gP7LvPf01P%|~npy{4kqDFv4p6iA z|ErI~`yerH>#c46PVIt)uhUHsFwA-%g}~v&wpwCXS24Id%m)U?BYbcGN%ntINVFFU z0Y$^ScI0+!ZGl7>ihx*O4^93W{b6M5sJcM-2Tbiur3R1bk;AMBLC*PqOdSL)CEg$Mj>{s+=SQb?Z$4N zE7BL?qOdc&bI476W6{>z+!@0mudKgah24kU?8*N(iH)%>3HjX;2n+%JZ-H%e+kzzl zx$yvSerRVbY9mU3r8z`b&vG3z(1Tc5ZQSS@LQKp|bvJHxj4x_l`cayw_; z!M8W<)c;QHW{{wp`1+2&N0;O%3&2oF7;YnJ_JGCa2Lagt{%!;2;{)@9VqKxmSS)Hi z<~u?=hBhbGVI&@b-bn3VgS!vo?6>}f?NEWOlH^F*lTeZF%RO0TDkohJExBeo$22Q0 zl4KH*eUUf34%w1}4$lz?haGH_o&eM#q3=D)W_bp{!)w1NtG|L6k1>d|KpY}YgmSY% z^mM%r#fyPYD2ApVbV1N zM=(9`jMMem)3t{*L@o~M6}#3{G-dAXFDFRZ%{E#j@h3%_yu%0jokNEh>0vjs|p4kG1LJS-Z#&G6`F>(?lxsb z3hzsbiv$yjm++3_Zz$uWws32Cd?2p}&@fir_)WjZn=*K?iG1o-K8JA^RV8Slk#)## z+!L!o_)6!8ey@){y7F-}8XGYFBG*hMAy~0ms6UwbviE_bP0!70f5r|C`7bYN22H<& zu6>p~*w4x~kPa1TH){`=OuTrv)^|eitr2U0>*|1iNA2g3YV@h;*^ZzuZ|f*WlVZz; z^)I356Y39Ry;nX+6%fPYN|5hM1M2= zJ*Cn?Iq{3uVG%la)_we$VIf@3KHj|%s+y*v?FZudvlRm~;vZd%Nh_TbO8g zFiYNo`VeVS)2)}zMvWI4TAM7tr0%~Kqa>J``}*YTwA4-q&Vh+Rmj#f3bQXh=treiw zKag&hHee9mzU08b1UjDst^3tl&ENE)?so?GjH|FffWZ3--%i;RSF0_khrx{=F*DgD zpyq*=9o=;?2Nh{JXOjFIl{vk$Ybd^_WrdCtxNDQC*epl*K%C9NDZk*RFG7y*K`j535LX#%e4KEyk9 zKT6Q8dLMa7*Uk-LJaZ}%j2SuF1Y3>sdl=@TZ1=MC976F?Loq<1(|EL_<{3#bWrgn< z-`%=tPDY2yS+nwLc$u?@ z?Y=SzMt)N@D^9{W^vsITYmh74B=x8RS6!Fg$zgU!i!5igB$B9e%0V&aFC4!RHSn|pwHbP#9j2Hc76&kHGeL++bQ)1rU2nY^tno!u{8!g9NS z(ieF?F^eOy=Xj`ziEYUQTg;IlbX?gP%06q$s+)SrnD@olZj*G0NP$+058z9>JS`NILDeS0{+s z5OHrUQd+3Rfl3Q`>vrWlR!}rHlLnv+R!u18j2^Aloyg2?^CY6yTNy?U;j~c6G&&gaK7Ufe^U|}^4>z%L-A*$Qq z&tn$RW;|R?aTl8Ne;(aa#*tAJ(sFFSdd{Jd=(92AX_1}2{x`IjD#$7iDFkFUxIF%8 zb|qNPNm8ktn|SOrg1KMeh6CJGUfr!Jm_kMukItUvvvR5;#JUCuH8awjK{jHQ0xax; zRGT)RrxME7ucR1f)UpYKl+yJ!}mD*ZV;hGATFX9RGOBA}=Kf0%>;+zwY5^ z07=<1*~?XV@wb8!FBB@985DHyJ>mft?8$fx`*1?H;u1|u+CVtrONnrCH^<4v#+YL- z9(XURX_vQhD(f`0S2Iz$_$~{*%d4czmq>cHg=Ja^`x_YZ(diFFp!zA!dxtY6`SbIP zc}G;(OnKooQVycgqRaxQRTd$(Sz2|y8t=0ahX z(@d9=qd390A50A{l+m4w3Y-faJtLT}UujXp5}5~hW-`0@$F#%xI-*7g^zowr6WCN2 zxTkTep%qQ^I#qQAnb8*DAZETytBf1v`85Ci&h;)u-=QNs5}5Jn_EvwZy=DOmHmU_T z)>R^C~s-Zv7@T0(NO9YVvvz%4PQg$;fPVlv}byI z;qo~C>W3n_dMA1R)c}b|Db-&6koZOCmAH{G$Rk@1irJ@FRz8HL$nr}YrPAUIO@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`9vH?HLE}M8j+mM~2qt zX0sn#oOgD%;j>w{al?d+L}f=Fj$izL7&jYRTU!J1W+3WTSk$V~G;6_Bj(K5@!k%27 zi!-l?m^C%+JYgzm+IlEa`PiZAug%~0et%V{G27QqcKU4Z!;fpFGNpD2Wu?c=Q`;^7 zZ(eSSf&1(4aWV6=1eIet*74cf=X_P|VVbQTb8w2=kv%sfWlbEJ8RVDNO@4GJ#CTo9 zwLo1LcXe0Q3K@siU&|i#6?JPj6))6Tv3w?zZScu0Itu(LZc{(q)@n~ojWPE(dL!X- zCG}4D5qD08xVe!J*S@vex%Fq+>(|?*ENWGwB5G{s*)d7+g`QdMYvU^1Gq3*EysTrc zPrD8t`{#e3|K`8m-`QXHm55}v`L-FDOyPBxUi?x{TyYtC+S2pm2J^T&^O`*+zpuveu95jy_fFv5UCHL39j$%>vqVyL zCIwu+acT$G5=S1@6^vQJNA9gZx>MfW{?uM>asL?a{aYrUygcJrT=(>8Gao#=>;8JO ztL>ZZbCeQYZU;W~kel46crI!6qb_5HH18bUl0+$v2Z8x&;yoK$rtb;S(byUGZ;#Cb-g$A~!u#Lqyg4T?Gymf*vjEL0R{SL?whX)ee>nMct@rY} N^K;AJ|7S)|MF7-8R#E@} literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro new file mode 100755 index 0000000000000000000000000000000000000000..c326fc434bf181b897372b8ab750a03320145617 GIT binary patch literal 1729 zcmeZI%3@>@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`=-%k3&)dh9=mr>?d2A)g-gYZ?w7wiGk50y8vm7z8cPISTrNj;3TzQOqqIcu znv!?Nq*MAL%vq|Qf)}5!tUvm&+^_$-_{@CI8SU@y{{4N^nd` zmSerb>Q9Bc70y@mHFRFMd(3T$l8bd-ny}d006C$x?Lw;gGB3?0iE0U4xtHg7I>vTt zfQd}@UYGNVXMBG}?OtbJuqbfPyY1iYYabttE?*Z@{Wb7UnZ+rco1SOvdRFIcebrVz zZ~umi)fZMI{YsXc!@&4^SL=db=f7U|*DHB^!=Uu_+_Hlvheg+&X`SnyI5WC!P1+&7 zG!~BOkKJqo^~_r=9XDBLB+kn%U-4F^XwI*9=elM;^SXFh<=}0WHLF~gL{00OI&s<( z%Z)44TvkrK_DHGtLACMJC#t?K!Ru!6B`WHr^oKGuO9o7>Rb6J7KDY5&?@3{!Kbwpu zGfd$L6WrPMI463`wBOSTkIuRkvs^;pT6chI_Ob~NTa{kvOj~(L!+0i(kmz*J6Gz{0 zPF~YJSumt>k>lh}f6qxdrjN|p1bFOL+K7fSKQ(N%xcYq7qg&0<3^#UU#PfB`sjxDM z;CA!ee!l;C`uz33j@$p$tgHU}TKhil^ta-NJ(HQX9yN*8<`w!Cx_$4*!+$rk+uygY z`*J+{x_(-CW?SKdguJ@{2j8Z-WSjSo3%aTtlTa_o^(SKZXLO`ivT-3Y|Y58ndZ=RUUWq4-t5c=uUbB|wL}zkT9Jmp(#Q#&!pE{9^2`x#2|7Y*{ z`{1wo>vvD9&Gq(|{`)y2hMU%h<)pAkLb0|2== B!hHY$ literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro new file mode 100755 index 0000000000000000000000000000000000000000..279f36c317eb8763be6cd3202a88d852550d381e GIT binary patch literal 1897 zcmeZI%3@>@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`p)McTZ*Uy6Uy$j-bjwwT zC4VZ5B+o4$)x3ziS=T z%0fzRdy1_m*D1!@iy?>OCM{d_Ln5h(Gcr~)?-Y|&l*cxgrb7Zo+$k+VY=-C9^hGyM zFg;T7B;cjN+GVp+kVq}^)JF+3&0 zEF~(+Uvg4GO4l}d<(9JtO?=JsYjxwRzPUF)xS5xEz9({W64$1safVFi#X1I5-@oD>yCf!!48l1 zIj1K07SG+js-Z^0X7PQS=GiaR(qxV90YLNAAaN%zu`ce5G07_M-b5on|{uzbrH6y(lMS(DC^zZ-fGaFl%SX zhP3t5r#=bMd=^mXymHy8o^wL$d?u^x7MXSW>aDMn)Zcvl{_f@J&;0*ue_cEJWe+pY Rj1d0to!27ue*Q<#pa29S0zLo$ literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro new file mode 100755 index 0000000000000000000000000000000000000000..8d70f5d1274d4f68f9e6cef0fdb20c8857b81e52 GIT binary patch literal 3420 zcma)8c{r4N8!lTSvXc;#o$SjbyCHiW%!IL=F*6vA8O@BHD6&OE23eAwX)MW>wFMPf zQixDSwjx57sC>5LbI!TG>-zqB-{*dw`+n}<``+($y%$_O4%2bLd~jGfI2M6~vm%hL zXcz&>vjgGb?nqBq4r?HQ;O~V5XaI0XI2MNh$O14h&+VFk#1ou%r?x;>6cUZV12mw4 zfA#SM917#M)!O6l*9a_pi*A90VYd2cBpxrk+Y-jQ3d7xDI2fFO#Nh$Kvj3Grd3vFd z&Ym!@eYt&GyPyvY?Ty@84?tlEfZzjEICh)r_y2JJm*k%D;DfT>7!(%cyxVccc%#w# z#-M<`a*yVCz|1ad%c3(Hi*frWZ`-;oV7s$**%A3S*AXG7G`XyXJxZ}12WOGxL^s|zBMoh>stdH4{q6y=x*fp z&$^HAZQ8H@o!XrsIq%?GJN_QsU`JU142^=}x0411SnYfedfUL?ZNP9in7=dD)!7A$ zMQ_D?UufUZ&cqgsA|O4tQ~TH8!Na)tZ+&6=RKRW->==9~enaSj7=xHT?IHYe$!MA* z4A;d4Q_PC)x^YLWo<$R1sD!iWM`hE|u?qsZ+{gf+)s$uU{8i}3U)sf4iI!;9B>k$Q zC@@!pn5`ZCBOaXCEt5WV0lV z_m>(*S>b(ACz9~|4EdA?fiO`+M-lEiO=$WbU!M=>iLwtKLy30!WfVl8X`p)M4F@PB z5X5c{8qAOy?JSH&lIA)Y&j)36*zzydnVe9EBoClELNt5so0}23g9gF$pEXR%R7;+N zMR;ub8<3Qnv+l8#mi5;?(w558a!l?dgUivku6;_luTXcIw(FA|&b0`{!HCOyUp%om z`M$3*b!__bNOf&^6%wW$H84pFr6-GwR~o!ru$eWmcU6C+g?EEy~}kUm7g?Ieqfetnut%TU#6TCv$hm*xg*pl-pW- z&F|&z4GHy;SOa`xUTRjt7|tHkmv+ImBJ zjc7^NeqS=w#__gb0nt{o-2P4UNue1gaQZuyvi@?AG1MV4>%`DGSzqAS$R#>8fm4Q3 z^oj^}RIemvkz~ym)An%9+p5?swxi${SU?U{PcPsVF8_GBe#3WO$8U++Mx8meF$yOJ zkLI4jzMus0gUKQ(tiYyb>N>Oh@0G}VwLd;j&Hd7hAL)#-5;u^Y5shX?AC>N&^DSxc z$iBR)$!P$&#&r{-8>|z;Wg1^nqYbUrXj&WZcG9Uelwu>q_(0{Ceo}~!35!2Cp*G_= z(vcqBh?5WE(mJesuZFjqk_20;rFi?Z@ZlC+i6`oYn^iExltIdpcWS9h3q7Vi^S)%o zln)>HhxW5fEmlEuF)b!SamvF_Lbm)5$KFZKt2~{Wn08Gsi#gru7_Yl5i-Gd@RExb8hAW zI6U3my&m?#;P)L{9cGQ}-Z*>%8R7)iOr@iXW0#uTnEd+fuTW6O+wdZW@g_`$eXHU{ zRr!Zf9TY{AXa{|cr@m207J@I6_1khiiRwpa9gd3odl{*t8knnq5~#0Sbkf#Kh|M2uR`e@QLj^u*D2FmY0g^KRV`VPT<9wjXkTXkqjkN<>*?CrM%l zoNf)}wX{Qw+u}G9flXPL$!Uu05}tRnj7ZbA_H96nhe!VMooTK*`fM1-8$ zr4(ux_Y@S1M#ZGskWv{gQs@m@Z97aY zj|F(;Jr9;RyC9U3^Nek}VM%pqARM3S@+V2Aqj2C7`*f2aF1MYzSR?GLGEX9V2A*XV z2OC#94`J-I8cAlCm<%d6N>IyYBW|p={i&*tuhy#J(X%=qenCj&D~IzdR5IArHt)XE zfW7{wVZELk%KYuDrEzEZgd(J9t-_2PCYYq?o?= zJ~ha2kF?m^Tia4ip)2O4nSa#$fiHf@g15pWr*4xu^d=Be)YT{V7Ib~Zg?2r^8BxKe zW}6ULUktoD%H74s<^gg`mo{>5hER_xc|q!eKh_$BiMiJ00>721xx!{4IgY$v3>MN7lsf|&_OA>L^zU`pYW4K&QvNhlGwwUJ1b+8Se7JGooWR` zze;p|NIt?TACv5*ToB&b5yMr%dw($VkBLz5gUMD(8O!Y?_@ka{!ZVYEQkxdg`4$e* zX`3)M1fgjSf~-R^URMYc6ix`8YvcK$sKTT(cBoQ_Z|Rmj4Png{+!C9e6+ud<=p-Pj z<)Y`WmEPh{U`3jlIEM?5hVJ}1W_qWA#CH0^KvO-;`ygGSowP%G&>A{UhYUHDu~ z#)h~ZIeTqXA%;S)nb|zu?7yg?dL=YZjk%&1%+%&w`DeR|s-{L^V=KDl4%aOtDxwa2&!3I80O~1fV4P{#T7`uW9tukJ<)!mMVKOEPqz% zx1}WX?$wQ|!rvQjPyPI{a6zbPEk$BJXu9!Uz-A=zF>r)WlPudm%s=k(`F2`oJ;^F; zCL17RB4r!5VNB6D8Ykz_uKNPh!ZRnwhFBb{r6PNXAYi1h=l-zgV)8mpjYf{OTSzMWC@EdPR{12J8cZteE9v6^JDE&pZw@Nh~YM*GtY%NloTUNlnX1EJ+mu3l%44q~<0r;;U9FsVqoUvQjEaP0lY$ zQPNS$OUwoFOHzwV;vuSlf@ztlIVr_TR?*ck`QnnI%)E4<6jVQ)pOT*p)b5*_m4 zNi8l`fJiCirzs?7Bo-wmm!uXIE7j^CjLOU{$VrXQO)P*L2X_`kuq-jBG! zQZ16!g>a~NE!5`22aQTmZy8)zx%yUgv*mfR9Hz#OKZgr5jX2zt|(rq?gT}Z zscnrc+y`eQEL2+Yy=1xN{;zjtW~*8(zc9;Sm+}JH%!YOW?ea*4tr15=4oo~^^-axX zIuB!i;L2&2UKDocZ0t!t*%9nmIJ@iP#T5cJ+x0rn6$Nu`2(*6ed{Ss(SV zF5}$oRomS9v{vkL*TJ7(m=cZ7XkFJxW!HF}mtfLx)8p$GhGTbb9LR|Dmli7#++3+q zJdZuD;08zbS2iW5S9jN{>YH9zGpS%x#tj#~O&{_eDrKL%b@bepgM8m+74Zrr9O!B2 z<-4xTuN1jZJfQfJt9ZuD1dD{nEMd01o7+uNQ+*=&-WQsS>aEJ~uDKH$P?~;g&azCq z45!x~i?^H$V1F*Zul9Ghw)pvfJ7%4;{c+^s-sR_V{nqg*#})&n1-&XgeZ6<0OG5Tm zmi+$y)8Wy!l*cQ4&KT5v-kbk!-ksKytWyv9Z@2rAQXTbnwn!+0TgCf}CZ{WJDO@j= z{r65&Cwe-|R*U;Te(ZegEnZjm^UR!ie_m`oF2A$v$Mg2{cAr1&X?_#X`YhAx>$~gE z&)d}2{0$U5nz<@5V{iX4X~pZ+%8ego=M@zG{kpjR+GT5Z{lB}mo}BRE+smh)k4IZy zY-~Gu)xK`$cDWr+iCr--rtyn?x0cU*B;lj~r}pVu;m=#e_xnG&AKgB`{@bdZg-Opo SU)xgoulMfk?f=-(^9}$|N9DQz literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro b/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro new file mode 100755 index 0000000000000000000000000000000000000000..aedc7f7e0e61ce84828b3046992b525ec029f473 GIT binary patch literal 3872 zcma)8c{r5)8n*8oMByWhEJKLwq-2Ybi0oq;GqzbQGxnwWvLr^Ih^$4FT^ReGrNP*> z$QnX+V;hoV={U~$u5(@IulIfK=eh6a{=M)0Ue|l;o(nk@2iVI4B?Un_Kp`v+P#7GH zhMqlwJRwd{1UQ|=AAt6Cg920l9#9C%!vSy+fCM8BYc$jo4LY7W0AbEhxPvD^#oGU0 zeNVK9Gt%*(^_zcEJD?y3bVDc@dC-SLJv{}ETLLJU0K^IG0fwNV9-e@}i~p8#M!3PD zAOzU$MDE1aG3W(`W1t7?0nSJ?An+6wf;y!7{6F0PCHdQU;AvS5(iw#W9d|%T3>+n|7!0fsy!1nZ(lZYNi z?&Pc|`0q_8_1{x_6eQ3Qez4;|qZ?>@5delegFO$EW(_bo`XDq8f&a7t_V57vf>1Ef zJroLl5c3nE6GKN62e305ia1Q|zXneq=AQ4t7j{Ag9G8K%fn-TnuFE)AmF!%o7m8_>3Kau5L3y-I=(ED>mm!M$yCdDzskoqW+!#uGPI)j_u_-%sfAL&B z$#FSA>^&A!lC-c@_iJ#U)Hyk5-PlyDn*#YXiK|qv3mT(}e@$4<1PZ0J>2_2#h>>0n z3#^*Ix0K_D+#E*mgW7nt@;o3yECAjB>tZ_HXFLWFGPGc-~9?A&DRI` z)&!}ICJCRb2L;<~6$CECFm=b5Y_ZP_lDfICZgHvg0y0Jkcg7(4)4;k-Y`)s;>_-uQ z(=K5qvviavX(`OQReZ92T|&+-Au2#cy{lb(0dIkl>8=bENNNA|_R_h&C`5u8`d9LJ z^FW0lVrXquTH0#m>(Z)GthVw;zXWminHwTg5ecy(uiz$Cn~9pkj&MPCVtS!TwB?Aa zYtv|Fk)RP)N}sJJQ>xgjm^+xEa+bN}E%qhYC^IHEs<=H_AL3rhhM0esu9d_kbH-it z%}D7mRL~or6i!v^hns4XBQ3z1#j)!nHL;(1FZ|kW*zFTnU8)O$z+D$yL0mF}J=SY# z8yO^ZLLr+%c6$5k9Aa?_9=!efWo&O*UO%Ep{nIGq#oY}4w&6;{v_4hYP@u?@0n7zn zNm{oU?#DgvJ6%P+7%vsoHqp8Cwa53gs$1RGaeduD#VGN9c%F>eQVFLZ&OEnK)RzOl z6=%Ci_z|v|=-LELvB?qX9u^nDMNU20VSU85`UCgG@saDI!h_szXfFA#lrFfrzUWf^EW6rTQqoh? zoM}|GSyAANRvr7G?3fu(K_@%hQ1pjb&j?>&+3YKm5WPZ}m2RqzQnIvfEpO8hFxpzr z#Q?YBRj?d}1JaVLo=N+a#;iRyw>0;(z2>s4?;ddtjI5rpZA9r6#Is5>M@+1kVWdas z^1s~9V_g}jCEOD^Q;DrmB6g;bvCUQHX5?iDwF}tueBb%)n{m~QbqVpPrtJ(nr4is zG&EX%DA=lqUj&YQn~vm`Ode#)_Tk34jdZaF4h+a7N_FcAl{GwxDyPsJip@-O^k}%m zwhH$?XT&ZLoD~iy#@(fK1NsQk+^mo;nJnrF@n$O;E1t(KZ{~R$ZMa}|>4K6S-g>oo z<(e=Fl~boPb}$!-w=HE~h46CWw3XqOuG zF(;?gVH^?aEC{PWS4dK{7}!2ntmAMt@n*f0K~NHD&(FK;eJA7L!jpgWFNLyxpAC#- z*I!tuV(_>x;sN_a!_-;1;Q(@`(N9Ve6)}wB`_@@i>5;?OGK<&Fid!P{T_3lXQWcY# z;d%a`bN8C%xGdIE0PDmw4Be~rFyLxE%Xdi2XTCBj%%8A|4gOZ^i_9s==8yS^n#cwyofuuR4o0My#MX6xa_@nuKc?#R{b}*Q$`?aXO&jh#yoge4L{(TunPb9DL!Ujg7ZofW?}f=js1`+B z+ohgJRC4v?|1=tuLiTHi($g|I=rM6l+Rgw!(Ngp}1I?8A{C%%dvOu_&fK-JKOr>>L zeB-+Z@_0Apcqgo?pMFtyN@*%ZX;{-HE2i$5&^dO%Bw;1}UA%>b_%a4!;3mhj6?47N zYF_HcrfS8wh-mwWSQW{K8r0;D*%2|=r@I_qW-kk%5;B+2Uaa<(?Mabt2y3)+M#!I?v~QX3q^)h${tKE`!kRQY{*prL6bQ^jL`OyJD|du+3pv zpB+8&n_tdt#I|eYPy2$OTiKqh7)fqYq+<5qsEuVyPSDA$|JP|e0rNa zhF#IchM+ECYaGc!jet@O;;!jjr;2Fw9#>~bYP$DDs(R6aj$-q?#Puw(kv>z=WvI?w z|5)Bpu<(YCM}_f3pC~07{?$gGN#z-4ZWak|o7#iE0&!wVh=oz*%$-k1pi>&3Tsk%PPJfP^Q zls5EEbpp}5odC|@%M9yx$)V%Yl(IlY8Y(q?Etk;c@V=bR_GymBpDhbl_5tHvFrK2b z5=sZCp3>~)pwhyZu6~*tXEptKDUjM}FOqJ1C9#+R|HjAuvn@b&UcB32qe*!!?p#@Z zrlxUEqBL#Vj6Pv5!Bqk&RDJ)K1&igwqGaU(e1OGOXE~J}hDoZoBy@U~lFB5;^m>S} zw}PHu3~r4)lPdvc3{%1~r86|$;H+4el~FlAQ95NqcjpJDaH;UQ7W1he*X}jSUZA+CxZ;v6Onj`D2T)KPZylIPP3Mczr+zYMi@o!q zSh5Un+YD~}bl!_J?=HV&78257oWSt8iIsyf_FAcwC_A=TZU&Q^IfzEpeI5zbs-rx^ zuqH$?6f&i)caigHc?f0i{kaz!A3=<5-@ltXaXC_0bTq$#Gs~O*0A7>%}^qFfSY}w-7m)2%x0sPz~40*`PCb*-hn~*N}r}xF%#^p0) z=1D7hDrd3~;jfF9U!&v``FE^zg%t)&-f{f3C1Zq@J6p7oz*7AnG;_F$+&Jb!3)fR@ zO-SNjoTO?fF50rGAdcqepG$K)C3@Yxw?0ODOG+$9N8NGYE?bMgx--4MS3kD>K6t-U zT1}m33fh^PsozKo-kQeTu0MRS3H)>U-ZwQ*Cu?cR@RI@om-1YR0Yml>hx-{ZCxu7Nh_G literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test.avro b/external/avro/src/test/resources/test.avro new file mode 100644 index 0000000000000000000000000000000000000000..6425e2107e3043aa9cfa201e274994def048f4c5 GIT binary patch literal 1365 zcma)5F>ljA6n0E-VF?tWA~B%r2?0YCiG-;z)NLaOGz7 jeHj?Q?U~=SzKdX{aKW zikXR#fsq;UW8gQiGB6R*bpC|b&G z=}#yppBslbolPlT!3p(665u9|30HPXW$G4D0EUc4fy67@hsS=ICM@0oSIO6QAbg&lu;;;S)Af|h3X4MJva~d ze<@4h^J>~GW+HYAkE;$%3){w}S<=Q8F$D`Gx{-)?PV)ib~Oc!Gk!KfiIx(a zjHv^VGwz8d< zwfP{qISw^Wj_!Qi#3W)ws!7|%!+arZ1)P*Yl7!4$5xSlb5sbM`qy^;>0JD^GHMPfq z)n>dIY?!9v!kmxi#-FD*-hE9L8z1j9zCZf;{+IRp;=MutF@n`n(?hRdLvLeft{3yBowhYWHAU_ z0gxgX-F?_fibx!wNybSGTboT;z|z^n9PHiYC>AM_8IXx5vQ(2=R?RRB%U)Z*HKK4h zF=7(+`fK)b-P+sRJE~cnb0sPl*)boO_szDl@3%YVQO$hyMLjoH7Zw(3ruC_|$wG<( z7UcC{Z766;1>$5Egi17}Nl5*)g|;Swf@)Q*#E?hTf=TEO5yUe|Gu|?c+aqYvS9Ay^ lXtUQ{F4TbRx(ToRS;UvGeyA^vy3WXTMnfeIL^`MM<1a>lvsC~9 literal 0 HcmV?d00001 diff --git a/external/avro/src/test/resources/test.avsc b/external/avro/src/test/resources/test.avsc new file mode 100644 index 0000000000000..d7119a01f6aa0 --- /dev/null +++ b/external/avro/src/test/resources/test.avsc @@ -0,0 +1,53 @@ +{ + "type" : "record", + "name" : "test_schema", + "fields" : [{ + "name" : "string", + "type" : "string", + "doc" : "Meaningless string of characters" + }, { + "name" : "simple_map", + "type" : {"type": "map", "values": "int"} + }, { + "name" : "complex_map", + "type" : {"type": "map", "values": {"type": "map", "values": "string"}} + }, { + "name" : "union_string_null", + "type" : ["null", "string"] + }, { + "name" : "union_int_long_null", + "type" : ["int", "long", "null"] + }, { + "name" : "union_float_double", + "type" : ["float", "double"] + }, { + "name": "fixed3", + "type": {"type": "fixed", "size": 3, "name": "fixed3"} + }, { + "name": "fixed2", + "type": {"type": "fixed", "size": 2, "name": "fixed2"} + }, { + "name": "enum", + "type": { "type": "enum", + "name": "Suit", + "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + } + }, { + "name": "record", + "type": { + "type": "record", + "name": "record", + "aliases": ["RecordAlias"], + "fields" : [{ + "name": "value_field", + "type": "string" + }] + } + }, { + "name": "array_of_boolean", + "type": {"type": "array", "items": "boolean"} + }, { + "name": "bytes", + "type": "bytes" + }] +} diff --git a/external/avro/src/test/resources/test.json b/external/avro/src/test/resources/test.json new file mode 100644 index 0000000000000..780189a92b378 --- /dev/null +++ b/external/avro/src/test/resources/test.json @@ -0,0 +1,42 @@ +{ + "string": "OMG SPARK IS AWESOME", + "simple_map": {"abc": 1, "bcd": 7}, + "complex_map": {"key": {"a": "b", "c": "d"}}, + "union_string_null": {"string": "abc"}, + "union_int_long_null": {"int": 1}, + "union_float_double": {"float": 3.1415926535}, + "fixed3":"\u0002\u0003\u0004", + "fixed2":"\u0011\u0012", + "enum": "SPADES", + "record": {"value_field": "Two things are infinite: the universe and human stupidity; and I'm not sure about universe."}, + "array_of_boolean": [true, false, false], + "bytes": "\u0041\u0042\u0043" +} +{ + "string": "Terran is IMBA!", + "simple_map": {"mmm": 0, "qqq": 66}, + "complex_map": {"key": {"1": "2", "3": "4"}}, + "union_string_null": {"string": "123"}, + "union_int_long_null": {"long": 66}, + "union_float_double": {"double": 6.6666666666666}, + "fixed3":"\u0007\u0007\u0007", + "fixed2":"\u0001\u0002", + "enum": "CLUBS", + "record": {"value_field": "Life did not intend to make us perfect. Whoever is perfect belongs in a museum."}, + "array_of_boolean": [], + "bytes": "" +} +{ + "string": "The cake is a LIE!", + "simple_map": {}, + "complex_map": {"key": {}}, + "union_string_null": {"null": null}, + "union_int_long_null": {"null": null}, + "union_float_double": {"double": 0}, + "fixed3":"\u0011\u0022\u0009", + "fixed2":"\u0010\u0090", + "enum": "DIAMONDS", + "record": {"value_field": "TEST_STR123"}, + "array_of_boolean": [false], + "bytes": "\u0053" +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala new file mode 100644 index 0000000000000..c6c1e4051a4b3 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -0,0 +1,812 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ +import java.nio.file.Files +import java.sql.{Date, Timestamp} +import java.util.{TimeZone, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} +import org.apache.commons.io.FileUtils + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ + +class AvroSuite extends SparkFunSuite { + val episodesFile = "src/test/resources/episodes.avro" + val testFile = "src/test/resources/test.avro" + + private var spark: SparkSession = _ + + override protected def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local[2]") + .appName("AvroSuite") + .config("spark.sql.files.maxPartitionBytes", 1024) + .getOrCreate() + } + + override protected def afterAll(): Unit = { + try { + spark.sparkContext.stop() + } finally { + super.afterAll() + } + } + + test("reading from multiple paths") { + val df = spark.read.avro(episodesFile, episodesFile) + assert(df.count == 16) + } + + test("reading and writing partitioned data") { + val df = spark.read.avro(episodesFile) + val fields = List("title", "air_date", "doctor") + for (field <- fields) { + TestUtils.withTempDir { dir => + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.partitionBy(field).avro(outputDir) + val input = spark.read.avro(outputDir) + // makes sure that no fields got dropped. + // We convert Rows to Seqs in order to work around SPARK-10325 + assert(input.select(field).collect().map(_.toSeq).toSet === + df.select(field).collect().map(_.toSeq).toSet) + } + } + } + + test("request no fields") { + val df = spark.read.avro(episodesFile) + df.registerTempTable("avro_table") + assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) + } + + test("convert formats") { + TestUtils.withTempDir { dir => + val df = spark.read.avro(episodesFile) + df.write.parquet(dir.getCanonicalPath) + assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) + } + } + + test("rearrange internal schema") { + TestUtils.withTempDir { dir => + val df = spark.read.avro(episodesFile) + df.select("doctor", "title").write.avro(dir.getCanonicalPath) + } + } + + test("test NULL avro type") { + TestUtils.withTempDir { dir => + val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + avroRec.put("null", null) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + intercept[IncompatibleSchemaException] { + spark.read.avro(s"$dir.avro") + } + } + } + + test("union(int, long) is read as long") { + TestUtils.withTempDir { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toLong) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", LongType, nullable = true))) + assert(df.collect().toSet == Set(Row(1L), Row(2L))) + } + } + + test("union(float, double) is read as double") { + TestUtils.withTempDir { dir => + val avroSchema: Schema = { + val union = + Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", 2.toDouble) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble))) + } + } + + test("union(float, double, null) is read as nullable double") { + TestUtils.withTempDir { dir => + val avroSchema: Schema = { + val union = Schema.createUnion( + List(Schema.create(Type.FLOAT), + Schema.create(Type.DOUBLE), + Schema.create(Type.NULL) + ).asJava + ) + val fields = Seq(new Field("field1", union, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + schema + } + + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, new File(s"$dir.avro")) + val rec1 = new GenericData.Record(avroSchema) + rec1.put("field1", 1.toFloat) + dataFileWriter.append(rec1) + val rec2 = new GenericData.Record(avroSchema) + rec2.put("field1", null) + dataFileWriter.append(rec2) + dataFileWriter.flush() + dataFileWriter.close() + val df = spark.read.avro(s"$dir.avro") + assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true))) + assert(df.collect().toSet == Set(Row(1.toDouble), Row(null))) + } + } + + test("Union of a single type") { + TestUtils.withTempDir { dir => + val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + + avroRec.put("field1", 8) + + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.read.avro(s"$dir.avro") + assert(df.first() == Row(8)) + } + } + + test("Complex Union Type") { + TestUtils.withTempDir { dir => + val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) + val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) + val complexUnionType = Schema.createUnion( + List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) + val fields = Seq( + new Field("field1", complexUnionType, "doc", null), + new Field("field2", complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) + ).asJava + val schema = Schema.createRecord("name", "docs", "namespace", false) + schema.setFields(fields) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val field1 = 1234 + val field2 = "Hope that was not load bearing" + val field3 = Array[Byte](1, 2, 3, 4) + val field4 = "e2" + avroRec.put("field1", field1) + avroRec.put("field2", field2) + avroRec.put("field3", new Fixed(fixedSchema, field3)) + avroRec.put("field4", new EnumSymbol(enumSchema, field4)) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val df = spark.sqlContext.read.avro(s"$dir.avro") + assertResult(field1)(df.selectExpr("field1.member0").first().get(0)) + assertResult(field2)(df.selectExpr("field2.member1").first().get(0)) + assertResult(field3)(df.selectExpr("field3.member2").first().get(0)) + assertResult(field4)(df.selectExpr("field4.member3").first().get(0)) + } + } + + test("Lots of nulls") { + TestUtils.withTempDir { dir => + val schema = StructType(Seq( + StructField("binary", BinaryType, true), + StructField("timestamp", TimestampType, true), + StructField("array", ArrayType(ShortType), true), + StructField("map", MapType(StringType, StringType), true), + StructField("struct", StructType(Seq(StructField("int", IntegerType, true)))))) + val rdd = spark.sparkContext.parallelize(Seq[Row]( + Row(null, new Timestamp(1), Array[Short](1, 2, 3), null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null), + Row(null, null, null, null, null))) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Struct field type") { + TestUtils.withTempDir { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("short", ShortType, true), + StructField("byte", ByteType, true), + StructField("boolean", BooleanType, true) + )) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, 1.toShort, 1.toByte, true), + Row(2f, 2.toShort, 2.toByte, true), + Row(3f, 3.toShort, 3.toByte, true) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("Date field type") { + TestUtils.withTempDir { dir => + val schema = StructType(Seq( + StructField("float", FloatType, true), + StructField("date", DateType, true) + )) + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + val rdd = spark.sparkContext.parallelize(Seq( + Row(1f, null), + Row(2f, new Date(1451948400000L)), + Row(3f, new Date(1460066400500L)) + )) + val df = spark.createDataFrame(rdd, schema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet == + Array(null, 1451865600000L, 1459987200000L).toSet) + } + } + + test("Array data types") { + TestUtils.withTempDir { dir => + val testSchema = StructType(Seq( + StructField("byte_array", ArrayType(ByteType), true), + StructField("short_array", ArrayType(ShortType), true), + StructField("float_array", ArrayType(FloatType), true), + StructField("bool_array", ArrayType(BooleanType), true), + StructField("long_array", ArrayType(LongType), true), + StructField("double_array", ArrayType(DoubleType), true), + StructField("decimal_array", ArrayType(DecimalType(10, 0)), true), + StructField("bin_array", ArrayType(BinaryType), true), + StructField("timestamp_array", ArrayType(TimestampType), true), + StructField("array_array", ArrayType(ArrayType(StringType), true), true), + StructField("struct_array", ArrayType( + StructType(Seq(StructField("name", StringType, true))))))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + + val rdd = spark.sparkContext.parallelize(Seq( + Row(arrayOfByte, Array[Short](1, 2, 3, 4), Array[Float](1f, 2f, 3f, 4f), + Array[Boolean](true, false, true, false), Array[Long](1L, 2L), Array[Double](1.0, 2.0), + Array[BigDecimal](BigDecimal.valueOf(3)), Array[Array[Byte]](arrayOfByte, arrayOfByte), + Array[Timestamp](new Timestamp(0)), + Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")), + Array[Row](Row("Bobby G. can't swim"))))) + val df = spark.createDataFrame(rdd, testSchema) + df.write.avro(dir.toString) + assert(spark.read.avro(dir.toString).count == rdd.count) + } + } + + test("write with compression") { + TestUtils.withTempDir { dir => + val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" + val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" + val uncompressDir = s"$dir/uncompress" + val deflateDir = s"$dir/deflate" + val snappyDir = s"$dir/snappy" + val fakeDir = s"$dir/fake" + + val df = spark.read.avro(testFile) + spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") + df.write.avro(uncompressDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") + spark.conf.set(AVRO_DEFLATE_LEVEL, "9") + df.write.avro(deflateDir) + spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy") + df.write.avro(snappyDir) + + val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir)) + val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir)) + val snappySize = FileUtils.sizeOfDirectory(new File(snappyDir)) + + assert(uncompressSize > deflateSize) + assert(snappySize > deflateSize) + } + } + + test("dsl test") { + val results = spark.read.avro(episodesFile).select("title").collect() + assert(results.length === 8) + } + + test("support of various data types") { + // This test uses data from test.avro. You can see the data and the schema of this file in + // test.json and test.avsc + val all = spark.read.avro(testFile).collect() + assert(all.length == 3) + + val str = spark.read.avro(testFile).select("string").collect() + assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) + + val simple_map = spark.read.avro(testFile).select("simple_map").collect() + assert(simple_map(0)(0).getClass.toString.contains("Map")) + assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) + + val union0 = spark.read.avro(testFile).select("union_string_null").collect() + assert(union0.map(_(0)).toSet == Set("abc", "123", null)) + + val union1 = spark.read.avro(testFile).select("union_int_long_null").collect() + assert(union1.map(_(0)).toSet == Set(66, 1, null)) + + val union2 = spark.read.avro(testFile).select("union_float_double").collect() + assert( + union2 + .map(x => new java.lang.Double(x(0).toString)) + .exists(p => Math.abs(p - Math.PI) < 0.001)) + + val fixed = spark.read.avro(testFile).select("fixed3").collect() + assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) + + val enum = spark.read.avro(testFile).select("enum").collect() + assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) + + val record = spark.read.avro(testFile).select("record").collect() + assert(record(0)(0).getClass.toString.contains("Row")) + assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) + + val array_of_boolean = spark.read.avro(testFile).select("array_of_boolean").collect() + assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) + + val bytes = spark.read.avro(testFile).select("bytes").collect() + assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) + } + + test("sql test") { + spark.sql( + s""" + |CREATE TEMPORARY TABLE avroTable + |USING avro + |OPTIONS (path "$episodesFile") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) + } + + test("conversion to avro and back") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + TestUtils.withTempDir { dir => + val avroDir = s"$dir/avro" + spark.read.avro(testFile).write.avro(avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + } + } + + test("conversion to avro and back with namespace") { + // Note that test.avro includes a variety of types, some of which are nullable. We expect to + // get the same values back. + TestUtils.withTempDir { tempDir => + val name = "AvroTest" + val namespace = "com.databricks.spark.avro" + val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) + + val avroDir = tempDir + "/namedAvro" + spark.read.avro(testFile).write.options(parameters).avro(avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + + // Look at raw file and make sure has namespace info + val rawSaved = spark.sparkContext.textFile(avroDir) + val schema = rawSaved.collect().mkString("") + assert(schema.contains(name)) + assert(schema.contains(namespace)) + } + } + + test("converting some specific sparkSQL types to avro") { + TestUtils.withTempDir { tempDir => + val testSchema = StructType(Seq( + StructField("Name", StringType, false), + StructField("Length", IntegerType, true), + StructField("Time", TimestampType, false), + StructField("Decimal", DecimalType(10, 2), true), + StructField("Binary", BinaryType, false))) + + val arrayOfByte = new Array[Byte](4) + for (i <- arrayOfByte.indices) { + arrayOfByte(i) = i.toByte + } + val cityRDD = spark.sparkContext.parallelize(Seq( + Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte), + Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte), + Row("Munich", 8, new Timestamp(42), Decimal(3.14), arrayOfByte))) + val cityDataFrame = spark.createDataFrame(cityRDD, testSchema) + + val avroDir = tempDir + "/avro" + cityDataFrame.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 3) + + // TimesStamps are converted to longs + val times = spark.read.avro(avroDir).select("Time").collect() + assert(times.map(_(0)).toSet == Set(666, 777, 42)) + + // DecimalType should be converted to string + val decimals = spark.read.avro(avroDir).select("Decimal").collect() + assert(decimals.map(_(0)).contains("3.14")) + + // There should be a null entry + val length = spark.read.avro(avroDir).select("Length").collect() + assert(length.map(_(0)).contains(null)) + + val binary = spark.read.avro(avroDir).select("Binary").collect() + for (i <- arrayOfByte.indices) { + assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i)) + } + } + } + + test("correctly read long as date/timestamp type") { + TestUtils.withTempDir { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val currentTime = new Timestamp(System.currentTimeMillis()) + val currentDate = new Date(System.currentTimeMillis()) + val schema = StructType(Seq( + StructField("_1", DateType, false), StructField("_2", TimestampType, false))) + val writeDs = Seq((currentDate, currentTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + assert(spark.read.avro(avroDir).collect().length == 1) + + val readDs = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)] + + assert(readDs.collect().sameElements(writeDs.collect())) + } + } + + test("support of globbed paths") { + val e1 = spark.read.avro("*/test/resources/episodes.avro").collect() + assert(e1.length == 8) + + val e2 = spark.read.avro("src/*/*/episodes.avro").collect() + assert(e2.length == 8) + } + + test("does not coerce null date/timestamp value to 0 epoch.") { + TestUtils.withTempDir { tempDir => + val sparkSession = spark + import sparkSession.implicits._ + + val nullTime: Timestamp = null + val nullDate: Date = null + val schema = StructType(Seq( + StructField("_1", DateType, nullable = true), + StructField("_2", TimestampType, nullable = true)) + ) + val writeDs = Seq((nullDate, nullTime)).toDS + + val avroDir = tempDir + "/avro" + writeDs.write.avro(avroDir) + val readValues = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)].collect + + assert(readValues.size == 1) + assert(readValues.head == ((nullDate, nullTime))) + } + } + + test("support user provided avro schema") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "string", + | "type" : "string", + | "doc" : "Meaningless string of characters" + | }] + |} + """.stripMargin + val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema).avro(testFile).collect() + val expected = spark.read.avro(testFile).select("string").collect() + assert(result.sameElements(expected)) + } + + test("support user provided avro schema with defaults for missing fields") { + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name" : "missingField", + | "type" : "string", + | "default" : "foo" + | }] + |} + """.stripMargin + val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) + .avro(testFile).select("missingField").first + assert(result === Row("foo")) + } + + test("reading from invalid path throws exception") { + + // Directory given has no avro files + intercept[AnalysisException] { + TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) + } + + intercept[AnalysisException] { + spark.read.avro("very/invalid/path/123.avro") + } + + // In case of globbed path that can't be matched to anything, another exception is thrown (and + // exception message is helpful) + intercept[AnalysisException] { + spark.read.avro("*/*/*/*/*/*/*/something.avro") + } + + intercept[FileNotFoundException] { + TestUtils.withTempDir { dir => + FileUtils.touch(new File(dir, "test")) + spark.read.avro(dir.toString) + } + } + + } + + test("SQL test insert overwrite") { + TestUtils.withTempDir { tempDir => + val tempEmptyDir = s"$tempDir/sqlOverwrite" + // Create a temp directory for table that will be overwritten + new File(tempEmptyDir).mkdirs() + spark.sql( + s""" + |CREATE TEMPORARY TABLE episodes + |USING avro + |OPTIONS (path "$episodesFile") + """.stripMargin.replaceAll("\n", " ")) + spark.sql( + s""" + |CREATE TEMPORARY TABLE episodesEmpty + |(name string, air_date string, doctor int) + |USING avro + |OPTIONS (path "$tempEmptyDir") + """.stripMargin.replaceAll("\n", " ")) + + assert(spark.sql("SELECT * FROM episodes").collect().length === 8) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().isEmpty) + + spark.sql( + s""" + |INSERT OVERWRITE TABLE episodesEmpty + |SELECT * FROM episodes + """.stripMargin.replaceAll("\n", " ")) + assert(spark.sql("SELECT * FROM episodesEmpty").collect().length == 8) + } + } + + test("test save and load") { + // Test if load works as expected + TestUtils.withTempDir { tempDir => + val df = spark.read.avro(episodesFile) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + + df.write.avro(tempSaveDir) + val newDf = spark.read.avro(tempSaveDir) + assert(newDf.count == 8) + } + } + + test("test load with non-Avro file") { + // Test if load works as expected + TestUtils.withTempDir { tempDir => + val df = spark.read.avro(episodesFile) + assert(df.count == 8) + + val tempSaveDir = s"$tempDir/save/" + df.write.avro(tempSaveDir) + + Files.createFile(new File(tempSaveDir, "non-avro").toPath) + + val newDf = spark + .read + .option(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + .avro(tempSaveDir) + + assert(newDf.count == 8) + } + } + + test("read avro with user defined schema: read partial columns") { + val partialColumns = StructType(Seq( + StructField("string", StringType, false), + StructField("simple_map", MapType(StringType, IntegerType), false), + StructField("complex_map", MapType(StringType, MapType(StringType, StringType)), false), + StructField("union_string_null", StringType, true), + StructField("union_int_long_null", LongType, true), + StructField("fixed3", BinaryType, true), + StructField("fixed2", BinaryType, true), + StructField("enum", StringType, false), + StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), + StructField("array_of_boolean", ArrayType(BooleanType), false), + StructField("bytes", BinaryType, true))) + val withSchema = spark.read.schema(partialColumns).avro(testFile).collect() + val withOutSchema = spark + .read + .avro(testFile) + .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", + "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") + .collect() + assert(withSchema.sameElements(withOutSchema)) + } + + test("read avro with user defined schema: read non-exist columns") { + val schema = + StructType( + Seq( + StructField("non_exist_string", StringType, true), + StructField( + "record", + StructType(Seq( + StructField("non_exist_field", StringType, false), + StructField("non_exist_field2", StringType, false))), + false))) + val withEmptyColumn = spark.read.schema(schema).avro(testFile).collect() + + assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) + } + + test("read avro file partitioned") { + TestUtils.withTempDir { dir => + val sparkSession = spark + import sparkSession.implicits._ + val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") + val outputDir = s"$dir/${UUID.randomUUID}" + df.write.avro(outputDir) + val input = spark.read.avro(outputDir) + assert(input.collect.toSet.size === 1024 * 3 + 1) + assert(input.rdd.partitions.size > 2) + } + } + + case class NestedBottom(id: Int, data: String) + + case class NestedMiddle(id: Int, data: NestedBottom) + + case class NestedTop(id: Int, data: NestedMiddle) + + test("saving avro that has nested records with the same name") { + TestUtils.withTempDir { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) + val outputFolder = s"$tempDir/duplicate_names/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleArray(id: Int, data: Array[NestedBottom]) + + case class NestedTopArray(id: Int, data: NestedMiddleArray) + + test("saving avro that has nested records with the same name inside an array") { + TestUtils.withTempDir { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopArray(1, NestedMiddleArray(2, Array( + NestedBottom(3, "1"), NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_array/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } + + case class NestedMiddleMap(id: Int, data: Map[String, NestedBottom]) + + case class NestedTopMap(id: Int, data: NestedMiddleMap) + + test("saving avro that has nested records with the same name inside a map") { + TestUtils.withTempDir { tempDir => + // Save avro file on output folder path + val writeDf = spark.createDataFrame( + List(NestedTopMap(1, NestedMiddleMap(2, Map( + "1" -> NestedBottom(3, "1"), "2" -> NestedBottom(4, "2") + )))) + ) + val outputFolder = s"$tempDir/duplicate_names_map/" + writeDf.write.avro(outputFolder) + // Read avro file saved on the last step + val readDf = spark.read.avro(outputFolder) + // Check if the written DataFrame is equals than read DataFrame + assert(readDf.collect().sameElements(writeDf.collect())) + } + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala new file mode 100755 index 0000000000000..a0f88515ed9d4 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableConfigurationSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableConfigurationSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + import AvroFileFormat.SerializableConfiguration + val conf = new SerializableConfiguration(new Configuration()) + + val serialized = serializer.serialize(conf) + + serializer.deserialize[Any](serialized) match { + case c: SerializableConfiguration => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + case other => fail( + s"Expecting ${classOf[SerializableConfiguration]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } + +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala new file mode 100755 index 0000000000000..4ae9b14d9ad0d --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{File, IOException} +import java.nio.ByteBuffer + +import scala.collection.immutable.HashSet +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import com.google.common.io.Files +import java.util + +import org.apache.spark.sql.SparkSession + +private[avro] object TestUtils { + + /** + * This function checks that all records in a file match the original + * record. + */ + def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { + + def convertToString(elem: Any): String = { + elem match { + case null => "NULL" // HashSets can't have null in them, so we use a string instead + case arrayBuf: ArrayBuffer[_] => + arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") + case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") + case other => other.toString + } + } + + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(avroDir).collect() + + assert(originalEntries.length == newEntries.length) + + val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) + for (origEntry <- originalEntries) { + var idx = 0 + for (origElement <- origEntry.toSeq) { + origEntrySet(idx) += convertToString(origElement) + idx += 1 + } + } + + for (newEntry <- newEntries) { + var idx = 0 + for (newElement <- newEntry.toSeq) { + assert(origEntrySet(idx).contains(convertToString(newElement))) + idx += 1 + } + } + } + + def withTempDir(f: File => Unit): Unit = { + val dir = Files.createTempDir() + dir.delete() + try f(dir) finally deleteRecursively(dir) + } + + /** + * This function deletes a file or a directory with everything that's in it. This function is + * copied from Spark with minor modifications made to it. See original source at: + * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala + */ + + def deleteRecursively(file: File) { + def listFilesSafely(file: File): Seq[File] = { + if (file.exists()) { + val files = file.listFiles() + if (files == null) { + throw new IOException("Failed to list files for dir: " + file) + } + files + } else { + List() + } + } + + if (file != null) { + try { + if (file.isDirectory) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + } + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } + } + } + } + } + + /** + * This function generates a random map(string, int) of a given size. + */ + private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { + val jMap = new util.HashMap[String, Int]() + for (i <- 0 until size) { + jMap.put(rand.nextString(5), i) + } + jMap + } + + /** + * This function generates a random array of booleans of a given size. + */ + private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { + val vec = new util.ArrayList[Boolean]() + for (i <- 0 until size) { + vec.add(rand.nextBoolean()) + } + vec + } + + /** + * This function generates a random ByteBuffer of a given size. + */ + private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { + val bb = ByteBuffer.allocate(size) + val arrayOfBytes = new Array[Byte](size) + rand.nextBytes(arrayOfBytes) + bb.put(arrayOfBytes) + } +} diff --git a/pom.xml b/pom.xml index 6dee6fce3ffc4..039292337eaa0 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,7 @@ external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + external/avro diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f887e4570c85d..247b6fee394bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -40,8 +40,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010, avro) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10", "avro" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -326,7 +326,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, sqlKafka010, kvstore + unsafe, tags, sqlKafka010, kvstore, avro ).contains(x) } @@ -688,9 +688,11 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, yarn, tags, streamingKafka010, sqlKafka010), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, kubernetes, + yarn, tags, streamingKafka010, sqlKafka010, avro), unidocAllClasspaths in (ScalaUnidoc, unidoc) := { ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) From 07704c971cbc92bff15e15f8c42fab9afaab3ef7 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 12 Jul 2018 14:08:49 -0700 Subject: [PATCH 1109/1161] [SPARK-23007][SQL][TEST] Add read schema suite for file-based data sources ## What changes were proposed in this pull request? The reader schema is said to be evolved (or projected) when it changed after the data is written. The followings are already supported in file-based data sources. Note that partition columns are not maintained in files. In this PR, `column` means `non-partition column`. 1. Add a column 2. Hide a column 3. Change a column position 4. Change a column type (upcast) This issue aims to guarantee users a backward-compatible read-schema test coverage on file-based data sources and to prevent future regressions by *adding read schema tests explicitly*. Here, we consider safe changes without data loss. For example, data type change should be from small types to larger types like `int`-to-`long`, not vice versa. As of today, in the master branch, file-based data sources have the following coverage. File Format | Coverage | Note ----------- | ---------- | ------------------------------------------------ TEXT | N/A | Schema consists of a single string column. CSV | 1, 2, 4 | JSON | 1, 2, 3, 4 | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage among ORC formats. PARQUET | 1, 2, 3 | ## How was this patch tested? Pass the Jenkins with newly added test suites. Author: Dongjoon Hyun Closes #20208 from dongjoon-hyun/SPARK-SCHEMA-EVOLUTION. --- .../datasources/ReadSchemaSuite.scala | 181 +++++++ .../datasources/ReadSchemaTest.scala | 493 ++++++++++++++++++ 2 files changed, 674 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala new file mode 100644 index 0000000000000..23c58e175fe5e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaSuite.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.internal.SQLConf + +/** + * Read schema suites have the following hierarchy and aims to guarantee users + * a backward-compatible read-schema change coverage on file-based data sources, and + * to prevent future regressions. + * + * ReadSchemaSuite + * -> CSVReadSchemaSuite + * -> HeaderCSVReadSchemaSuite + * + * -> JsonReadSchemaSuite + * + * -> OrcReadSchemaSuite + * -> VectorizedOrcReadSchemaSuite + * + * -> ParquetReadSchemaSuite + * -> VectorizedParquetReadSchemaSuite + * -> MergedParquetReadSchemaSuite + */ + +/** + * All file-based data sources supports column addition and removal at the end. + */ +abstract class ReadSchemaSuite + extends AddColumnTest + with HideColumnAtTheEndTest { + + var originalConf: Boolean = _ +} + +class CSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" +} + +class HeaderCSVReadSchemaSuite + extends ReadSchemaSuite + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "csv" + + override val options = Map("header" -> "true") +} + +class JsonReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with IntegralTypeTest + with ToDoubleTypeTest + with ToDecimalTypeTest + with ToStringTypeTest { + + override val format: String = "json" +} + +class OrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedOrcReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest + with BooleanTypeTest + with IntegralTypeTest + with ToDoubleTypeTest { + + override val format: String = "orc" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.ORC_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.ORC_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class ParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "false") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class VectorizedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, originalConf) + super.afterAll() + } +} + +class MergedParquetReadSchemaSuite + extends ReadSchemaSuite + with HideColumnInTheMiddleTest + with ChangePositionTest { + + override val format: String = "parquet" + + override def beforeAll() { + super.beforeAll() + originalConf = spark.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED) + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, "true") + } + + override def afterAll() { + spark.conf.set(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key, originalConf) + super.afterAll() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala new file mode 100644 index 0000000000000..2a5457e00b4ef --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/ReadSchemaTest.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +/** + * The reader schema is said to be evolved (or projected) when it changed after the data is + * written by writers. The followings are supported in file-based data sources. + * Note that partition columns are not maintained in files. Here, `column` means non-partition + * column. + * + * 1. Add a column + * 2. Hide a column + * 3. Change a column position + * 4. Change a column type (Upcast) + * + * Here, we consider safe changes without data loss. For example, data type changes should be + * from small types to larger types like `int`-to-`long`, not vice versa. + * + * So far, file-based data sources have the following coverages. + * + * | File Format | Coverage | Note | + * | ------------ | ------------ | ------------------------------------------------------ | + * | TEXT | N/A | Schema consists of a single string column. | + * | CSV | 1, 2, 4 | | + * | JSON | 1, 2, 3, 4 | | + * | ORC | 1, 2, 3, 4 | Native vectorized ORC reader has the widest coverage. | + * | PARQUET | 1, 2, 3 | | + * + * This aims to provide an explicit test coverage for reader schema change on file-based data + * sources. Since a file format has its own coverage, we need a test suite for each file-based + * data source with corresponding supported test case traits. + * + * The following is a hierarchy of test traits. + * + * ReadSchemaTest + * -> AddColumnTest + * -> HideColumnTest + * -> ChangePositionTest + * -> BooleanTypeTest + * -> IntegralTypeTest + * -> ToDoubleTypeTest + * -> ToDecimalTypeTest + */ + +trait ReadSchemaTest extends QueryTest with SQLTestUtils with SharedSQLContext { + val format: String + val options: Map[String, String] = Map.empty[String, String] +} + +/** + * Add column (Case 1). + * This test suite assumes that the missing column should be `null`. + */ +trait AddColumnTest extends ReadSchemaTest { + import testImplicits._ + + test("append column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq("a", "b").toDF("col1") + val df2 = df1.withColumn("col2", lit("x")) + val df3 = df2.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + val dir3 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + df3.write.format(format).options(options).save(dir3) + + val df = spark.read + .schema(df3.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", null, null, "one"), + Row("b", null, null, "one"), + Row("a", "x", null, "two"), + Row("b", "x", null, "two"), + Row("a", "x", "y", "three"), + Row("b", "x", "y", "three"))) + } + } +} + +/** + * Hide column (Case 2-1). + */ +trait HideColumnAtTheEndTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column at the end") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(df1.schema) + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("1", "a", "two"), + Row("2", "b", "two"), + Row("1", "a", "three"), + Row("2", "b", "three"))) + + val df3 = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df3, Seq( + Row("1", "two"), + Row("2", "two"), + Row("1", "three"), + Row("2", "three"))) + } + } +} + +/** + * Hide column in the middle (Case 2-2). + */ +trait HideColumnInTheMiddleTest extends ReadSchemaTest { + import testImplicits._ + + test("hide column in the middle") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b")).toDF("col1", "col2") + val df2 = df1.withColumn("col3", lit("y")) + + val dir1 = s"$path${File.separator}part=two" + val dir2 = s"$path${File.separator}part=three" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema("col2 string") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, Seq( + Row("a", "two"), + Row("b", "two"), + Row("a", "three"), + Row("b", "three"))) + } + } +} + +/** + * Change column positions (Case 3). + * This suite assumes that all data set have the same number of columns. + */ +trait ChangePositionTest extends ReadSchemaTest { + import testImplicits._ + + test("change column position") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = Seq(("1", "a"), ("2", "b"), ("3", "c")).toDF("col1", "col2") + val df2 = Seq(("d", "4"), ("e", "5"), ("f", "6")).toDF("col2", "col1") + val unionDF = df1.unionByName(df2) + + val dir1 = s"$path${File.separator}part=one" + val dir2 = s"$path${File.separator}part=two" + + df1.write.format(format).options(options).save(dir1) + df2.write.format(format).options(options).save(dir2) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1", "col2") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait BooleanTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("change column type from boolean to byte/short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val values = (1 to 10).map(_ % 2) + val booleanDF = (1 to 10).map(_ % 2 == 1).toDF("col1") + val byteDF = values.map(_.toByte).toDF("col1") + val shortDF = values.map(_.toShort).toDF("col1") + val intDF = values.toDF("col1") + val longDF = values.map(_.toLong).toDF("col1") + + booleanDF.write.mode("overwrite").format(format).options(options).save(path) + + Seq( + ("col1 byte", byteDF), + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToStringTypeTest extends ReadSchemaTest { + import testImplicits._ + + test("read as string") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + .selectExpr("cast(col1 AS STRING) col1") + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema("col1 string") + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait IntegralTypeTest extends ReadSchemaTest { + + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val byteDF = values.map(_.toByte).toDF("col1") + private lazy val shortDF = values.map(_.toShort).toDF("col1") + private lazy val intDF = values.toDF("col1") + private lazy val longDF = values.map(_.toLong).toDF("col1") + + test("change column type from byte to short/int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + byteDF.write.format(format).options(options).save(path) + + Seq( + ("col1 short", shortDF), + ("col1 int", intDF), + ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from short to int/long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + shortDF.write.format(format).options(options).save(path) + + Seq(("col1 int", intDF), ("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("change column type from int to long") { + withTempPath { dir => + val path = dir.getCanonicalPath + + intDF.write.format(format).options(options).save(path) + + Seq(("col1 long", longDF)).foreach { case (schema, answerDF) => + checkAnswer(spark.read.schema(schema).format(format).options(options).load(path), answerDF) + } + } + } + + test("read byte, int, short, long together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val byteDF = (Byte.MaxValue - 2 to Byte.MaxValue).map(_.toByte).toDF("col1") + val shortDF = (Short.MaxValue - 2 to Short.MaxValue).map(_.toShort).toDF("col1") + val intDF = (Int.MaxValue - 2 to Int.MaxValue).toDF("col1") + val longDF = (Long.MaxValue - 2 to Long.MaxValue).toDF("col1") + val unionDF = byteDF.union(shortDF).union(intDF).union(longDF) + + val byteDir = s"$path${File.separator}part=byte" + val shortDir = s"$path${File.separator}part=short" + val intDir = s"$path${File.separator}part=int" + val longDir = s"$path${File.separator}part=long" + + byteDF.write.format(format).options(options).save(byteDir) + shortDF.write.format(format).options(options).save(shortDir) + intDF.write.format(format).options(options).save(intDir) + longDF.write.format(format).options(options).save(longDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDoubleTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF) + + test("change column type from float to double") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read.schema("col1 double").format(format).options(options).load(path) + + checkAnswer(df, doubleDF) + } + } + + test("read float and double together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} + +/** + * Change a column type (Case 4). + * This suite assumes that a user gives a wider schema intentionally. + */ +trait ToDecimalTypeTest extends ReadSchemaTest { + import testImplicits._ + + private lazy val values = 1 to 10 + private lazy val floatDF = values.map(_.toFloat).toDF("col1") + private lazy val doubleDF = values.map(_.toDouble).toDF("col1") + private lazy val decimalDF = values.map(BigDecimal(_)).toDF("col1") + private lazy val unionDF = floatDF.union(doubleDF).union(decimalDF) + + test("change column type from float to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + floatDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("change column type from double to decimal") { + withTempPath { dir => + val path = dir.getCanonicalPath + + doubleDF.write.format(format).options(options).save(path) + + val df = spark.read + .schema("col1 decimal(38,18)") + .format(format) + .options(options) + .load(path) + + checkAnswer(df, decimalDF) + } + } + + test("read float, double, decimal together") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val floatDir = s"$path${File.separator}part=float" + val doubleDir = s"$path${File.separator}part=double" + val decimalDir = s"$path${File.separator}part=decimal" + + floatDF.write.format(format).options(options).save(floatDir) + doubleDF.write.format(format).options(options).save(doubleDir) + decimalDF.write.format(format).options(options).save(decimalDir) + + val df = spark.read + .schema(unionDF.schema) + .format(format) + .options(options) + .load(path) + .select("col1") + + checkAnswer(df, unionDF) + } + } +} From 11384893b6ad09c0c8bc6a350bb9540d0d704bb4 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 12 Jul 2018 15:13:26 -0700 Subject: [PATCH 1110/1161] [SPARK-24208][SQL][FOLLOWUP] Move test cases to proper locations ## What changes were proposed in this pull request? The PR is a followup to move the test cases introduced by the original PR in their proper location. ## How was this patch tested? moved UTs Author: Marco Gaido Closes #21751 from mgaido91/SPARK-24208_followup. --- python/pyspark/sql/tests.py | 32 +++++++++---------- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 +++++++++++ .../spark/sql/GroupedDatasetSuite.scala | 12 ------- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4404dbe40590a..565654e7f03bb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -5471,6 +5471,22 @@ def foo(_): self.assertEqual(r.a, 'hi') self.assertEqual(r.b, 1) + def test_self_join_with_pandas(self): + import pyspark.sql.functions as F + + @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) + def dummy_pandas_udf(df): + return df[['key', 'col']] + + df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), + Row(key=2, col='C')]) + df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf) + + # this was throwing an AnalysisException before SPARK-24208 + res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'), + F.col('temp0.key') == F.col('temp1.key')) + self.assertEquals(res.count(), 5) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, @@ -5925,22 +5941,6 @@ def test_invalid_args(self): 'mixture.*aggregate function.*group aggregate pandas UDF'): df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() - def test_self_join_with_pandas(self): - import pyspark.sql.functions as F - - @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP) - def dummy_pandas_udf(df): - return df[['key', 'col']] - - df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'), - Row(key=2, col='C')]) - dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf) - - # this was throwing an AnalysisException before SPARK-24208 - res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'), - F.col('temp0.key') == F.col('temp1.key')) - self.assertEquals(res.count(), 5) - @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index cd8579584eada..bbcdf6c1b8481 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.Matchers +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -557,4 +558,21 @@ class AnalysisSuite extends AnalysisTest with Matchers { SubqueryAlias("tbl", testRelation))) assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'")) } + + test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("a", LongType))), + Seq.empty, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + true) + val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val project = Project(Seq(UnresolvedAttribute("a")), testRelation) + val flatMapGroupsInPandas = FlatMapGroupsInPandas( + Seq(UnresolvedAttribute("a")), pythonUdf, output, project) + val left = SubqueryAlias("temp0", flatMapGroupsInPandas) + val right = SubqueryAlias("temp1", flatMapGroupsInPandas) + val join = Join(left, right, Inner, None) + assertAnalysisSuccess( + Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala index bd54ea415ca88..147c0b61f5017 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala @@ -93,16 +93,4 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext { } datasetWithUDF.unpersist(true) } - - test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") { - val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF( - "pyUDF", - null, - StructType(Seq(StructField("s", LongType))), - Seq.empty, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - true)) - val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s") - df1.queryExecution.assertAnalyzed() - } } From 75725057b3ffdb0891844b10bd707bb0830f92ca Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 12 Jul 2018 16:54:03 -0700 Subject: [PATCH 1111/1161] [SPARK-24790][SQL] Allow complex aggregate expressions in Pivot ## What changes were proposed in this pull request? Relax the check to allow complex aggregate expressions, like `ceil(sum(col1))` or `sum(col1) + 1`, which roughly means any aggregate expression that could appear in an Aggregate plan except pandas UDF (due to the fact that it is not supported in pivot yet). ## How was this patch tested? Added 2 tests in pivot.sql Author: maryannxue Closes #21753 from maryannxue/pivot-relax-syntax. --- .../sql/catalyst/analysis/Analyzer.scala | 24 ++++++------- .../test/resources/sql-tests/inputs/pivot.sql | 18 ++++++++++ .../resources/sql-tests/results/pivot.sql.out | 34 +++++++++++++++++-- 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c078efdfc0000..9749893dbdb0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -509,12 +509,7 @@ class Analyzer( || !p.pivotColumn.resolved => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => // Check all aggregate expressions. - aggregates.foreach { e => - if (!isAggregateExpression(e)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$e'") - } - } + aggregates.foreach(checkValidAggregateExpression) // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) @@ -586,12 +581,17 @@ class Analyzer( } } - private def isAggregateExpression(expr: Expression): Boolean = { - expr match { - case Alias(e, _) => isAggregateExpression(e) - case AggregateExpression(_, _, _, _) => true - case _ => false - } + // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF. + // TODO: Support Pandas UDF. + private def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis. + case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) => + failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.") + case e: Attribute => + failAnalysis( + s"Aggregate expression required for pivot, but '${e.sql}' " + + s"did not appear in any aggregate function.") + case e => e.children.foreach(checkValidAggregateExpression) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index 01dea6c81c11b..b3d53adfbebe7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -111,3 +111,21 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ); + +-- pivot with complex aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +); + +-- pivot with invalid arguments in aggregate expressions +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 85e3488990e20..922d8b9f9152c 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query 0 @@ -176,7 +176,7 @@ PIVOT ( struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -Aggregate expression required for pivot, found 'abs(earnings#x)'; +Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; -- !query 12 @@ -192,3 +192,33 @@ struct<> -- !query 12 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 + + +-- !query 13 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + ceil(sum(earnings)), avg(earnings) + 1 as a1 + FOR course IN ('dotNET', 'Java') +) +-- !query 13 schema +struct +-- !query 13 output +2012 15000 7501.0 20000 20001.0 +2013 48000 48001.0 30000 30001.0 + + +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(avg(earnings)) + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; From e0f4f206b737c62da307c1b1b8e6d2eae832696e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 13 Jul 2018 10:40:58 +0800 Subject: [PATCH 1112/1161] [SPARK-24537][R] Add array_remove / array_zip / map_from_arrays / array_distinct ## What changes were proposed in this pull request? Add array_remove / array_zip / map_from_arrays / array_distinct functions in SparkR. ## How was this patch tested? Add tests in test_sparkSQL.R Author: Huaxin Gao Closes #21645 from huaxingao/spark-24537. --- R/pkg/NAMESPACE | 4 ++ R/pkg/R/functions.R | 70 +++++++++++++++++++++++++-- R/pkg/R/generics.R | 16 ++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 21 ++++++++ 4 files changed, 107 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9696f6987ad78..adfd3871f3426 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,13 +201,16 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_distinct", "array_join", "array_max", "array_min", "array_position", + "array_remove", "array_repeat", "array_sort", "arrays_overlap", + "arrays_zip", "asc", "ascii", "asin", @@ -306,6 +309,7 @@ exportMethods("%<=>%", "lpad", "ltrim", "map_entries", + "map_from_arrays", "map_keys", "map_values", "max", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3bff633fbc1ff..2929a00330c62 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -194,10 +194,12 @@ NULL #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. #' \item \code{array_position}: a value to locate in the given array. +#' \item \code{array_remove}: a value to remove in the given array. #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains #' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. +#' options as the JSON data source. In \code{arrays_zip}, this contains additional +#' Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -207,9 +209,9 @@ NULL #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) -#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) +#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1), array_remove(tmp$v1, 21))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) @@ -221,6 +223,7 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp4, arrays_zip(tmp4$v4, tmp4$v5), map_from_arrays(tmp4$v4, tmp4$v5))) #' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) #' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) #' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} @@ -1978,7 +1981,7 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. #' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} #' are on the same day of month, or both are the last day of month, time of day will be ignored. #' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. @@ -3008,6 +3011,19 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_distinct}: Removes duplicate values from the array. +#' +#' @rdname column_collection_functions +#' @aliases array_distinct array_distinct,Column-method +#' @note array_distinct since 2.4.0 +setMethod("array_distinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_distinct", x@jc) + column(jc) + }) + #' @details #' \code{array_join}: Concatenates the elements of column using the delimiter. #' Null values are replaced with nullReplacement if set, otherwise they are ignored. @@ -3071,6 +3087,19 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_remove}: Removes all elements that equal to element from the given array. +#' +#' @rdname column_collection_functions +#' @aliases array_remove array_remove,Column-method +#' @note array_remove since 2.4.0 +setMethod("array_remove", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_remove", x@jc, value) + column(jc) + }) + #' @details #' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times #' given by \code{count}. @@ -3120,6 +3149,24 @@ setMethod("arrays_overlap", column(jc) }) +#' @details +#' \code{arrays_zip}: Returns a merged array of structs in which the N-th struct contains all N-th +#' values of input arrays. +#' +#' @rdname column_collection_functions +#' @aliases arrays_zip arrays_zip,Column-method +#' @note arrays_zip since 2.4.0 +setMethod("arrays_zip", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function(arg) { + stopifnot(class(arg) == "Column") + arg@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_zip", jcols) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3147,6 +3194,21 @@ setMethod("map_entries", column(jc) }) +#' @details +#' \code{map_from_arrays}: Creates a new map column. The array in the first column is used for +#' keys. The array in the second column is used for values. All elements in the array for key +#' should not be null. +#' +#' @rdname column_collection_functions +#' @aliases map_from_arrays map_from_arrays,Column-method +#' @note map_from_arrays since 2.4.0 +setMethod("map_from_arrays", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_from_arrays", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 9321bbaf96ff8..4a7210bf1b902 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_distinct", function(x) { standardGeneric("array_distinct") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) @@ -773,6 +777,10 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_remove", function(x, value) { standardGeneric("array_remove") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) @@ -785,6 +793,10 @@ setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) #' @name NULL setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_zip", function(x, ...) { standardGeneric("arrays_zip") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1050,6 +1062,10 @@ setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) #' @name NULL setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_from_arrays", function(x, y) { standardGeneric("map_from_arrays") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 36e0f78bb0599..adcbbff823a2d 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,27 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_distinct() and array_remove() + df <- createDataFrame(list(list(list(1L, 2L, 3L, 1L, 2L)), list(list(6L, 5L, 5L, 4L, 6L)))) + result <- collect(select(df, array_distinct(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(6L, 5L, 4L))) + + result <- collect(select(df, array_remove(df[[1]], 2L)))[[1]] + expect_equal(result, list(list(1L, 3L, 1L), list(6L, 5L, 5L, 4L, 6L))) + + # Test arrays_zip() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 4L))), schema = c("c1", "c2")) + result <- collect(select(df, arrays_zip(df[[1]], df[[2]])))[[1]] + expected_entries <- list(listToStruct(list(c1 = 1L, c2 = 3L)), + listToStruct(list(c1 = 2L, c2 = 4L))) + expect_equal(result, list(expected_entries)) + + # Test map_from_arrays() + df <- createDataFrame(list(list(list("x", "y"), list(1, 2))), schema = c("k", "v")) + result <- collect(select(df, map_from_arrays(df$k, df$v)))[[1]] + expected_entries <- list(as.environment(list(x = 1, y = 2))) + expect_equal(result, expected_entries) + # Test array_repeat() df <- createDataFrame(list(list("a", 3L), list("b", 2L))) result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] From 0ce11d0e3a7c8c48d9f7305d2dd39c7b281b2a53 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 12 Jul 2018 22:20:06 -0700 Subject: [PATCH 1113/1161] [SPARK-23486] cache the function name from the external catalog for lookupFunctions ## What changes were proposed in this pull request? This PR will cache the function name from external catalog, it is used by lookupFunctions in the analyzer, and it is cached for each query plan. The original problem is reported in the [ spark-19737](https://issues.apache.org/jira/browse/SPARK-19737) ## How was this patch tested? create new test file LookupFunctionsSuite and add test case in SessionCatalogSuite Author: Kevin Yu Closes #20795 from kevinyu98/spark-23486. --- .../sql/catalyst/analysis/Analyzer.scala | 45 +++++++- .../sql/catalyst/catalog/SessionCatalog.scala | 16 +++ .../analysis/LookupFunctionsSuite.scala | 104 ++++++++++++++++++ .../catalog/SessionCatalogSuite.scala | 36 ++++++ .../spark/sql/hive/HiveSessionCatalog.scala | 4 + 5 files changed, 199 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9749893dbdb0c..960ee27aec7be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -1208,16 +1211,46 @@ class Analyzer( * only performs simple existence check according to the function identifier to quickly identify * undefined functions without triggering relation resolution, which may incur potentially * expensive partition/schema discovery process in some cases. - * + * In order to avoid duplicate external functions lookup, the external function identifier will + * store in the local hash set externalFunctionNameSet. * @see [[ResolveFunctions]] * @see https://issues.apache.org/jira/browse/SPARK-19737 */ object LookupFunctions extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { - case f: UnresolvedFunction if !catalog.functionExists(f.name) => - withPosition(f) { - throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) - } + override def apply(plan: LogicalPlan): LogicalPlan = { + val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]() + plan.transformAllExpressions { + case f: UnresolvedFunction + if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f + case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f + case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) => + externalFunctionNameSet.add(normalizeFuncName(f.name)) + f + case f: UnresolvedFunction => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase), + f.name.funcName) + } + } + } + + def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = { + val funcName = if (conf.caseSensitiveAnalysis) { + name.funcName + } else { + name.funcName.toLowerCase(Locale.ROOT) + } + + val databaseName = name.database match { + case Some(a) => formatDatabaseName(a) + case None => catalog.getCurrentDatabase + } + + FunctionIdentifier(funcName, Some(databaseName)) + } + + protected def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c26a34528c162..b09b81eabf60d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1193,6 +1193,22 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } + /** + * Return whether this function has been registered in the function registry of the current + * session. If not existed, return false. + */ + def isRegisteredFunction(name: FunctionIdentifier): Boolean = { + functionRegistry.functionExists(name) + } + + /** + * Returns whether it is a persistent function. If not existed, returns false. + */ + def isPersistentFunction(name: FunctionIdentifier): Boolean = { + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + databaseExists(db) && externalCatalog.functionExists(db, name.funcName) + } + protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { throw new NoSuchFunctionException( db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala new file mode 100644 index 0000000000000..cea0f2a9cbc97 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/LookupFunctionsSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.net.URI + +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf + +class LookupFunctionsSuite extends PlanTest { + + test("SPARK-23486: the functionExists for the Persistent function check") { + val externalCatalog = new CustomInMemoryCatalog + val conf = new SQLConf() + val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(), + Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(), + Alias(unresolvedRegisteredFunc, "call5")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(externalCatalog.getFunctionExistsCalledTimes == 1) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedPersistentFunc.name).database == Some("default")) + } + + test("SPARK-23486: the functionExists for the Registered function check") { + val externalCatalog = new InMemoryCatalog + val conf = new SQLConf() + val customerFunctionReg = new CustomerFunctionRegistry + val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf) + val analyzer = { + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) + new Analyzer(catalog, conf) + } + + def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref)) + val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false) + val plan = Project( + Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()), + table("TaBlE")) + analyzer.LookupFunctions.apply(plan) + + assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2) + assert(analyzer.LookupFunctions.normalizeFuncName + (unresolvedRegisteredFunc.name).database == Some("default")) + } +} + +class CustomerFunctionRegistry extends SimpleFunctionRegistry { + + private var isRegisteredFunctionCalledTimes: Int = 0; + + override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized { + isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1 + true + } + + def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes +} + +class CustomInMemoryCatalog extends InMemoryCatalog { + + private var functionExistsCalledTimes: Int = 0 + + override def functionExists(db: String, funcName: String): Boolean = synchronized { + functionExistsCalledTimes = functionExistsCalledTimes + 1 + true + } + + def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 6a7375ee186fa..50496a0410528 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -1217,6 +1217,42 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } + test("isRegisteredFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isRegisteredFunction(FunctionIdentifier("temp1"))) + + // Returns true when the function does register + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc1) ) + assert(catalog.isRegisteredFunction(FunctionIdentifier("iff"))) + + // Returns false when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum"))) + assert(!catalog.isRegisteredFunction(FunctionIdentifier("sum", Some("db2")))) + } + } + + test("isPersistentFunction") { + withBasicCatalog { catalog => + // Returns false when the function does not register + assert(!catalog.isPersistentFunction(FunctionIdentifier("temp2"))) + + // Returns false when the function does register + val tempFunc2 = (e: Seq[Expression]) => e.head + catalog.registerFunction(newFunc("iff", None), overrideIfExists = false, + functionBuilder = Some(tempFunc2)) + assert(!catalog.isPersistentFunction(FunctionIdentifier("iff"))) + + // Return true when using the createFunction + catalog.createFunction(newFunc("sum", Some("db2")), ignoreIfExists = false) + assert(catalog.isPersistentFunction(FunctionIdentifier("sum", Some("db2")))) + assert(!catalog.isPersistentFunction(FunctionIdentifier("db2.sum"))) + } + } + test("drop function") { withBasicCatalog { catalog => assert(catalog.externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 94ddeae1bf547..de41bb418181d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -175,6 +175,10 @@ private[sql] class HiveSessionCatalog( super.functionExists(name) || hiveFunctions.contains(name.funcName) } + override def isPersistentFunction(name: FunctionIdentifier): Boolean = { + super.isPersistentFunction(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. From 0f24c6f8abc11c4525c21ac5cd25991bfed36dc4 Mon Sep 17 00:00:00 2001 From: Yuanbo Liu Date: Fri, 13 Jul 2018 07:37:24 -0600 Subject: [PATCH 1114/1161] =?UTF-8?q?[SPARK-24713]=20AppMatser=20of=20spar?= =?UTF-8?q?k=20streaming=20kafka=20OOM=20if=20there=20are=20hund=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We have hundreds of kafka topics need to be consumed in one application. The application master will throw OOM exception after hanging for nearly half of an hour. OOM happens in the env with a lot of topics, and it's not convenient to set up such kind of env in the unit test. So I didn't change/add test case. Author: Yuanbo Liu Author: yuanbo Closes #21690 from yuanboliu/master. --- .../spark/streaming/kafka010/DirectKafkaInputDStream.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index c3221481556f5..0246006acf0bd 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -166,6 +166,8 @@ private[spark] class DirectKafkaInputDStream[K, V]( * which would throw off consumer position. Fix position if this happens. */ private def paranoidPoll(c: Consumer[K, V]): Unit = { + // don't actually want to consume any messages, so pause all partitions + c.pause(c.assignment()) val msgs = c.poll(0) if (!msgs.isEmpty) { // position should be minimum offset per topicpartition @@ -204,8 +206,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( // position for new partitions determined by auto.offset.reset if no commit currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap - // don't want to consume messages, so pause - c.pause(newPartitions.asJava) // find latest available offsets c.seekToEnd(currentOffsets.keySet.asJava) parts.map(tp => tp -> c.position(tp)).toMap @@ -262,9 +262,6 @@ private[spark] class DirectKafkaInputDStream[K, V]( tp -> c.position(tp) }.toMap } - - // don't actually want to consume any messages, so pause all partitions - c.pause(currentOffsets.keySet.asJava) } override def stop(): Unit = this.synchronized { From dfd7ac9887f89b9b51b7b143ab54d01f11cfcdb5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 13 Jul 2018 08:25:00 -0700 Subject: [PATCH 1115/1161] [SPARK-24781][SQL] Using a reference from Dataset in Filter/Sort might not work ## What changes were proposed in this pull request? When we use a reference from Dataset in filter or sort, which was not used in the prior select, an AnalysisException occurs, e.g., ```scala val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") df.select(df("name")).filter(df("id") === 0).show() ``` ```scala org.apache.spark.sql.AnalysisException: Resolved attribute(s) id#6 missing from name#5 in operator !Filter (id#6 = 0).;; !Filter (id#6 = 0) +- AnalysisBarrier +- Project [name#5] +- Project [_1#2 AS name#5, _2#3 AS id#6] +- LocalRelation [_1#2, _2#3] ``` This change updates the rule `ResolveMissingReferences` so `Filter` and `Sort` with non-empty `missingInputs` will also be transformed. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh Closes #21745 from viirya/SPARK-24781. --- .../sql/catalyst/analysis/Analyzer.scala | 30 ++++++++++++++----- .../org/apache/spark/sql/DataFrameSuite.scala | 25 ++++++++++++++++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 960ee27aec7be..36f14ccdc6989 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1132,7 +1132,8 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) + if (!s.resolved || s.missingInput.nonEmpty) && child.resolved => val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) val ordering = newOrder.map(_.asInstanceOf[SortOrder]) if (child.output == newChild.output) { @@ -1143,7 +1144,7 @@ class Analyzer( Project(child.output, newSort) } - case f @ Filter(cond, child) if !f.resolved && child.resolved => + case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved => val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) if (child.output == newChild.output) { f.copy(condition = newCond.head) @@ -1154,10 +1155,17 @@ class Analyzer( } } + /** + * This method tries to resolve expressions and find missing attributes recursively. Specially, + * when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved + * attributes which are missed from child output. This method tries to find the missing + * attributes out and add into the projection. + */ private def resolveExprsAndAddMissingAttrs( exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - if (exprs.forall(_.resolved)) { - // All given expressions are resolved, no need to continue anymore. + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { (exprs, plan) } else { plan match { @@ -1168,15 +1176,19 @@ class Analyzer( (newExprs, AnalysisBarrier(newChild)) case p: Project => + // Resolving expressions against current plan. val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + // Recursively resolving expressions on the child of current plan. val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) (newExprs, Project(p.projectList ++ missingAttrs, newChild)) case a @ Aggregate(groupExprs, aggExprs, child) => val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { // All the missing attributes are grouping expressions, valid case. (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) @@ -1526,7 +1538,11 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { - val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) + // If a sort order is unresolved, containing references not in aggregate, or containing + // `AggregateExpression`, we need to push down it to the underlying aggregate operator. + val unresolvedSortOrders = sortOrder.filter { s => + !s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s) + } val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9d7645d232d08..5babdf6f33b99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2387,4 +2387,29 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) } + + test("SPARK-24781: Using a reference from Dataset in Filter/Sort") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + val filter1 = df.select(df("name")).filter(df("id") === 0) + val filter2 = df.select(col("name")).filter(col("id") === 0) + checkAnswer(filter1, filter2.collect()) + + val sort1 = df.select(df("name")).orderBy(df("id")) + val sort2 = df.select(col("name")).orderBy(col("id")) + checkAnswer(sort1, sort2.collect()) + } + + test("SPARK-24781: Using a reference not in aggregation in Filter/Sort") { + withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + + val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name")) + val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name")) + checkAnswer(aggPlusSort1, aggPlusSort2.collect()) + + val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0) + val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0) + checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) + } + } } From c1b62e420a43aa7da36733ccdbec057d87ac1b43 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 13 Jul 2018 08:55:46 -0700 Subject: [PATCH 1116/1161] [SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated methods ## What changes were proposed in this pull request? Improve Avro unit test: 1. use QueryTest/SharedSQLContext/SQLTestUtils, instead of the duplicated test utils. 2. replace deprecated methods ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21760 from gengliangwang/improve_avro_test. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 114 ++++++------- .../org/apache/spark/sql/avro/TestUtils.scala | 156 ------------------ 2 files changed, 53 insertions(+), 217 deletions(-) delete mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index c6c1e4051a4b3..108b347ca0f56 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class AvroSuite extends SparkFunSuite { +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" - private var spark: SparkSession = _ - override protected def beforeAll(): Unit = { super.beforeAll() - spark = SparkSession.builder() - .master("local[2]") - .appName("AvroSuite") - .config("spark.sql.files.maxPartitionBytes", 1024) - .getOrCreate() - } - - override protected def afterAll(): Unit = { - try { - spark.sparkContext.stop() - } finally { - super.afterAll() - } + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(newFile) + checkAnswer(newEntries, originalEntries) } test("reading from multiple paths") { @@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite { val df = spark.read.avro(episodesFile) val fields = List("title", "air_date", "doctor") for (field <- fields) { - TestUtils.withTempDir { dir => + withTempPath { dir => val outputDir = s"$dir/${UUID.randomUUID}" df.write.partitionBy(field).avro(outputDir) val input = spark.read.avro(outputDir) @@ -82,12 +74,12 @@ class AvroSuite extends SparkFunSuite { test("request no fields") { val df = spark.read.avro(episodesFile) - df.registerTempTable("avro_table") + df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) @@ -95,15 +87,16 @@ class AvroSuite extends SparkFunSuite { } test("rearrange internal schema") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } test("test NULL avro type") { - TestUtils.withTempDir { dir => - val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -122,11 +115,11 @@ class AvroSuite extends SparkFunSuite { } test("union(int, long) is read as long") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -150,11 +143,11 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double) is read as double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -178,7 +171,7 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double, null) is read as nullable double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion( List(Schema.create(Type.FLOAT), @@ -186,7 +179,7 @@ class AvroSuite extends SparkFunSuite { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null)).asJava + val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -210,9 +203,9 @@ class AvroSuite extends SparkFunSuite { } test("Union of a single type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -233,16 +226,16 @@ class AvroSuite extends SparkFunSuite { } test("Complex Union Type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null), - new Field("field2", complexUnionType, "doc", null), - new Field("field3", complexUnionType, "doc", null), - new Field("field4", complexUnionType, "doc", null) + new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]), + new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]), + new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]), + new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any]) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite { } test("Lots of nulls") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("binary", BinaryType, true), StructField("timestamp", TimestampType, true), @@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite { } test("Struct field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("short", ShortType, true), @@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite { } test("Date field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("date", DateType, true) @@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite { } test("Array data types") { - TestUtils.withTempDir { dir => + withTempPath { dir => val testSchema = StructType(Seq( StructField("byte_array", ArrayType(ByteType), true), StructField("short_array", ArrayType(ShortType), true), @@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite { } test("write with compression") { - TestUtils.withTempDir { dir => + withTempPath { dir => val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" - val fakeDir = s"$dir/fake" val df = spark.read.avro(testFile) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") @@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY TABLE avroTable + |CREATE TEMPORARY VIEW avroTable |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite { test("conversion to avro and back") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { dir => + withTempPath { dir => val avroDir = s"$dir/avro" spark.read.avro(testFile).write.avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) } } test("conversion to avro and back with namespace") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val name = "AvroTest" val namespace = "com.databricks.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" spark.read.avro(testFile).write.options(parameters).avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite { } test("converting some specific sparkSQL types to avro") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val testSchema = StructType(Seq( StructField("Name", StringType, false), StructField("Length", IntegerType, true), @@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite { } test("correctly read long as date/timestamp type") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite { } test("does not coerce null date/timestamp value to 0 epoch.") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite { // Directory given has no avro files intercept[AnalysisException] { - TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) + withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) } intercept[AnalysisException] { @@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite { } intercept[FileNotFoundException] { - TestUtils.withTempDir { dir => + withTempPath { dir => FileUtils.touch(new File(dir, "test")) spark.read.avro(dir.toString) } @@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite { } test("SQL test insert overwrite") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val tempEmptyDir = s"$tempDir/sqlOverwrite" // Create a temp directory for table that will be overwritten new File(tempEmptyDir).mkdirs() spark.sql( s""" - |CREATE TEMPORARY TABLE episodes + |CREATE TEMPORARY VIEW episodes |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY TABLE episodesEmpty + |CREATE TEMPORARY VIEW episodesEmpty |(name string, air_date string, doctor int) |USING avro |OPTIONS (path "$tempEmptyDir") @@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite { test("test save and load") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite { test("test load with non-Avro file") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite { } test("read avro file partitioned") { - TestUtils.withTempDir { dir => + withTempPath { dir => val sparkSession = spark import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") @@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTop(id: Int, data: NestedMiddle) test("saving avro that has nested records with the same name") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" @@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopArray(id: Int, data: NestedMiddleArray) test("saving avro that has nested records with the same name inside an array") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopArray(1, NestedMiddleArray(2, Array( @@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopMap(id: Int, data: NestedMiddleMap) test("saving avro that has nested records with the same name inside a map") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopMap(1, NestedMiddleMap(2, Map( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala deleted file mode 100755 index 4ae9b14d9ad0d..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import java.io.{File, IOException} -import java.nio.ByteBuffer - -import scala.collection.immutable.HashSet -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import com.google.common.io.Files -import java.util - -import org.apache.spark.sql.SparkSession - -private[avro] object TestUtils { - - /** - * This function checks that all records in a file match the original - * record. - */ - def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { - - def convertToString(elem: Any): String = { - elem match { - case null => "NULL" // HashSets can't have null in them, so we use a string instead - case arrayBuf: ArrayBuffer[_] => - arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") - case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") - case other => other.toString - } - } - - val originalEntries = spark.read.avro(testFile).collect() - val newEntries = spark.read.avro(avroDir).collect() - - assert(originalEntries.length == newEntries.length) - - val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) - for (origEntry <- originalEntries) { - var idx = 0 - for (origElement <- origEntry.toSeq) { - origEntrySet(idx) += convertToString(origElement) - idx += 1 - } - } - - for (newEntry <- newEntries) { - var idx = 0 - for (newElement <- newEntry.toSeq) { - assert(origEntrySet(idx).contains(convertToString(newElement))) - idx += 1 - } - } - } - - def withTempDir(f: File => Unit): Unit = { - val dir = Files.createTempDir() - dir.delete() - try f(dir) finally deleteRecursively(dir) - } - - /** - * This function deletes a file or a directory with everything that's in it. This function is - * copied from Spark with minor modifications made to it. See original source at: - * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala - */ - - def deleteRecursively(file: File) { - def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - if (file != null) { - try { - if (file.isDirectory) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - } - } finally { - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } - } - } - - /** - * This function generates a random map(string, int) of a given size. - */ - private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { - val jMap = new util.HashMap[String, Int]() - for (i <- 0 until size) { - jMap.put(rand.nextString(5), i) - } - jMap - } - - /** - * This function generates a random array of booleans of a given size. - */ - private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { - val vec = new util.ArrayList[Boolean]() - for (i <- 0 until size) { - vec.add(rand.nextBoolean()) - } - vec - } - - /** - * This function generates a random ByteBuffer of a given size. - */ - private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { - val bb = ByteBuffer.allocate(size) - val arrayOfBytes = new Array[Byte](size) - rand.nextBytes(arrayOfBytes) - bb.put(arrayOfBytes) - } -} From 3bcb1b481423aedf1ac531ad582c7cb8685f1e3c Mon Sep 17 00:00:00 2001 From: Xiao Li Date: Fri, 13 Jul 2018 10:06:26 -0700 Subject: [PATCH 1117/1161] Revert "[SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated methods" This reverts commit c1b62e420a43aa7da36733ccdbec057d87ac1b43. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 114 +++++++------ .../org/apache/spark/sql/avro/TestUtils.scala | 156 ++++++++++++++++++ 2 files changed, 217 insertions(+), 53 deletions(-) create mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 108b347ca0f56..c6c1e4051a4b3 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -31,24 +31,32 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class AvroSuite extends SparkFunSuite { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" + private var spark: SparkSession = _ + override protected def beforeAll(): Unit = { super.beforeAll() - spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) - } - - def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { - val originalEntries = spark.read.avro(testFile).collect() - val newEntries = spark.read.avro(newFile) - checkAnswer(newEntries, originalEntries) + spark = SparkSession.builder() + .master("local[2]") + .appName("AvroSuite") + .config("spark.sql.files.maxPartitionBytes", 1024) + .getOrCreate() + } + + override protected def afterAll(): Unit = { + try { + spark.sparkContext.stop() + } finally { + super.afterAll() + } } test("reading from multiple paths") { @@ -60,7 +68,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read.avro(episodesFile) val fields = List("title", "air_date", "doctor") for (field <- fields) { - withTempPath { dir => + TestUtils.withTempDir { dir => val outputDir = s"$dir/${UUID.randomUUID}" df.write.partitionBy(field).avro(outputDir) val input = spark.read.avro(outputDir) @@ -74,12 +82,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("request no fields") { val df = spark.read.avro(episodesFile) - df.createOrReplaceTempView("avro_table") + df.registerTempTable("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { - withTempPath { dir => + TestUtils.withTempDir { dir => val df = spark.read.avro(episodesFile) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) @@ -87,16 +95,15 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("rearrange internal schema") { - withTempPath { dir => + TestUtils.withTempDir { dir => val df = spark.read.avro(episodesFile) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } test("test NULL avro type") { - withTempPath { dir => - val fields = - Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava + TestUtils.withTempDir { dir => + val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -115,11 +122,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(int, long) is read as long") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -143,11 +150,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(float, double) is read as double") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -171,7 +178,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(float, double, null) is read as nullable double") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion( List(Schema.create(Type.FLOAT), @@ -179,7 +186,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -203,9 +210,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Union of a single type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -226,16 +233,16 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Complex Union Type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any]) + new Field("field1", complexUnionType, "doc", null), + new Field("field2", complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -264,7 +271,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Lots of nulls") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("binary", BinaryType, true), StructField("timestamp", TimestampType, true), @@ -283,7 +290,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Struct field type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("short", ShortType, true), @@ -302,7 +309,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Date field type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("date", DateType, true) @@ -322,7 +329,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Array data types") { - withTempPath { dir => + TestUtils.withTempDir { dir => val testSchema = StructType(Seq( StructField("byte_array", ArrayType(ByteType), true), StructField("short_array", ArrayType(ShortType), true), @@ -356,12 +363,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("write with compression") { - withTempPath { dir => + TestUtils.withTempDir { dir => val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" + val fakeDir = s"$dir/fake" val df = spark.read.avro(testFile) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") @@ -431,7 +439,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY VIEW avroTable + |CREATE TEMPORARY TABLE avroTable |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -442,24 +450,24 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("conversion to avro and back") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - withTempPath { dir => + TestUtils.withTempDir { dir => val avroDir = s"$dir/avro" spark.read.avro(testFile).write.avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) } } test("conversion to avro and back with namespace") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val name = "AvroTest" val namespace = "com.databricks.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" spark.read.avro(testFile).write.options(parameters).avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -470,7 +478,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("converting some specific sparkSQL types to avro") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val testSchema = StructType(Seq( StructField("Name", StringType, false), StructField("Length", IntegerType, true), @@ -512,7 +520,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("correctly read long as date/timestamp type") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -541,7 +549,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("does not coerce null date/timestamp value to 0 epoch.") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -602,7 +610,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Directory given has no avro files intercept[AnalysisException] { - withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) + TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) } intercept[AnalysisException] { @@ -616,7 +624,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } intercept[FileNotFoundException] { - withTempPath { dir => + TestUtils.withTempDir { dir => FileUtils.touch(new File(dir, "test")) spark.read.avro(dir.toString) } @@ -625,19 +633,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SQL test insert overwrite") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val tempEmptyDir = s"$tempDir/sqlOverwrite" // Create a temp directory for table that will be overwritten new File(tempEmptyDir).mkdirs() spark.sql( s""" - |CREATE TEMPORARY VIEW episodes + |CREATE TEMPORARY TABLE episodes |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY VIEW episodesEmpty + |CREATE TEMPORARY TABLE episodesEmpty |(name string, air_date string, doctor int) |USING avro |OPTIONS (path "$tempEmptyDir") @@ -657,7 +665,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test save and load") { // Test if load works as expected - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -671,7 +679,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test load with non-Avro file") { // Test if load works as expected - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -729,7 +737,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("read avro file partitioned") { - withTempPath { dir => + TestUtils.withTempDir { dir => val sparkSession = spark import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") @@ -748,7 +756,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTop(id: Int, data: NestedMiddle) test("saving avro that has nested records with the same name") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" @@ -765,7 +773,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTopArray(id: Int, data: NestedMiddleArray) test("saving avro that has nested records with the same name inside an array") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopArray(1, NestedMiddleArray(2, Array( @@ -786,7 +794,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTopMap(id: Int, data: NestedMiddleMap) test("saving avro that has nested records with the same name inside a map") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopMap(1, NestedMiddleMap(2, Map( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala new file mode 100755 index 0000000000000..4ae9b14d9ad0d --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{File, IOException} +import java.nio.ByteBuffer + +import scala.collection.immutable.HashSet +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import com.google.common.io.Files +import java.util + +import org.apache.spark.sql.SparkSession + +private[avro] object TestUtils { + + /** + * This function checks that all records in a file match the original + * record. + */ + def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { + + def convertToString(elem: Any): String = { + elem match { + case null => "NULL" // HashSets can't have null in them, so we use a string instead + case arrayBuf: ArrayBuffer[_] => + arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") + case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") + case other => other.toString + } + } + + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(avroDir).collect() + + assert(originalEntries.length == newEntries.length) + + val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) + for (origEntry <- originalEntries) { + var idx = 0 + for (origElement <- origEntry.toSeq) { + origEntrySet(idx) += convertToString(origElement) + idx += 1 + } + } + + for (newEntry <- newEntries) { + var idx = 0 + for (newElement <- newEntry.toSeq) { + assert(origEntrySet(idx).contains(convertToString(newElement))) + idx += 1 + } + } + } + + def withTempDir(f: File => Unit): Unit = { + val dir = Files.createTempDir() + dir.delete() + try f(dir) finally deleteRecursively(dir) + } + + /** + * This function deletes a file or a directory with everything that's in it. This function is + * copied from Spark with minor modifications made to it. See original source at: + * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala + */ + + def deleteRecursively(file: File) { + def listFilesSafely(file: File): Seq[File] = { + if (file.exists()) { + val files = file.listFiles() + if (files == null) { + throw new IOException("Failed to list files for dir: " + file) + } + files + } else { + List() + } + } + + if (file != null) { + try { + if (file.isDirectory) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + } + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } + } + } + } + } + + /** + * This function generates a random map(string, int) of a given size. + */ + private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { + val jMap = new util.HashMap[String, Int]() + for (i <- 0 until size) { + jMap.put(rand.nextString(5), i) + } + jMap + } + + /** + * This function generates a random array of booleans of a given size. + */ + private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { + val vec = new util.ArrayList[Boolean]() + for (i <- 0 until size) { + vec.add(rand.nextBoolean()) + } + vec + } + + /** + * This function generates a random ByteBuffer of a given size. + */ + private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { + val bb = ByteBuffer.allocate(size) + val arrayOfBytes = new Array[Byte](size) + rand.nextBytes(arrayOfBytes) + bb.put(arrayOfBytes) + } +} From 3b6005b8a276e646c0785d924f139a48238a7c87 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 13 Jul 2018 11:23:42 -0700 Subject: [PATCH 1118/1161] [SPARK-23528][ML] Add numIter to ClusteringSummary ## What changes were proposed in this pull request? Added the number of iterations in `ClusteringSummary`. This is an helpful information in evaluating how to eventually modify the parameters in order to get a better model. ## How was this patch tested? modified existing UTs Author: Marco Gaido Closes #20701 from mgaido91/SPARK-23528. --- .../org/apache/spark/ml/clustering/BisectingKMeans.scala | 6 ++++-- .../apache/spark/ml/clustering/ClusteringSummary.scala | 6 ++++-- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 8 +++++--- .../scala/org/apache/spark/ml/clustering/KMeans.scala | 6 ++++-- .../scala/org/apache/spark/mllib/clustering/KMeans.scala | 2 +- .../org/apache/spark/mllib/clustering/KMeansModel.scala | 9 +++++++-- .../spark/ml/clustering/BisectingKMeansSuite.scala | 1 + .../spark/ml/clustering/GaussianMixtureSuite.scala | 1 + .../org/apache/spark/ml/clustering/KMeansSuite.scala | 1 + project/MimaExcludes.scala | 5 +++++ 10 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 9c9614509c64f..de564471c2b4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -274,7 +274,7 @@ class BisectingKMeans @Since("2.0.0") ( val parentModel = bkm.run(rdd) val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), $(maxIter)) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) @@ -304,6 +304,7 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.1.0") @Experimental @@ -311,4 +312,5 @@ class BisectingKMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala index 44e832b058b62..7da4c43a1abf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/ClusteringSummary.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.sql.{DataFrame, Row} /** @@ -28,13 +28,15 @@ import org.apache.spark.sql.{DataFrame, Row} * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Experimental class ClusteringSummary private[clustering] ( @transient val predictions: DataFrame, val predictionCol: String, val featuresCol: String, - val k: Int) extends Serializable { + val k: Int, + @Since("2.4.0") val numIter: Int) extends Serializable { /** * Cluster centers of the transformed data. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 64ecc1ebda589..dae64ba9a515d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -423,7 +423,7 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iter) model.setSummary(Some(summary)) instr.logNamedValue("logLikelihood", logLikelihood) instr.logNamedValue("clusterSizes", summary.clusterSizes) @@ -687,6 +687,7 @@ private class ExpectationAggregator( * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. * @param logLikelihood Total log-likelihood for this model on the given data. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -696,8 +697,9 @@ class GaussianMixtureSummary private[clustering] ( @Since("2.0.0") val probabilityCol: String, featuresCol: String, k: Int, - @Since("2.2.0") val logLikelihood: Double) - extends ClusteringSummary(predictions, predictionCol, featuresCol, k) { + @Since("2.2.0") val logLikelihood: Double, + numIter: Int) + extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) { /** * Probability of each cluster. diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 1704412741d49..f40037a8d9aa9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -356,7 +356,7 @@ class KMeans @Since("1.5.0") ( val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( - model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.transform(dataset), $(predictionCol), $(featuresCol), $(k), parentModel.numIter) model.setSummary(Some(summary)) instr.logNamedValue("clusterSizes", summary.clusterSizes) @@ -388,6 +388,7 @@ object KMeans extends DefaultParamsReadable[KMeans] { * @param predictionCol Name for column of predicted clusters in `predictions`. * @param featuresCol Name for column of features in `predictions`. * @param k Number of clusters. + * @param numIter Number of iterations. */ @Since("2.0.0") @Experimental @@ -395,4 +396,5 @@ class KMeansSummary private[clustering] ( predictions: DataFrame, predictionCol: String, featuresCol: String, - k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) + k: Int, + numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b5b1be3490497..37ae8b1a6171a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -348,7 +348,7 @@ class KMeans private ( logInfo(s"The cost is $cost.") - new KMeansModel(centers.map(_.vector), distanceMeasure) + new KMeansModel(centers.map(_.vector), distanceMeasure, iteration) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index a78c21e838e44..e3a88b42fbf73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -36,8 +36,9 @@ import org.apache.spark.sql.{Row, SparkSession} * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector], - @Since("2.4.0") val distanceMeasure: String) +class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector], + @Since("2.4.0") val distanceMeasure: String, + private[spark] val numIter: Int) extends Saveable with Serializable with PMMLExportable { private val distanceMeasureInstance: DistanceMeasure = @@ -46,6 +47,10 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec private val clusterCentersWithNorm = if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_)) + @Since("2.4.0") + private[spark] def this(clusterCenters: Array[Vector], distanceMeasure: String) = + this(clusterCenters: Array[Vector], distanceMeasure, -1) + @Since("1.1.0") def this(clusterCenters: Array[Vector]) = this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 81842afbddbbb..1b7780e171e77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -133,6 +133,7 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 20) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 0b91f502f615b..13bed9dbe3e89 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -145,6 +145,7 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 2) model.setSummary(None) assert(!model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2569e7a432ca4..829c90fe34e94 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -135,6 +135,7 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + assert(summary.numIter == 1) model.setSummary(None) assert(!model.hasSummary) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eeb097ef153ad..8f96bb0f33849 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-23528] Add numIter to ClusteringSummary + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), + // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), From a75571b46f813005a6d4b076ec39081ffab11844 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 13 Jul 2018 14:07:52 -0700 Subject: [PATCH 1119/1161] [SPARK-23831][SQL] Add org.apache.derby to IsolatedClientLoader ## What changes were proposed in this pull request? Add `org.apache.derby` to `IsolatedClientLoader`, otherwise it may throw an exception: ```scala ... [info] Cause: java.sql.SQLException: Failed to start database 'metastore_db' with class loader org.apache.spark.sql.hive.client.IsolatedClientLoader$$anon$12439ab23, see the next exception for details. [info] at org.apache.derby.impl.jdbc.SQLExceptionFactory.getSQLException(Unknown Source) [info] at org.apache.derby.impl.jdbc.SQLExceptionFactory.getSQLException(Unknown Source) [info] at org.apache.derby.impl.jdbc.Util.seeNextException(Unknown Source) [info] at org.apache.derby.impl.jdbc.EmbedConnection.bootDatabase(Unknown Source) [info] at org.apache.derby.impl.jdbc.EmbedConnection.(Unknown Source) [info] at org.apache.derby.jdbc.InternalDriver$1.run(Unknown Source) ... ``` ## How was this patch tested? unit tests and manual tests Author: Yuming Wang Closes #20944 from wangyum/SPARK-23831. --- .../apache/spark/sql/hive/client/IsolatedClientLoader.scala | 1 + .../apache/spark/sql/hive/HiveExternalCatalogSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 2f34f69b5cf48..6a90c44a2633d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -182,6 +182,7 @@ private[hive] class IsolatedClientLoader( name.startsWith("org.slf4j") || name.startsWith("org.apache.log4j") || // log4j1.x name.startsWith("org.apache.logging.log4j") || // log4j2 + name.startsWith("org.apache.derby.") || name.startsWith("org.apache.spark.") || (sharesHadoopClasses && isHadoopClass) || name.startsWith("scala.") || diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 0a522b6a11c80..1de258f060943 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -113,4 +113,10 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { catalog.createDatabase(newDb("dbWithNullDesc").copy(description = null), ignoreIfExists = false) assert(catalog.getDatabase("dbWithNullDesc").description == "") } + + test("SPARK-23831: Add org.apache.derby to IsolatedClientLoader") { + val client1 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + val client2 = HiveUtils.newClientForMetadata(new SparkConf, new Configuration) + assert(!client1.equals(client2)) + } } From f1a99ad5825daf1b4cc275146ba8460cbcdf9701 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 13 Jul 2018 17:19:28 -0700 Subject: [PATCH 1120/1161] [SPARK-23984][K8S][TEST] Added Integration Tests for PySpark on Kubernetes ## What changes were proposed in this pull request? I added integration tests for PySpark ( + checking JVM options + RemoteFileTest) which wasn't properly merged in the initial integration test PR. ## How was this patch tested? I tested this with integration tests using: `dev/dev-run-integration-tests.sh --spark-tgz spark-2.4.0-SNAPSHOT-bin-2.7.3.tgz` Author: Ilan Filonenko Closes #21583 from ifilonenko/master. --- .../k8s/integrationtest/KubernetesSuite.scala | 171 +++++++++++++++--- .../KubernetesTestComponents.scala | 28 ++- 2 files changed, 167 insertions(+), 32 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 6e334c83fbde8..774c3936b877c 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -22,7 +22,7 @@ import java.util.UUID import java.util.regex.Pattern import com.google.common.io.PatternFilenameFilter -import io.fabric8.kubernetes.api.model.{Container, Pod} +import io.fabric8.kubernetes.api.model.Pod import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} @@ -43,6 +43,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite private var kubernetesTestComponents: KubernetesTestComponents = _ private var sparkAppConf: SparkAppConf = _ private var image: String = _ + private var pyImage: String = _ private var containerLocalSparkDistroExamplesJar: String = _ private var appLocator: String = _ private var driverPodName: String = _ @@ -65,6 +66,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite val imageTag = getTestImageTag val imageRepo = getTestImageRepo image = s"$imageRepo/spark:$imageTag" + pyImage = s"$imageRepo/spark-py:$imageTag" val sparkDistroExamplesJarFile: File = sparkHomeDir.resolve(Paths.get("examples", "jars")) .toFile @@ -156,22 +158,77 @@ private[spark] class KubernetesSuite extends SparkFunSuite }) } - // TODO(ssuchter): Enable the below after debugging - // test("Run PageRank using remote data file") { - // sparkAppConf - // .set("spark.kubernetes.mountDependencies.filesDownloadDir", - // CONTAINER_LOCAL_FILE_DOWNLOAD_PATH) - // .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) - // runSparkPageRankAndVerifyCompletion( - // appArgs = Array(CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE)) - // } + test("Run extraJVMOptions check on driver") { + sparkAppConf + .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") + runSparkJVMCheckAndVerifyCompletion( + expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) + } + + test("Run SparkRemoteFileTest using a remote data file") { + sparkAppConf + .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) + runSparkRemoteCheckAndVerifyCompletion( + appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) + } + + test("Run PySpark on simple pi.py example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_PI, + mainClass = "", + expectedLogOnCompletion = Seq("Pi is roughly 3"), + appArgs = Array("5"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false) + } + + test("Run PySpark with Python2 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "2") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } + + test("Run PySpark with Python3 to test a pyfiles example") { + sparkAppConf + .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") + .set("spark.kubernetes.pyspark.pythonversion", "3") + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_FILES, + mainClass = "", + expectedLogOnCompletion = Seq( + "Python runtime version check is: True", + "Python environment version check is: True"), + appArgs = Array("python3"), + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + pyFiles = Some(PYSPARK_CONTAINER_TESTS)) + } private def runSparkPiAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, appArgs: Array[String] = Array.empty[String], - appLocator: String = appLocator): Unit = { + appLocator: String = appLocator, + isJVM: Boolean = true ): Unit = { runSparkApplicationAndVerifyCompletion( appResource, SPARK_PI_MAIN_CLASS, @@ -179,10 +236,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + isJVM) } - private def runSparkPageRankAndVerifyCompletion( + private def runSparkRemoteCheckAndVerifyCompletion( appResource: String = containerLocalSparkDistroExamplesJar, driverPodChecker: Pod => Unit = doBasicDriverPodCheck, executorPodChecker: Pod => Unit = doBasicExecutorPodCheck, @@ -190,12 +248,50 @@ private[spark] class KubernetesSuite extends SparkFunSuite appLocator: String = appLocator): Unit = { runSparkApplicationAndVerifyCompletion( appResource, - SPARK_PAGE_RANK_MAIN_CLASS, - Seq("1 has rank", "2 has rank", "3 has rank", "4 has rank"), + SPARK_REMOTE_MAIN_CLASS, + Seq(s"Mounting of ${appArgs.head} was true"), appArgs, driverPodChecker, executorPodChecker, - appLocator) + appLocator, + true) + } + + private def runSparkJVMCheckAndVerifyCompletion( + appResource: String = containerLocalSparkDistroExamplesJar, + mainClass: String = SPARK_DRIVER_MAIN_CLASS, + driverPodChecker: Pod => Unit = doBasicDriverPodCheck, + appArgs: Array[String] = Array("5"), + expectedJVMValue: Seq[String]): Unit = { + val appArguments = SparkAppArguments( + mainAppResource = appResource, + mainClass = mainClass, + appArgs = appArgs) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + true) + + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + doBasicDriverPodCheck(driverPod) + + Eventually.eventually(TIMEOUT, INTERVAL) { + expectedJVMValue.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(driverPod.getMetadata.getName) + .getLog + .contains(e), "The application did not complete.") + } + } } private def runSparkApplicationAndVerifyCompletion( @@ -205,12 +301,20 @@ private[spark] class KubernetesSuite extends SparkFunSuite appArgs: Array[String], driverPodChecker: Pod => Unit, executorPodChecker: Pod => Unit, - appLocator: String): Unit = { + appLocator: String, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val appArguments = SparkAppArguments( mainAppResource = appResource, mainClass = mainClass, appArgs = appArgs) - SparkAppLauncher.launch(appArguments, sparkAppConf, TIMEOUT.value.toSeconds.toInt, sparkHomeDir) + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) val driverPod = kubernetesTestComponents.kubernetesClient .pods() @@ -248,11 +352,22 @@ private[spark] class KubernetesSuite extends SparkFunSuite assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") } + private def doBasicDriverPyPodCheck(driverPod: Pod): Unit = { + assert(driverPod.getMetadata.getName === driverPodName) + assert(driverPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(driverPod.getSpec.getContainers.get(0).getName === "spark-kubernetes-driver") + } + private def doBasicExecutorPodCheck(executorPod: Pod): Unit = { assert(executorPod.getSpec.getContainers.get(0).getImage === image) assert(executorPod.getSpec.getContainers.get(0).getName === "executor") } + private def doBasicExecutorPyPodCheck(executorPod: Pod): Unit = { + assert(executorPod.getSpec.getContainers.get(0).getImage === pyImage) + assert(executorPod.getSpec.getContainers.get(0).getName === "executor") + } + private def checkCustomSettings(pod: Pod): Unit = { assert(pod.getMetadata.getLabels.get("label1") === "label1-value") assert(pod.getMetadata.getLabels.get("label2") === "label2-value") @@ -287,14 +402,22 @@ private[spark] object KubernetesSuite { val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) val SPARK_PI_MAIN_CLASS: String = "org.apache.spark.examples.SparkPi" + val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" + val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" val SPARK_PAGE_RANK_MAIN_CLASS: String = "org.apache.spark.examples.SparkPageRank" + val CONTAINER_LOCAL_PYSPARK: String = "local:///opt/spark/examples/src/main/python/" + val PYSPARK_PI: String = CONTAINER_LOCAL_PYSPARK + "pi.py" + val PYSPARK_FILES: String = CONTAINER_LOCAL_PYSPARK + "pyfiles.py" + val PYSPARK_CONTAINER_TESTS: String = CONTAINER_LOCAL_PYSPARK + "py_container_checks.py" - // val CONTAINER_LOCAL_FILE_DOWNLOAD_PATH = "/var/spark-data/spark-files" + val TEST_SECRET_NAME_PREFIX = "test-secret-" + val TEST_SECRET_KEY = "test-key" + val TEST_SECRET_VALUE = "test-data" + val TEST_SECRET_MOUNT_PATH = "/etc/secrets" - // val REMOTE_PAGE_RANK_DATA_FILE = - // "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" - // val CONTAINER_LOCAL_DOWNLOADED_PAGE_RANK_DATA_FILE = - // s"$CONTAINER_LOCAL_FILE_DOWNLOAD_PATH/pagerank_data.txt" + val REMOTE_PAGE_RANK_DATA_FILE = + "https://storage.googleapis.com/spark-k8s-integration-tests/files/pagerank_data.txt" + val REMOTE_PAGE_RANK_FILE_NAME = "pagerank_data.txt" - // case object ShuffleNotReadyException extends Exception + case object ShuffleNotReadyException extends Exception } diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala index b2471e51116cb..a9b49a8e5a610 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala @@ -97,21 +97,33 @@ private[spark] case class SparkAppArguments( appArgs: Array[String]) private[spark] object SparkAppLauncher extends Logging { - def launch( appArguments: SparkAppArguments, appConf: SparkAppConf, timeoutSecs: Int, - sparkHomeDir: Path): Unit = { + sparkHomeDir: Path, + isJVM: Boolean, + pyFiles: Option[String] = None): Unit = { val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit")) logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf") - val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath, + val preCommandLine = if (isJVM) { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, "--deploy-mode", "cluster", "--class", appArguments.mainClass, - "--master", appConf.get("spark.master") - ) ++ appConf.toStringArray :+ - appArguments.mainAppResource) ++ - appArguments.appArgs - ProcessUtils.executeProcess(commandLine, timeoutSecs) + "--master", appConf.get("spark.master")) + } else { + mutable.ArrayBuffer(sparkSubmitExecutable.toFile.getAbsolutePath, + "--deploy-mode", "cluster", + "--master", appConf.get("spark.master")) + } + val commandLine = + pyFiles.map(s => preCommandLine ++ Array("--py-files", s)).getOrElse(preCommandLine) ++ + appConf.toStringArray :+ appArguments.mainAppResource + + if (appArguments.appArgs.nonEmpty) { + commandLine += appArguments.appArgs.mkString(" ") + } + logInfo(s"Launching a spark app with command line: ${commandLine.mkString(" ")}") + ProcessUtils.executeProcess(commandLine.toArray, timeoutSecs) } } From e1de34113e057707dfc5ff54a8109b3ec7c16dfb Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 14 Jul 2018 17:50:54 +0800 Subject: [PATCH 1121/1161] [SPARK-17091][SQL] Add rule to convert IN predicate to equivalent Parquet filter ## What changes were proposed in this pull request? The original pr is: https://github.com/apache/spark/pull/18424 Add a new optimizer rule to convert an IN predicate to an equivalent Parquet filter and add `spark.sql.parquet.pushdown.inFilterThreshold` to control limit thresholds. Different data types have different limit thresholds, this is a copy of data for reference: Type | limit threshold -- | -- string | 370 int | 210 long | 285 double | 270 float | 220 decimal | Won't provide better performance before [SPARK-24549](https://issues.apache.org/jira/browse/SPARK-24549) ## How was this patch tested? unit tests and manual tests Author: Yuming Wang Closes #21603 from wangyum/SPARK-17091. --- .../apache/spark/sql/internal/SQLConf.scala | 15 +++ .../FilterPushdownBenchmark-results.txt | 96 +++++++++---------- .../parquet/ParquetFileFormat.scala | 15 ++- .../datasources/parquet/ParquetFilters.scala | 20 +++- .../parquet/ParquetFilterSuite.scala | 66 ++++++++++++- 5 files changed, 153 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 14dd5281fbcb1..699e9394f5be5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -386,6 +386,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD = + buildConf("spark.sql.parquet.pushdown.inFilterThreshold") + .doc("The maximum number of values to filter push-down optimization for IN predicate. " + + "Large threshold won't necessarily provide much better performance. " + + "The experiment argued that 300 is the limit threshold. " + + "By setting this value to 0 this feature can be disabled. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .intConf + .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") + .createWithDefault(10) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("Whether to be compatible with the legacy Parquet format adopted by Spark 1.4 and prior " + "versions, when converting Parquet schema to Spark SQL schema and vice versa.") @@ -1485,6 +1497,9 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownStringStartWith: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) + def parquetFilterPushDownInFilterThreshold: Int = + getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index 110669b69a00d..c44908b3b5406 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -417,120 +417,120 @@ Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7477 / 7587 2.1 475.4 1.0X -Parquet Vectorized (Pushdown) 7862 / 8346 2.0 499.9 1.0X -Native ORC Vectorized 6447 / 7021 2.4 409.9 1.2X -Native ORC Vectorized (Pushdown) 983 / 1003 16.0 62.5 7.6X +Parquet Vectorized 7993 / 8104 2.0 508.2 1.0X +Parquet Vectorized (Pushdown) 507 / 532 31.0 32.2 15.8X +Native ORC Vectorized 6922 / 7163 2.3 440.1 1.2X +Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7107 / 7290 2.2 451.9 1.0X -Parquet Vectorized (Pushdown) 7196 / 7258 2.2 457.5 1.0X -Native ORC Vectorized 6102 / 6222 2.6 388.0 1.2X -Native ORC Vectorized (Pushdown) 926 / 958 17.0 58.9 7.7X +Parquet Vectorized 7855 / 7963 2.0 499.4 1.0X +Parquet Vectorized (Pushdown) 503 / 516 31.3 32.0 15.6X +Native ORC Vectorized 6825 / 6954 2.3 433.9 1.2X +Native ORC Vectorized (Pushdown) 1019 / 1044 15.4 64.8 7.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7374 / 7692 2.1 468.8 1.0X -Parquet Vectorized (Pushdown) 7771 / 7848 2.0 494.1 0.9X -Native ORC Vectorized 6184 / 6356 2.5 393.2 1.2X -Native ORC Vectorized (Pushdown) 920 / 963 17.1 58.5 8.0X +Parquet Vectorized 7858 / 7928 2.0 499.6 1.0X +Parquet Vectorized (Pushdown) 490 / 519 32.1 31.1 16.0X +Native ORC Vectorized 7079 / 7966 2.2 450.1 1.1X +Native ORC Vectorized (Pushdown) 1276 / 1673 12.3 81.1 6.2X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7073 / 7326 2.2 449.7 1.0X -Parquet Vectorized (Pushdown) 7304 / 7647 2.2 464.4 1.0X -Native ORC Vectorized 6222 / 6579 2.5 395.6 1.1X -Native ORC Vectorized (Pushdown) 958 / 994 16.4 60.9 7.4X +Parquet Vectorized 8007 / 11155 2.0 509.0 1.0X +Parquet Vectorized (Pushdown) 519 / 540 30.3 33.0 15.4X +Native ORC Vectorized 6848 / 7072 2.3 435.4 1.2X +Native ORC Vectorized (Pushdown) 1026 / 1050 15.3 65.2 7.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7121 / 7501 2.2 452.7 1.0X -Parquet Vectorized (Pushdown) 7751 / 8334 2.0 492.8 0.9X -Native ORC Vectorized 6225 / 6680 2.5 395.8 1.1X -Native ORC Vectorized (Pushdown) 998 / 1020 15.8 63.5 7.1X +Parquet Vectorized 7876 / 7956 2.0 500.7 1.0X +Parquet Vectorized (Pushdown) 521 / 535 30.2 33.1 15.1X +Native ORC Vectorized 7051 / 7368 2.2 448.3 1.1X +Native ORC Vectorized (Pushdown) 1014 / 1035 15.5 64.5 7.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7157 / 7399 2.2 455.1 1.0X -Parquet Vectorized (Pushdown) 7806 / 7911 2.0 496.3 0.9X -Native ORC Vectorized 6548 / 6720 2.4 416.3 1.1X -Native ORC Vectorized (Pushdown) 1016 / 1050 15.5 64.6 7.0X +Parquet Vectorized 7897 / 8229 2.0 502.1 1.0X +Parquet Vectorized (Pushdown) 513 / 530 30.7 32.6 15.4X +Native ORC Vectorized 6730 / 6990 2.3 427.9 1.2X +Native ORC Vectorized (Pushdown) 1003 / 1036 15.7 63.8 7.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7662 / 7805 2.1 487.1 1.0X -Parquet Vectorized (Pushdown) 7590 / 7861 2.1 482.5 1.0X -Native ORC Vectorized 6840 / 8073 2.3 434.9 1.1X -Native ORC Vectorized (Pushdown) 1041 / 1075 15.1 66.2 7.4X +Parquet Vectorized 7967 / 8175 2.0 506.5 1.0X +Parquet Vectorized (Pushdown) 8155 / 8434 1.9 518.5 1.0X +Native ORC Vectorized 7002 / 7107 2.2 445.2 1.1X +Native ORC Vectorized (Pushdown) 1092 / 1139 14.4 69.4 7.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8230 / 9266 1.9 523.2 1.0X -Parquet Vectorized (Pushdown) 7735 / 7960 2.0 491.8 1.1X -Native ORC Vectorized 6945 / 7109 2.3 441.6 1.2X -Native ORC Vectorized (Pushdown) 1123 / 1144 14.0 71.4 7.3X +Parquet Vectorized 8032 / 8122 2.0 510.7 1.0X +Parquet Vectorized (Pushdown) 8141 / 8908 1.9 517.6 1.0X +Native ORC Vectorized 7140 / 7387 2.2 454.0 1.1X +Native ORC Vectorized (Pushdown) 1156 / 1220 13.6 73.5 6.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7656 / 8058 2.1 486.7 1.0X -Parquet Vectorized (Pushdown) 7860 / 8247 2.0 499.7 1.0X -Native ORC Vectorized 6684 / 7003 2.4 424.9 1.1X -Native ORC Vectorized (Pushdown) 1085 / 1172 14.5 69.0 7.1X +Parquet Vectorized 8088 / 8350 1.9 514.2 1.0X +Parquet Vectorized (Pushdown) 8629 / 8702 1.8 548.6 0.9X +Native ORC Vectorized 7480 / 7886 2.1 475.6 1.1X +Native ORC Vectorized (Pushdown) 1106 / 1145 14.2 70.3 7.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 7594 / 8128 2.1 482.8 1.0X -Parquet Vectorized (Pushdown) 7845 / 7923 2.0 498.8 1.0X -Native ORC Vectorized 5859 / 6421 2.7 372.5 1.3X -Native ORC Vectorized (Pushdown) 1037 / 1054 15.2 66.0 7.3X +Parquet Vectorized 8028 / 8165 2.0 510.4 1.0X +Parquet Vectorized (Pushdown) 8349 / 8674 1.9 530.8 1.0X +Native ORC Vectorized 7107 / 7354 2.2 451.8 1.1X +Native ORC Vectorized (Pushdown) 1175 / 1207 13.4 74.7 6.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6762 / 6775 2.3 429.9 1.0X -Parquet Vectorized (Pushdown) 6911 / 6970 2.3 439.4 1.0X -Native ORC Vectorized 5884 / 5960 2.7 374.1 1.1X -Native ORC Vectorized (Pushdown) 1028 / 1052 15.3 65.4 6.6X +Parquet Vectorized 8041 / 8195 2.0 511.2 1.0X +Parquet Vectorized (Pushdown) 8466 / 8604 1.9 538.2 0.9X +Native ORC Vectorized 7116 / 7286 2.2 452.4 1.1X +Native ORC Vectorized (Pushdown) 1197 / 1214 13.1 76.1 6.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6718 / 6767 2.3 427.1 1.0X -Parquet Vectorized (Pushdown) 6812 / 6909 2.3 433.1 1.0X -Native ORC Vectorized 5842 / 5883 2.7 371.4 1.1X -Native ORC Vectorized (Pushdown) 1040 / 1058 15.1 66.1 6.5X +Parquet Vectorized 7998 / 8311 2.0 508.5 1.0X +Parquet Vectorized (Pushdown) 9366 / 11257 1.7 595.5 0.9X +Native ORC Vectorized 7856 / 9273 2.0 499.5 1.0X +Native ORC Vectorized (Pushdown) 1350 / 1747 11.7 85.8 5.9X ================================================================================================ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b86b97ec7b103..efddf8d68eb8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -334,17 +334,15 @@ class ParquetFileFormat val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val enableRecordFilter: Boolean = - sparkSession.sessionState.conf.parquetRecordFilterEnabled - val timestampConversion: Boolean = - sparkSession.sessionState.conf.isParquetINT96TimestampConversion + val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled + val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize - val enableParquetFilterPushDown: Boolean = - sparkSession.sessionState.conf.parquetFilterPushDown + val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith + val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -368,12 +366,13 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) .getFileMetaData.getSchema + val parquetFilters = new ParquetFilters(pushDownDate, + pushDownStringStartWith, pushDownInFilterThreshold) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith) - .createFilter(parquetSchema, _)) + .flatMap(parquetFilters.createFilter(parquetSchema, _)) .reduceOption(FilterApi.and) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 4c9b940db2b30..e590c153c4151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -37,7 +37,10 @@ import org.apache.spark.unsafe.types.UTF8String /** * Some utility function to convert Spark data source filters to Parquet filters. */ -private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) { +private[parquet] class ParquetFilters( + pushDownDate: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int) { private case class ParquetSchemaType( originalType: OriginalType, @@ -232,6 +235,15 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: // See SPARK-20364. def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") + // All DataTypes that support `makeEq` can provide better performance. + def shouldConvertInPredicate(name: String): Boolean = nameToType(name) match { + case ParquetBooleanType | ParquetByteType | ParquetShortType | ParquetIntegerType + | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType + | ParquetBinaryType => true + case ParquetDateType if pushDownDate => true + case _ => false + } + // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -295,6 +307,12 @@ private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) + case sources.In(name, values) if canMakeFilterOn(name) && shouldConvertInPredicate(name) + && values.distinct.length <= pushDownInFilterThreshold => + values.distinct.flatMap { v => + makeEq.lift(nameToType(name)).map(_(name, v)) + }.reduceLeftOption(FilterApi.or) + case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => Option(prefix).map { v => FilterApi.userDefined(binaryColumn(name), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 067d2fea14fd7..00c191f755520 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets import java.sql.Date -import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} @@ -56,7 +56,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private lazy val parquetFilters = - new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith) + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold) override def beforeEach(): Unit = { super.beforeEach() @@ -803,6 +804,67 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // Test inverseCanDrop() has taken effect testStringStartsWith(spark.range(1024).map(c => "100").toDF(), "value not like '10%'") } + + test("SPARK-17091: Convert IN predicate to Parquet filter push-down") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + assertResult(Some(FilterApi.eq(intColumn("a"), null: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(null))) + } + + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10))) + } + + // Remove duplicates + assertResult(Some(FilterApi.eq(intColumn("a"), 10: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 10))) + } + + assertResult(Some(or(or( + FilterApi.eq(intColumn("a"), 10: Integer), + FilterApi.eq(intColumn("a"), 20: Integer)), + FilterApi.eq(intColumn("a"), 30: Integer))) + ) { + parquetFilters.createFilter(parquetSchema, sources.In("a", Array(10, 20, 30))) + } + + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold).toArray)).isDefined) + assert(parquetFilters.createFilter(parquetSchema, sources.In("a", + Range(0, conf.parquetFilterPushDownInFilterThreshold + 1).toArray)).isEmpty) + + import testImplicits._ + withTempPath { path => + val data = 0 to 1024 + data.toDF("a").selectExpr("if (a = 1024, null, a) AS a") // convert 1024 to null + .coalesce(1).write.option("parquet.block.size", 512) + .parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) + Seq(true, false).foreach { pushEnabled => + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> pushEnabled.toString) { + Seq(1, 5, 10, 11).foreach { count => + val filter = s"a in(${Range(0, count).mkString(",")})" + assert(df.where(filter).count() === count) + val actual = stripSparkFilter(df.where(filter)).collect().length + if (pushEnabled && count <= conf.parquetFilterPushDownInFilterThreshold) { + assert(actual > 1 && actual < data.length) + } else { + assert(actual === data.length) + } + } + assert(df.where("a in(null)").count() === 0) + assert(df.where("a = null").count() === 0) + assert(df.where("a is null").count() === 1) + } + } + } + } } class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { From 8aceb961c3b8e462c6002dbe03be61b4fe194f47 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 14 Jul 2018 15:59:17 -0500 Subject: [PATCH 1122/1161] [SPARK-24754][ML] Minhash integer overflow ## What changes were proposed in this pull request? Use longs in calculating min hash to avoid bias due to int overflow. ## How was this patch tested? Existing tests. Author: Sean Owen Closes #21750 from srowen/SPARK-24754. --- .../main/scala/org/apache/spark/ml/feature/MinHashLSH.scala | 2 +- python/pyspark/ml/feature.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index a67a3b0abbc1f..a043033e96724 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -66,7 +66,7 @@ class MinHashLSHModel private[ml]( val elemsList = elems.toSparse.indices.toList val hashValues = randCoefficients.map { case (a, b) => elemsList.map { elem: Int => - ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME + ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME }.min.toDouble } // TODO: Output vectors of dimension numHashFunctions in SPARK-18450 diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 14800d4d9327a..ddba7389145e3 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1294,14 +1294,14 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, >>> mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=12345) >>> model = mh.fit(df) >>> model.transform(df).head() - Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([-1638925... + Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668... >>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),), ... (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),), ... (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)] >>> df2 = spark.createDataFrame(data2, ["id", "features"]) >>> key = Vectors.sparse(6, [1, 2], [1.0, 1.0]) >>> model.approxNearestNeighbors(df2, key, 1).collect() - [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([-163892... + [Row(id=5, features=SparseVector(6, {1: 1.0, 2: 1.0, 4: 1.0}), hashes=[DenseVector([6179668... >>> model.approxSimilarityJoin(df, df2, 0.6, distCol="JaccardDistance").select( ... col("datasetA.id").alias("idA"), ... col("datasetB.id").alias("idB"), @@ -1309,8 +1309,8 @@ class MinHashLSH(JavaEstimator, LSHParams, HasInputCol, HasOutputCol, HasSeed, +---+---+---------------+ |idA|idB|JaccardDistance| +---+---+---------------+ - | 1| 4| 0.5| | 0| 5| 0.5| + | 1| 4| 0.5| +---+---+---------------+ ... >>> mhPath = temp_path + "/mh" From 43e4e851b642bbee535d22e1b9e72ec6b99f6ed4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 15 Jul 2018 11:13:49 +0800 Subject: [PATCH 1123/1161] [SPARK-24718][SQL] Timestamp support pushdown to parquet data source ## What changes were proposed in this pull request? `Timestamp` support pushdown to parquet data source. Only `TIMESTAMP_MICROS` and `TIMESTAMP_MILLIS` support push down. ## How was this patch tested? unit tests and benchmark tests Author: Yuming Wang Closes #21741 from wangyum/SPARK-24718. --- .../apache/spark/sql/internal/SQLConf.scala | 11 ++ .../FilterPushdownBenchmark-results.txt | 124 ++++++++++++++++++ .../parquet/ParquetFileFormat.scala | 3 +- .../datasources/parquet/ParquetFilters.scala | 59 ++++++++- .../benchmark/FilterPushdownBenchmark.scala | 37 +++++- .../parquet/ParquetFilterSuite.scala | 74 ++++++++++- 6 files changed, 301 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 699e9394f5be5..07d33fa7d52ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -378,6 +378,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.timestamp") + .doc("If true, enables Parquet filter push-down optimization for Timestamp. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is " + + "enabled and Timestamp stored as TIMESTAMP_MICROS or TIMESTAMP_MILLIS type.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = buildConf("spark.sql.parquet.filterPushdown.string.startsWith") .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + @@ -1494,6 +1503,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownDate: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DATE_ENABLED) + def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + def parquetFilterPushDownStringStartWith: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index c44908b3b5406..4f38cc4cee96d 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -578,3 +578,127 @@ Native ORC Vectorized 11622 / 12196 1.4 7 Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X +================================================================================================ +Pushdown benchmark for Timestamp +================================================================================================ + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as INT96 row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4784 / 4956 3.3 304.2 1.0X +Parquet Vectorized (Pushdown) 4838 / 4917 3.3 307.6 1.0X +Native ORC Vectorized 3923 / 4173 4.0 249.4 1.2X +Native ORC Vectorized (Pushdown) 894 / 943 17.6 56.8 5.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as INT96 rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5686 / 5901 2.8 361.5 1.0X +Parquet Vectorized (Pushdown) 5555 / 5895 2.8 353.2 1.0X +Native ORC Vectorized 4844 / 4957 3.2 308.0 1.2X +Native ORC Vectorized (Pushdown) 2141 / 2230 7.3 136.1 2.7X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as INT96 rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9100 / 9421 1.7 578.6 1.0X +Parquet Vectorized (Pushdown) 9122 / 9496 1.7 580.0 1.0X +Native ORC Vectorized 8365 / 8874 1.9 531.9 1.1X +Native ORC Vectorized (Pushdown) 7128 / 7376 2.2 453.2 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as INT96 rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12764 / 13120 1.2 811.5 1.0X +Parquet Vectorized (Pushdown) 12656 / 13003 1.2 804.7 1.0X +Native ORC Vectorized 13096 / 13233 1.2 832.6 1.0X +Native ORC Vectorized (Pushdown) 12710 / 15611 1.2 808.1 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MICROS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4381 / 4796 3.6 278.5 1.0X +Parquet Vectorized (Pushdown) 122 / 137 129.3 7.7 36.0X +Native ORC Vectorized 3913 / 3988 4.0 248.8 1.1X +Native ORC Vectorized (Pushdown) 905 / 945 17.4 57.6 4.8X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5145 / 5184 3.1 327.1 1.0X +Parquet Vectorized (Pushdown) 1426 / 1519 11.0 90.7 3.6X +Native ORC Vectorized 4827 / 4901 3.3 306.9 1.1X +Native ORC Vectorized (Pushdown) 2133 / 2210 7.4 135.6 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 9234 / 9516 1.7 587.1 1.0X +Parquet Vectorized (Pushdown) 6752 / 7046 2.3 429.3 1.4X +Native ORC Vectorized 8418 / 8998 1.9 535.2 1.1X +Native ORC Vectorized (Pushdown) 7199 / 7314 2.2 457.7 1.3X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MICROS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12414 / 12458 1.3 789.2 1.0X +Parquet Vectorized (Pushdown) 12094 / 12249 1.3 768.9 1.0X +Native ORC Vectorized 12198 / 13755 1.3 775.5 1.0X +Native ORC Vectorized (Pushdown) 12205 / 12431 1.3 776.0 1.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 1 timestamp stored as TIMESTAMP_MILLIS row (value = CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 4369 / 4515 3.6 277.8 1.0X +Parquet Vectorized (Pushdown) 116 / 125 136.2 7.3 37.8X +Native ORC Vectorized 3965 / 4703 4.0 252.1 1.1X +Native ORC Vectorized (Pushdown) 892 / 1162 17.6 56.7 4.9X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 10% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(1572864 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 5211 / 5409 3.0 331.3 1.0X +Parquet Vectorized (Pushdown) 1427 / 1438 11.0 90.7 3.7X +Native ORC Vectorized 4719 / 4883 3.3 300.1 1.1X +Native ORC Vectorized (Pushdown) 2191 / 2228 7.2 139.3 2.4X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 50% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(7864320 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 8716 / 8953 1.8 554.2 1.0X +Parquet Vectorized (Pushdown) 6632 / 6968 2.4 421.7 1.3X +Native ORC Vectorized 8376 / 9118 1.9 532.5 1.0X +Native ORC Vectorized (Pushdown) 7218 / 7609 2.2 458.9 1.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz + +Select 90% timestamp stored as TIMESTAMP_MILLIS rows (value < CAST(14155776 AS timestamp)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +Parquet Vectorized 12264 / 12452 1.3 779.7 1.0X +Parquet Vectorized (Pushdown) 11766 / 11927 1.3 748.0 1.0X +Native ORC Vectorized 12101 / 12301 1.3 769.3 1.0X +Native ORC Vectorized (Pushdown) 11983 / 12651 1.3 761.9 1.0X + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index efddf8d68eb8f..3ec33b2f4b540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -341,6 +341,7 @@ class ParquetFileFormat // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate + val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold @@ -366,7 +367,7 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) .getFileMetaData.getSchema - val parquetFilters = new ParquetFilters(pushDownDate, + val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownStringStartWith, pushDownInFilterThreshold) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index e590c153c4151..0c146f2f6f915 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.sql.Date +import java.lang.{Long => JLong} +import java.sql.{Date, Timestamp} import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator, PrimitiveType} +import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ @@ -39,6 +40,7 @@ import org.apache.spark.unsafe.types.UTF8String */ private[parquet] class ParquetFilters( pushDownDate: Boolean, + pushDownTimestamp: Boolean, pushDownStartWith: Boolean, pushDownInFilterThreshold: Int) { @@ -57,6 +59,8 @@ private[parquet] class ParquetFilters( private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null) private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null) private val ParquetDateType = ParquetSchemaType(DATE, INT32, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) @@ -89,6 +93,15 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) } private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -117,6 +130,15 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp]) + .asInstanceOf[JLong]).orNull) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) } private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -139,6 +161,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.lt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -161,6 +191,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.ltEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -183,6 +221,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gt( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -205,6 +251,14 @@ private[parquet] class ParquetFilters( case ParquetDateType if pushDownDate => (n: String, v: Any) => FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer]) + case ParquetTimestampMicrosType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gtEq( + longColumn(n), + DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong]) + case ParquetTimestampMillisType if pushDownTimestamp => + (n: String, v: Any) => FilterApi.gtEq( + longColumn(n), + v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) } /** @@ -241,6 +295,7 @@ private[parquet] class ParquetFilters( | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType | ParquetBinaryType => true case ParquetDateType if pushDownDate => true + case ParquetTimestampMicrosType | ParquetTimestampMillisType if pushDownTimestamp => true case _ => false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index fc716dec9f337..567a8ebf9d102 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -28,7 +28,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions.monotonically_increasing_id import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType} +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType +import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType} import org.apache.spark.util.{Benchmark, Utils} /** @@ -359,6 +360,40 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter } } } + + ignore(s"Pushdown benchmark for Timestamp") { + withTempPath { dir => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> true.toString) { + ParquetOutputTimestampType.values.toSeq.map(_.toString).foreach { fileType => + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> fileType) { + val columns = (1 to width).map(i => s"CAST(id AS string) c$i") + val df = spark.range(numRows).selectExpr(columns: _*) + .withColumn("value", monotonically_increasing_id().cast(TimestampType)) + withTempTable("orcTable", "patquetTable") { + saveAsTable(df, dir) + + Seq(s"value = CAST($mid AS timestamp)").foreach { whereExpr => + val title = s"Select 1 timestamp stored as $fileType row ($whereExpr)" + .replace("value AND value", "value") + filterPushDownBenchmark(numRows, title, whereExpr) + } + + val selectExpr = (1 to width).map(i => s"MAX(c$i)").mkString("", ",", ", MAX(value)") + Seq(10, 50, 90).foreach { percent => + filterPushDownBenchmark( + numRows, + s"Select $percent% timestamp stored as $fileType rows " + + s"(value < CAST(${numRows * percent / 100} AS timestamp))", + s"value < CAST(${numRows * percent / 100} as timestamp)", + selectExpr + ) + } + } + } + } + } + } + } } trait BenchmarkBeforeAndAfterEachTest extends BeforeAndAfterEachTestData { this: Suite => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 00c191f755520..924f136503656 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.charset.StandardCharsets -import java.sql.Date +import java.sql.{Date, Timestamp} import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators} import org.apache.parquet.filter2.predicate.FilterApi._ @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} @@ -56,8 +57,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext { private lazy val parquetFilters = - new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownStringStartWith, - conf.parquetFilterPushDownInFilterThreshold) + new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownStringStartWith, conf.parquetFilterPushDownInFilterThreshold) override def beforeEach(): Unit = { super.beforeEach() @@ -84,6 +85,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withSQLConf( SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df @@ -144,6 +146,39 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df) } + private def testTimestampPushdown(data: Seq[Timestamp]): Unit = { + assert(data.size === 4) + val ts1 = data.head + val ts2 = data(1) + val ts3 = data(2) + val ts4 = data(3) + + withParquetDataFrame(data.map(i => Tuple1(i))) { implicit df => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i))) + + checkFilterPredicate('_1 === ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 <=> ts1, classOf[Eq[_]], ts1) + checkFilterPredicate('_1 =!= ts1, classOf[NotEq[_]], + Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + + checkFilterPredicate('_1 < ts2, classOf[Lt[_]], ts1) + checkFilterPredicate('_1 > ts1, classOf[Gt[_]], Seq(ts2, ts3, ts4).map(i => Row.apply(i))) + checkFilterPredicate('_1 <= ts1, classOf[LtEq[_]], ts1) + checkFilterPredicate('_1 >= ts4, classOf[GtEq[_]], ts4) + + checkFilterPredicate(Literal(ts1) === '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts1) <=> '_1, classOf[Eq[_]], ts1) + checkFilterPredicate(Literal(ts2) > '_1, classOf[Lt[_]], ts1) + checkFilterPredicate(Literal(ts3) < '_1, classOf[Gt[_]], ts4) + checkFilterPredicate(Literal(ts1) >= '_1, classOf[LtEq[_]], ts1) + checkFilterPredicate(Literal(ts4) <= '_1, classOf[GtEq[_]], ts4) + + checkFilterPredicate(!('_1 < ts4), classOf[GtEq[_]], ts4) + checkFilterPredicate('_1 < ts2 || '_1 > ts3, classOf[Operators.Or], Seq(Row(ts1), Row(ts4))) + } + } + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { withTempPath { dir => @@ -444,6 +479,39 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - timestamp") { + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS + val millisData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123"), + Timestamp.valueOf("2018-06-15 08:28:53.123"), + Timestamp.valueOf("2018-06-16 08:28:53.123"), + Timestamp.valueOf("2018-06-17 08:28:53.123")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) { + testTimestampPushdown(millisData) + } + + // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS + val microsData = Seq(Timestamp.valueOf("2018-06-14 08:28:53.123456"), + Timestamp.valueOf("2018-06-15 08:28:53.123456"), + Timestamp.valueOf("2018-06-16 08:28:53.123456"), + Timestamp.valueOf("2018-06-17 08:28:53.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) { + testTimestampPushdown(microsData) + } + + // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> + ParquetOutputTimestampType.INT96.toString) { + withParquetDataFrame(millisData.map(i => Tuple1(i))) { implicit df => + assertResult(None) { + parquetFilters.createFilter( + new SparkToParquetSchemaConverter(conf).convert(df.schema), sources.IsNull("_1")) + } + } + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From 3e7dc82960fd3339eee16d83df66761ae6e3fe3d Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 14 Jul 2018 21:36:56 -0700 Subject: [PATCH 1124/1161] [SPARK-24776][SQL] Avro unit test: deduplicate code and replace deprecated methods ## What changes were proposed in this pull request? Improve Avro unit test: 1. use QueryTest/SharedSQLContext/SQLTestUtils, instead of the duplicated test utils. 2. replace deprecated methods This is a follow up PR for #21760, the PR passes pull request tests but failed in: https://amplab.cs.berkeley.edu/jenkins/view/Spark%20QA%20Compile/job/spark-master-compile-maven-hadoop-2.6/7842/ This PR is to fix it. ## How was this patch tested? Unit test. Compile with different commands: ``` ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-2.6 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-2.7 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ./build/mvn --force -DzincPort=3643 -DskipTests -Phadoop-3.1 -Phive-thriftserver -Pkinesis-asl -Pspark-ganglia-lgpl -Pmesos -Pyarn compile test-compile ``` Author: Gengliang Wang Closes #21768 from gengliangwang/improve_avro_test. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 98 +++++------ .../org/apache/spark/sql/avro/TestUtils.scala | 156 ------------------ 2 files changed, 45 insertions(+), 209 deletions(-) delete mode 100755 external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index c6c1e4051a4b3..4f94d827e3127 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class AvroSuite extends SparkFunSuite { +class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" - private var spark: SparkSession = _ - override protected def beforeAll(): Unit = { super.beforeAll() - spark = SparkSession.builder() - .master("local[2]") - .appName("AvroSuite") - .config("spark.sql.files.maxPartitionBytes", 1024) - .getOrCreate() - } - - override protected def afterAll(): Unit = { - try { - spark.sparkContext.stop() - } finally { - super.afterAll() - } + spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) + } + + def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(newFile) + checkAnswer(newEntries, originalEntries) } test("reading from multiple paths") { @@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite { val df = spark.read.avro(episodesFile) val fields = List("title", "air_date", "doctor") for (field <- fields) { - TestUtils.withTempDir { dir => + withTempPath { dir => val outputDir = s"$dir/${UUID.randomUUID}" df.write.partitionBy(field).avro(outputDir) val input = spark.read.avro(outputDir) @@ -82,12 +74,12 @@ class AvroSuite extends SparkFunSuite { test("request no fields") { val df = spark.read.avro(episodesFile) - df.registerTempTable("avro_table") + df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) @@ -95,15 +87,16 @@ class AvroSuite extends SparkFunSuite { } test("rearrange internal schema") { - TestUtils.withTempDir { dir => + withTempPath { dir => val df = spark.read.avro(episodesFile) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } test("test NULL avro type") { - TestUtils.withTempDir { dir => - val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava + withTempPath { dir => + val fields = + Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -122,7 +115,7 @@ class AvroSuite extends SparkFunSuite { } test("union(int, long) is read as long") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) @@ -150,7 +143,7 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double) is read as double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) @@ -178,7 +171,7 @@ class AvroSuite extends SparkFunSuite { } test("union(float, double, null) is read as nullable double") { - TestUtils.withTempDir { dir => + withTempPath { dir => val avroSchema: Schema = { val union = Schema.createUnion( List(Schema.create(Type.FLOAT), @@ -210,7 +203,7 @@ class AvroSuite extends SparkFunSuite { } test("Union of a single type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) @@ -233,7 +226,7 @@ class AvroSuite extends SparkFunSuite { } test("Complex Union Type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) val complexUnionType = Schema.createUnion( @@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite { } test("Lots of nulls") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("binary", BinaryType, true), StructField("timestamp", TimestampType, true), @@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite { } test("Struct field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("short", ShortType, true), @@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite { } test("Date field type") { - TestUtils.withTempDir { dir => + withTempPath { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("date", DateType, true) @@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite { } test("Array data types") { - TestUtils.withTempDir { dir => + withTempPath { dir => val testSchema = StructType(Seq( StructField("byte_array", ArrayType(ByteType), true), StructField("short_array", ArrayType(ShortType), true), @@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite { } test("write with compression") { - TestUtils.withTempDir { dir => + withTempPath { dir => val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" - val fakeDir = s"$dir/fake" val df = spark.read.avro(testFile) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") @@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY TABLE avroTable + |CREATE TEMPORARY VIEW avroTable |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite { test("conversion to avro and back") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { dir => + withTempPath { dir => val avroDir = s"$dir/avro" spark.read.avro(testFile).write.avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) } } test("conversion to avro and back with namespace") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val name = "AvroTest" val namespace = "com.databricks.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" spark.read.avro(testFile).write.options(parameters).avro(avroDir) - TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) + checkReloadMatchesSaved(testFile, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite { } test("converting some specific sparkSQL types to avro") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val testSchema = StructType(Seq( StructField("Name", StringType, false), StructField("Length", IntegerType, true), @@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite { } test("correctly read long as date/timestamp type") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite { } test("does not coerce null date/timestamp value to 0 epoch.") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite { // Directory given has no avro files intercept[AnalysisException] { - TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) + withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) } intercept[AnalysisException] { @@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite { } intercept[FileNotFoundException] { - TestUtils.withTempDir { dir => + withTempPath { dir => FileUtils.touch(new File(dir, "test")) spark.read.avro(dir.toString) } @@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite { } test("SQL test insert overwrite") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val tempEmptyDir = s"$tempDir/sqlOverwrite" // Create a temp directory for table that will be overwritten new File(tempEmptyDir).mkdirs() spark.sql( s""" - |CREATE TEMPORARY TABLE episodes + |CREATE TEMPORARY VIEW episodes |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY TABLE episodesEmpty + |CREATE TEMPORARY VIEW episodesEmpty |(name string, air_date string, doctor int) |USING avro |OPTIONS (path "$tempEmptyDir") @@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite { test("test save and load") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite { test("test load with non-Avro file") { // Test if load works as expected - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite { } test("read avro file partitioned") { - TestUtils.withTempDir { dir => + withTempPath { dir => val sparkSession = spark import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") @@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTop(id: Int, data: NestedMiddle) test("saving avro that has nested records with the same name") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" @@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopArray(id: Int, data: NestedMiddleArray) test("saving avro that has nested records with the same name inside an array") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopArray(1, NestedMiddleArray(2, Array( @@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite { case class NestedTopMap(id: Int, data: NestedMiddleMap) test("saving avro that has nested records with the same name inside a map") { - TestUtils.withTempDir { tempDir => + withTempPath { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopMap(1, NestedMiddleMap(2, Map( diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala deleted file mode 100755 index 4ae9b14d9ad0d..0000000000000 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import java.io.{File, IOException} -import java.nio.ByteBuffer - -import scala.collection.immutable.HashSet -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import com.google.common.io.Files -import java.util - -import org.apache.spark.sql.SparkSession - -private[avro] object TestUtils { - - /** - * This function checks that all records in a file match the original - * record. - */ - def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { - - def convertToString(elem: Any): String = { - elem match { - case null => "NULL" // HashSets can't have null in them, so we use a string instead - case arrayBuf: ArrayBuffer[_] => - arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") - case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") - case other => other.toString - } - } - - val originalEntries = spark.read.avro(testFile).collect() - val newEntries = spark.read.avro(avroDir).collect() - - assert(originalEntries.length == newEntries.length) - - val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) - for (origEntry <- originalEntries) { - var idx = 0 - for (origElement <- origEntry.toSeq) { - origEntrySet(idx) += convertToString(origElement) - idx += 1 - } - } - - for (newEntry <- newEntries) { - var idx = 0 - for (newElement <- newEntry.toSeq) { - assert(origEntrySet(idx).contains(convertToString(newElement))) - idx += 1 - } - } - } - - def withTempDir(f: File => Unit): Unit = { - val dir = Files.createTempDir() - dir.delete() - try f(dir) finally deleteRecursively(dir) - } - - /** - * This function deletes a file or a directory with everything that's in it. This function is - * copied from Spark with minor modifications made to it. See original source at: - * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala - */ - - def deleteRecursively(file: File) { - def listFilesSafely(file: File): Seq[File] = { - if (file.exists()) { - val files = file.listFiles() - if (files == null) { - throw new IOException("Failed to list files for dir: " + file) - } - files - } else { - List() - } - } - - if (file != null) { - try { - if (file.isDirectory) { - var savedIOException: IOException = null - for (child <- listFilesSafely(file)) { - try { - deleteRecursively(child) - } catch { - // In case of multiple exceptions, only last one will be thrown - case ioe: IOException => savedIOException = ioe - } - } - if (savedIOException != null) { - throw savedIOException - } - } - } finally { - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) - } - } - } - } - } - - /** - * This function generates a random map(string, int) of a given size. - */ - private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { - val jMap = new util.HashMap[String, Int]() - for (i <- 0 until size) { - jMap.put(rand.nextString(5), i) - } - jMap - } - - /** - * This function generates a random array of booleans of a given size. - */ - private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { - val vec = new util.ArrayList[Boolean]() - for (i <- 0 until size) { - vec.add(rand.nextBoolean()) - } - vec - } - - /** - * This function generates a random ByteBuffer of a given size. - */ - private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { - val bb = ByteBuffer.allocate(size) - val arrayOfBytes = new Array[Byte](size) - rand.nextBytes(arrayOfBytes) - bb.put(arrayOfBytes) - } -} From 69993217fc4f5e5e41a297702389e86fe534dc2f Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 14 Jul 2018 22:07:49 -0700 Subject: [PATCH 1125/1161] [SPARK-24807][CORE] Adding files/jars twice: output a warning and add a note ## What changes were proposed in this pull request? In the PR, I propose to output an warning if the `addFile()` or `addJar()` methods are callled more than once for the same path. Currently, overwriting of already added files is not supported. New comments and warning are reflected the existing behaviour. Author: Maxim Gekk Closes #21771 from MaxGekk/warning-on-adding-file. --- R/pkg/R/context.R | 2 ++ .../main/scala/org/apache/spark/SparkContext.scala | 12 ++++++++++++ .../org/apache/spark/api/java/JavaSparkContext.scala | 6 ++++++ python/pyspark/context.py | 4 ++++ 4 files changed, 24 insertions(+) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8ec727dd042bc..3e996a5ba26fc 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -305,6 +305,8 @@ setCheckpointDirSC <- function(sc, dirName) { #' Currently directories are only supported for Hadoop-supported filesystems. #' Refer Hadoop-supported filesystems at \url{https://wiki.apache.org/hadoop/HCFS}. #' +#' Note: A path can be added only once. Subsequent additions of the same path are ignored. +#' #' @rdname spark.addFile #' @param path The path of the file to be added #' @param recursive Whether to add files recursively from the path. Default is FALSE. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 74bfb5d6d2ea3..531384ab57305 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1496,6 +1496,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param path can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String): Unit = { addFile(path, false) @@ -1516,6 +1518,8 @@ class SparkContext(config: SparkConf) extends Logging { * use `SparkFiles.get(fileName)` to find its download location. * @param recursive if true, a directory can be given in `path`. Currently directories are * only supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri @@ -1555,6 +1559,9 @@ class SparkContext(config: SparkConf) extends Logging { Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() + } else { + logWarning(s"The path $path has been added already. Overwriting of added paths " + + "is not supported in the current version.") } } @@ -1803,6 +1810,8 @@ class SparkContext(config: SparkConf) extends Logging { * * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { def addJarFile(file: File): String = { @@ -1849,6 +1858,9 @@ class SparkContext(config: SparkConf) extends Logging { if (addedJars.putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added JAR $path at $key with timestamp $timestamp") postEnvironmentUpdate() + } else { + logWarning(s"The jar $path has been added already. Overwriting of added jars " + + "is not supported in the current version.") } } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index f1936bf587282..09c83849e26b2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -668,6 +668,8 @@ class JavaSparkContext(val sc: SparkContext) * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String) { sc.addFile(path) @@ -681,6 +683,8 @@ class JavaSparkContext(val sc: SparkContext) * * A directory can be given if the recursive option is set to true. Currently directories are only * supported for Hadoop-supported filesystems. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addFile(path: String, recursive: Boolean): Unit = { sc.addFile(path, recursive) @@ -690,6 +694,8 @@ class JavaSparkContext(val sc: SparkContext) * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported * filesystems), or an HTTP, HTTPS or FTP URI. + * + * @note A path can be added only once. Subsequent additions of the same path are ignored. */ def addJar(path: String) { sc.addJar(path) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ede3b6af0a8cf..2cb3117184334 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -847,6 +847,8 @@ def addFile(self, path, recursive=False): A directory can be given if the recursive option is set to True. Currently directories are only supported for Hadoop-supported filesystems. + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. + >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: @@ -867,6 +869,8 @@ def addPyFile(self, path): SparkContext in the future. The C{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, HTTPS or FTP URI. + + .. note:: A path can be added only once. Subsequent additions of the same path are ignored. """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix From 96030876383822645a5b35698ee407a8d4eb76af Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sun, 15 Jul 2018 22:06:33 +0800 Subject: [PATCH 1126/1161] [SPARK-24800][SQL] Refactor Avro Serializer and Deserializer ## What changes were proposed in this pull request? Currently the Avro Deserializer converts input Avro format data to `Row`, and then convert the `Row` to `InternalRow`. While the Avro Serializer converts `InternalRow` to `Row`, and then output Avro format data. This PR allows direct conversion between `InternalRow` and Avro format data. ## How was this patch tested? Unit test Author: Gengliang Wang Closes #21762 from gengliangwang/avro_io. --- .../spark/sql/avro/AvroDeserializer.scala | 348 ++++++++++++++++++ .../spark/sql/avro/AvroFileFormat.scala | 24 +- .../spark/sql/avro/AvroOutputWriter.scala | 109 +----- .../sql/avro/AvroOutputWriterFactory.scala | 5 +- .../spark/sql/avro/AvroSerializer.scala | 180 +++++++++ .../spark/sql/avro/SchemaConverters.scala | 333 ++--------------- .../spark/sql/avro/SerializableSchema.scala | 69 ++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 1 - .../sql/avro/SerializableSchemaSuite.scala | 56 +++ 9 files changed, 704 insertions(+), 421 deletions(-) create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala create mode 100644 external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..b31149a2c74c2 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + */ +class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { + private val converter: Any => Any = rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (data: Any) => InternalRow.empty + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val writer = getRecordWriter(rootAvroType, st, Nil) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + writer(fieldUpdater, record) + resultRow + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + tmpRow.get(0, rootCatalystType) + } + } + + def deserialize(data: Any): Any = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter( + avroType: Schema, + catalystType: DataType, + path: List[String]): (CatalystDataUpdater, Int, Any) => Unit = + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long] * 1000) + + case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / DateTimeUtils.MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + bytes + case b: Array[Byte] => b + case other => throw new RuntimeException(s"$other is not a valid avro binary.") + + } + updater.set(ordinal, bytes) + + case (RECORD, st: StructType) => + val writeRecord = getRecordWriter(avroType, st, path) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val elementWriter = newWriter(avroType.getElementType, elementType, path) + (updater, ordinal, value) => + val array = value.asInstanceOf[GenericData.Array[Any]] + val len = array.size() + val result = createArrayData(elementType, len) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + while (i < len) { + val element = array.get(i) + if (element == null) { + if (!containsNull) { + throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, path) + val valueWriter = newWriter(avroType.getValueType, valueType, path) + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + + "allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, path) + } else { + nonNullTypes.map(_.getType) match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => newWriter(schema, field.dataType, path :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(avroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path " + + s"${path.mkString(".")} is not compatible " + + s"(avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + } + } + } else { + (updater, ordinal, value) => updater.setNullAt(ordinal) + } + + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Avro to catalyst because schema at path ${path.mkString(".")} " + + s"is not compatible (avroType = $avroType, sqlType = $catalystType).\n" + + s"Source Avro schema: $rootAvroType.\n" + + s"Target Catalyst type: $rootCatalystType") + } + + private def getRecordWriter( + avroType: Schema, + sqlType: StructType, + path: List[String]): (CatalystDataUpdater, GenericRecord) => Unit = { + val validFieldIndexes = ArrayBuffer.empty[Int] + val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] + + val length = sqlType.length + var i = 0 + while (i < length) { + val sqlField = sqlType.fields(i) + val avroField = avroType.getField(sqlField.name) + if (avroField != null) { + validFieldIndexes += avroField.pos() + + val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) + val ordinal = i + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + fieldWriters += fieldWriter + } else if (!sqlField.nullable) { + throw new IncompatibleSchemaException( + s""" + |Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema. + |Source Avro schema: $rootAvroType. + |Target Catalyst type: $rootCatalystType. + """.stripMargin) + } + i += 1 + } + + (fieldUpdater, record) => { + var i = 0 + while (i < validFieldIndexes.length) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + i += 1 + } + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 46e5a189c5eb3..fb93033bb15d4 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.avro.{Schema, SchemaBuilder} +import org.apache.avro.Schema import org.apache.avro.file.{DataFileConstants, DataFileReader} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} import org.apache.avro.mapred.{AvroOutputFormat, FsInput} @@ -38,8 +38,6 @@ import org.slf4j.LoggerFactory import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.sources.{DataSourceRegister, Filter} import org.apache.spark.sql.types.StructType @@ -118,8 +116,8 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { dataSchema: StructType): OutputWriterFactory = { val recordName = options.getOrElse("recordName", "topLevelRecord") val recordNamespace = options.getOrElse("recordNamespace", "") - val build = SchemaBuilder.record(recordName).namespace(recordNamespace) - val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace) + val outputAvroSchema = SchemaConverters.toAvroType( + dataSchema, nullable = false, recordName, recordNamespace) AvroJob.setOutputKeySchema(job, outputAvroSchema) val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" @@ -148,7 +146,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { log.error(s"unsupported compression codec $unknown") } - new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace) + new AvroOutputWriterFactory(dataSchema, new SerializableSchema(outputAvroSchema)) } override def buildReader( @@ -205,13 +203,10 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { reader.sync(file.start) val stop = file.start + file.length - val rowConverter = SchemaConverters.createConverterToSQL( - userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) + val deserializer = + new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), requiredSchema) new Iterator[InternalRow] { - // Used to convert `Row`s containing data columns into `InternalRow`s. - private val encoderForDataColumns = RowEncoder(requiredSchema) - private[this] var completed = false override def hasNext: Boolean = { @@ -228,14 +223,11 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { } override def next(): InternalRow = { - if (reader.pastSync(stop)) { + if (!hasNext) { throw new NoSuchElementException("next on empty iterator") } val record = reader.next() - val safeDataRow = rowConverter(record).asInstanceOf[GenericRow] - - // The safeDataRow is reused, we must do a copy - encoderForDataColumns.toRow(safeDataRow) + deserializer.deserialize(record).asInstanceOf[InternalRow] } } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index 830bf3c0570bf..06507115f5ed8 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -18,14 +18,8 @@ package org.apache.spark.sql.avro import java.io.{IOException, OutputStream} -import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} -import java.util.HashMap -import scala.collection.immutable.Map - -import org.apache.avro.{Schema, SchemaBuilder} -import org.apache.avro.generic.GenericData.Record +import org.apache.avro.Schema import org.apache.avro.generic.GenericRecord import org.apache.avro.mapred.AvroKey import org.apache.avro.mapreduce.AvroKeyOutputFormat @@ -33,8 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.NullWritable import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ @@ -43,13 +36,10 @@ private[avro] class AvroOutputWriter( path: String, context: TaskAttemptContext, schema: StructType, - recordName: String, - recordNamespace: String) extends OutputWriter { + avroSchema: Schema) extends OutputWriter { - private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) - // copy of the old conversion logic after api change in SPARK-19085 - private lazy val internalRowConverter = - CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row] + // The input rows will never be null. + private lazy val serializer = new AvroSerializer(schema, avroSchema, nullable = false) /** * Overrides the couple of methods responsible for generating the output streams / files so @@ -70,95 +60,10 @@ private[avro] class AvroOutputWriter( }.getRecordWriter(context) - override def write(internalRow: InternalRow): Unit = { - val row = internalRowConverter(internalRow) - val key = new AvroKey(converter(row).asInstanceOf[GenericRecord]) + override def write(row: InternalRow): Unit = { + val key = new AvroKey(serializer.serialize(row).asInstanceOf[GenericRecord]) recordWriter.write(key, NullWritable.get()) } override def close(): Unit = recordWriter.close(context) - - /** - * This function constructs converter function for a given sparkSQL datatype. This is used in - * writing Avro records out to disk - */ - private def createConverterToAvro( - dataType: DataType, - structName: String, - recordNamespace: String): (Any) => Any = { - dataType match { - case BinaryType => (item: Any) => item match { - case null => null - case bytes: Array[Byte] => ByteBuffer.wrap(bytes) - } - case ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | StringType | BooleanType => identity - case _: DecimalType => (item: Any) => if (item == null) null else item.toString - case TimestampType => (item: Any) => - if (item == null) null else item.asInstanceOf[Timestamp].getTime - case DateType => (item: Any) => - if (item == null) null else item.asInstanceOf[Date].getTime - case ArrayType(elementType, _) => - val elementConverter = createConverterToAvro( - elementType, - structName, - SchemaConverters.getNewRecordNamespace(elementType, recordNamespace, structName)) - (item: Any) => { - if (item == null) { - null - } else { - val sourceArray = item.asInstanceOf[Seq[Any]] - val sourceArraySize = sourceArray.size - val targetArray = new Array[Any](sourceArraySize) - var idx = 0 - while (idx < sourceArraySize) { - targetArray(idx) = elementConverter(sourceArray(idx)) - idx += 1 - } - targetArray - } - } - case MapType(StringType, valueType, _) => - val valueConverter = createConverterToAvro( - valueType, - structName, - SchemaConverters.getNewRecordNamespace(valueType, recordNamespace, structName)) - (item: Any) => { - if (item == null) { - null - } else { - val javaMap = new HashMap[String, Any]() - item.asInstanceOf[Map[String, Any]].foreach { case (key, value) => - javaMap.put(key, valueConverter(value)) - } - javaMap - } - } - case structType: StructType => - val builder = SchemaBuilder.record(structName).namespace(recordNamespace) - val schema: Schema = SchemaConverters.convertStructToAvro( - structType, builder, recordNamespace) - val fieldConverters = structType.fields.map(field => - createConverterToAvro( - field.dataType, - field.name, - SchemaConverters.getNewRecordNamespace(field.dataType, recordNamespace, field.name))) - (item: Any) => { - if (item == null) { - null - } else { - val record = new Record(schema) - val convertersIterator = fieldConverters.iterator - val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator - val rowIterator = item.asInstanceOf[Row].toSeq.iterator - - while (convertersIterator.hasNext) { - val converter = convertersIterator.next() - record.put(fieldNamesIterator.next(), converter(rowIterator.next())) - } - record - } - } - } - } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala index 5b2ce7d7d8e0f..18a6d93951408 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.types.StructType private[avro] class AvroOutputWriterFactory( schema: StructType, - recordName: String, - recordNamespace: String) extends OutputWriterFactory { + avroSchema: SerializableSchema) extends OutputWriterFactory { override def getFileExtension(context: TaskAttemptContext): String = ".avro" @@ -33,6 +32,6 @@ private[avro] class AvroOutputWriterFactory( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new AvroOutputWriter(path, context, schema, recordName, recordNamespace) + new AvroOutputWriter(path, context, schema, avroSchema.value) } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala new file mode 100644 index 0000000000000..2b4c5813a535b --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.Schema.Type.NULL +import org.apache.avro.generic.GenericData.Record +import org.apache.avro.util.Utf8 + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in avro format. + */ +class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private def newConverter(catalystType: DataType, avroType: Schema): Converter = { + catalystType match { + case NullType => + (getter, ordinal) => null + case BooleanType => + (getter, ordinal) => getter.getBoolean(ordinal) + case ByteType => + (getter, ordinal) => getter.getByte(ordinal).toInt + case ShortType => + (getter, ordinal) => getter.getShort(ordinal).toInt + case IntegerType => + (getter, ordinal) => getter.getInt(ordinal) + case LongType => + (getter, ordinal) => getter.getLong(ordinal) + case FloatType => + (getter, ordinal) => getter.getFloat(ordinal) + case DoubleType => + (getter, ordinal) => getter.getDouble(ordinal) + case d: DecimalType => + (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString + case StringType => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + case BinaryType => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + case DateType => + (getter, ordinal) => getter.getInt(ordinal) * DateTimeUtils.MILLIS_PER_DAY + case TimestampType => + (getter, ordinal) => getter.getLong(ordinal) / 1000 + + case ArrayType(et, containsNull) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull)) + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val result = new java.util.ArrayList[Any] + var i = 0 + while (i < arrayData.numElements()) { + if (arrayData.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(arrayData, i)) + } + i += 1 + } + result + } + + case st: StructType => + val structConverter = newStructConverter(st, avroType) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case MapType(kt, vt, valueContainsNull) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull)) + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val result = new java.util.HashMap[String, Any](mapData.numElements()) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < mapData.numElements()) { + val key = keyArray.getUTF8String(i).toString + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case other => + throw new IncompatibleSchemaException(s"Unexpected type: $other") + } + } + + private def newStructConverter( + catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + val avroFields = avroStruct.getFields + assert(avroFields.size() == catalystStruct.length) + val fieldConverters = catalystStruct.zip(avroFields.asScala).map { + case (f1, f2) => newConverter(f1.dataType, resolveNullableType(f2.schema(), f1.nullable)) + } + val numFields = catalystStruct.length + (row: InternalRow) => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(i, null) + } else { + result.put(i, fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + if (nullable) { + // avro uses union to represent nullable type. + val fields = avroType.getTypes.asScala + assert(fields.length == 2) + val actualType = fields.filter(_.getType != NULL) + assert(actualType.length == 1) + actualType.head + } else { + avroType + } + } +} diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index 01f8c74982535..87fae63aeff2b 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -17,18 +17,11 @@ package org.apache.spark.sql.avro -import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} - import scala.collection.JavaConverters._ import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.Schema.Type._ -import org.apache.avro.SchemaBuilder._ -import org.apache.avro.generic.{GenericData, GenericRecord} -import org.apache.avro.generic.GenericFixed -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types._ /** @@ -36,9 +29,6 @@ import org.apache.spark.sql.types._ * versa. */ object SchemaConverters { - - class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) - case class SchemaType(dataType: DataType, nullable: Boolean) /** @@ -109,298 +99,43 @@ object SchemaConverters { } } - /** - * This function converts sparkSQL StructType into avro schema. This method uses two other - * converter methods in order to do the conversion. - */ - def convertStructToAvro[T]( - structType: StructType, - schemaBuilder: RecordBuilder[T], - recordNamespace: String): T = { - val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields() - structType.fields.foreach { field => - val newField = fieldsAssembler.name(field.name).`type`() - - if (field.nullable) { - convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace) - .noDefault - } else { - convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace) - .noDefault - } - } - fieldsAssembler.endRecord() - } - - /** - * Returns a converter function to convert row in avro format to GenericRow of catalyst. - * - * @param sourceAvroSchema Source schema before conversion inferred from avro file by passed in - * by user. - * @param targetSqlType Target catalyst sql type after the conversion. - * @return returns a converter function to convert row in avro format to GenericRow of catalyst. - */ - private[avro] def createConverterToSQL( - sourceAvroSchema: Schema, - targetSqlType: DataType): AnyRef => AnyRef = { - - def createConverter(avroSchema: Schema, - sqlType: DataType, path: List[String]): AnyRef => AnyRef = { - val avroType = avroSchema.getType - (sqlType, avroType) match { - // Avro strings are in Utf8, so we have to call toString on them - case (StringType, STRING) | (StringType, ENUM) => - (item: AnyRef) => item.toString - // Byte arrays are reused by avro, so we have to make a copy of them. - case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) | - (FloatType, FLOAT) | (LongType, LONG) => - identity - case (TimestampType, LONG) => - (item: AnyRef) => new Timestamp(item.asInstanceOf[Long]) - case (DateType, LONG) => - (item: AnyRef) => new Date(item.asInstanceOf[Long]) - case (BinaryType, FIXED) => - (item: AnyRef) => item.asInstanceOf[GenericFixed].bytes().clone() - case (BinaryType, BYTES) => - (item: AnyRef) => - val byteBuffer = item.asInstanceOf[ByteBuffer] - val bytes = new Array[Byte](byteBuffer.remaining) - byteBuffer.get(bytes) - bytes - case (struct: StructType, RECORD) => - val length = struct.fields.length - val converters = new Array[AnyRef => AnyRef](length) - val avroFieldIndexes = new Array[Int](length) - var i = 0 - while (i < length) { - val sqlField = struct.fields(i) - val avroField = avroSchema.getField(sqlField.name) - if (avroField != null) { - val converter = (item: AnyRef) => { - if (item == null) { - item - } else { - createConverter(avroField.schema, sqlField.dataType, path :+ sqlField.name)(item) - } - } - converters(i) = converter - avroFieldIndexes(i) = avroField.pos() - } else if (!sqlField.nullable) { - throw new IncompatibleSchemaException( - s"Cannot find non-nullable field ${sqlField.name} at path ${path.mkString(".")} " + - "in Avro schema\n" + - s"Source Avro schema: $sourceAvroSchema.\n" + - s"Target Catalyst type: $targetSqlType") - } - i += 1 - } - - (item: AnyRef) => - val record = item.asInstanceOf[GenericRecord] - val result = new Array[Any](length) - var i = 0 - while (i < converters.length) { - if (converters(i) != null) { - val converter = converters(i) - result(i) = converter(record.get(avroFieldIndexes(i))) - } - i += 1 - } - new GenericRow(result) - case (arrayType: ArrayType, ARRAY) => - val elementConverter = createConverter(avroSchema.getElementType, arrayType.elementType, - path) - val allowsNull = arrayType.containsNull - (item: AnyRef) => - item.asInstanceOf[java.lang.Iterable[AnyRef]].asScala.map { element => - if (element == null && !allowsNull) { - throw new RuntimeException(s"Array value at path ${path.mkString(".")} is not " + - "allowed to be null") - } else { - elementConverter(element) - } - } - case (mapType: MapType, MAP) if mapType.keyType == StringType => - val valueConverter = createConverter(avroSchema.getValueType, mapType.valueType, path) - val allowsNull = mapType.valueContainsNull - (item: AnyRef) => - item.asInstanceOf[java.util.Map[AnyRef, AnyRef]].asScala.map { case (k, v) => - if (v == null && !allowsNull) { - throw new RuntimeException(s"Map value at path ${path.mkString(".")} is not " + - "allowed to be null") - } else { - (k.toString, valueConverter(v)) - } - }.toMap - case (sqlType, UNION) => - if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { - val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) - if (remainingUnionTypes.size == 1) { - createConverter(remainingUnionTypes.head, sqlType, path) - } else { - createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path) - } - } else avroSchema.getTypes.asScala.map(_.getType) match { - case Seq(t1) => createConverter(avroSchema.getTypes.get(0), sqlType, path) - case Seq(a, b) if Set(a, b) == Set(INT, LONG) && sqlType == LongType => - (item: AnyRef) => - item match { - case l: java.lang.Long => l - case i: java.lang.Integer => new java.lang.Long(i.longValue()) - } - case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && sqlType == DoubleType => - (item: AnyRef) => - item match { - case d: java.lang.Double => d - case f: java.lang.Float => new java.lang.Double(f.doubleValue()) - } - case other => - sqlType match { - case t: StructType if t.fields.length == avroSchema.getTypes.size => - val fieldConverters = t.fields.zip(avroSchema.getTypes.asScala).map { - case (field, schema) => - createConverter(schema, field.dataType, path :+ field.name) - } - (item: AnyRef) => - val i = GenericData.get().resolveUnion(avroSchema, item) - val converted = new Array[Any](fieldConverters.length) - converted(i) = fieldConverters(i)(item) - new GenericRow(converted) - case _ => throw new IncompatibleSchemaException( - s"Cannot convert Avro schema to catalyst type because schema at path " + - s"${path.mkString(".")} is not compatible " + - s"(avroType = $other, sqlType = $sqlType). \n" + - s"Source Avro schema: $sourceAvroSchema.\n" + - s"Target Catalyst type: $targetSqlType") - } - } - case (left, right) => - throw new IncompatibleSchemaException( - s"Cannot convert Avro schema to catalyst type because schema at path " + - s"${path.mkString(".")} is not compatible (avroType = $right, sqlType = $left). \n" + - s"Source Avro schema: $sourceAvroSchema.\n" + - s"Target Catalyst type: $targetSqlType") - } - } - createConverter(sourceAvroSchema, targetSqlType, List.empty[String]) - } - - /** - * This function is used to convert some sparkSQL type to avro type. Note that this function won't - * be used to construct fields of avro record (convertFieldTypeToAvro is used for that). - */ - private def convertTypeToAvro[T]( - dataType: DataType, - schemaBuilder: BaseTypeBuilder[T], - structName: String, - recordNamespace: String): T = { - dataType match { - case ByteType => schemaBuilder.intType() - case ShortType => schemaBuilder.intType() - case IntegerType => schemaBuilder.intType() - case LongType => schemaBuilder.longType() - case FloatType => schemaBuilder.floatType() - case DoubleType => schemaBuilder.doubleType() - case _: DecimalType => schemaBuilder.stringType() - case StringType => schemaBuilder.stringType() - case BinaryType => schemaBuilder.bytesType() - case BooleanType => schemaBuilder.booleanType() - case TimestampType => schemaBuilder.longType() - case DateType => schemaBuilder.longType() - - case ArrayType(elementType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) - val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace) - schemaBuilder.array().items(elementSchema) - - case MapType(StringType, valueType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) - val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) - schemaBuilder.map().values(valueSchema) - - case structType: StructType => - convertStructToAvro( - structType, - schemaBuilder.record(structName).namespace(recordNamespace), - recordNamespace) - - case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") - } - } - - /** - * This function is used to construct fields of the avro record, where schema of the field is - * specified by avro representation of dataType. Since builders for record fields are different - * from those for everything else, we have to use a separate method. - */ - private def convertFieldTypeToAvro[T]( - dataType: DataType, - newFieldBuilder: BaseFieldTypeBuilder[T], - structName: String, - recordNamespace: String): FieldDefault[T, _] = { - dataType match { - case ByteType => newFieldBuilder.intType() - case ShortType => newFieldBuilder.intType() - case IntegerType => newFieldBuilder.intType() - case LongType => newFieldBuilder.longType() - case FloatType => newFieldBuilder.floatType() - case DoubleType => newFieldBuilder.doubleType() - case _: DecimalType => newFieldBuilder.stringType() - case StringType => newFieldBuilder.stringType() - case BinaryType => newFieldBuilder.bytesType() - case BooleanType => newFieldBuilder.booleanType() - case TimestampType => newFieldBuilder.longType() - case DateType => newFieldBuilder.longType() - - case ArrayType(elementType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) - val elementSchema = convertTypeToAvro( - elementType, - builder, - structName, - getNewRecordNamespace(elementType, recordNamespace, structName)) - newFieldBuilder.array().items(elementSchema) - - case MapType(StringType, valueType, _) => - val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) - val valueSchema = convertTypeToAvro( - valueType, - builder, - structName, - getNewRecordNamespace(valueType, recordNamespace, structName)) - newFieldBuilder.map().values(valueSchema) - - case structType: StructType => - convertStructToAvro( - structType, - newFieldBuilder.record(structName).namespace(s"$recordNamespace.$structName"), - s"$recordNamespace.$structName") - - case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") - } - } - - /** - * Returns a new namespace depending on the data type of the element. - * If the data type is a StructType it returns the current namespace concatenated - * with the element name, otherwise it returns the current namespace as it is. - */ - private[avro] def getNewRecordNamespace( - elementDataType: DataType, - currentRecordNamespace: String, - elementName: String): String = { - - elementDataType match { - case StructType(_) => s"$currentRecordNamespace.$elementName" - case _ => currentRecordNamespace - } - } - - private def getSchemaBuilder(isNullable: Boolean): BaseTypeBuilder[Schema] = { - if (isNullable) { + def toAvroType( + catalystType: DataType, + nullable: Boolean = false, + recordName: String = "topLevelRecord", + prevNameSpace: String = ""): Schema = { + val builder = if (nullable) { SchemaBuilder.builder().nullable() } else { SchemaBuilder.builder() } + catalystType match { + case BooleanType => builder.booleanType() + case ByteType | ShortType | IntegerType => builder.intType() + case LongType => builder.longType() + case DateType => builder.longType() + case TimestampType => builder.longType() + case FloatType => builder.floatType() + case DoubleType => builder.doubleType() + case _: DecimalType | StringType => builder.stringType() + case BinaryType => builder.bytesType() + case ArrayType(et, containsNull) => + builder.array().items(toAvroType(et, containsNull, recordName, prevNameSpace)) + case MapType(StringType, vt, valueContainsNull) => + builder.map().values(toAvroType(vt, valueContainsNull, recordName, prevNameSpace)) + case st: StructType => + val nameSpace = s"$prevNameSpace.$recordName" + val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() + st.foreach { f => + val fieldAvroType = toAvroType(f.dataType, f.nullable, f.name, nameSpace) + fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() + } + fieldsAssembler.endRecord() + + // This should never happen. + case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") + } } } + +class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala new file mode 100644 index 0000000000000..ec0ddc778c8f6 --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SerializableSchema.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io._ + +import scala.util.control.NonFatal + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.Schema +import org.slf4j.LoggerFactory + +class SerializableSchema(@transient var value: Schema) + extends Serializable with KryoSerializable { + + @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) + + private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { + out.defaultWriteObject() + out.writeUTF(value.toString()) + out.flush() + } + + private def readObject(in: ObjectInputStream): Unit = tryOrIOException { + val json = in.readUTF() + value = new Schema.Parser().parse(json) + } + + private def tryOrIOException[T](block: => T): T = { + try { + block + } catch { + case e: IOException => + log.error("Exception encountered", e) + throw e + case NonFatal(e) => + log.error("Exception encountered", e) + throw new IOException(e) + } + } + + def write(kryo: Kryo, out: Output): Unit = { + val dos = new DataOutputStream(out) + dos.writeUTF(value.toString()) + dos.flush() + } + + def read(kryo: Kryo, in: Input): Unit = { + val dis = new DataInputStream(in) + val json = dis.readUTF() + value = new Schema.Parser().parse(json) + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 4f94d827e3127..6ed66563b987a 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -32,7 +32,6 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils import org.apache.spark.sql._ -import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala new file mode 100644 index 0000000000000..510bcbdd31929 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/SerializableSchemaSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance} + +class SerializableSchemaSuite extends SparkFunSuite { + + private def testSerialization(serializer: SerializerInstance): Unit = { + val avroTypeJson = + s""" + |{ + | "type": "string", + | "name": "my_string" + |} + """.stripMargin + val avroSchema = new Schema.Parser().parse(avroTypeJson) + val serializableSchema = new SerializableSchema(avroSchema) + val serialized = serializer.serialize(serializableSchema) + + serializer.deserialize[Any](serialized) match { + case c: SerializableSchema => + assert(c.log != null, "log was null") + assert(c.value != null, "value was null") + assert(c.value == avroSchema) + case other => fail( + s"Expecting ${classOf[SerializableSchema]}, but got ${other.getClass}.") + } + } + + test("serialization with JavaSerializer") { + testSerialization(new JavaSerializer(new SparkConf()).newInstance()) + } + + test("serialization with KryoSerializer") { + testSerialization(new KryoSerializer(new SparkConf()).newInstance()) + } +} From 5d62a985dca9280f884e13e29fc7166ef13c459f Mon Sep 17 00:00:00 2001 From: "Zoltan C. Toth" Date: Sun, 15 Jul 2018 17:08:26 -0500 Subject: [PATCH 1127/1161] Doc fix: The Imputer is an Estimator Fixing the doc as the imputer is not a `Transformer` but an `Estimator`. https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala#L96-L97 ## What changes were proposed in this pull request? Simple documentation fix ## How was this patch tested? manual testing Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Zoltan C. Toth Closes #21755 from zoltanctoth/doc-imputer-is-estimator. --- docs/ml-features.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 7aed2341584fc..ad6e718b37f1b 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1429,7 +1429,7 @@ for more details on the API. ## Imputer -The `Imputer` transformer completes missing values in a dataset, either using the mean or the +The `Imputer` estimator completes missing values in a dataset, either using the mean or the median of the columns in which the missing values are located. The input columns should be of `DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features and possibly creates incorrect values for columns containing categorical features. Imputer can impute custom values From bbc2ffc8ab27192384def9847c36b873efd87234 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 16 Jul 2018 09:29:51 +0800 Subject: [PATCH 1128/1161] [SPARK-24813][TESTS][HIVE][HOTFIX] HiveExternalCatalogVersionsSuite still flaky; fall back to Apache archive ## What changes were proposed in this pull request? Try only unique ASF mirrors to download Spark release; fall back to Apache archive if no mirrors available or release is not mirrored ## How was this patch tested? Existing HiveExternalCatalogVersionsSuite Author: Sean Owen Closes #21776 from srowen/SPARK-24813. --- .../HiveExternalCatalogVersionsSuite.scala | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 514921875f1f9..f8212684d5335 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -56,14 +56,21 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { } private def tryDownloadSpark(version: String, path: String): Unit = { - // Try mirrors a few times until one succeeds - for (i <- 0 until 3) { - // we don't retry on a failure to get mirror url. If we can't get a mirror url, - // the test fails (getStringFromUrl will throw an exception) - val preferredMirror = - getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + // Try a few mirrors first; fall back to Apache archive + val mirrors = + (0 until 2).flatMap { _ => + try { + Some(getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true")) + } catch { + // If we can't get a mirror URL, skip it. No retry. + case _: Exception => None + } + } + val sites = mirrors.distinct :+ "https://archive.apache.org/dist" + logInfo(s"Trying to download Spark $version from $sites") + for (site <- sites) { val filename = s"spark-$version-bin-hadoop2.7.tgz" - val url = s"$preferredMirror/spark/spark-$version/$filename" + val url = s"$site/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") try { getFileFromUrl(url, path, filename) @@ -83,7 +90,8 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { Seq("rm", "-rf", targetDir).! } } catch { - case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) + case ex: Exception => + logWarning(s"Failed to download Spark $version from $url: ${ex.getMessage}") } } fail(s"Unable to download Spark $version") From bcf7121ed2283d88424863ac1d35393870eaae6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E7=91=9E=E5=B3=B0?= Date: Sun, 15 Jul 2018 20:14:17 -0700 Subject: [PATCH 1129/1161] [TRIVIAL][ML] GMM unpersist RDD after training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? unpersist `instances` after training ## How was this patch tested? existing tests Author: 郑瑞峰 Closes #21562 from zhengruifeng/gmm_unpersist. --- .../scala/org/apache/spark/ml/clustering/GaussianMixture.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index dae64ba9a515d..f0707b380c673 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -341,7 +341,7 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset + val instances = dataset .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() @@ -416,6 +416,7 @@ class GaussianMixture @Since("2.0.0") ( iter += 1 } + instances.unpersist(false) val gaussianDists = gaussians.map { case (mean, covVec) => val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) new MultivariateGaussian(mean, cov) From d463533ded89a05e9f77e590fd3de2ffa212d68b Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 15 Jul 2018 20:22:09 -0700 Subject: [PATCH 1130/1161] [SPARK-24676][SQL] Project required data from CSV parsed data when column pruning disabled ## What changes were proposed in this pull request? This pr modified code to project required data from CSV parsed data when column pruning disabled. In the current master, an exception below happens if `spark.sql.csv.parser.columnPruning.enabled` is false. This is because required formats and CSV parsed formats are different from each other; ``` ./bin/spark-shell --conf spark.sql.csv.parser.columnPruning.enabled=false scala> val dir = "/tmp/spark-csv/csv" scala> spark.range(10).selectExpr("id % 2 AS p", "id").write.mode("overwrite").partitionBy("p").csv(dir) scala> spark.read.csv(dir).selectExpr("sum(p)").collect() 18/06/25 13:48:46 ERROR Executor: Exception in task 2.0 in stage 2.0 (TID 7) java.lang.ClassCastException: org.apache.spark.unsafe.types.UTF8String cannot be cast to java.lang.Integer at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:101) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getInt(rows.scala:41) ... ``` ## How was this patch tested? Added tests in `CSVSuite`. Author: Takeshi Yamamuro Closes #21657 from maropu/SPARK-24676. --- .../datasources/csv/UnivocityParser.scala | 54 ++++++++++++++----- .../execution/datasources/csv/CSVSuite.scala | 29 ++++++++++ 2 files changed, 70 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index aa545e1a0c00a..79143cce4a380 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -33,29 +33,49 @@ import org.apache.spark.sql.execution.datasources.FailureSafeParser import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String + +/** + * Constructs a parser for a given schema that translates CSV data to an [[InternalRow]]. + * + * @param dataSchema The CSV data schema that is specified by the user, or inferred from underlying + * data files. + * @param requiredSchema The schema of the data that should be output for each row. This should be a + * subset of the columns in dataSchema. + * @param options Configuration options for a CSV parser. + */ class UnivocityParser( dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { require(requiredSchema.toSet.subsetOf(dataSchema.toSet), - "requiredSchema should be the subset of schema.") + s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " + + s"dataSchema (${dataSchema.catalogString}).") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any + // This index is used to reorder parsed tokens + private val tokenIndexArr = + requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))).toArray + + // When column pruning is enabled, the parser only parses the required columns based on + // their positions in the data schema. + private val parsedSchema = if (options.columnPruning) requiredSchema else dataSchema + val tokenizer = { val parserSetting = options.asParserSettings - if (options.columnPruning && requiredSchema.length < dataSchema.length) { - val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + // When to-be-parsed schema is shorter than the to-be-read data schema, we let Univocity CSV + // parser select a sequence of fields for reading by their positions. + // if (options.columnPruning && requiredSchema.length < dataSchema.length) { + if (parsedSchema.length < dataSchema.length) { parserSetting.selectIndexes(tokenIndexArr: _*) } new CsvParser(parserSetting) } - private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(schema.length) + private val row = new GenericInternalRow(requiredSchema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -82,7 +102,7 @@ class UnivocityParser( // // output row - ["A", 2] private val valueConverters: Array[ValueConverter] = { - schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray + requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray } /** @@ -183,7 +203,7 @@ class UnivocityParser( } } - private val doParse = if (schema.nonEmpty) { + private val doParse = if (requiredSchema.nonEmpty) { (input: String) => convert(tokenizer.parseLine(input)) } else { // If `columnPruning` enabled and partition attributes scanned only, @@ -197,15 +217,21 @@ class UnivocityParser( */ def parse(input: String): InternalRow = doParse(input) + private val getToken = if (options.columnPruning) { + (tokens: Array[String], index: Int) => tokens(index) + } else { + (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) + } + private def convert(tokens: Array[String]): InternalRow = { - if (tokens.length != schema.length) { + if (tokens.length != parsedSchema.length) { // If the number of tokens doesn't match the schema, we should treat it as a malformed record. // However, we still have chance to parse some of the tokens, by adding extra null tokens in // the tail if the number is smaller, or by dropping extra tokens if the number is larger. - val checkedTokens = if (schema.length > tokens.length) { - tokens ++ new Array[String](schema.length - tokens.length) + val checkedTokens = if (parsedSchema.length > tokens.length) { + tokens ++ new Array[String](parsedSchema.length - tokens.length) } else { - tokens.take(schema.length) + tokens.take(parsedSchema.length) } def getPartialResult(): Option[InternalRow] = { try { @@ -222,9 +248,11 @@ class UnivocityParser( new RuntimeException("Malformed CSV record")) } else { try { + // When the length of the returned tokens is identical to the length of the parsed schema, + // we just need to convert the tokens that correspond to the required columns. var i = 0 - while (i < schema.length) { - row(i) = valueConverters(i).apply(tokens(i)) + while (i < requiredSchema.length) { + row(i) = valueConverters(i).apply(getToken(tokens, i)) i += 1 } row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 84b91f6309fe8..ae8110fdf1709 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1579,4 +1579,33 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } } } + + test("SPARK-24676 project required data from parsed data when columnPruning disabled") { + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + withTempPath { path => + val dir = path.getAbsolutePath + spark.range(10).selectExpr("id % 2 AS p", "id AS c0", "id AS c1").write.partitionBy("p") + .option("header", "true").csv(dir) + val df1 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)", "count(c0)") + checkAnswer(df1, Row(5, 10)) + + // empty required column case + val df2 = spark.read.option("header", true).csv(dir).selectExpr("sum(p)") + checkAnswer(df2, Row(5)) + } + + // the case where tokens length != parsedSchema length + withTempPath { path => + val dir = path.getAbsolutePath + Seq("1,2").toDF().write.text(dir) + // more tokens + val df1 = spark.read.schema("c0 int").format("csv").option("mode", "permissive").load(dir) + checkAnswer(df1, Row(1)) + // less tokens + val df2 = spark.read.schema("c0 int, c1 int, c2 int").format("csv") + .option("mode", "permissive").load(dir) + checkAnswer(df2, Row(1, 2, null)) + } + } + } } From 9f929458fb0a8a106f3b5a6ed3ee2cd3faa85770 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sun, 15 Jul 2018 23:01:36 -0700 Subject: [PATCH 1131/1161] [SPARK-24810][SQL] Fix paths to test files in AvroSuite ## What changes were proposed in this pull request? In the PR, I propose to move `testFile()` to the common trait `SQLTestUtilsBase` and wrap test files in `AvroSuite` by the method `testFile()` which returns full paths to test files in the resource folder. Author: Maxim Gekk Closes #21773 from MaxGekk/test-file. --- .../org/apache/spark/sql/avro/AvroSuite.scala | 79 ++++++++++--------- .../execution/datasources/csv/CSVSuite.scala | 4 - .../datasources/json/JsonSuite.scala | 4 - .../datasources/text/WholeTextFileSuite.scala | 11 +-- .../apache/spark/sql/test/SQLTestUtils.scala | 7 ++ 5 files changed, 53 insertions(+), 52 deletions(-) diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 6ed66563b987a..9c6526b29dca3 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { - val episodesFile = "src/test/resources/episodes.avro" - val testFile = "src/test/resources/test.avro" + val episodesAvro = testFile("episodes.avro") + val testAvro = testFile("test.avro") override protected def beforeAll(): Unit = { super.beforeAll() @@ -45,18 +45,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { - val originalEntries = spark.read.avro(testFile).collect() + val originalEntries = spark.read.avro(testAvro).collect() val newEntries = spark.read.avro(newFile) checkAnswer(newEntries, originalEntries) } test("reading from multiple paths") { - val df = spark.read.avro(episodesFile, episodesFile) + val df = spark.read.avro(episodesAvro, episodesAvro) assert(df.count == 16) } test("reading and writing partitioned data") { - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) val fields = List("title", "air_date", "doctor") for (field <- fields) { withTempPath { dir => @@ -72,14 +72,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("request no fields") { - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) df.createOrReplaceTempView("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { withTempPath { dir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) } @@ -87,7 +87,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("rearrange internal schema") { withTempPath { dir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } @@ -362,7 +362,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" - val df = spark.read.avro(testFile) + val df = spark.read.avro(testAvro) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") df.write.avro(uncompressDir) spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate") @@ -381,49 +381,49 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("dsl test") { - val results = spark.read.avro(episodesFile).select("title").collect() + val results = spark.read.avro(episodesAvro).select("title").collect() assert(results.length === 8) } test("support of various data types") { // This test uses data from test.avro. You can see the data and the schema of this file in // test.json and test.avsc - val all = spark.read.avro(testFile).collect() + val all = spark.read.avro(testAvro).collect() assert(all.length == 3) - val str = spark.read.avro(testFile).select("string").collect() + val str = spark.read.avro(testAvro).select("string").collect() assert(str.map(_(0)).toSet.contains("Terran is IMBA!")) - val simple_map = spark.read.avro(testFile).select("simple_map").collect() + val simple_map = spark.read.avro(testAvro).select("simple_map").collect() assert(simple_map(0)(0).getClass.toString.contains("Map")) assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0)) - val union0 = spark.read.avro(testFile).select("union_string_null").collect() + val union0 = spark.read.avro(testAvro).select("union_string_null").collect() assert(union0.map(_(0)).toSet == Set("abc", "123", null)) - val union1 = spark.read.avro(testFile).select("union_int_long_null").collect() + val union1 = spark.read.avro(testAvro).select("union_int_long_null").collect() assert(union1.map(_(0)).toSet == Set(66, 1, null)) - val union2 = spark.read.avro(testFile).select("union_float_double").collect() + val union2 = spark.read.avro(testAvro).select("union_float_double").collect() assert( union2 .map(x => new java.lang.Double(x(0).toString)) .exists(p => Math.abs(p - Math.PI) < 0.001)) - val fixed = spark.read.avro(testFile).select("fixed3").collect() + val fixed = spark.read.avro(testAvro).select("fixed3").collect() assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3)) - val enum = spark.read.avro(testFile).select("enum").collect() + val enum = spark.read.avro(testAvro).select("enum").collect() assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS")) - val record = spark.read.avro(testFile).select("record").collect() + val record = spark.read.avro(testAvro).select("record").collect() assert(record(0)(0).getClass.toString.contains("Row")) assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123")) - val array_of_boolean = spark.read.avro(testFile).select("array_of_boolean").collect() + val array_of_boolean = spark.read.avro(testAvro).select("array_of_boolean").collect() assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0)) - val bytes = spark.read.avro(testFile).select("bytes").collect() + val bytes = spark.read.avro(testAvro).select("bytes").collect() assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0)) } @@ -432,7 +432,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { s""" |CREATE TEMPORARY VIEW avroTable |USING avro - |OPTIONS (path "$episodesFile") + |OPTIONS (path "${episodesAvro}") """.stripMargin.replaceAll("\n", " ")) assert(spark.sql("SELECT * FROM avroTable").collect().length === 8) @@ -443,8 +443,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // get the same values back. withTempPath { dir => val avroDir = s"$dir/avro" - spark.read.avro(testFile).write.avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + spark.read.avro(testAvro).write.avro(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) } } @@ -457,8 +457,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" - spark.read.avro(testFile).write.options(parameters).avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + spark.read.avro(testAvro).write.options(parameters).avro(avroDir) + checkReloadMatchesSaved(testAvro, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -532,10 +532,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("support of globbed paths") { - val e1 = spark.read.avro("*/test/resources/episodes.avro").collect() + val resourceDir = testFile(".") + val e1 = spark.read.avro(resourceDir + "../*/episodes.avro").collect() assert(e1.length == 8) - val e2 = spark.read.avro("src/*/*/episodes.avro").collect() + val e2 = spark.read.avro(resourceDir + "../../*/*/episodes.avro").collect() assert(e2.length == 8) } @@ -574,8 +575,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { | }] |} """.stripMargin - val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema).avro(testFile).collect() - val expected = spark.read.avro(testFile).select("string").collect() + val result = spark + .read + .option(AvroFileFormat.AvroSchema, avroSchema) + .avro(testAvro) + .collect() + val expected = spark.read.avro(testAvro).select("string").collect() assert(result.sameElements(expected)) } @@ -593,7 +598,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { |} """.stripMargin val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) - .avro(testFile).select("missingField").first + .avro(testAvro).select("missingField").first assert(result === Row("foo")) } @@ -632,7 +637,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { s""" |CREATE TEMPORARY VIEW episodes |USING avro - |OPTIONS (path "$episodesFile") + |OPTIONS (path "${episodesAvro}") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" @@ -657,7 +662,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test save and load") { // Test if load works as expected withTempPath { tempDir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) assert(df.count == 8) val tempSaveDir = s"$tempDir/save/" @@ -671,7 +676,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test load with non-Avro file") { // Test if load works as expected withTempPath { tempDir => - val df = spark.read.avro(episodesFile) + val df = spark.read.avro(episodesAvro) assert(df.count == 8) val tempSaveDir = s"$tempDir/save/" @@ -701,10 +706,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false), StructField("array_of_boolean", ArrayType(BooleanType), false), StructField("bytes", BinaryType, true))) - val withSchema = spark.read.schema(partialColumns).avro(testFile).collect() + val withSchema = spark.read.schema(partialColumns).avro(testAvro).collect() val withOutSchema = spark .read - .avro(testFile) + .avro(testAvro) .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null", "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes") .collect() @@ -722,7 +727,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("non_exist_field", StringType, false), StructField("non_exist_field2", StringType, false))), false))) - val withEmptyColumn = spark.read.schema(schema).avro(testFile).collect() + val withEmptyColumn = spark.read.schema(schema).avro(testAvro).collect() assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index ae8110fdf1709..63cc5985040c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -60,10 +60,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private val valueMalformedFile = "test-data/value-malformed.csv" - private def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - /** Verifies data and schema. */ private def verifyCars( df: DataFrame, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index eab15b35c97d3..655f40ad549e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -48,10 +48,6 @@ class TestFileFilter extends PathFilter { class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ - def testFile(fileName: String): String = { - Thread.currentThread().getContextClassLoader.getResource(fileName).toString - } - test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { assert(expected.getClass == actual.getClass, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala index fff0f82f9bc2b..a302d67b5cbf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/WholeTextFileSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types.{StringType, StructType} -class WholeTextFileSuite extends QueryTest with SharedSQLContext { +class WholeTextFileSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which // can cause Filesystem.get(Configuration) to return a cached instance created with a different @@ -35,13 +35,10 @@ class WholeTextFileSuite extends QueryTest with SharedSQLContext { protected override def sparkConf = super.sparkConf.set("spark.hadoop.fs.file.impl.disable.cache", "true") - private def testFile: String = { - Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString - } - test("reading text file with option wholetext=true") { val df = spark.read.option("wholetext", "true") - .format("text").load(testFile) + .format("text") + .load(testFile("test-data/text-suite.txt")) // schema assert(df.schema == new StructType().add("value", StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index bc4a120f7042f..e562be83822e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -391,6 +391,13 @@ private[sql] trait SQLTestUtilsBase val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.makeQualified(hadoopPath).toUri } + + /** + * Returns full path to the given file in the resouce folder + */ + protected def testFile(fileName: String): String = { + Thread.currentThread().getContextClassLoader.getResource(fileName).toString + } } private[sql] object SQLTestUtils { From 2603ae30be78c6cb24a67c26fb781fae8763f229 Mon Sep 17 00:00:00 2001 From: sandeep-katta Date: Mon, 16 Jul 2018 14:52:49 +0800 Subject: [PATCH 1132/1161] [SPARK-24558][CORE] wrong Idle Timeout value is used in case of the cacheBlock. It is corrected as per the configuration. ## What changes were proposed in this pull request? IdleTimeout info used to print in the logs is taken based on the cacheBlock. If it is cacheBlock then cachedExecutorIdleTimeoutS is considered else executorIdleTimeoutS ## How was this patch tested? Manual Test spark-sql> cache table sample; 2018-05-15 14:44:02 INFO DAGScheduler:54 - Submitting 3 missing tasks from ShuffleMapStage 0 (MapPartitionsRDD[8] at processCmd at CliDriver.java:376) (first 15 tasks are for partitions Vector(0, 1, 2)) 2018-05-15 14:44:02 INFO YarnScheduler:54 - Adding task set 0.0 with 3 tasks 2018-05-15 14:44:03 INFO ExecutorAllocationManager:54 - Requesting 1 new executor because tasks are backlogged (new desired total will be 1) ... ... 2018-05-15 14:46:10 INFO YarnClientSchedulerBackend:54 - Actual list of executor(s) to be killed is 1 2018-05-15 14:46:10 INFO **ExecutorAllocationManager:54 - Removing executor 1 because it has been idle for 120 seconds (new desired total will be 0)** 2018-05-15 14:46:11 INFO YarnSchedulerBackend$YarnDriverEndpoint:54 - Disabling executor 1. 2018-05-15 14:46:11 INFO DAGScheduler:54 - Executor lost: 1 (epoch 1) Author: sandeep-katta Closes #21565 from sandeep-katta/loginfoBug. --- .../org/apache/spark/ExecutorAllocationManager.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index aa363eeffffb8..17b88631bcb4c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -488,9 +488,15 @@ private[spark] class ExecutorAllocationManager( newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { executorsRemoved.foreach { removedExecutorId => + // If it is a cached block, it uses cachedExecutorIdleTimeoutS for timeout + val idleTimeout = if (blockManagerMaster.hasCachedBlocks(removedExecutorId)) { + cachedExecutorIdleTimeoutS + } else { + executorIdleTimeoutS + } newExecutorTotal -= 1 logInfo(s"Removing executor $removedExecutorId because it has been idle for " + - s"$executorIdleTimeoutS seconds (new desired total will be $newExecutorTotal)") + s"$idleTimeout seconds (new desired total will be $newExecutorTotal)") executorsPendingToRemove.add(removedExecutorId) } executorsRemoved From 9549a2814951f9ba969955d78ac4bd2240f85989 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 16 Jul 2018 15:44:51 +0800 Subject: [PATCH 1133/1161] [SPARK-24549][SQL] Support Decimal type push down to the parquet data sources ## What changes were proposed in this pull request? Support Decimal type push down to the parquet data sources. The Decimal comparator used is: [`BINARY_AS_SIGNED_INTEGER_COMPARATOR`](https://github.com/apache/parquet-mr/blob/c6764c4a0848abf1d581e22df8b33e28ee9f2ced/parquet-column/src/main/java/org/apache/parquet/schema/PrimitiveComparator.java#L224-L292). ## How was this patch tested? unit tests and manual tests. **manual tests**: ```scala spark.range(10000000).selectExpr("id", "cast(id as decimal(9)) as d1", "cast(id as decimal(9, 2)) as d2", "cast(id as decimal(18)) as d3", "cast(id as decimal(18, 4)) as d4", "cast(id as decimal(38)) as d5", "cast(id as decimal(38, 18)) as d6").coalesce(1).write.option("parquet.block.size", 1048576).parquet("/tmp/spark/parquet/decimal") val df = spark.read.parquet("/tmp/spark/parquet/decimal/") spark.sql("set spark.sql.parquet.filterPushdown.decimal=true") // Only read about 1 MB data df.filter("d2 = 10000").show // Only read about 1 MB data df.filter("d4 = 10000").show spark.sql("set spark.sql.parquet.filterPushdown.decimal=false") // Read 174.3 MB data df.filter("d2 = 10000").show // Read 174.3 MB data df.filter("d4 = 10000").show ``` Author: Yuming Wang Closes #21556 from wangyum/SPARK-24549. --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../FilterPushdownBenchmark-results.txt | 96 ++++---- .../parquet/ParquetFileFormat.scala | 3 +- .../datasources/parquet/ParquetFilters.scala | 225 +++++++++++++----- .../benchmark/FilterPushdownBenchmark.scala | 8 +- .../parquet/ParquetFilterSuite.scala | 90 ++++++- 6 files changed, 324 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 07d33fa7d52ae..41fe0c3b60d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -387,6 +387,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED = + buildConf("spark.sql.parquet.filterPushdown.decimal") + .doc("If true, enables Parquet filter push-down optimization for Decimal. " + + "This configuration only has an effect when 'spark.sql.parquet.filterPushdown' is enabled.") + .internal() + .booleanConf + .createWithDefault(true) + val PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED = buildConf("spark.sql.parquet.filterPushdown.string.startsWith") .doc("If true, enables Parquet filter push-down optimization for string startsWith function. " + @@ -1505,6 +1513,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownTimestamp: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED) + def parquetFilterPushDownDecimal: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED) + def parquetFilterPushDownStringStartWith: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED) diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt index 4f38cc4cee96d..2215ed91e2018 100644 --- a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt +++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt @@ -292,120 +292,120 @@ Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 3785 / 3867 4.2 240.6 1.0X -Parquet Vectorized (Pushdown) 3820 / 3928 4.1 242.9 1.0X -Native ORC Vectorized 3981 / 4049 4.0 253.1 1.0X -Native ORC Vectorized (Pushdown) 702 / 735 22.4 44.6 5.4X +Parquet Vectorized 4546 / 4743 3.5 289.0 1.0X +Parquet Vectorized (Pushdown) 161 / 175 98.0 10.2 28.3X +Native ORC Vectorized 5721 / 5842 2.7 363.7 0.8X +Native ORC Vectorized (Pushdown) 1019 / 1070 15.4 64.8 4.5X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4694 / 4813 3.4 298.4 1.0X -Parquet Vectorized (Pushdown) 4839 / 4907 3.3 307.6 1.0X -Native ORC Vectorized 4943 / 5032 3.2 314.2 0.9X -Native ORC Vectorized (Pushdown) 2043 / 2085 7.7 129.9 2.3X +Parquet Vectorized 6340 / 7236 2.5 403.1 1.0X +Parquet Vectorized (Pushdown) 3052 / 3164 5.2 194.1 2.1X +Native ORC Vectorized 8370 / 9214 1.9 532.1 0.8X +Native ORC Vectorized (Pushdown) 4137 / 4242 3.8 263.0 1.5X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8321 / 8472 1.9 529.0 1.0X -Parquet Vectorized (Pushdown) 8125 / 8471 1.9 516.6 1.0X -Native ORC Vectorized 8524 / 8616 1.8 541.9 1.0X -Native ORC Vectorized (Pushdown) 7961 / 8383 2.0 506.1 1.0X +Parquet Vectorized 12976 / 13249 1.2 825.0 1.0X +Parquet Vectorized (Pushdown) 12655 / 13570 1.2 804.6 1.0X +Native ORC Vectorized 15562 / 15950 1.0 989.4 0.8X +Native ORC Vectorized (Pushdown) 15042 / 15668 1.0 956.3 0.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 9587 / 10112 1.6 609.5 1.0X -Parquet Vectorized (Pushdown) 9726 / 10370 1.6 618.3 1.0X -Native ORC Vectorized 10119 / 11147 1.6 643.4 0.9X -Native ORC Vectorized (Pushdown) 9366 / 9497 1.7 595.5 1.0X +Parquet Vectorized 14303 / 14616 1.1 909.3 1.0X +Parquet Vectorized (Pushdown) 14380 / 14649 1.1 914.3 1.0X +Native ORC Vectorized 16964 / 17358 0.9 1078.5 0.8X +Native ORC Vectorized (Pushdown) 17255 / 17874 0.9 1097.0 0.8X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 4060 / 4093 3.9 258.1 1.0X -Parquet Vectorized (Pushdown) 4037 / 4125 3.9 256.6 1.0X -Native ORC Vectorized 4756 / 4811 3.3 302.4 0.9X -Native ORC Vectorized (Pushdown) 824 / 889 19.1 52.4 4.9X +Parquet Vectorized 4701 / 6416 3.3 298.9 1.0X +Parquet Vectorized (Pushdown) 128 / 164 122.8 8.1 36.7X +Native ORC Vectorized 5698 / 7904 2.8 362.3 0.8X +Native ORC Vectorized (Pushdown) 913 / 942 17.2 58.0 5.2X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5157 / 5271 3.0 327.9 1.0X -Parquet Vectorized (Pushdown) 5051 / 5141 3.1 321.1 1.0X -Native ORC Vectorized 5723 / 6146 2.7 363.9 0.9X -Native ORC Vectorized (Pushdown) 2198 / 2317 7.2 139.8 2.3X +Parquet Vectorized 5376 / 5461 2.9 341.8 1.0X +Parquet Vectorized (Pushdown) 1479 / 1543 10.6 94.0 3.6X +Native ORC Vectorized 6640 / 6748 2.4 422.2 0.8X +Native ORC Vectorized (Pushdown) 2438 / 2479 6.5 155.0 2.2X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 8608 / 8647 1.8 547.3 1.0X -Parquet Vectorized (Pushdown) 8471 / 8584 1.9 538.6 1.0X -Native ORC Vectorized 9249 / 10048 1.7 588.0 0.9X -Native ORC Vectorized (Pushdown) 7645 / 8091 2.1 486.1 1.1X +Parquet Vectorized 9224 / 9356 1.7 586.5 1.0X +Parquet Vectorized (Pushdown) 7172 / 7415 2.2 456.0 1.3X +Native ORC Vectorized 11017 / 11408 1.4 700.4 0.8X +Native ORC Vectorized (Pushdown) 8771 / 10218 1.8 557.7 1.1X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11658 / 11888 1.3 741.2 1.0X -Parquet Vectorized (Pushdown) 11812 / 12098 1.3 751.0 1.0X -Native ORC Vectorized 12943 / 13312 1.2 822.9 0.9X -Native ORC Vectorized (Pushdown) 13139 / 13465 1.2 835.4 0.9X +Parquet Vectorized 13933 / 15990 1.1 885.8 1.0X +Parquet Vectorized (Pushdown) 12683 / 12942 1.2 806.4 1.1X +Native ORC Vectorized 16344 / 20196 1.0 1039.1 0.9X +Native ORC Vectorized (Pushdown) 15162 / 16627 1.0 964.0 0.9X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 5491 / 5716 2.9 349.1 1.0X -Parquet Vectorized (Pushdown) 5515 / 5615 2.9 350.6 1.0X -Native ORC Vectorized 4582 / 4654 3.4 291.3 1.2X -Native ORC Vectorized (Pushdown) 815 / 861 19.3 51.8 6.7X +Parquet Vectorized 7102 / 8282 2.2 451.5 1.0X +Parquet Vectorized (Pushdown) 124 / 150 126.4 7.9 57.1X +Native ORC Vectorized 5811 / 6883 2.7 369.5 1.2X +Native ORC Vectorized (Pushdown) 1121 / 1502 14.0 71.3 6.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 6432 / 6527 2.4 409.0 1.0X -Parquet Vectorized (Pushdown) 6513 / 6607 2.4 414.1 1.0X -Native ORC Vectorized 5618 / 6085 2.8 357.2 1.1X -Native ORC Vectorized (Pushdown) 2403 / 2443 6.5 152.8 2.7X +Parquet Vectorized 6894 / 7562 2.3 438.3 1.0X +Parquet Vectorized (Pushdown) 1863 / 1980 8.4 118.4 3.7X +Native ORC Vectorized 6812 / 6848 2.3 433.1 1.0X +Native ORC Vectorized (Pushdown) 2511 / 2598 6.3 159.7 2.7X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 11041 / 11467 1.4 701.9 1.0X -Parquet Vectorized (Pushdown) 10909 / 11484 1.4 693.5 1.0X -Native ORC Vectorized 9860 / 10436 1.6 626.9 1.1X -Native ORC Vectorized (Pushdown) 7908 / 8069 2.0 502.8 1.4X +Parquet Vectorized 11732 / 12183 1.3 745.9 1.0X +Parquet Vectorized (Pushdown) 8912 / 9945 1.8 566.6 1.3X +Native ORC Vectorized 11499 / 12387 1.4 731.1 1.0X +Native ORC Vectorized (Pushdown) 9328 / 9382 1.7 593.1 1.3X Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ -Parquet Vectorized 14816 / 16877 1.1 942.0 1.0X -Parquet Vectorized (Pushdown) 15383 / 15740 1.0 978.0 1.0X -Native ORC Vectorized 14408 / 14771 1.1 916.0 1.0X -Native ORC Vectorized (Pushdown) 13968 / 14805 1.1 888.1 1.1X +Parquet Vectorized 16272 / 16328 1.0 1034.6 1.0X +Parquet Vectorized (Pushdown) 15714 / 18100 1.0 999.1 1.0X +Native ORC Vectorized 16539 / 18897 1.0 1051.5 1.0X +Native ORC Vectorized (Pushdown) 16328 / 17306 1.0 1038.1 1.0X ================================================================================================ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 3ec33b2f4b540..295960b1c2d30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -342,6 +342,7 @@ class ParquetFileFormat val returningBatch = supportBatch(sparkSession, resultSchema) val pushDownDate = sqlConf.parquetFilterPushDownDate val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp + val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold @@ -367,7 +368,7 @@ class ParquetFileFormat val pushed = if (enableParquetFilterPushDown) { val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS) .getFileMetaData.getSchema - val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, + val parquetFilters = new ParquetFilters(pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold) filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 0c146f2f6f915..58b4a769fcb62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet -import java.lang.{Long => JLong} +import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} import scala.collection.JavaConverters.asScalaBufferConverter @@ -41,44 +42,65 @@ import org.apache.spark.unsafe.types.UTF8String private[parquet] class ParquetFilters( pushDownDate: Boolean, pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, pushDownStartWith: Boolean, pushDownInFilterThreshold: Int) { private case class ParquetSchemaType( originalType: OriginalType, primitiveTypeName: PrimitiveTypeName, + length: Int, decimalMetadata: DecimalMetadata) - private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null) - private val ParquetByteType = ParquetSchemaType(INT_8, INT32, null) - private val ParquetShortType = ParquetSchemaType(INT_16, INT32, null) - private val ParquetIntegerType = ParquetSchemaType(null, INT32, null) - private val ParquetLongType = ParquetSchemaType(null, INT64, null) - private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null) - private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, null) - private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null) - private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null) - private val ParquetDateType = ParquetSchemaType(DATE, INT32, null) - private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, null) - private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, null) + private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null) + private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null) + private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null) + private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null) + private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null) + private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null) + private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null) + private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null) + private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null) + private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null) + private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null) + private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null) private def dateToDays(date: Date): SQLDate = { DateTimeUtils.fromJavaDate(date) } + private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue() + + private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue() + + private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = { + val decimalBuffer = new Array[Byte](numBytes) + val bytes = decimal.unscaledValue().toByteArray + + val fixedLengthBytes = if (bytes.length == numBytes) { + bytes + } else { + val signByte = if (bytes.head < 0) -1: Byte else 0: Byte + java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte) + System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length) + decimalBuffer + } + Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes) + } + private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean]) case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => FilterApi.eq( intColumn(n), Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble]) // Binary.fromString and Binary.fromByteArray don't accept null values case ParquetStringType => @@ -102,21 +124,34 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.eq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { case ParquetBooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean]) case ParquetByteType | ParquetShortType | ParquetIntegerType => (n: String, v: Any) => FilterApi.notEq( intColumn(n), Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull) case ParquetLongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => FilterApi.notEq( @@ -139,6 +174,19 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.notEq( longColumn(n), Option(v).map(_.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]).orNull) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + intColumn(n), + Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + longColumn(n), + Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull) } private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -146,11 +194,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -169,6 +217,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.lt( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -176,11 +234,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -199,6 +257,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.ltEq( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -206,11 +274,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -229,6 +297,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gt( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = { @@ -236,11 +314,11 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer]) case ParquetLongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong]) case ParquetFloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat]) case ParquetDoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble]) case ParquetStringType => (n: String, v: Any) => @@ -259,6 +337,16 @@ private[parquet] class ParquetFilters( (n: String, v: Any) => FilterApi.gtEq( longColumn(n), v.asInstanceOf[Timestamp].getTime.asInstanceOf[JLong]) + + case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal])) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length)) } /** @@ -271,7 +359,7 @@ private[parquet] class ParquetFilters( // and it does not support to create filters for them. m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => f.getName -> ParquetSchemaType( - f.getOriginalType, f.getPrimitiveTypeName, f.getDecimalMetadata) + f.getOriginalType, f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata) }.toMap case _ => Map.empty[String, ParquetSchemaType] } @@ -282,21 +370,45 @@ private[parquet] class ParquetFilters( def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = { val nameToType = getFieldMap(schema) + // Decimal type must make sure that filter value's scale matched the file. + // If doesn't matched, which would cause data corruption. + def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match { + case decimal: JBigDecimal => + decimal.scale == decimalMeta.getScale + case _ => false + } + + // Parquet's type in the given file should be matched to the value's type + // in the pushed filter in order to push down the filter to Parquet. + def valueCanMakeFilterOn(name: String, value: Any): Boolean = { + value == null || (nameToType(name) match { + case ParquetBooleanType => value.isInstanceOf[JBoolean] + case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number] + case ParquetLongType => value.isInstanceOf[JLong] + case ParquetFloatType => value.isInstanceOf[JFloat] + case ParquetDoubleType => value.isInstanceOf[JDouble] + case ParquetStringType => value.isInstanceOf[String] + case ParquetBinaryType => value.isInstanceOf[Array[Byte]] + case ParquetDateType => value.isInstanceOf[Date] + case ParquetTimestampMicrosType | ParquetTimestampMillisType => + value.isInstanceOf[Timestamp] + case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) => + isDecimalMatched(value, decimalMeta) + case _ => false + }) + } + // Parquet does not allow dots in the column name because dots are used as a column path // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates // with missing columns. The incorrect results could be got from Parquet when we push down // filters for the column having dots in the names. Thus, we do not push down such filters. // See SPARK-20364. - def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") - - // All DataTypes that support `makeEq` can provide better performance. - def shouldConvertInPredicate(name: String): Boolean = nameToType(name) match { - case ParquetBooleanType | ParquetByteType | ParquetShortType | ParquetIntegerType - | ParquetLongType | ParquetFloatType | ParquetDoubleType | ParquetStringType - | ParquetBinaryType => true - case ParquetDateType if pushDownDate => true - case ParquetTimestampMicrosType | ParquetTimestampMillisType if pushDownTimestamp => true - case _ => false + def canMakeFilterOn(name: String, value: Any): Boolean = { + nameToType.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) } // NOTE: @@ -315,29 +427,29 @@ private[parquet] class ParquetFilters( // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) if canMakeFilterOn(name) => + case sources.IsNull(name) if canMakeFilterOn(name, null) => makeEq.lift(nameToType(name)).map(_(name, null)) - case sources.IsNotNull(name) if canMakeFilterOn(name) => + case sources.IsNotNull(name) if canMakeFilterOn(name, null) => makeNotEq.lift(nameToType(name)).map(_(name, null)) - case sources.EqualTo(name, value) if canMakeFilterOn(name) => + case sources.EqualTo(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) => + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) => makeEq.lift(nameToType(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) => + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) => makeNotEq.lift(nameToType(name)).map(_(name, value)) - case sources.LessThan(name, value) if canMakeFilterOn(name) => + case sources.LessThan(name, value) if canMakeFilterOn(name, value) => makeLt.lift(nameToType(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeLtEq.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThan(name, value) if canMakeFilterOn(name) => + case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) => makeGt.lift(nameToType(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) => + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) => makeGtEq.lift(nameToType(name)).map(_(name, value)) case sources.And(lhs, rhs) => @@ -362,13 +474,14 @@ private[parquet] class ParquetFilters( case sources.Not(pred) => createFilter(schema, pred).map(FilterApi.not) - case sources.In(name, values) if canMakeFilterOn(name) && shouldConvertInPredicate(name) + case sources.In(name, values) if canMakeFilterOn(name, values.head) && values.distinct.length <= pushDownInFilterThreshold => values.distinct.flatMap { v => makeEq.lift(nameToType(name)).map(_(name, v)) }.reduceLeftOption(FilterApi.or) - case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name) => + case sources.StringStartsWith(name, prefix) + if pushDownStartWith && canMakeFilterOn(name, prefix) => Option(prefix).map { v => FilterApi.userDefined(binaryColumn(name), new UserDefinedPredicate[Binary] with Serializable { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala index 567a8ebf9d102..bdb60b44750c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala @@ -290,8 +290,12 @@ class FilterPushdownBenchmark extends SparkFunSuite with BenchmarkBeforeAndAfter s"decimal(${DecimalType.MAX_PRECISION}, 2)" ).foreach { dt => val columns = (1 to width).map(i => s"CAST(id AS string) c$i") - val df = spark.range(numRows).selectExpr(columns: _*) - .withColumn("value", monotonically_increasing_id().cast(dt)) + val valueCol = if (dt.equalsIgnoreCase(s"decimal(${Decimal.MAX_INT_DIGITS}, 2)")) { + monotonically_increasing_id() % 9999999 + } else { + monotonically_increasing_id() + } + val df = spark.range(numRows).selectExpr(columns: _*).withColumn("value", valueCol.cast(dt)) withTempTable("orcTable", "patquetTable") { saveAsTable(df, dir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 924f136503656..be4f498c921ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal => JBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} @@ -58,7 +59,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex private lazy val parquetFilters = new ParquetFilters(conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, - conf.parquetFilterPushDownStringStartWith, conf.parquetFilterPushDownInFilterThreshold) + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringStartWith, + conf.parquetFilterPushDownInFilterThreshold) override def beforeEach(): Unit = { super.beforeEach() @@ -86,6 +88,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_TIMESTAMP_ENABLED.key -> "true", + SQLConf.PARQUET_FILTER_PUSHDOWN_DECIMAL_ENABLED.key -> "true", SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true", SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df @@ -179,6 +182,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + private def testDecimalPushDown(data: DataFrame)(f: DataFrame => Unit): Unit = { + withTempPath { file => + data.write.parquet(file.getCanonicalPath) + readParquetFile(file.toString)(f) + } + } + // This function tests that exactly go through the `canDrop` and `inverseCanDrop`. private def testStringStartsWith(dataFrame: DataFrame, filter: String): Unit = { withTempPath { dir => @@ -512,6 +522,84 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("filter pushdown - decimal") { + Seq(true, false).foreach { legacyFormat => + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> legacyFormat.toString) { + Seq( + s"a decimal(${Decimal.MAX_INT_DIGITS}, 2)", // 32BitDecimalType + s"a decimal(${Decimal.MAX_LONG_DIGITS}, 2)", // 64BitDecimalType + "a decimal(38, 18)" // ByteArrayDecimalType + ).foreach { schemaDDL => + val schema = StructType.fromDDL(schemaDDL) + val rdd = + spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i)))) + val dataFrame = spark.createDataFrame(rdd, schema) + testDecimalPushDown(dataFrame) { implicit df => + assert(df.schema === schema) + checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + + checkFilterPredicate('a === 1, classOf[Eq[_]], 1) + checkFilterPredicate('a <=> 1, classOf[Eq[_]], 1) + checkFilterPredicate('a =!= 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate('a < 2, classOf[Lt[_]], 1) + checkFilterPredicate('a > 3, classOf[Gt[_]], 4) + checkFilterPredicate('a <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('a >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> 'a, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > 'a, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < 'a, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= 'a, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= 'a, classOf[GtEq[_]], 4) + + checkFilterPredicate(!('a < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('a < 2 || 'a > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + } + } + } + + test("Ensure that filter value matched the parquet file schema") { + val scale = 2 + val schema = StructType(Seq( + StructField("cint", IntegerType), + StructField("cdecimal1", DecimalType(Decimal.MAX_INT_DIGITS, scale)), + StructField("cdecimal2", DecimalType(Decimal.MAX_LONG_DIGITS, scale)), + StructField("cdecimal3", DecimalType(DecimalType.MAX_PRECISION, scale)) + )) + + val parquetSchema = new SparkToParquetSchemaConverter(conf).convert(schema) + + val decimal = new JBigDecimal(10).setScale(scale) + val decimal1 = new JBigDecimal(10).setScale(scale + 1) + assert(decimal.scale() === scale) + assert(decimal1.scale() === scale + 1) + + assertResult(Some(lt(intColumn("cdecimal1"), 1000: Integer))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal1", decimal1)) + } + + assertResult(Some(lt(longColumn("cdecimal2"), 1000L: java.lang.Long))) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal)) + } + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal2", decimal1)) + } + + assert(parquetFilters.createFilter( + parquetSchema, sources.LessThan("cdecimal3", decimal)).isDefined) + assertResult(None) { + parquetFilters.createFilter(parquetSchema, sources.LessThan("cdecimal3", decimal1)) + } + } + test("SPARK-6554: don't push down predicates which reference partition columns") { import testImplicits._ From cf9704534903b5bbd9bd4834728c92953e45293e Mon Sep 17 00:00:00 2001 From: Shahid Date: Mon, 16 Jul 2018 09:50:43 -0500 Subject: [PATCH 1134/1161] [SPARK-18230][MLLIB] Throw a better exception, if the user or product doesn't exist When invoking MatrixFactorizationModel.recommendProducts(Int, Int) with a non-existing user, a java.util.NoSuchElementException is thrown: > java.util.NoSuchElementException: next on empty iterator at scala.collection.Iterator$$anon$2.next(Iterator.scala:39) at scala.collection.Iterator$$anon$2.next(Iterator.scala:37) at scala.collection.IndexedSeqLike$Elements.next(IndexedSeqLike.scala:63) at scala.collection.IterableLike$class.head(IterableLike.scala:107) at scala.collection.mutable.WrappedArray.scala$collection$IndexedSeqOptimized$$super$head(WrappedArray.scala:35) at scala.collection.IndexedSeqOptimized$class.head(IndexedSeqOptimized.scala:126) at scala.collection.mutable.WrappedArray.head(WrappedArray.scala:35) at org.apache.spark.mllib.recommendation.MatrixFactorizationModel.recommendProducts(MatrixFactorizationModel.scala:169) ## What changes were proposed in this pull request? Throw a better exception, like "user-id/product-id doesn't found in the model", for a non-existent user/product ## How was this patch tested? Added UT Author: Shahid Closes #21740 from shahidki31/checkInvalidUserProduct. --- .../MatrixFactorizationModel.scala | 23 ++++++++++++++----- .../MatrixFactorizationModelSuite.scala | 21 +++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ac709ad72f0c0..7b49d4d0812f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -78,8 +78,13 @@ class MatrixFactorizationModel @Since("0.8.0") ( /** Predict the rating of one user for one product. */ @Since("0.8.0") def predict(user: Int, product: Int): Double = { - val userVector = userFeatures.lookup(user).head - val productVector = productFeatures.lookup(product).head + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + + val userVector = userFeatureSeq.head + val productVector = productFeatureSeq.head blas.ddot(rank, userVector, 1, productVector, 1) } @@ -164,9 +169,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the product is. */ @Since("1.1.0") - def recommendProducts(user: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) + def recommendProducts(user: Int, num: Int): Array[Rating] = { + val userFeatureSeq = userFeatures.lookup(user) + require(userFeatureSeq.nonEmpty, s"userId: $user not found in the model") + MatrixFactorizationModel.recommend(userFeatureSeq.head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) + } /** * Recommends users to a product. That is, this returns users who are most likely to be @@ -181,9 +189,12 @@ class MatrixFactorizationModel @Since("0.8.0") ( * recommended the user is. */ @Since("1.1.0") - def recommendUsers(product: Int, num: Int): Array[Rating] = - MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) + def recommendUsers(product: Int, num: Int): Array[Rating] = { + val productFeatureSeq = productFeatures.lookup(product) + require(productFeatureSeq.nonEmpty, s"productId: $product not found in the model") + MatrixFactorizationModel.recommend(productFeatureSeq.head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + } protected override val formatVersion: String = "1.0" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c8ed057a516a..5ed9d077afe78 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -72,6 +72,27 @@ class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkCon } } + test("invalid user and product") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + + intercept[IllegalArgumentException] { + // invalid user + model.predict(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.predict(0, 5) + } + intercept[IllegalArgumentException] { + // invalid user + model.recommendProducts(5, 2) + } + intercept[IllegalArgumentException] { + // invalid product + model.recommendUsers(5, 2) + } + } + test("batch predict API recommendProductsForUsers") { val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) val topK = 10 From b045315e5d87b7ea3588436053aaa4d5a7bd103f Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 16 Jul 2018 23:16:25 +0800 Subject: [PATCH 1135/1161] [SPARK-24734][SQL] Fix type coercions and nullabilities of nested data types of some functions. ## What changes were proposed in this pull request? We have some functions which need to aware the nullabilities of all children, such as `CreateArray`, `CreateMap`, `Concat`, and so on. Currently we add casts to fix the nullabilities, but the casts might be removed during the optimization phase. After the discussion, we decided to not add extra casts for just fixing the nullabilities of the nested types, but handle them by functions themselves. ## How was this patch tested? Modified and added some tests. Author: Takuya UESHIN Closes #21704 from ueshin/issues/SPARK-24734/concat_containsnull. --- .../sql/catalyst/analysis/TypeCoercion.scala | 113 ++++++++++-------- .../sql/catalyst/expressions/Expression.scala | 12 +- .../sql/catalyst/expressions/arithmetic.scala | 14 +-- .../expressions/collectionOperations.scala | 22 ++-- .../expressions/complexTypeCreator.scala | 15 ++- .../expressions/conditionalExpressions.scala | 4 +- .../sql/catalyst/expressions/literals.scala | 2 +- .../expressions/nullExpressions.scala | 6 +- .../spark/sql/catalyst/util/TypeUtils.scala | 16 +-- .../catalyst/analysis/TypeCoercionSuite.scala | 43 ++++--- .../ArithmeticExpressionSuite.scala | 12 ++ .../CollectionExpressionsSuite.scala | 60 ++++++++-- .../expressions/ComplexTypeSuite.scala | 19 +++ .../expressions/NullExpressionsSuite.scala | 7 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 8 ++ 15 files changed, 211 insertions(+), 142 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e8331c90ea0f6..316aebdeaffa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -184,6 +184,17 @@ object TypeCoercion { } } + def findCommonTypeDifferentOnlyInNullFlags(types: Seq[DataType]): Option[DataType] = { + if (types.isEmpty) { + None + } else { + types.tail.foldLeft[Option[DataType]](Some(types.head)) { + case (Some(t1), t2) => findCommonTypeDifferentOnlyInNullFlags(t1, t2) + case _ => None + } + } + } + /** * Case 2 type widening (see the classdoc comment above for TypeCoercion). * @@ -259,8 +270,25 @@ object TypeCoercion { } } - private def haveSameType(exprs: Seq[Expression]): Boolean = - exprs.map(_.dataType).distinct.length == 1 + /** + * Check whether the given types are equal ignoring nullable, containsNull and valueContainsNull. + */ + def haveSameType(types: Seq[DataType]): Boolean = { + if (types.size <= 1) { + true + } else { + val head = types.head + types.tail.forall(_.sameType(head)) + } + } + + private def castIfNotSameType(expr: Expression, dt: DataType): Expression = { + if (!expr.dataType.sameType(dt)) { + Cast(expr, dt) + } else { + expr + } + } /** * Widens numeric types and converts strings to numbers when appropriate. @@ -525,23 +553,24 @@ object TypeCoercion { * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends TypeCoercionRule { + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !haveSameType(children) => + case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType))) case None => a } case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && - !haveSameType(children) => + !haveSameType(c.inputTypesForMerging) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType))) case None => c } @@ -553,7 +582,8 @@ object TypeCoercion { case None => aj } - case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) => + case s @ Sequence(_, _, _, timeZoneId) + if !haveSameType(s.coercibleChildren.map(_.dataType)) => val types = s.coercibleChildren.map(_.dataType) findWiderCommonType(types) match { case Some(widerDataType) => s.castChildrenTo(widerDataType) @@ -561,33 +591,25 @@ object TypeCoercion { } case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) && - !haveSameType(children) => + !haveSameType(m.inputTypesForMerging) => val types = children.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType))) case None => m } case m @ CreateMap(children) if m.keys.length == m.values.length && - (!haveSameType(m.keys) || !haveSameType(m.values)) => - val newKeys = if (haveSameType(m.keys)) { - m.keys - } else { - val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) - case None => m.keys - } + (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) => + val keyTypes = m.keys.map(_.dataType) + val newKeys = findWiderCommonType(keyTypes) match { + case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType)) + case None => m.keys } - val newValues = if (haveSameType(m.values)) { - m.values - } else { - val types = m.values.map(_.dataType) - findWiderCommonType(types) match { - case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) - case None => m.values - } + val valueTypes = m.values.map(_.dataType) + val newValues = findWiderCommonType(valueTypes) match { + case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType)) + case None => m.values } CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) }) @@ -610,27 +632,27 @@ object TypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case c @ Coalesce(es) if !haveSameType(es) => + case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) => val types = es.map(_.dataType) findWiderCommonType(types) match { - case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) + case Some(finalDataType) => Coalesce(es.map(castIfNotSameType(_, finalDataType))) case None => c } // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if // we need to truncate, but we should not promote one side to string if the other side is // string.g - case g @ Greatest(children) if !haveSameType(children) => + case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType))) case None => g } - case l @ Least(children) if !haveSameType(children) => + case l @ Least(children) if !haveSameType(l.inputTypesForMerging) => val types = children.map(_.dataType) findWiderTypeWithoutStringPromotion(types) match { - case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType))) case None => l } @@ -672,27 +694,14 @@ object TypeCoercion { object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual => + case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => - var changed = false val newBranches = c.branches.map { case (condition, value) => - if (value.dataType.sameType(commonType)) { - (condition, value) - } else { - changed = true - (condition, Cast(value, commonType)) - } - } - val newElseValue = c.elseValue.map { value => - if (value.dataType.sameType(commonType)) { - value - } else { - changed = true - Cast(value, commonType) - } + (condition, castIfNotSameType(value, commonType)) } - if (changed) CaseWhen(newBranches, newElseValue) else c + val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) + CaseWhen(newBranches, newElseValue) }.getOrElse(c) } } @@ -705,10 +714,10 @@ object TypeCoercion { plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. - case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual => + case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType) - val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType) + val newLeft = castIfNotSameType(left, widestType) + val newRight = castIfNotSameType(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 44c5556ff9ccf..f7d1b105964d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -709,22 +709,12 @@ trait ComplexTypeMergingExpression extends Expression { @transient lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType) - /** - * A method determining whether the input types are equal ignoring nullable, containsNull and - * valueContainsNull flags and thus convenient for resolution of the final data type. - */ - def areInputTypesForMergingEqual: Boolean = { - inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall { - case Seq(dt1, dt2) => dt1.sameType(dt2) - } - } - override def dataType: DataType = { require( inputTypesForMerging.nonEmpty, "The collection of input data types must not be empty.") require( - areInputTypesForMergingEqual, + TypeCoercion.haveSameType(inputTypesForMerging), "All input types must be the same except nullable, containsNull, valueContainsNull flags.") inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fe91e520169b4..55940410cc4d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -514,7 +514,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { > SELECT _FUNC_(10, 9, 2, 4, 3); 2 """) -case class Least(children: Seq[Expression]) extends Expression { +case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -525,7 +525,7 @@ case class Least(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") @@ -534,8 +534,6 @@ case class Least(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) @@ -589,7 +587,7 @@ case class Least(children: Seq[Expression]) extends Expression { > SELECT _FUNC_(10, 9, 2, 4, 3); 10 """) -case class Greatest(children: Seq[Expression]) extends Expression { +case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -600,7 +598,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { if (children.length <= 1) { TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least two arguments") - } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") @@ -609,8 +607,6 @@ case class Greatest(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { children.foldLeft[Any](null)((r, c) => { val evalc = c.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0f4f4f1601b4a..972bc6e57892c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -507,7 +507,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')); [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]] """, since = "2.4.0") -case class MapConcat(children: Seq[Expression]) extends Expression { +case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { override def checkInputDataTypes(): TypeCheckResult = { var funcName = s"function $prettyName" @@ -521,14 +521,10 @@ case class MapConcat(children: Seq[Expression]) extends Expression { } override def dataType: MapType = { - val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption - .getOrElse(MapType(StringType, StringType)) - val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType]) - .exists(_.valueContainsNull) - if (dt.valueContainsNull != valueContainsNull) { - dt.copy(valueContainsNull = valueContainsNull) + if (children.isEmpty) { + MapType(StringType, StringType) } else { - dt + super.dataType.asInstanceOf[MapType] } } @@ -2211,7 +2207,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); | [1,2,3,4,5,6] """) -case class Concat(children: Seq[Expression]) extends Expression { +case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH @@ -2232,7 +2228,13 @@ case class Concat(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + override def dataType: DataType = { + if (children.isEmpty) { + StringType + } else { + super.dataType + } + } lazy val javaType: String = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0a5f8a907b50a..a43de028360b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ @@ -48,7 +48,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def dataType: ArrayType = { ArrayType( - children.headOption.map(_.dataType).getOrElse(StringType), + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) + .getOrElse(StringType), containsNull = children.exists(_.nullable)) } @@ -179,11 +180,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure( s"$prettyName expects a positive even number of arguments.") - } else if (keys.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given keys of function map should all be the same type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) - } else if (values.map(_.dataType).distinct.length > 1) { + } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) { TypeCheckResult.TypeCheckFailure( "The given values of function map should all be the same type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]")) @@ -194,8 +195,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def dataType: DataType = { MapType( - keyType = keys.headOption.map(_.dataType).getOrElse(StringType), - valueType = values.headOption.map(_.dataType).getOrElse(StringType), + keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) + .getOrElse(StringType), + valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) + .getOrElse(StringType), valueContainsNull = values.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 30ce9e4743da9..3b597e8b5263b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -48,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( "type of predicate expression in If should be boolean, " + s"not ${predicate.dataType.simpleString}") - } else if (!areInputTypesForMergingEqual) { + } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { @@ -137,7 +137,7 @@ case class CaseWhen( } override def checkInputDataTypes(): TypeCheckResult = { - if (areInputTypesForMergingEqual) { + if (TypeCoercion.haveSameType(inputTypesForMerging)) { // Make sure all branch conditions are boolean types. if (branches.forall(_._1.dataType == BooleanType)) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0cc2a332f2c30..0efd1224f1bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -186,7 +186,7 @@ object Literal { case map: MapType => create(Map(), map) case struct: StructType => create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct) - case udt: UserDefinedType[_] => default(udt.sqlType) + case udt: UserDefinedType[_] => Literal(default(udt.sqlType).value, udt) case other => throw new RuntimeException(s"no default for type $dataType") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 2eeed3bbb2d91..b683d2a7e9ef3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -44,7 +44,7 @@ import org.apache.spark.sql.types._ 1 """) // scalastyle:on line.size.limit -case class Coalesce(children: Seq[Expression]) extends Expression { +case class Coalesce(children: Seq[Expression]) extends ComplexTypeMergingExpression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ override def nullable: Boolean = children.forall(_.nullable) @@ -61,8 +61,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def dataType: DataType = children.head.dataType - override def eval(input: InternalRow): Any = { var result: Any = null val childIterator = children.iterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 1dcda49a3af6a..b795abea95a74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.types._ @@ -42,18 +42,12 @@ object TypeUtils { } def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { - if (types.size <= 1) { + if (TypeCoercion.haveSameType(types)) { TypeCheckResult.TypeCheckSuccess } else { - val firstType = types.head - types.foreach { t => - if (!t.sameType(firstType)) { - return TypeCheckResult.TypeCheckFailure( - s"input to $caller should all be the same type, but it's " + - types.map(_.simpleString).mkString("[", ", ", "]")) - } - } - TypeCheckResult.TypeCheckSuccess + return TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's " + + types.map(_.simpleString).mkString("[", ", ", "]")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 8cc5a23779a2a..4161f09c63190 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -687,46 +687,43 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(rule, Coalesce(Seq(doubleLit, intLit, floatLit)), - Coalesce(Seq(Cast(doubleLit, DoubleType), - Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + Coalesce(Seq(doubleLit, Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) ruleTest(rule, Coalesce(Seq(longLit, intLit, decimalLit)), Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), - Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + Cast(intLit, DecimalType(22, 0)), decimalLit))) ruleTest(rule, Coalesce(Seq(nullLit, intLit)), - Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + Coalesce(Seq(Cast(nullLit, IntegerType), intLit))) ruleTest(rule, Coalesce(Seq(timestampLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, intLit)), - Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), - Cast(intLit, FloatType)))) + Coalesce(Seq(Cast(nullLit, FloatType), floatNullLit, Cast(intLit, FloatType)))) ruleTest(rule, Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), - Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + Cast(decimalLit, DoubleType), doubleLit))) ruleTest(rule, Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), - Cast(doubleLit, StringType), Cast(stringLit, StringType)))) + Cast(doubleLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(timestampLit, intLit, stringLit)), - Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), - Cast(stringLit, StringType)))) + Coalesce(Seq(Cast(timestampLit, StringType), Cast(intLit, StringType), stringLit))) ruleTest(rule, Coalesce(Seq(tsArrayLit, intArrayLit, strArrayLit)), Coalesce(Seq(Cast(tsArrayLit, ArrayType(StringType)), - Cast(intArrayLit, ArrayType(StringType)), Cast(strArrayLit, ArrayType(StringType))))) + Cast(intArrayLit, ArrayType(StringType)), strArrayLit))) } test("CreateArray casts") { @@ -735,7 +732,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - CreateArray(Cast(Literal(1.0), DoubleType) + CreateArray(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -747,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Cast(Literal(1.0), StringType) :: Cast(Literal(1), StringType) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -765,7 +762,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) } @@ -779,7 +776,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), CreateMap(Cast(Literal(1), FloatType) :: Literal("a") - :: Cast(Literal.create(2.0, FloatType), FloatType) + :: Literal.create(2.0, FloatType) :: Literal("b") :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -801,7 +798,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Literal(1) - :: Cast(Literal("a"), StringType) + :: Literal("a") :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) @@ -814,7 +811,7 @@ class TypeCoercionSuite extends AnalysisTest { CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38)) :: Literal(2) - :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values ruleTest(TypeCoercion.FunctionArgumentConversion, @@ -824,8 +821,8 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(3.0) :: Nil), CreateMap(Cast(Literal(1), DoubleType) - :: Cast(Literal("a"), StringType) - :: Cast(Literal(2.0), DoubleType) + :: Literal("a") + :: Literal(2.0) :: Cast(Literal(3.0), StringType) :: Nil)) } @@ -837,7 +834,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1) :: Literal.create(1.0, FloatType) :: Nil), - operator(Cast(Literal(1.0), DoubleType) + operator(Literal(1.0) :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) @@ -848,14 +845,14 @@ class TypeCoercionSuite extends AnalysisTest { :: Nil), operator(Cast(Literal(1L), DecimalType(22, 0)) :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil)) ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) :: Nil), - operator(Literal(1.0).cast(DoubleType) + operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6edb4348f8309..021217606dc03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -282,6 +282,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } + + val least = Least(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(least.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(least, Seq(1, 2)) } test("function greatest") { @@ -334,6 +340,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } + + val greatest = Greatest(Seq( + Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(greatest.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(greatest, Seq(1, 3, null)) } test("SPARK-22499: Least and greatest should not generate codes beyond 64KB") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 85d6a1befed6b..f1e3bd091565d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -227,6 +227,27 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull) assert(!MapConcat(Seq(m1, m2)).nullable) assert(MapConcat(Seq(m1, mNull)).nullable) + + val mapConcat = MapConcat(Seq( + Literal.create(Map(Seq(1, 2) -> Seq("a", "b")), + MapType( + ArrayType(IntegerType, containsNull = false), + ArrayType(StringType, containsNull = false), + valueContainsNull = false)), + Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null), + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)))) + assert(mapConcat.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = true)) + checkEvaluation(mapConcat, Map( + Seq(1, 2) -> Seq("a", "b"), + Seq(3, 4, null) -> Seq("c", "d", null), + Seq(6) -> null)) } test("MapFromEntries") { @@ -1050,11 +1071,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper test("Concat") { // Primitive-type elements - val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) - val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType)) - val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType)) - val ai4 = Literal.create(null, ArrayType(IntegerType)) + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType, containsNull = true)) + val ai4 = Literal.create(null, ArrayType(IntegerType, containsNull = false)) checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3)) checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3)) @@ -1067,14 +1088,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(ai4, ai0)), null) // Non-primitive-type elements - val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType)) - val as1 = Literal.create(Seq.empty[String], ArrayType(StringType)) - val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType)) - val as3 = Literal.create(Seq(null, null), ArrayType(StringType)) - val as4 = Literal.create(null, ArrayType(StringType)) - - val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType))) - val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType))) + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType, containsNull = true)) + val as3 = Literal.create(Seq(null, null), ArrayType(StringType, containsNull = true)) + val as4 = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + val aa2 = Literal.create(Seq(Seq("g", null), null), + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c")) checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c")) @@ -1087,6 +1112,15 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Concat(Seq(as4, as0)), null) checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f"))) + + assert(Concat(Seq(ai0, ai1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(ai0, ai2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(as0, as1)).dataType.asInstanceOf[ArrayType].containsNull === false) + assert(Concat(Seq(as0, as2)).dataType.asInstanceOf[ArrayType].containsNull === true) + assert(Concat(Seq(aa0, aa1)).dataType === + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) + assert(Concat(Seq(aa0, aa2)).dataType === + ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) } test("Flatten") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 726193b411737..77aaf55480ec2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -144,6 +144,13 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(byteWithNull), byteSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) + + val array = CreateArray(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)))) + assert(array.dataType === + ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false)) + checkEvaluation(array, Seq(intSeq, intSeq :+ null)) } test("CreateMap") { @@ -184,6 +191,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), null, null) } + + val map = CreateMap(Seq( + Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), + Literal.create(strSeq, ArrayType(StringType, containsNull = false)), + Literal.create(intSeq :+ null, ArrayType(IntegerType, containsNull = true)), + Literal.create(strSeq :+ null, ArrayType(StringType, containsNull = true)))) + assert(map.dataType === + MapType( + ArrayType(IntegerType, containsNull = true), + ArrayType(StringType, containsNull = true), + valueContainsNull = false)) + checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null))) } test("MapFromArrays") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 424c3a4696077..6e07f7a59b730 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -86,6 +86,13 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) } + + val coalesce = Coalesce(Seq( + Literal.create(null, ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)), + Literal.create(Seq(1, 2, 3, null), ArrayType(IntegerType, containsNull = true)))) + assert(coalesce.dataType === ArrayType(IntegerType, containsNull = true)) + checkEvaluation(coalesce, Seq(1, 2, 3)) } test("SPARK-16602 Nvl should support numeric-string cases") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d4615714cff03..3f6f4556e2d6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1622,6 +1622,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg.contains(s"input to function $name requires at least two arguments")) } } + + test("SPARK-24734: Fix containsNull of Concat for array type") { + val df = Seq((Seq(1), Seq[Integer](null), Seq("a", "b"))).toDF("k1", "k2", "v") + val ex = intercept[RuntimeException] { + df.select(map_from_arrays(concat($"k1", $"k2"), $"v")).show() + } + assert(ex.getMessage.contains("Cannot use null as map key")) + } } object DataFrameFunctionsSuite { From b0c95a1d698888df752bd62e49838a98268f6847 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Mon, 16 Jul 2018 14:28:35 -0700 Subject: [PATCH 1136/1161] [SPARK-23901][SQL] Removing masking functions The PR reverts #21246. Author: Marek Novotny Closes #21786 from mn-mikke/SPARK-23901. --- .../expressions/MaskExpressionsUtils.java | 80 --- .../catalyst/analysis/FunctionRegistry.scala | 8 - .../expressions/maskExpressions.scala | 569 ------------------ .../expressions/MaskExpressionsSuite.scala | 236 -------- .../org/apache/spark/sql/functions.scala | 119 ---- .../spark/sql/DataFrameFunctionsSuite.scala | 107 ---- 6 files changed, 1119 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java deleted file mode 100644 index 05879902a4ed9..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -/** - * Contains all the Utils methods used in the masking expressions. - */ -public class MaskExpressionsUtils { - static final int UNMASKED_VAL = -1; - - /** - * Returns the masking character for {@param c} or {@param c} is it should not be masked. - * @param c the character to transform - * @param maskedUpperChar the character to use instead of a uppercase letter - * @param maskedLowerChar the character to use instead of a lowercase letter - * @param maskedDigitChar the character to use instead of a digit - * @param maskedOtherChar the character to use instead of a any other character - * @return masking character for {@param c} - */ - public static int transformChar( - final int c, - int maskedUpperChar, - int maskedLowerChar, - int maskedDigitChar, - int maskedOtherChar) { - switch(Character.getType(c)) { - case Character.UPPERCASE_LETTER: - if(maskedUpperChar != UNMASKED_VAL) { - return maskedUpperChar; - } - break; - - case Character.LOWERCASE_LETTER: - if(maskedLowerChar != UNMASKED_VAL) { - return maskedLowerChar; - } - break; - - case Character.DECIMAL_DIGIT_NUMBER: - if(maskedDigitChar != UNMASKED_VAL) { - return maskedDigitChar; - } - break; - - default: - if(maskedOtherChar != UNMASKED_VAL) { - return maskedOtherChar; - } - break; - } - - return c; - } - - /** - * Returns the replacement char to use according to the {@param rep} specified by the user and - * the {@param def} default. - */ - public static int getReplacementChar(String rep, int def) { - if (rep != null && rep.length() > 0) { - return rep.codePointAt(0); - } - return def; - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1d9e470c9ba83..d696ce9a766d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -440,14 +440,6 @@ object FunctionRegistry { expression[ArrayDistinct]("array_distinct"), CreateStruct.registryEntry, - // mask functions - expression[Mask]("mask"), - expression[MaskFirstN]("mask_first_n"), - expression[MaskLastN]("mask_last_n"), - expression[MaskShowFirstN]("mask_show_first_n"), - expression[MaskShowLastN]("mask_show_last_n"), - expression[MaskHash]("mask_hash"), - // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala deleted file mode 100644 index 276a57266a6e0..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ /dev/null @@ -1,569 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.commons.codec.digest.DigestUtils - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ -import org.apache.spark.sql.catalyst.expressions.MaskLike._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - - -trait MaskLike { - def upper: String - def lower: String - def digit: String - - protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) - protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) - protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) - - protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName - - def inputStringLengthCode(inputString: String, length: String): String = { - s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" - } - - def appendMaskedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, - | $upperReplacement, $lowerReplacement, - | $digitReplacement, $defaultMaskedOther)); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendUnchangedToStringBuilderCode( - ctx: CodegenContext, - sb: String, - inputString: String, - offset: String, - numChars: String): String = { - val i = ctx.freshName("i") - val codePoint = ctx.freshName("codePoint") - s""" - |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { - | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); - | $sb.appendCodePoint($codePoint); - | $offset += Character.charCount($codePoint); - |} - """.stripMargin - } - - def appendMaskedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(transformChar( - codePoint, - upperReplacement, - lowerReplacement, - digitReplacement, - defaultMaskedOther)) - offset += Character.charCount(codePoint) - } - offset - } - - def appendUnchangedToStringBuilder( - sb: java.lang.StringBuilder, - inputString: String, - startOffset: Int, - numChars: Int): Int = { - var offset = startOffset - (1 to numChars) foreach { _ => - val codePoint = inputString.codePointAt(offset) - sb.appendCodePoint(codePoint) - offset += Character.charCount(codePoint) - } - offset - } -} - -trait MaskLikeWithN extends MaskLike { - def n: Int - protected lazy val charCount: Int = if (n < 0) 0 else n -} - -/** - * Utils for mask operations. - */ -object MaskLike { - val defaultCharCount = 4 - val defaultMaskedUppercase: Int = 'X' - val defaultMaskedLowercase: Int = 'x' - val defaultMaskedDigit: Int = 'n' - val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL - - def extractCharCount(e: Expression): Int = e match { - case Literal(i, IntegerType | NullType) => - if (i == null) defaultCharCount else i.asInstanceOf[Int] - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } - - def extractReplacement(e: Expression): String = e match { - case Literal(s, StringType | NullType) => if (s == null) null else s.toString - case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + - s"${StringType.simpleString}, but got literal of ${dt.simpleString}") - case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") - } -} - -/** - * Masks the input string. Additional parameters can be set to change the masking chars for - * uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); - llll-UUUU-####-#### - """) -// scalastyle:on line.size.limit -case class Mask(child: Expression, upper: String, lower: String, digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLike { - - def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) - - def this(child: Expression, upper: Expression) = - this(child, extractReplacement(upper), null, null) - - def this(child: Expression, upper: Expression, lower: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), null) - - def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = - this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val sb = new java.lang.StringBuilder(length) - appendMaskedToStringBuilder(sb, str, 0, length) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |StringBuilder $sb = new StringBuilder($length); - |${CodeGenerator.JAVA_INT} $offset = 0; - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} - |${ev.value} = UTF8String.fromString($sb.toString()); - """.stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) -} - -/** - * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-5678-8765-4321 - """) -// scalastyle:on line.size.limit -case class MaskFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_first_n" -} - -/** - * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set - * to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-5678-8765-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? - | 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_last_n" -} - -/** - * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - 1234-nnnn-nnnn-nnnn - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowFirstN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val startOfMask = if (charCount > length) length else charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) - appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val startOfMask = ctx.freshName("startOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} - |${appendMaskedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $startOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_first_n" -} - -/** - * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can - * be set to change the masking chars for uppercase letters, lowercase letters and digits. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", - examples = """ - Examples: - > SELECT _FUNC_("1234-5678-8765-4321", 4); - nnnn-nnnn-nnnn-4321 - """, since = "2.4.0") -// scalastyle:on line.size.limit -case class MaskShowLastN( - child: Expression, - n: Int, - upper: String, - lower: String, - digit: String) - extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { - - def this(child: Expression) = - this(child, defaultCharCount, null, null, null) - - def this(child: Expression, n: Expression) = - this(child, extractCharCount(n), null, null, null) - - def this(child: Expression, n: Expression, upper: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), null, null) - - def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = - this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) - - def this( - child: Expression, - n: Expression, - upper: Expression, - lower: Expression, - digit: Expression) = - this(child, - extractCharCount(n), - extractReplacement(upper), - extractReplacement(lower), - extractReplacement(digit)) - - override def nullSafeEval(input: Any): Any = { - val str = input.asInstanceOf[UTF8String].toString - val length = str.codePointCount(0, str.length()) - val endOfMask = if (charCount >= length) 0 else length - charCount - val sb = new java.lang.StringBuilder(length) - val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) - appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) - UTF8String.fromString(sb.toString) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val sb = ctx.freshName("sb") - val length = ctx.freshName("length") - val offset = ctx.freshName("offset") - val inputString = ctx.freshName("inputString") - val endOfMask = ctx.freshName("endOfMask") - s""" - |String $inputString = $input.toString(); - |${inputStringLengthCode(inputString, length)} - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; - |${CodeGenerator.JAVA_INT} $offset = 0; - |StringBuilder $sb = new StringBuilder($length); - |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} - |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, offset, s"$length - $endOfMask")} - |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_show_last_n" -} - -/** - * Returns a hashed value based on str. - */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(str) - Returns a hashed value based on str. The hash is consistent and can be used to join masked values together across tables.", - examples = """ - Examples: - > SELECT _FUNC_("abcd-EFGH-8765-4321"); - 60c713f5ec6912229d2060df1c322776 - """) -// scalastyle:on line.size.limit -case class MaskHash(child: Expression) - extends UnaryExpression with ExpectsInputTypes { - - override def nullSafeEval(input: Any): Any = { - UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (input: String) => { - val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") - s""" - |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); - |""".stripMargin - }) - } - - override def dataType: DataType = StringType - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def prettyName: String = "mask_hash" -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala deleted file mode 100644 index 4d69dc32ace82..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.{IntegerType, StringType} - -class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - test("mask") { - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") - checkEvaluation( - new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-####-####") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), - "llll-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal(null, StringType)), null) - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") - checkEvaluation(new Mask( - Literal("abcd-EFGH-8765-4321"), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), - "xxxx-UUUU-nnnn-nnnn") - checkEvaluation(new Mask(Literal("")), "") - checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") - checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), - "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") - // scalastyle:on nonascii - intercept[AnalysisException] { - checkEvaluation(new Mask(Literal(""), Literal(1)), "") - } - } - - test("mask_first_n") { - checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UFGH-8765") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UFGH-8765-4321") - checkEvaluation( - new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XFGH-8765-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-EFGH-8765-4321") - checkEvaluation(new MaskFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-EFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UFGH-8765-4321") - checkEvaluation(new MaskFirstN(Literal("")), "") - checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-EFGH-8765-4321") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - // scalastyle:off nonascii - checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xxxxo World") - // scalastyle:on nonascii - } - - test("mask_last_n") { - checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), - "abcd-EFGU-lU#l") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EFGU-####") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), - "abcd-EFGU-nnnn") - checkEvaluation( - new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), - "abcd-EFGX-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") - } - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal(null, StringType)), null) - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-EFGH-8765-nnnn") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), - "abcd-EFUU-nnnn-nnnn") - checkEvaluation(new MaskLastN(Literal("")), "") - checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), - "abcx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "abcd-EFGH-8765-4321") - // scalastyle:off nonascii - checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), - "あ, 𠀋, Hello Xxxxx あ 𠀋") - // scalastyle:on nonascii - } - - test("mask_show_first_n") { - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), - "abcd-EUUU-####-lU#l") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "abcd-EUUU-####-####") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "abcd-EXXX-nnnn-nnnn") - intercept[AnalysisException] { - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "abcd-UUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") - checkEvaluation( - new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "abcd-EUUU-nnnn-nnnn") - checkEvaluation(new MaskShowFirstN(Literal("")), "") - checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "abcd-XXXX-nnnn-nnnn") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") - checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Hellx Xxxxx") - // scalastyle:on nonascii - } - - test("mask_show_last_n") { - checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), - "lU#l-UUUH-8765") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), - "llll-UUUU-###5-4321") - checkEvaluation( - new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), - "llll-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), - "xxxx-XXXX-nnn5-4321") - intercept[AnalysisException] { - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") - } - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), - "llll-UUUU-nnnn-4321") - checkEvaluation(new MaskShowLastN( - Literal("abcd-EFGH-8765-4321"), - Literal(null, IntegerType), - Literal(null, StringType), - Literal(null, StringType), - Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), - "xxxx-UUUU-nnn5-4321") - checkEvaluation(new MaskShowLastN(Literal("")), "") - checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), - "xxxx-XXXX-nnnn-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), - "abcd-EFGH-8765-4321") - checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), - "xxxx-XXXX-nnnn-nnnn") - // scalastyle:off nonascii - checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") - checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), - "あ, 𠀋, Xello World") - // scalastyle:on nonascii - } - - test("mask_hash") { - checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") - checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") - checkEvaluation(MaskHash(Literal(null, StringType)), null) - // scalastyle:off nonascii - checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") - // scalastyle:on nonascii - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b98ab11e56feb..de1d422856ba9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3646,125 +3646,6 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } - ////////////////////////////////////////////////////////////////////////////////////////////// - // Mask functions - ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Returns a string which is the masked representation of the input. - * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column): Column = withExpr { new Mask(e.expr) } - - /** - * Returns a string which is the masked representation of the input, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { - Mask(e.expr, upper, lower, digit) - } - - /** - * Returns a string with the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } - - /** - * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } - - /** - * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as - * replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskLastN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the first `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n(e: Column, n: Int): Column = withExpr { - new MaskShowFirstN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_first_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowFirstN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a string with all but the last `n` characters masked. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_last_n(e: Column, n: Int): Column = withExpr { - new MaskShowLastN(e.expr, Literal(n)) - } - - /** - * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and - * `digit` as replacement characters. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_show_last_n( - e: Column, - n: Int, - upper: String, - lower: String, - digit: String): Column = withExpr { - MaskShowLastN(e.expr, n, upper, lower, digit) - } - - /** - * Returns a hashed value based on the input column. - * @group mask_funcs - * @since 2.4.0 - */ - def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } - // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3f6f4556e2d6b..a39aef18d27e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -309,113 +309,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("mask functions") { - val df = Seq("TestString-123", "", null).toDF("a") - checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4)), - Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4)), - Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_hash($"a")), - Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), - Row("d41d8cd98f00b204e9800998ecf8427e"), - Row(null))) - - checkAnswer(df.select(mask($"a", "U", "l", "#")), - Seq(Row("UlllUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllString-123"), Row(""), Row(null))) - checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), - Seq(Row("TestString-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), - Seq(Row("TestUlllll-###"), Row(""), Row(null))) - checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), - Seq(Row("UlllUlllll-123"), Row(""), Row(null))) - - checkAnswer( - df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), - Row("", "", "", ""), - Row(null, null, null, null))) - checkAnswer(sql("select mask(null)"), Row(null)) - checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) - intercept[AnalysisException] { - checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) - } - - checkAnswer( - df.selectExpr( - "mask_first_n(a)", - "mask_first_n(a, 6)", - "mask_first_n(a, 6, 'U')", - "mask_first_n(a, 6, 'U', 'l')", - "mask_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", - "UlllUlring-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) - } - - checkAnswer( - df.selectExpr( - "mask_last_n(a)", - "mask_last_n(a, 6)", - "mask_last_n(a, 6, 'U')", - "mask_last_n(a, 6, 'U', 'l')", - "mask_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", - "TestStrill-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_first_n(a)", - "mask_show_first_n(a, 6)", - "mask_show_first_n(a, 6, 'U')", - "mask_show_first_n(a, 6, 'U', 'l')", - "mask_show_first_n(a, 6, 'U', 'l', '#')"), - Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", - "TestStllll-###"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) - } - - checkAnswer( - df.selectExpr( - "mask_show_last_n(a)", - "mask_show_last_n(a, 6)", - "mask_show_last_n(a, 6, 'U')", - "mask_show_last_n(a, 6, 'U', 'l')", - "mask_show_last_n(a, 6, 'U', 'l', '#')"), - Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", - "UlllUlllng-123"), - Row("", "", "", "", ""), - Row(null, null, null, null, null))) - checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) - checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) - intercept[AnalysisException] { - checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) - } - - checkAnswer(sql("select mask_hash(null)"), Row(null)) - } - test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), From ba437fc5c73b95ee4c59327abf3161c58f64cb12 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 16 Jul 2018 14:35:44 -0700 Subject: [PATCH 1137/1161] [SPARK-24805][SQL] Do not ignore avro files without extensions by default ## What changes were proposed in this pull request? In the PR, I propose to change default behaviour of AVRO datasource which currently ignores files without `.avro` extension in read by default. This PR sets the default value for `avro.mapred.ignore.inputs.without.extension` to `false` in the case if the parameter is not set by an user. ## How was this patch tested? Added a test file without extension in AVRO format, and new test for reading the file with and wihout specified schema. Author: Maxim Gekk Author: Maxim Gekk Closes #21769 from MaxGekk/avro-without-extension. --- .../spark/sql/avro/AvroFileFormat.scala | 14 +++--- .../org/apache/spark/sql/avro/AvroSuite.scala | 45 ++++++++++++++++--- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index fb93033bb15d4..9eb206457809c 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -62,7 +62,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { // Schema evolution is not supported yet. Here we only pick a single random sample file to // figure out the schema of the whole dataset. val sampleFile = - if (conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true)) { + if (AvroFileFormat.ignoreFilesWithoutExtensions(conf)) { files.find(_.getPath.getName.endsWith(".avro")).getOrElse { throw new FileNotFoundException( "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " + @@ -170,10 +170,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { // Doing input file filtering is improper because we may generate empty tasks that process no // input files but stress the scheduler. We should probably add a more general input file // filtering mechanism for `FileFormat` data sources. See SPARK-16317. - if ( - conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true) && - !file.filePath.endsWith(".avro") - ) { + if (AvroFileFormat.ignoreFilesWithoutExtensions(conf) && !file.filePath.endsWith(".avro")) { Iterator.empty } else { val reader = { @@ -278,4 +275,11 @@ private[avro] object AvroFileFormat { value.readFields(new DataInputStream(in)) } } + + def ignoreFilesWithoutExtensions(conf: Configuration): Boolean = { + // Files without .avro extensions are not ignored by default + val defaultValue = false + + conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, defaultValue) + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 9c6526b29dca3..446b42124ceca 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.avro import java.io._ -import java.nio.file.Files +import java.net.URL +import java.nio.file.{Files, Path, Paths} import java.sql.{Date, Timestamp} import java.util.{TimeZone, UUID} @@ -622,7 +623,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { intercept[FileNotFoundException] { withTempPath { dir => FileUtils.touch(new File(dir, "test")) - spark.read.avro(dir.toString) + val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration + try { + hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + spark.read.avro(dir.toString) + } finally { + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) } } } @@ -684,12 +690,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Files.createFile(new File(tempSaveDir, "non-avro").toPath) - val newDf = spark - .read - .option(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") - .avro(tempSaveDir) + val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration + val count = try { + hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true") + val newDf = spark + .read + .avro(tempSaveDir) + newDf.count() + } finally { + hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty) + } - assert(newDf.count == 8) + assert(count == 8) } } @@ -805,4 +817,23 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(readDf.collect().sameElements(writeDf.collect())) } } + + test("SPARK-24805: do not ignore files without .avro extension by default") { + withTempDir { dir => + Files.copy( + Paths.get(new URL(episodesAvro).toURI), + Paths.get(dir.getCanonicalPath, "episodes")) + + val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes" + val df1 = spark.read.avro(fileWithoutExtension) + assert(df1.count == 8) + + val schema = new StructType() + .add("title", StringType) + .add("air_date", StringType) + .add("doctor", IntegerType) + val df2 = spark.read.schema(schema).avro(fileWithoutExtension) + assert(df2.count == 8) + } + } } From 0f0d1865f581a9158d73505471953656b173beba Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 16 Jul 2018 15:33:39 -0700 Subject: [PATCH 1138/1161] [SPARK-24402][SQL] Optimize `In` expression when only one element in the collection or collection is empty ## What changes were proposed in this pull request? Two new rules in the logical plan optimizers are added. 1. When there is only one element in the **`Collection`**, the physical plan will be optimized to **`EqualTo`**, so predicate pushdown can be used. ```scala profileDF.filter( $"profileID".isInCollection(Set(6))).explain(true) """ |== Physical Plan == |*(1) Project [profileID#0] |+- *(1) Filter (isnotnull(profileID#0) && (profileID#0 = 6)) | +- *(1) FileScan parquet [profileID#0] Batched: true, Format: Parquet, | PartitionFilters: [], | PushedFilters: [IsNotNull(profileID), EqualTo(profileID,6)], | ReadSchema: struct """.stripMargin ``` 2. When the **`Collection`** is empty, and the input is nullable, the logical plan will be simplified to ```scala profileDF.filter( $"profileID".isInCollection(Set())).explain(true) """ |== Optimized Logical Plan == |Filter if (isnull(profileID#0)) null else false |+- Relation[profileID#0] parquet """.stripMargin ``` TODO: 1. For multiple conditions with numbers less than certain thresholds, we should still allow predicate pushdown. 2. Optimize the **`In`** using **`tableswitch`** or **`lookupswitch`** when the numbers of the categories are low, and they are **`Int`**, **`Long`**. 3. The default immutable hash trees set is slow for query, and we should do benchmark for using different set implementation for faster query. 4. **`filter(if (condition) null else false)`** can be optimized to false. ## How was this patch tested? Couple new tests are added. Author: DB Tsai Closes #21442 from dbtsai/optimize-in. --- .../sql/catalyst/optimizer/expressions.scala | 13 +++++--- .../catalyst/optimizer/OptimizeInSuite.scala | 32 +++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d363b8146e3f..f78a0ff95f382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,15 +218,20 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.length == 1 && !newList.isInstanceOf[ListQuery]) { + EqualTo(v, newList.head) + } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.size < list.size) { + } else if (newList.length < list.length) { expr.copy(list = newList) - } else { // newList.length == list.length + } else { // newList.length == list.length && newList.length > 1 expr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..86522a6a54ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +206,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to `If` expression " + + "when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } From d57a267b79f4015508c3686c34a0f438bad41ea1 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Tue, 17 Jul 2018 09:13:35 +0800 Subject: [PATCH 1139/1161] [SPARK-23259][SQL] Clean up legacy code around hive external catalog and HiveClientImpl ## What changes were proposed in this pull request? Three legacy statements are removed by this patch: - in HiveExternalCatalog: The withClient wrapper is not necessary for the private method getRawTable. - in HiveClientImpl: There are some redundant code in both the tableExists and getTableOption method. This PR takes over https://github.com/apache/spark/pull/20425 ## How was this patch tested? Existing tests Closes #20425 Author: hyukjinkwon Closes #21780 from HyukjinKwon/SPARK-23259. --- .../org/apache/spark/sql/hive/HiveExternalCatalog.scala | 2 +- .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 011a3ba553cb2..44480cecf0039 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. */ - private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = { client.getTable(db, table) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 1df46d7431a21..db8fd5a43d842 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -353,15 +353,19 @@ private[hive] class HiveClientImpl( client.getDatabasesByPattern(pattern).asScala } + private def getRawTableOption(dbName: String, tableName: String): Option[HiveTable] = { + Option(client.getTable(dbName, tableName, false /* do not throw exception */)) + } + override def tableExists(dbName: String, tableName: String): Boolean = withHiveState { - Option(client.getTable(dbName, tableName, false /* do not throw exception */)).nonEmpty + getRawTableOption(dbName, tableName).nonEmpty } override def getTableOption( dbName: String, tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up $dbName.$tableName") - Option(client.getTable(dbName, tableName, false)).map { h => + getRawTableOption(dbName, tableName).map { h => // Note: Hive separates partition columns and the schema, but for us the // partition columns are part of the schema val cols = h.getCols.asScala.map(fromHiveColumn) From f876d3fa800ae04ec33f27295354669bb1db911e Mon Sep 17 00:00:00 2001 From: Miklos C Date: Tue, 17 Jul 2018 09:22:16 +0800 Subject: [PATCH 1140/1161] [SPARK-20220][DOCS] Documentation Add thrift scheduling pool config to scheduling docs ## What changes were proposed in this pull request? The thrift scheduling pool configuration was removed from a previous release. Adding this back to the job scheduling configuration docs. This PR takes over #17536 and handle some comments here. ## How was this patch tested? Manually. Closes #17536 Author: hyukjinkwon Closes #21778 from HyukjinKwon/SPARK-20220. --- docs/job-scheduling.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index da90342406c84..2316f175676ee 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -264,3 +264,11 @@ within it for the various settings. For example: A full example is also available in `conf/fairscheduler.xml.template`. Note that any pools not configured in the XML file will simply get default values for all settings (scheduling mode FIFO, weight 1, and minShare 0). + +## Scheduling using JDBC Connections +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + +{% highlight SQL %} +SET spark.sql.thriftserver.scheduler.pool=accounting; +{% endhighlight %} From 0ca16f6e143768f0c96b5310c1f81b3b51dcbbc8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 17 Jul 2018 11:30:53 +0800 Subject: [PATCH 1141/1161] Revert "[SPARK-24402][SQL] Optimize `In` expression when only one element in the collection or collection is empty" This reverts commit 0f0d1865f581a9158d73505471953656b173beba. --- .../sql/catalyst/optimizer/expressions.scala | 13 +++----- .../catalyst/optimizer/OptimizeInSuite.scala | 32 ------------------- 2 files changed, 4 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index f78a0ff95f382..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,20 +218,15 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty => - // When v is not nullable, the following expression will be optimized - // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.length == 1 && !newList.isInstanceOf[ListQuery]) { - EqualTo(v, newList.head) - } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.length < list.length) { + } else if (newList.size < list.size) { expr.copy(list = newList) - } else { // newList.length == list.length && newList.length > 1 + } else { // newList.length == list.length expr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 86522a6a54ed5..478118ed709f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,21 +176,6 @@ class OptimizeInSuite extends PlanTest { } } - test("OptimizedIn test: one element in list gets transformed to EqualTo.") { - val originalQuery = - testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) - .analyze - - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) - .analyze - - comparePlans(optimized, correctAnswer) - } - test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -206,21 +191,4 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - - test("OptimizedIn test: In empty list gets transformed to `If` expression " + - "when value is nullable") { - val originalQuery = - testRelation - .where(In(UnresolvedAttribute("a"), Nil)) - .analyze - - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(If(IsNotNull(UnresolvedAttribute("a")), - Literal(false), Literal.create(null, BooleanType))) - .analyze - - comparePlans(optimized, correctAnswer) - } } From 4cf1bec4dc574c541d03ea2f49db4de8b76ef6d2 Mon Sep 17 00:00:00 2001 From: Marek Novotny Date: Tue, 17 Jul 2018 23:07:18 +0800 Subject: [PATCH 1142/1161] [SPARK-24305][SQL][FOLLOWUP] Avoid serialization of private fields in collection expressions. ## What changes were proposed in this pull request? The PR tries to avoid serialization of private fields of already added collection functions and follows up on comments in [SPARK-23922](https://github.com/apache/spark/pull/21028) and [SPARK-23935](https://github.com/apache/spark/pull/21236) ## How was this patch tested? Run tests from: - CollectionExpressionSuite.scala - DataFrameFunctionsSuite.scala Author: Marek Novotny Closes #21352 from mn-mikke/SPARK-24305. --- .../expressions/collectionOperations.scala | 132 +++++++++--------- 1 file changed, 64 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 972bc6e57892c..d60f4c36fa214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -168,27 +168,22 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) - override def dataType: DataType = ArrayType(mountSchema) - - override def nullable: Boolean = children.exists(_.nullable) - - private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) - - private lazy val arrayElementTypes = arrayTypes.map(_.elementType) - - @transient private lazy val mountSchema: StructType = { + @transient override lazy val dataType: DataType = { val fields = children.zip(arrayElementTypes).zipWithIndex.map { case ((expr: NamedExpression, elementType), _) => StructField(expr.name, elementType, nullable = true) case ((_, elementType), idx) => StructField(idx.toString, elementType, nullable = true) } - StructType(fields) + ArrayType(StructType(fields), containsNull = false) } - @transient lazy val numberOfArrays: Int = children.length + override def nullable: Boolean = children.exists(_.nullable) + + @transient private lazy val arrayElementTypes = + children.map(_.dataType.asInstanceOf[ArrayType].elementType) - @transient lazy val genericArrayData = classOf[GenericArrayData].getName + private def genericArrayData = classOf[GenericArrayData].getName def emptyInputGenCode(ev: ExprCode): ExprCode = { ev.copy(code""" @@ -256,7 +251,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI ("ArrayData[]", arrVals) :: Nil) val initVariables = s""" - |ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; + |ArrayData[] $arrVals = new ArrayData[${children.length}]; |int $biggestCardinality = 0; |${CodeGenerator.javaType(dataType)} ${ev.value} = null; """.stripMargin @@ -268,7 +263,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |if (!${ev.isNull}) { | Object[] $args = new Object[$biggestCardinality]; | for (int $i = 0; $i < $biggestCardinality; $i ++) { - | Object[] $currentRow = new Object[$numberOfArrays]; + | Object[] $currentRow = new Object[${children.length}]; | $getValueForTypeSplitted | $args[$i] = new $genericInternalRow($currentRow); | } @@ -278,7 +273,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (numberOfArrays == 0) { + if (children.length == 0) { emptyInputGenCode(ev) } else { nonEmptyInputGenCode(ctx, ev) @@ -360,7 +355,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp override def inputTypes: Seq[AbstractDataType] = Seq(MapType) - lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + @transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] override def dataType: DataType = { ArrayType( @@ -520,7 +515,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres } } - override def dataType: MapType = { + @transient override lazy val dataType: MapType = { if (children.isEmpty) { MapType(StringType, StringType) } else { @@ -747,11 +742,11 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { case _ => None } - private def nullEntries: Boolean = dataTypeDetails.get._3 + @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3 override def nullable: Boolean = child.nullable || nullEntries - override def dataType: MapType = dataTypeDetails.get._1 + @transient override lazy val dataType: MapType = dataTypeDetails.get._1 override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { case Some(_) => TypeCheckResult.TypeCheckSuccess @@ -949,8 +944,7 @@ trait ArraySortLike extends ExpectsInputTypes { protected def nullOrder: NullOrder - @transient - private lazy val lt: Comparator[Any] = { + @transient private lazy val lt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -972,8 +966,7 @@ trait ArraySortLike extends ExpectsInputTypes { } } - @transient - private lazy val gt: Comparator[Any] = { + @transient private lazy val gt: Comparator[Any] = { val ordering = arrayExpression.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] @@ -995,7 +988,9 @@ trait ArraySortLike extends ExpectsInputTypes { } } - def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType + @transient lazy val elementType: DataType = + arrayExpression.dataType.asInstanceOf[ArrayType].elementType + def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull def sortEval(array: Any, ascending: Boolean): Any = { @@ -1211,7 +1206,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI override def dataType: DataType = child.dataType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(input: Any): Any = input match { case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse) @@ -1601,9 +1596,9 @@ case class Slice(x: Expression, start: Expression, length: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) - override def children: Seq[Expression] = Seq(x, start, length) + @transient override lazy val children: Seq[Expression] = Seq(x, start, length) // called from eval - lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { val startInt = startVal.asInstanceOf[Int] @@ -1889,7 +1884,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1930,7 +1925,7 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast min } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -1954,7 +1949,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) - private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() @@ -1995,7 +1990,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast max } - override def dataType: DataType = child.dataType match { + @transient override lazy val dataType: DataType = child.dataType match { case ArrayType(dt, _) => dt case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") } @@ -2097,10 +2092,13 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { - @transient private lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType + + @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull + + @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType) - override def dataType: DataType = left.dataType match { + @transient override lazy val dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType } @@ -2109,7 +2107,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti Seq(TypeCollection(ArrayType, MapType), left.dataType match { case _: ArrayType => IntegerType - case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _: MapType => mapKeyType case _ => AnyDataType // no match for a wrong 'left' expression type } ) @@ -2119,8 +2117,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti super.checkInputDataTypes() match { case f: TypeCheckResult.TypeCheckFailure => f case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => - TypeUtils.checkForOrderingExpr( - left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(mapKeyType, s"function $prettyName") case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess } } @@ -2142,14 +2139,14 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } else { array.numElements() + index } - if (left.dataType.asInstanceOf[ArrayType].containsNull && array.isNullAt(idx)) { + if (arrayContainsNull && array.isNullAt(idx)) { null } else { array.get(idx, dataType) } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) + getValueEval(value, ordinal, mapKeyType, ordering) } } @@ -2158,7 +2155,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti case _: ArrayType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("elementAtIndex") - val nullCheck = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + val nullCheck = if (arrayContainsNull) { s""" |if ($eval1.isNullAt($index)) { | ${ev.isNull} = true; @@ -2209,9 +2206,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti """) case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression { - private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - val allowedTypes = Seq(StringType, BinaryType, ArrayType) + private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2228,7 +2223,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } - override def dataType: DataType = { + @transient override lazy val dataType: DataType = { if (children.isEmpty) { StringType } else { @@ -2236,7 +2231,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } } - lazy val javaType: String = CodeGenerator.javaType(dataType) + private def javaType: String = CodeGenerator.javaType(dataType) override def nullable: Boolean = children.exists(_.nullable) @@ -2256,9 +2251,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio } else { val arrayData = inputs.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + - s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + " elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val finalData = new Array[AnyRef](numberOfElements.toInt) var position = 0 @@ -2316,9 +2312,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio |for (int z = 0; z < ${children.length}; z++) { | $numElements += args[z].numElements(); |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -2413,15 +2410,13 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio since = "2.4.0") case class Flatten(child: Expression) extends UnaryExpression { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - - private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] + private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] override def nullable: Boolean = child.nullable || childDataType.containsNull - override def dataType: DataType = childDataType.elementType + @transient override lazy val dataType: DataType = childDataType.elementType - lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType override def checkInputDataTypes(): TypeCheckResult = child.dataType match { case ArrayType(_: ArrayType, _) => @@ -2441,9 +2436,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { val arrayData = elements.map(_.asInstanceOf[ArrayData]) val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements()) - if (numberOfElements > MAX_ARRAY_LENGTH) { + if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + s"$numberOfElements elements due to exceeding the array size limit " + + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".") } val flattenedData = new Array(numberOfElements.toInt) var position = 0 @@ -2476,9 +2472,10 @@ case class Flatten(child: Expression) extends UnaryExpression { |for (int z = 0; z < $childVariableName.numElements(); z++) { | $variableName += $childVariableName.getArray(z).numElements(); |} - |if ($variableName > $MAX_ARRAY_LENGTH) { + |if ($variableName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " + - | $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | $variableName + " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin (code, variableName) @@ -2602,7 +2599,7 @@ case class Sequence( override def nullable: Boolean = children.exists(_.nullable) - override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false) + override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false) override def checkInputDataTypes(): TypeCheckResult = { val startType = start.dataType @@ -2633,7 +2630,7 @@ case class Sequence( stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), timeZoneId) - private lazy val impl: SequenceImpl = dataType.elementType match { + @transient private lazy val impl: SequenceImpl = dataType.elementType match { case iType: IntegralType => type T = iType.InternalType val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) @@ -2953,8 +2950,6 @@ object Sequence { case class ArrayRepeat(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { - private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) @@ -2966,9 +2961,9 @@ case class ArrayRepeat(left: Expression, right: Expression) if (count == null) { null } else { - if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + if (count.asInstanceOf[Int] > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + - s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); } val element = left.eval(input) new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) @@ -3027,9 +3022,10 @@ case class ArrayRepeat(left: Expression, right: Expression) |if ($count > 0) { | $numElements = $count; |} - |if ($numElements > $MAX_ARRAY_LENGTH) { + |if ($numElements > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + - | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + | " elements due to exceeding the array size limit" + + | " ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); |} """.stripMargin @@ -3111,7 +3107,7 @@ case class ArrayRemove(left: Expression, right: Expression) Seq(ArrayType, elementType) } - lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -3228,7 +3224,7 @@ case class ArrayDistinct(child: Expression) override def dataType: DataType = child.dataType - @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType + @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(elementType) From 5215344deaa5533e593c62aba3fcdfa1a2901801 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 17 Jul 2018 11:23:34 -0500 Subject: [PATCH 1143/1161] [SPARK-24813][BUILD][FOLLOW-UP][HOTFIX] HiveExternalCatalogVersionsSuite still flaky; fall back to Apache archive ## What changes were proposed in this pull request? Test HiveExternalCatalogVersionsSuite vs only current Spark releases ## How was this patch tested? `HiveExternalCatalogVersionsSuite` Author: Sean Owen Closes #21793 from srowen/SPARK-24813.3. --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index f8212684d5335..5103aa8a207db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -203,7 +203,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. - val testingVersions = Seq("2.0.2", "2.1.2", "2.2.1", "2.3.1") + val testingVersions = Seq("2.1.3", "2.2.2", "2.3.1") protected var spark: SparkSession = _ From 7688ce88b2ea514054200845ae860fbccc25a927 Mon Sep 17 00:00:00 2001 From: HanShuliang Date: Tue, 17 Jul 2018 11:25:23 -0500 Subject: [PATCH 1144/1161] [SPARK-21590][SS] Window start time should support negative values ## What changes were proposed in this pull request? Remove the non-negative checks of window start time to make window support negative start time, and add a check to guarantee the absolute value of start time is less than slide duration. ## How was this patch tested? New unit tests. Author: HanShuliang Closes #18903 from KevinZwx/dev. --- .../sql/catalyst/expressions/TimeWindow.scala | 9 ++-- .../analysis/AnalysisErrorSuite.scala | 25 +++++++---- .../expressions/TimeWindowSuite.scala | 13 ++++++ .../sql/DataFrameTimeWindowingSuite.scala | 45 +++++++++++++++++++ 4 files changed, 77 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 84e38a8b2711e..8e48856d4607c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -80,16 +80,13 @@ case class TimeWindow( if (slideDuration <= 0) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") } - if (startTime < 0) { - return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.") - } if (slideDuration > windowDuration) { return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + s" to the windowDuration ($windowDuration).") } - if (startTime >= slideDuration) { - return TypeCheckFailure(s"The start time ($startTime) must be less than the " + - s"slideDuration ($slideDuration).") + if (startTime.abs >= slideDuration) { + return TypeCheckFailure(s"The absolute value of start time ($startTime) must be less " + + s"than the slideDuration ($slideDuration).") } } dataTypeCheck diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5d2f8e735e3d4..0ce94d39e994a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -334,14 +334,28 @@ class AnalysisErrorSuite extends AnalysisTest { "start time greater than slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( "start time equal to slide duration in time window", testRelation.select( TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), - "The start time " :: " must be less than the slideDuration " :: Nil + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time greater than slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 minute").as("window")), + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "SPARK-21590: absolute value of start time equal to slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-1 second").as("window")), + "The absolute value of start time " :: " must be less than the slideDuration " :: Nil ) errorTest( @@ -372,13 +386,6 @@ class AnalysisErrorSuite extends AnalysisTest { "The slide duration" :: " must be greater than 0." :: Nil ) - errorTest( - "negative start time in time window", - testRelation.select( - TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")), - "The start time" :: "must be greater than or equal to 0." :: Nil - ) - errorTest( "generator nested in expressions", listRelation.select(Explode('list) + 1), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index 351d4d0c2eac9..d46135c02bc01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -77,6 +77,19 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva } } + test("SPARK-21590: Start time works with negative values and return microseconds") { + val validDuration = "10 minutes" + for ((text, seconds) <- Seq( + ("-10 seconds", -10000000), // -1e7 + ("-1 minute", -60000000), + ("-1 hour", -3600000000L))) { // -6e7 + assert(TimeWindow(Literal(10L), validDuration, validDuration, "interval " + text).startTime + === seconds) + assert(TimeWindow(Literal(10L), validDuration, validDuration, text).startTime + === seconds) + } + } + private val parseExpression = PrivateMethod[Long]('parseExpression) test("parse sql expression for duration in microseconds - string") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 6fe356877c268..2953425b1db49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -43,6 +43,22 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } + test("SPARK-21590: tumbling window using negative start time") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a"), + ("2016-03-27 19:39:25", 2, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 2) + ) + ) + } + test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -72,6 +88,20 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B Seq(Row(1), Row(1), Row(1))) } + test("SPARK-21590: tumbling window groupBy statement with negative startTime") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "-5 seconds"), $"id") + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1))) + } + test("tumbling window with multi-column projection") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -309,4 +339,19 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B ) } } + + test("SPARK-21590: time window in SQL with three expressions including negative start time") { + withTempTable { table => + checkAnswer( + spark.sql( + s"""select window(time, "10 seconds", 10000000, "-5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } } From 912634b004c2302533a8a8501b4ecb803d17e335 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Tue, 17 Jul 2018 13:11:52 -0700 Subject: [PATCH 1145/1161] [SPARK-24747][ML] Make Instrumentation class more flexible ## What changes were proposed in this pull request? This PR updates the Instrumentation class to make it more flexible and a little bit easier to use. When these APIs are merged, I'll followup with a PR to update the training code to use these new APIs so we can remove the old APIs. These changes are all to private APIs so this PR doesn't make any user facing changes. ## How was this patch tested? Existing tests. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Bago Amirbekian Closes #21719 from MrBago/new-instrumentation-apis. --- .../classification/LogisticRegression.scala | 8 +- .../spark/ml/tree/impl/RandomForest.scala | 2 +- .../spark/ml/tuning/ValidatorParams.scala | 2 +- .../spark/ml/util/Instrumentation.scala | 128 ++++++++++++------ .../spark/mllib/clustering/KMeans.scala | 4 +- 5 files changed, 93 insertions(+), 51 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 92e342ed4a464..25fb9c8aab0bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -35,6 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics} import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer @@ -490,7 +491,7 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train( dataset: Dataset[_], - handlePersistence: Boolean): LogisticRegressionModel = { + handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr => val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { @@ -500,7 +501,8 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val instr = Instrumentation.create(this, dataset) + instr.logPipelineStage(this) + instr.logDataset(dataset) instr.logParams(regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept) @@ -905,8 +907,6 @@ class LogisticRegression @Since("1.2.0") ( objectiveHistory) } model.setSummary(Some(logRegSummary)) - instr.logSuccess(model) - model } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 905870178e549..bb3f3a015c715 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -91,7 +91,7 @@ private[spark] object RandomForest extends Logging { numTrees: Int, featureSubsetStrategy: String, seed: Long, - instr: Option[Instrumentation[_]], + instr: Option[Instrumentation], prune: Boolean = true, // exposed for testing only, real trees are always pruned parentUID: Option[String] = None): Array[DecisionTreeModel] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 363304ef10147..135828815504a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -80,7 +80,7 @@ private[ml] trait ValidatorParams extends HasSeed with Params { /** * Instrumentation logging for tuning params including the inner estimator and evaluator info. */ - protected def logTuningParams(instrumentation: Instrumentation[_]): Unit = { + protected def logTuningParams(instrumentation: Instrumentation): Unit = { instrumentation.logNamedValue("estimator", $(estimator).getClass.getCanonicalName) instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName) instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 11f46eb9e4359..2e43a9ef49ee1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -19,15 +19,16 @@ package org.apache.spark.ml.util import java.util.UUID -import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset import org.apache.spark.util.Utils @@ -35,29 +36,44 @@ import org.apache.spark.util.Utils /** * A small wrapper that defines a training session for an estimator, and some methods to log * useful information during this session. - * - * A new instance is expected to be created within fit(). - * - * @param estimator the estimator that is being fit - * @param dataset the training dataset - * @tparam E the type of the estimator */ -private[spark] class Instrumentation[E <: Estimator[_]] private ( - val estimator: E, - val dataset: RDD[_]) extends Logging { +private[spark] class Instrumentation extends Logging { private val id = UUID.randomUUID() - private val prefix = { + private val shortId = id.toString.take(8) + private val prefix = s"[$shortId] " + + // TODO: remove stage + var stage: Params = _ + // TODO: update spark.ml to use new Instrumentation APIs and remove this constructor + private def this(estimator: Estimator[_], dataset: RDD[_]) = { + this() + logPipelineStage(estimator) + logDataset(dataset) + } + + /** + * Log some info about the pipeline stage being fit. + */ + def logPipelineStage(stage: PipelineStage): Unit = { + this.stage = stage // estimator.getClass.getSimpleName can cause Malformed class name error, // call safer `Utils.getSimpleName` instead - val className = Utils.getSimpleName(estimator.getClass) - s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " + val className = Utils.getSimpleName(stage.getClass) + logInfo(s"Stage class: $className") + logInfo(s"Stage uid: ${stage.uid}") } - init() + /** + * Log some data about the dataset being fit. + */ + def logDataset(dataset: Dataset[_]): Unit = logDataset(dataset.rdd) - private def init(): Unit = { - log(s"training: numPartitions=${dataset.partitions.length}" + + /** + * Log some data about the dataset being fit. + */ + def logDataset(dataset: RDD[_]): Unit = { + logInfo(s"training: numPartitions=${dataset.partitions.length}" + s" storageLevel=${dataset.getStorageLevel}") } @@ -89,23 +105,25 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( super.logInfo(prefix + msg) } - /** - * Alias for logInfo, see above. - */ - def log(msg: String): Unit = logInfo(msg) - /** * Logs the value of the given parameters for the estimator being used in this session. */ - def logParams(params: Param[_]*): Unit = { + def logParams(hasParams: Params, params: Param[_]*): Unit = { val pairs: Seq[(String, JValue)] = for { p <- params - value <- estimator.get(p) + value <- hasParams.get(p) } yield { val cast = p.asInstanceOf[Param[Any]] p.name -> parse(cast.jsonEncode(value)) } - log(compact(render(map2jvalue(pairs.toMap)))) + logInfo(compact(render(map2jvalue(pairs.toMap)))) + } + + // TODO: remove this + def logParams(params: Param[_]*): Unit = { + require(stage != null, "`logStageParams` must be called before `logParams` (or an instance of" + + " Params must be provided explicitly).") + logParams(stage, params: _*) } def logNumFeatures(num: Long): Unit = { @@ -124,35 +142,48 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( * Logs the value with customized name field. */ def logNamedValue(name: String, value: String): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Long): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Double): Unit = { - log(compact(render(name -> value))) + logInfo(compact(render(name -> value))) } def logNamedValue(name: String, value: Array[String]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } def logNamedValue(name: String, value: Array[Long]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } def logNamedValue(name: String, value: Array[Double]): Unit = { - log(compact(render(name -> compact(render(value.toSeq))))) + logInfo(compact(render(name -> compact(render(value.toSeq))))) } + // TODO: Remove this (possibly replace with logModel?) /** * Logs the successful completion of the training session. */ def logSuccess(model: Model[_]): Unit = { - log(s"training finished") + logInfo(s"training finished") + } + + def logSuccess(): Unit = { + logInfo("training finished") + } + + /** + * Logs an exception raised during a training session. + */ + def logFailure(e: Throwable): Unit = { + val msg = e.getStackTrace.mkString("\n") + super.logError(msg) } } @@ -169,22 +200,33 @@ private[spark] object Instrumentation { val varianceOfLabels = "varianceOfLabels" } + // TODO: Remove these /** * Creates an instrumentation object for a training session. */ - def create[E <: Estimator[_]]( - estimator: E, dataset: Dataset[_]): Instrumentation[E] = { - create[E](estimator, dataset.rdd) + def create(estimator: Estimator[_], dataset: Dataset[_]): Instrumentation = { + create(estimator, dataset.rdd) } /** * Creates an instrumentation object for a training session. */ - def create[E <: Estimator[_]]( - estimator: E, dataset: RDD[_]): Instrumentation[E] = { - new Instrumentation[E](estimator, dataset) + def create(estimator: Estimator[_], dataset: RDD[_]): Instrumentation = { + new Instrumentation(estimator, dataset) + } + // end remove + + def instrumented[T](body: (Instrumentation => T)): T = { + val instr = new Instrumentation() + Try(body(instr)) match { + case Failure(NonFatal(e)) => + instr.logFailure(e) + throw e + case Success(result) => + instr.logSuccess() + result + } } - } /** @@ -193,7 +235,7 @@ private[spark] object Instrumentation { * will log via it, otherwise will log via common logger. */ private[spark] class OptionalInstrumentation private( - val instrumentation: Option[Instrumentation[_ <: Estimator[_]]], + val instrumentation: Option[Instrumentation], val className: String) extends Logging { protected override def logName: String = className @@ -225,9 +267,9 @@ private[spark] object OptionalInstrumentation { /** * Creates an `OptionalInstrumentation` object from an existing `Instrumentation` object. */ - def create[E <: Estimator[_]](instr: Instrumentation[E]): OptionalInstrumentation = { + def create(instr: Instrumentation): OptionalInstrumentation = { new OptionalInstrumentation(Some(instr), - instr.estimator.getClass.getName.stripSuffix("$")) + instr.stage.getClass.getName.stripSuffix("$")) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 37ae8b1a6171a..4f554f420b903 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -235,7 +235,7 @@ class KMeans private ( private[spark] def run( data: RDD[Vector], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" @@ -264,7 +264,7 @@ class KMeans private ( */ private def runAlgorithm( data: RDD[VectorWithNorm], - instr: Option[Instrumentation[NewKMeans]]): KMeansModel = { + instr: Option[Instrumentation]): KMeansModel = { val sc = data.sparkContext From 2a4dd6f06cfd2f58fda9786c88809e6de695444e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 17 Jul 2018 14:15:30 -0700 Subject: [PATCH 1146/1161] [SPARK-24681][SQL] Verify nested column names in Hive metastore ## What changes were proposed in this pull request? This pr added code to check if nested column names do not include ',', ':', and ';' because Hive metastore can't handle these characters in nested column names; ref: https://github.com/apache/hive/blob/release-1.2.1/serde/src/java/org/apache/hadoop/hive/serde2/typeinfo/TypeInfoUtils.java#L239 ## How was this patch tested? Added tests in `HiveDDLSuite`. Author: Takeshi Yamamuro Closes #21711 from maropu/SPARK-24681. --- .../spark/sql/hive/HiveExternalCatalog.scala | 34 +++++++++++++++---- .../sql/hive/execution/HiveDDLSuite.scala | 19 +++++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 44480cecf0039..7f28fc40b4469 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -138,17 +138,37 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } /** - * Checks the validity of data column names. Hive metastore disallows the table to use comma in - * data column names. Partition columns do not have such a restriction. Views do not have such - * a restriction. + * Checks the validity of data column names. Hive metastore disallows the table to use some + * special characters (',', ':', and ';') in data column names, including nested column names. + * Partition columns do not have such a restriction. Views do not have such a restriction. */ private def verifyDataSchema( tableName: TableIdentifier, tableType: CatalogTableType, dataSchema: StructType): Unit = { if (tableType != VIEW) { - dataSchema.map(_.name).foreach { colName => - if (colName.contains(",")) { - throw new AnalysisException("Cannot create a table having a column whose name contains " + - s"commas in Hive metastore. Table: $tableName; Column: $colName") + val invalidChars = Seq(",", ":", ";") + def verifyNestedColumnNames(schema: StructType): Unit = schema.foreach { f => + f.dataType match { + case st: StructType => verifyNestedColumnNames(st) + case _ if invalidChars.exists(f.name.contains) => + val invalidCharsString = invalidChars.map(c => s"'$c'").mkString(", ") + val errMsg = "Cannot create a table having a nested column whose name contains " + + s"invalid characters ($invalidCharsString) in Hive metastore. Table: $tableName; " + + s"Column: ${f.name}" + throw new AnalysisException(errMsg) + case _ => + } + } + + dataSchema.foreach { f => + f.dataType match { + // Checks top-level column names + case _ if f.name.contains(",") => + throw new AnalysisException("Cannot create a table having a column whose name " + + s"contains commas in Hive metastore. Table: $tableName; Column: ${f.name}") + // Checks nested column names + case st: StructType => + verifyNestedColumnNames(st) + case _ => } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 0341c3b378918..31fd4c5a1f996 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator @@ -2248,4 +2249,22 @@ class HiveDDLSuite checkAnswer(spark.table("t4"), Row(0, 0)) } } + + test("SPARK-24681 checks if nested column names do not include ',', ':', and ';'") { + val expectedMsg = "Cannot create a table having a nested column whose name contains invalid " + + "characters (',', ':', ';') in Hive metastore." + + Seq("nested,column", "nested:column", "nested;column").foreach { nestedColumnName => + withTable("t") { + val e = intercept[AnalysisException] { + spark.range(1) + .select(struct(lit(0).as(nestedColumnName)).as("toplevel")) + .write + .format("hive") + .saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } } From 681845fd62bbcbbf1d9309b7d8a252198d96c738 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 17 Jul 2018 17:33:52 -0700 Subject: [PATCH 1147/1161] [SPARK-24402][SQL] Optimize `In` expression when only one element in the collection or collection is empty ## What changes were proposed in this pull request? Two new rules in the logical plan optimizers are added. 1. When there is only one element in the **`Collection`**, the physical plan will be optimized to **`EqualTo`**, so predicate pushdown can be used. ```scala profileDF.filter( $"profileID".isInCollection(Set(6))).explain(true) """ |== Physical Plan == |*(1) Project [profileID#0] |+- *(1) Filter (isnotnull(profileID#0) && (profileID#0 = 6)) | +- *(1) FileScan parquet [profileID#0] Batched: true, Format: Parquet, | PartitionFilters: [], | PushedFilters: [IsNotNull(profileID), EqualTo(profileID,6)], | ReadSchema: struct """.stripMargin ``` 2. When the **`Collection`** is empty, and the input is nullable, the logical plan will be simplified to ```scala profileDF.filter( $"profileID".isInCollection(Set())).explain(true) """ |== Optimized Logical Plan == |Filter if (isnull(profileID#0)) null else false |+- Relation[profileID#0] parquet """.stripMargin ``` TODO: 1. For multiple conditions with numbers less than certain thresholds, we should still allow predicate pushdown. 2. Optimize the **`In`** using **`tableswitch`** or **`lookupswitch`** when the numbers of the categories are low, and they are **`Int`**, **`Long`**. 3. The default immutable hash trees set is slow for query, and we should do benchmark for using different set implementation for faster query. 4. **`filter(if (condition) null else false)`** can be optimized to false. ## How was this patch tested? Couple new tests are added. Author: DB Tsai Closes #21797 from dbtsai/optimize-in. --- .../sql/catalyst/optimizer/expressions.scala | 17 +++++++--- .../catalyst/optimizer/OptimizeInSuite.scala | 32 +++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1d363b8146e3f..cf17f59599968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,15 +218,24 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq - if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { + if (newList.length == 1 + // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed, + // TODO: we exclude them in this rule. + && !v.isInstanceOf[CreateNamedStructLike] + && !newList.head.isInstanceOf[CreateNamedStructLike]) { + EqualTo(v, newList.head) + } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) { val hSet = newList.map(e => e.eval(EmptyRow)) InSet(v, HashSet() ++ hSet) - } else if (newList.size < list.size) { + } else if (newList.length < list.length) { expr.copy(list = newList) - } else { // newList.length == list.length + } else { // newList.length == list.length && newList.length > 1 expr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..86522a6a54ed5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +206,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to `If` expression " + + "when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } From fc2e18963efdf4b50258f85c8779122742876910 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 18 Jul 2018 10:00:13 +0800 Subject: [PATCH 1148/1161] [SPARK-24529][BUILD][TEST-MAVEN][FOLLOW-UP] Set spotbugs-maven-plugin's fork to true ## What changes were proposed in this pull request? Set `spotbugs-maven-plugin`'s fork to `true`, otherwise will throw exception when make distribution: ``` ./dev/make-distribution.sh --name SPARK-24529 --tgz -Phadoop-2.7 -Phive -Phive-thriftserver -Pyarn -Phadoop-provided ``` exception: ```java ... [INFO] Reactor Summary: [INFO] [INFO] Spark Project Parent POM ........................... SUCCESS [ 8.753 s] [INFO] Spark Project Tags ................................. SUCCESS [ 9.334 s] [INFO] Spark Project Sketch ............................... SUCCESS [ 12.029 s] [INFO] Spark Project Local DB ............................. SUCCESS [ 13.641 s] [INFO] Spark Project Networking ........................... FAILURE [10:10 min] [INFO] Spark Project Shuffle Streaming Service ............ SKIPPED [INFO] Spark Project Unsafe ............................... SUCCESS [ 16.415 s] [INFO] Spark Project Launcher ............................. SKIPPED [INFO] Spark Project Core ................................. SKIPPED [INFO] Spark Project ML Local Library ..................... SKIPPED [INFO] Spark Project GraphX ............................... SKIPPED [INFO] Spark Project Streaming ............................ SKIPPED [INFO] Spark Project Catalyst ............................. SKIPPED [INFO] Spark Project SQL .................................. SKIPPED [INFO] Spark Project ML Library ........................... SKIPPED [INFO] Spark Project Tools ................................ SUCCESS [ 8.750 s] [INFO] Spark Project Hive ................................. SKIPPED [INFO] Spark Project REPL ................................. SKIPPED [INFO] Spark Project YARN Shuffle Service ................. SKIPPED [INFO] Spark Project YARN ................................. SKIPPED [INFO] Spark Project Hive Thrift Server ................... SKIPPED [INFO] Spark Project Assembly ............................. SKIPPED [INFO] Spark Integration for Kafka 0.10 ................... SKIPPED [INFO] Kafka 0.10 Source for Structured Streaming ......... SKIPPED [INFO] Spark Project Examples ............................. SKIPPED [INFO] Spark Integration for Kafka 0.10 Assembly .......... SKIPPED [INFO] Spark Avro ......................................... SKIPPED [INFO] ------------------------------------------------------------------------ [INFO] BUILD FAILURE [INFO] ------------------------------------------------------------------------ [INFO] Total time: 10:29 min (Wall Clock) [INFO] Finished at: 2018-07-16T21:39:46+08:00 [INFO] Final Memory: 61M/885M [INFO] ------------------------------------------------------------------------ Timeout: sub-process interrupted [ERROR] Failed to execute goal com.github.spotbugs:spotbugs-maven-plugin:3.1.3:spotbugs (spotbugs) on project spark-network-common_2.11: Execution spotbugs of goal com.github.spotbugs:spotbugs-maven-plugin:3.1.3:spotbugs failed: Timeout: killed the sub-process -> [Help 1] [ERROR] [ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch. [ERROR] Re-run Maven using the -X switch to enable full debug logging. [ERROR] [ERROR] For more information about the errors and possible solutions, please read the following articles: [ERROR] [Help 1] http://cwiki.apache.org/confluence/display/MAVEN/PluginExecutionException [ERROR] [ERROR] After correcting the problems, you can resume the build with the command [ERROR] mvn -rf :spark-network-common_2.11 org.apache.tools.ant.ExitException: Permission ("java.lang.RuntimePermission" "exitVM") was not granted. at org.apache.tools.ant.types.Permissions$MySM.checkExit(Permissions.java:194) at java.lang.Runtime.exit(Runtime.java:107) at java.lang.System.exit(System.java:971) at org.codehaus.plexus.classworlds.launcher.Launcher.main(Launcher.java:358) Exception in thread "main" org.apache.tools.ant.ExitException: Permission ("java.lang.RuntimePermission" "exitVM") was not granted. at org.apache.tools.ant.types.Permissions$MySM.checkExit(Permissions.java:194) at java.lang.Runtime.exit(Runtime.java:107) at java.lang.System.exit(System.java:971) at org.codehaus.plexus.classworlds.launcher.Launcher.main(Launcher.java:364) Timeout: sub-process interrupted ``` ## How was this patch tested? manual tests Author: Yuming Wang Closes #21785 from wangyum/SPARK-24529. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 039292337eaa0..1892bbee4e97b 100644 --- a/pom.xml +++ b/pom.xml @@ -2618,7 +2618,7 @@ Low true FindPuzzlers - false + true From 3b59d326c77bec96e5fb856d827139e0389394ba Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 17 Jul 2018 23:52:17 -0700 Subject: [PATCH 1149/1161] [SPARK-24576][BUILD] Upgrade Apache ORC to 1.5.2 ## What changes were proposed in this pull request? This issue aims to upgrade Apache ORC library from 1.4.4 to 1.5.2 in order to bring the following benefits into Apache Spark. - [ORC-91](https://issues.apache.org/jira/browse/ORC-91) Support for variable length blocks in HDFS (The current space wasted in ORC to padding is known to be 5%.) - [ORC-344](https://issues.apache.org/jira/browse/ORC-344) Support for using Decimal64ColumnVector In addition to that, Apache Hive 3.1 and 3.2 will use ORC 1.5.1 ([HIVE-19669](https://issues.apache.org/jira/browse/HIVE-19465)) and 1.5.2 ([HIVE-19792](https://issues.apache.org/jira/browse/HIVE-19792)) respectively. This will improve the compatibility between Apache Spark and Apache Hive by sharing the common library. ## How was this patch tested? Pass the Jenkins with all existing tests. Author: Dongjoon Hyun Closes #21582 from dongjoon-hyun/SPARK-24576. --- dev/deps/spark-deps-hadoop-2.6 | 7 +++-- dev/deps/spark-deps-hadoop-2.7 | 7 +++-- dev/deps/spark-deps-hadoop-3.1 | 7 +++-- pom.xml | 2 +- sql/core/pom.xml | 28 +++++++++++++++++++ .../datasources/orc/OrcFileFormat.scala | 15 +++++++++- .../datasources/orc/OrcSerializer.scala | 2 +- 7 files changed, 56 insertions(+), 12 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f50a0aac0aefc..ff6d5c30c1eb4 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -157,8 +157,9 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 774f9dc39ce4d..72a94f8953c6c 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,7 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -158,8 +158,9 @@ objenesis-2.1.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 19c05ad1e991f..3409dc4613324 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -4,7 +4,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar accessors-smart-1.2.jar activation-1.1.1.jar -aircompressor-0.8.jar +aircompressor-0.10.jar antlr-2.7.7.jar antlr-runtime-3.4.jar antlr4-runtime-4.7.jar @@ -176,8 +176,9 @@ okhttp-2.7.5.jar okhttp-3.8.1.jar okio-1.13.0.jar opencsv-2.3.jar -orc-core-1.4.4-nohive.jar -orc-mapreduce-1.4.4-nohive.jar +orc-core-1.5.2-nohive.jar +orc-mapreduce-1.5.2-nohive.jar +orc-shims-1.5.2.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar paranamer-2.8.jar diff --git a/pom.xml b/pom.xml index 1892bbee4e97b..649221d526086 100644 --- a/pom.xml +++ b/pom.xml @@ -131,7 +131,7 @@ 1.2.1 10.12.1.1 1.10.0 - 1.4.4 + 1.5.2 nohive 1.6.0 9.3.20.v20170531 diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 18ae314309d7b..8873b00e7117a 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -90,11 +90,39 @@ org.apache.orc orc-core ${orc.classifier} + + + org.apache.hadoop + hadoop-hdfs + + + + org.apache.hive + hive-storage-api + + org.apache.orc orc-mapreduce ${orc.classifier} + + + org.apache.hadoop + hadoop-hdfs + + + + org.apache.hive + hive-storage-api + + org.apache.parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 3a8c0add8c2f9..df1cebed5bd0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -59,6 +59,19 @@ private[sql] object OrcFileFormat { def checkFieldNames(names: Seq[String]): Unit = { names.foreach(checkFieldName) } + + def getQuotedSchemaString(dataType: DataType): String = dataType match { + case _: AtomicType => dataType.catalogString + case StructType(fields) => + fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}") + .mkString("struct<", ",", ">") + case ArrayType(elementType, _) => + s"array<${getQuotedSchemaString(elementType)}>" + case MapType(keyType, valueType, _) => + s"map<${getQuotedSchemaString(keyType)},${getQuotedSchemaString(valueType)}>" + case _ => // UDT and others + dataType.catalogString + } } /** @@ -93,7 +106,7 @@ class OrcFileFormat val conf = job.getConfiguration - conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) + conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema)) conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 899af0750cadf..90d1268028096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -223,6 +223,6 @@ class OrcSerializer(dataSchema: StructType) { * Return a Orc value object for the given Spark schema. */ private def createOrcValue(dataType: DataType) = { - OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString)) + OrcStruct.createValue(TypeDescription.fromString(OrcFileFormat.getQuotedSchemaString(dataType))) } } From 34cb3b54e9b7d4c5739cd446a9a66ec3fe59516b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 18 Jul 2018 19:17:18 +0800 Subject: [PATCH 1150/1161] [SPARK-24386][SPARK-24768][BUILD][FOLLOWUP] Fix lint-java and Scala 2.12 build. ## What changes were proposed in this pull request? This pr fixes lint-java and Scala 2.12 build. lint-java: ``` [ERROR] src/test/resources/log4j.properties:[0] (misc) NewlineAtEndOfFile: File does not end with a newline. ``` Scala 2.12 build: ``` [error] /.../sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala:121: overloaded method value addTaskCompletionListener with alternatives: [error] (f: org.apache.spark.TaskContext => Unit)org.apache.spark.TaskContext [error] (listener: org.apache.spark.util.TaskCompletionListener)org.apache.spark.TaskContext [error] cannot be applied to (org.apache.spark.TaskContext => java.util.List[Runnable]) [error] context.addTaskCompletionListener { ctx => [error] ^ ``` ## How was this patch tested? Manually executed lint-java and Scala 2.12 build in my local environment. Author: Takuya UESHIN Closes #21801 from ueshin/issues/SPARK-24386_24768/fix_build. --- external/avro/src/test/resources/log4j.properties | 2 +- .../execution/streaming/continuous/ContinuousCoalesceRDD.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/external/avro/src/test/resources/log4j.properties b/external/avro/src/test/resources/log4j.properties index c18a724007c27..f80a5291bc078 100644 --- a/external/avro/src/test/resources/log4j.properties +++ b/external/avro/src/test/resources/log4j.properties @@ -46,4 +46,4 @@ log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF log4j.additivity.hive.ql.metadata.Hive=false -log4j.logger.hive.ql.metadata.Hive=OFF \ No newline at end of file +log4j.logger.hive.ql.metadata.Hive=OFF diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala index ba85b355f974f..10d8fc553fede 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala @@ -120,6 +120,7 @@ class ContinuousCoalesceRDD( context.addTaskCompletionListener { ctx => threadPool.shutdownNow() + () } part.writersInitialized = true From 2694dd2bf084410ff346d21aaf74025b587d46a8 Mon Sep 17 00:00:00 2001 From: Nihar Sheth Date: Wed, 18 Jul 2018 09:14:36 -0500 Subject: [PATCH 1151/1161] [MINOR][CORE] Add test cases for RDD.cartesian ## What changes were proposed in this pull request? While looking through the codebase, it appeared that the scala code for RDD.cartesian does not have any tests for correctness. This adds a couple basic tests to verify cartesian yields correct values. While the implementation for RDD.cartesian is pretty simple, it always helps to have a few tests! ## How was this patch tested? The new test cases pass, and the scala style tests from running dev/run-tests all pass. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Nihar Sheth Closes #21765 from NiharS/cartesianTests. --- .../scala/org/apache/spark/rdd/RDDSuite.scala | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5148ce05bd918..b143a468a1baf 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -443,7 +443,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") } - test("coalesced RDDs with partial locality") { + test("coalesced RDDs with partial locality") { // Make an RDD that has some locality preferences and some without. This can happen // with UnionRDD val data = sc.makeRDD((1 to 9).map(i => { @@ -846,6 +846,28 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8))) } + test("cartesian on empty RDD") { + val a = sc.emptyRDD[Int] + val b = sc.parallelize(1 to 3) + val cartesian_result = Array.empty[(Int, Int)] + assert(a.cartesian(a).collect().toList === cartesian_result) + assert(a.cartesian(b).collect().toList === cartesian_result) + assert(b.cartesian(a).collect().toList === cartesian_result) + } + + test("cartesian on non-empty RDDs") { + val a = sc.parallelize(1 to 3) + val b = sc.parallelize(2 to 4) + val c = sc.parallelize(1 to 1) + val a_cartesian_b = + Array((1, 2), (1, 3), (1, 4), (2, 2), (2, 3), (2, 4), (3, 2), (3, 3), (3, 4)) + val a_cartesian_c = Array((1, 1), (2, 1), (3, 1)) + val c_cartesian_a = Array((1, 1), (1, 2), (1, 3)) + assert(a.cartesian[Int](b).collect().toList.sorted === a_cartesian_b) + assert(a.cartesian[Int](c).collect().toList.sorted === a_cartesian_c) + assert(c.cartesian[Int](a).collect().toList.sorted === c_cartesian_a) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) From 002300dd41ccc2d1c38351cb09c430f0ded6ab85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E7=94=B0=E7=94=B000222924?= Date: Wed, 18 Jul 2018 09:40:36 -0500 Subject: [PATCH 1152/1161] [SPARK-24804] There are duplicate words in the test title in the DatasetSuite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? In DatasetSuite.scala, in the 1299 line, test("SPARK-19896: cannot have circular references in in case class") , there are duplicate words "in in". We can get rid of one. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: 韩田田00222924 Closes #21767 from httfighter/inin. --- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ce8db99d4e2f1..cf24eba128012 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1296,7 +1296,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { new java.sql.Timestamp(100000)) } - test("SPARK-19896: cannot have circular references in in case class") { + test("SPARK-19896: cannot have circular references in case class") { val errMsg1 = intercept[UnsupportedOperationException] { Seq(CircularReferenceClassA(null)).toDS } From ebe9e28488a30b40f85f79c69be69185a1c4e4f5 Mon Sep 17 00:00:00 2001 From: Huangweizhe Date: Wed, 18 Jul 2018 09:45:56 -0500 Subject: [PATCH 1153/1161] [SPARK-24628][DOC] Typos of the example code in docs/mllib-data-types.md ## What changes were proposed in this pull request? The example wants to create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)), but the list is given as [1, 2, 3, 4, 5, 6]. Now it is changed as [1, 3, 5, 2, 4, 6]. And the example wants to create an RDD of coordinate entries like: entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]). However, it is done with the MatrixEntry class like: entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]), where the third MatrixEntry has a different row index. Now it is changed as MatrixEntry(2, 1, 3.7). ## How was this patch tested? This is trivial enough that it should not affect tests. Author: Weizhe Huang Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Huangweizhe Closes #21612 from huangweizhe123/my_change. --- docs/mllib-data-types.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 5066bb29387dc..eca101132d2e5 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -317,7 +317,7 @@ Refer to the [`Matrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib. from pyspark.mllib.linalg import Matrix, Matrices # Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) -dm2 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) +dm2 = Matrices.dense(3, 2, [1, 3, 5, 2, 4, 6]) # Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) @@ -624,7 +624,7 @@ from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry # Create an RDD of coordinate entries. # - This can be done explicitly with the MatrixEntry class: -entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)]) +entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(2, 1, 3.7)]) # - or using (long, long, float) tuples: entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)]) From fc0c8c97173e641b50eb7cea80c63262d5ba4180 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 18 Jul 2018 10:01:39 -0700 Subject: [PATCH 1154/1161] [SPARK-24825][K8S][TEST] Kubernetes integration tests build the whole reactor ## What changes were proposed in this pull request? Make the integration test script build all modules. In order to not run all the non-Kubernetes integration tests in the build, support specifying tags and tag all integration tests specifically with "k8s". Supply the k8s tag in the dev/dev-run-integration-tests.sh script. ## How was this patch tested? The build system will test this. Author: mcheah Closes #21800 from mccheah/k8s-integration-tests-maven-fix. --- pom.xml | 3 +++ .../dev/dev-run-integration-tests.sh | 21 ++++++---------- .../k8s/integrationtest/KubernetesSuite.scala | 25 ++++++++++--------- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/pom.xml b/pom.xml index 649221d526086..81a53eee14f29 100644 --- a/pom.xml +++ b/pom.xml @@ -194,6 +194,7 @@ ${java.home} + org.spark_project @@ -2162,6 +2163,7 @@ false ${test.exclude.tags} + ${test.include.tags} @@ -2209,6 +2211,7 @@ __not_used__ ${test.exclude.tags} + ${test.include.tags} diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 3acd0f5cd3349..b28b8b82ca016 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -16,9 +16,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -TEST_ROOT_DIR=$(git rev-parse --show-toplevel)/resource-managers/kubernetes/integration-tests - -cd "${TEST_ROOT_DIR}" +set -xo errexit +TEST_ROOT_DIR=$(git rev-parse --show-toplevel) DEPLOY_MODE="minikube" IMAGE_REPO="docker.io/kubespark" @@ -27,7 +26,7 @@ IMAGE_TAG="N/A" SPARK_MASTER= NAMESPACE= SERVICE_ACCOUNT= -INCLUDE_TAGS= +INCLUDE_TAGS="k8s" EXCLUDE_TAGS= # Parse arguments @@ -62,7 +61,7 @@ while (( "$#" )); do shift ;; --include-tags) - INCLUDE_TAGS="$2" + INCLUDE_TAGS="k8s,$2" shift ;; --exclude-tags) @@ -76,13 +75,12 @@ while (( "$#" )); do shift done -cd $TEST_ROOT_DIR - properties=( -Dspark.kubernetes.test.sparkTgz=$SPARK_TGZ \ -Dspark.kubernetes.test.imageTag=$IMAGE_TAG \ -Dspark.kubernetes.test.imageRepo=$IMAGE_REPO \ - -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE + -Dspark.kubernetes.test.deployMode=$DEPLOY_MODE \ + -Dtest.include.tags=$INCLUDE_TAGS ) if [ -n $NAMESPACE ]; @@ -105,9 +103,4 @@ then properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) fi -if [ -n $INCLUDE_TAGS ]; -then - properties=( ${properties[@]} -Dtest.include.tags=$INCLUDE_TAGS ) -fi - -../../../build/mvn integration-test ${properties[@]} +$TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pkubernetes -Phadoop-2.7 ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 774c3936b877c..daabfaaac8c7e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -23,7 +23,7 @@ import java.util.regex.Pattern import com.google.common.io.PatternFilenameFilter import io.fabric8.kubernetes.api.model.Pod -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Tag} import org.scalatest.concurrent.{Eventually, PatienceConfiguration} import org.scalatest.time.{Minutes, Seconds, Span} import scala.collection.JavaConverters._ @@ -47,6 +47,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite private var containerLocalSparkDistroExamplesJar: String = _ private var appLocator: String = _ private var driverPodName: String = _ + private val k8sTestTag = Tag("k8s") override def beforeAll(): Unit = { // The scalatest-maven-plugin gives system properties that are referenced but not set null @@ -102,22 +103,22 @@ private[spark] class KubernetesSuite extends SparkFunSuite deleteDriverPod() } - test("Run SparkPi with no resources") { + test("Run SparkPi with no resources", k8sTestTag) { runSparkPiAndVerifyCompletion() } - test("Run SparkPi with a very long application name.") { + test("Run SparkPi with a very long application name.", k8sTestTag) { sparkAppConf.set("spark.app.name", "long" * 40) runSparkPiAndVerifyCompletion() } - test("Use SparkLauncher.NO_RESOURCE") { + test("Use SparkLauncher.NO_RESOURCE", k8sTestTag) { sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar)) runSparkPiAndVerifyCompletion( appResource = SparkLauncher.NO_RESOURCE) } - test("Run SparkPi with a master URL without a scheme.") { + test("Run SparkPi with a master URL without a scheme.", k8sTestTag) { val url = kubernetesTestComponents.kubernetesClient.getMasterUrl val k8sMasterUrl = if (url.getPort < 0) { s"k8s://${url.getHost}" @@ -128,11 +129,11 @@ private[spark] class KubernetesSuite extends SparkFunSuite runSparkPiAndVerifyCompletion() } - test("Run SparkPi with an argument.") { + test("Run SparkPi with an argument.", k8sTestTag) { runSparkPiAndVerifyCompletion(appArgs = Array("5")) } - test("Run SparkPi with custom labels, annotations, and environment variables.") { + test("Run SparkPi with custom labels, annotations, and environment variables.", k8sTestTag) { sparkAppConf .set("spark.kubernetes.driver.label.label1", "label1-value") .set("spark.kubernetes.driver.label.label2", "label2-value") @@ -158,21 +159,21 @@ private[spark] class KubernetesSuite extends SparkFunSuite }) } - test("Run extraJVMOptions check on driver") { + test("Run extraJVMOptions check on driver", k8sTestTag) { sparkAppConf .set("spark.driver.extraJavaOptions", "-Dspark.test.foo=spark.test.bar") runSparkJVMCheckAndVerifyCompletion( expectedJVMValue = Seq("(spark.test.foo,spark.test.bar)")) } - test("Run SparkRemoteFileTest using a remote data file") { + test("Run SparkRemoteFileTest using a remote data file", k8sTestTag) { sparkAppConf .set("spark.files", REMOTE_PAGE_RANK_DATA_FILE) runSparkRemoteCheckAndVerifyCompletion( appArgs = Array(REMOTE_PAGE_RANK_FILE_NAME)) } - test("Run PySpark on simple pi.py example") { + test("Run PySpark on simple pi.py example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") runSparkApplicationAndVerifyCompletion( @@ -186,7 +187,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite isJVM = false) } - test("Run PySpark with Python2 to test a pyfiles example") { + test("Run PySpark with Python2 to test a pyfiles example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") .set("spark.kubernetes.pyspark.pythonversion", "2") @@ -204,7 +205,7 @@ private[spark] class KubernetesSuite extends SparkFunSuite pyFiles = Some(PYSPARK_CONTAINER_TESTS)) } - test("Run PySpark with Python3 to test a pyfiles example") { + test("Run PySpark with Python3 to test a pyfiles example", k8sTestTag) { sparkAppConf .set("spark.kubernetes.container.image", s"${getTestImageRepo}/spark-py:${getTestImageTag}") .set("spark.kubernetes.pyspark.pythonversion", "3") From c8bee932cb644627c4049b5a07dd8028968572d9 Mon Sep 17 00:00:00 2001 From: sychen Date: Wed, 18 Jul 2018 13:24:41 -0500 Subject: [PATCH 1155/1161] [SPARK-24677][CORE] Avoid NoSuchElementException from MedianHeap ## What changes were proposed in this pull request? When speculation is enabled, TaskSetManager#markPartitionCompleted should write successful task duration to MedianHeap, not just increase tasksSuccessful. Otherwise when TaskSetManager#checkSpeculatableTasks,tasksSuccessful non-zero, but MedianHeap is empty. Then throw an exception successfulTaskDurations.median java.util.NoSuchElementException: MedianHeap is empty. Finally led to stopping SparkContext. ## How was this patch tested? TaskSetManagerSuite.scala unit test:[SPARK-24677] MedianHeap should not be empty when speculation is enabled Author: sychen Closes #21656 from cxzl25/fix_MedianHeap_empty. --- .../spark/scheduler/TaskSchedulerImpl.scala | 7 ++- .../spark/scheduler/TaskSetManager.scala | 7 ++- .../spark/scheduler/TaskSetManagerSuite.scala | 49 +++++++++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 598b62f85a1fa..56c0bf6c09351 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -697,9 +697,12 @@ private[spark] class TaskSchedulerImpl( * do not also submit those same tasks. That also means that a task completion from an earlier * attempt can lead to the entire stage getting marked as successful. */ - private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + private[scheduler] def markPartitionCompletedInAllTaskSets( + stageId: Int, + partitionId: Int, + taskInfo: TaskInfo) = { taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => - tsm.markPartitionCompleted(partitionId) + tsm.markPartitionCompleted(partitionId, taskInfo) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a18c66596852a..6071605ad7f9d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -758,7 +758,7 @@ private[spark] class TaskSetManager( } // There may be multiple tasksets for this stage -- we let all of them know that the partition // was completed. This may result in some of the tasksets getting completed. - sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -769,9 +769,12 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } - private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = { partitionToIndex.get(partitionId).foreach { index => if (!successful(index)) { + if (speculationEnabled && !isZombie) { + successfulTaskDurations.insert(taskInfo.duration) + } tasksSuccessful += 1 successful(index) = true if (tasksSuccessful == numTasks) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ca6a7e5db3b17..ae571e5a3583a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1365,6 +1365,55 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(taskOption4.get.addedJars === addedJarsMidTaskSet) } + test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.1") + sc.conf.set("spark.speculation", "true") + + sched = new FakeTaskScheduler(sc) + sched.initialize(new FakeSchedulerBackend()) + + val dagScheduler = new FakeDAGScheduler(sc, sched) + sched.setDAGScheduler(dagScheduler) + + val taskSet1 = FakeTask.createTaskSet(10) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task => + task.metrics.internalAccums + } + + sched.submitTasks(taskSet1) + sched.resourceOffers( + (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get + + // fail fetch + taskSetManager1.handleFailedTask( + taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + + assert(taskSetManager1.isZombie) + assert(taskSetManager1.runningTasks === 9) + + val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1) + sched.submitTasks(taskSet2) + sched.resourceOffers( + (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) + + // Complete the 2 tasks and leave 8 task in running + for (id <- Set(0, 1)) { + taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get + assert(!taskSetManager2.successfulTaskDurations.isEmpty()) + taskSetManager2.checkSpeculatableTasks(0) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { From 1272b2034d4eed4bfe60a49e1065871b3a3f96e0 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Wed, 18 Jul 2018 14:07:03 -0500 Subject: [PATCH 1156/1161] =?UTF-8?q?[SPARK-22151]=20PYTHONPATH=20not=20pi?= =?UTF-8?q?cked=20up=20from=20the=20spark.yarn.appMaste=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rEnv properly Running in yarn cluster mode and trying to set pythonpath via spark.yarn.appMasterEnv.PYTHONPATH doesn't work. the yarn Client code looks at the env variables: val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) But when you set spark.yarn.appMasterEnv it puts it into the local env. So the python path set in spark.yarn.appMasterEnv isn't properly set. You can work around if you are running in cluster mode by setting it on the client like: PYTHONPATH=./addon/python/ spark-submit ## What changes were proposed in this pull request? In Client.scala, PYTHONPATH was being overridden, so changed code to append values to PYTHONPATH instead of overriding them. ## How was this patch tested? Added log statements to ApplicationMaster.scala to check for environment variable PYTHONPATH, ran a spark job in cluster mode before the change and verified the issue. Performed the same test after the change and verified the fix. Author: pgandhi Closes #21468 from pgandhi999/SPARK-22151. --- .../main/scala/org/apache/spark/deploy/yarn/Client.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 793d012218490..ed9879c06968d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -811,10 +811,12 @@ private[spark] class Client( // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. if (pythonPath.nonEmpty) { - val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + val pythonPathList = (sys.env.get("PYTHONPATH") ++ pythonPath) + env("PYTHONPATH") = (env.get("PYTHONPATH") ++ pythonPathList) .mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) - env("PYTHONPATH") = pythonPathStr - sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) + val pythonPathExecutorEnv = (sparkConf.getExecutorEnv.toMap.get("PYTHONPATH") ++ + pythonPathList).mkString(ApplicationConstants.CLASS_PATH_SEPARATOR) + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathExecutorEnv) } if (isClusterMode) { From cd203e0dfc0758a2a90297e8c74c22a1212db846 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Wed, 18 Jul 2018 13:33:26 -0700 Subject: [PATCH 1157/1161] [SPARK-24163][SPARK-24164][SQL] Support column list as the pivot column in Pivot ## What changes were proposed in this pull request? 1. Extend the Parser to enable parsing a column list as the pivot column. 2. Extend the Parser and the Pivot node to enable parsing complex expressions with aliases as the pivot value. 3. Add type check and constant check in Analyzer for Pivot node. ## How was this patch tested? Add tests in pivot.sql Author: maryannxue Closes #21720 from maryannxue/spark-24164. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 11 +- .../sql/catalyst/analysis/Analyzer.scala | 47 +++- .../sql/catalyst/parser/AstBuilder.scala | 22 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../test/resources/sql-tests/inputs/pivot.sql | 92 +++++++ .../resources/sql-tests/results/pivot.sql.out | 230 +++++++++++++++--- 6 files changed, 348 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index dc95751bf905c..1b43874af6feb 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -414,7 +414,16 @@ groupingSet ; pivotClause - : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn=identifier IN '(' pivotValues+=constant (',' pivotValues+=constant)* ')' ')' + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? ; lateralView diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 36f14ccdc6989..59c371eb1557b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -509,17 +509,39 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) - || !p.pivotColumn.resolved => p + || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) + // Check all pivot values are literal and match pivot column data type. + val evalPivotValues = pivotValues.map { value => + val foldable = value match { + case Alias(v, _) => v.foldable + case _ => value.foldable + } + if (!foldable) { + throw new AnalysisException( + s"Literal expressions required for pivot values, found '$value'") + } + if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { + throw new AnalysisException(s"Invalid pivot value '$value': " + + s"value data type ${value.dataType.simpleString} does not match " + + s"pivot column data type ${pivotColumn.dataType.catalogString}") + } + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + } // Group-by expressions coming from SQL are implicit and need to be deduced. val groupByExprs = groupByExprsOpt.getOrElse( (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq) val singleAgg = aggregates.size == 1 - def outputName(value: Literal, aggregate: Expression): String = { - val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) - val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null") + def outputName(value: Expression, aggregate: Expression): String = { + val stringValue = value match { + case n: NamedExpression => n.name + case _ => + val utf8Value = + Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) + Option(utf8Value).map(_.toString).getOrElse("null") + } if (singleAgg) { stringValue } else { @@ -534,15 +556,10 @@ class Analyzer( // Since evaluating |pivotValues| if statements for each input row can get slow this is an // alternate plan that instead uses two steps of aggregation. val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) - val namedPivotCol = pivotColumn match { - case n: NamedExpression => n - case _ => Alias(pivotColumn, "__pivot_col")() - } - val bigGroup = groupByExprs :+ namedPivotCol + val bigGroup = groupByExprs ++ pivotColumn.references val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) - val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow)) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) + Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } @@ -557,8 +574,12 @@ class Analyzer( Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => - def ifExpr(expr: Expression) = { - If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) + def ifExpr(e: Expression) = { + If( + EqualNullSafe( + pivotColumn, + Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))), + e, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f398b479dc273..49f578a24aaeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -630,11 +630,29 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val aggregates = Option(ctx.aggregates).toSeq .flatMap(_.namedExpression.asScala) .map(typedVisit[Expression]) - val pivotColumn = UnresolvedAttribute.quoted(ctx.pivotColumn.getText) - val pivotValues = ctx.pivotValues.asScala.map(typedVisit[Expression]).map(Literal.apply) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText))) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) Pivot(None, pivotColumn, pivotValues, aggregates, query) } + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3bf32ef7884e5..ea5a9b8ed5542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -700,7 +700,7 @@ case class GroupingSets( case class Pivot( groupByExprsOpt: Option[Seq[NamedExpression]], pivotColumn: Expression, - pivotValues: Seq[Literal], + pivotValues: Seq[Expression], aggregates: Seq[Expression], child: LogicalPlan) extends UnaryNode { override lazy val resolved = false // Pivot will be replaced after being resolved. diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index b3d53adfbebe7..a6c8d4854ff38 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -11,6 +11,11 @@ create temporary view years as select * from values (2013, 2) as years(y, s); +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a); + -- pivot courses SELECT * FROM ( SELECT year, course, earnings FROM courseSales @@ -96,6 +101,15 @@ PIVOT ( FOR y IN (2012, 2013) ); +-- pivot with projection and value aliases +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +); + -- pivot years with non-aggregate function SELECT * FROM courseSales PIVOT ( @@ -103,6 +117,15 @@ PIVOT ( FOR year IN (2012, 2013) ); +-- pivot with one of the expressions as non-aggregate function +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +); + -- pivot with unresolvable columns SELECT * FROM ( SELECT course, earnings FROM courseSales @@ -129,3 +152,72 @@ PIVOT ( sum(avg(earnings)) FOR course IN ('dotNET', 'Java') ); + +-- pivot on multiple pivot columns +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +); + +-- pivot on multiple pivot columns with aliased values +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +); + +-- pivot on multiple pivot columns with values of wrong data types +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +); + +-- pivot with unresolvable values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +); + +-- pivot with non-literal values +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +); + +-- pivot on join query with columns of complex data types +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +); + +-- pivot on multiple pivot columns with agg columns of complex data types +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +); diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 922d8b9f9152c..6bb51b946f960 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 25 -- !query 0 @@ -28,6 +28,17 @@ struct<> -- !query 2 +create temporary view yearsWithArray as select * from values + (2012, array(1, 1)), + (2013, array(2, 2)) + as yearsWithArray(y, a) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -35,27 +46,27 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 2 schema +-- !query 3 schema struct --- !query 2 output +-- !query 3 output 2012 15000 20000 2013 48000 30000 --- !query 3 +-- !query 4 SELECT * FROM courseSales PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output Java 20000 30000 dotNET 15000 48000 --- !query 4 +-- !query 5 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -63,14 +74,14 @@ PIVOT ( sum(earnings), avg(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output 2012 15000 7500.0 20000 20000.0 2013 48000 48000.0 30000 30000.0 --- !query 5 +-- !query 6 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -78,13 +89,13 @@ PIVOT ( sum(earnings) FOR course IN ('dotNET', 'Java') ) --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 63000 50000 --- !query 6 +-- !query 7 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -92,13 +103,13 @@ PIVOT ( sum(earnings), min(year) FOR course IN ('dotNET', 'Java') ) --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output 63000 2012 50000 2012 --- !query 7 +-- !query 8 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -108,16 +119,16 @@ PIVOT ( sum(earnings) FOR s IN (1, 2) ) --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output Java 2012 20000 NULL Java 2013 NULL 30000 dotNET 2012 15000 NULL dotNET 2013 NULL 48000 --- !query 8 +-- !query 9 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -127,14 +138,14 @@ PIVOT ( sum(earnings), min(s) FOR course IN ('dotNET', 'Java') ) --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output 2012 15000 1 20000 1 2013 48000 2 30000 2 --- !query 9 +-- !query 10 SELECT * FROM ( SELECT course, year, earnings, s FROM courseSales @@ -144,14 +155,14 @@ PIVOT ( sum(earnings * s) FOR course IN ('dotNET', 'Java') ) --- !query 9 schema +-- !query 10 schema struct --- !query 9 output +-- !query 10 output 2012 15000 20000 2013 96000 60000 --- !query 10 +-- !query 11 SELECT 2012_s, 2013_s, 2012_a, 2013_a, c FROM ( SELECT year y, course c, earnings e FROM courseSales ) @@ -159,27 +170,57 @@ PIVOT ( sum(e) s, avg(e) a FOR y IN (2012, 2013) ) --- !query 10 schema +-- !query 11 schema struct<2012_s:bigint,2013_s:bigint,2012_a:double,2013_a:double,c:string> --- !query 10 output +-- !query 11 output 15000 48000 7500.0 48000.0 dotNET 20000 30000 20000.0 30000.0 Java --- !query 11 +-- !query 12 +SELECT firstYear_s, secondYear_s, firstYear_a, secondYear_a, c FROM ( + SELECT year y, course c, earnings e FROM courseSales +) +PIVOT ( + sum(e) s, avg(e) a + FOR y IN (2012 as firstYear, 2013 secondYear) +) +-- !query 12 schema +struct +-- !query 12 output +15000 48000 7500.0 48000.0 dotNET +20000 30000 20000.0 30000.0 Java + + +-- !query 13 SELECT * FROM courseSales PIVOT ( abs(earnings) FOR year IN (2012, 2013) ) --- !query 11 schema +-- !query 13 schema struct<> --- !query 11 output +-- !query 13 output org.apache.spark.sql.AnalysisException Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.; --- !query 12 +-- !query 14 +SELECT * FROM ( + SELECT year, course, earnings FROM courseSales +) +PIVOT ( + sum(earnings), year + FOR course IN ('dotNET', 'Java') +) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Aggregate expression required for pivot, but '__auto_generated_subquery_name.`year`' did not appear in any aggregate function.; + + +-- !query 15 SELECT * FROM ( SELECT course, earnings FROM courseSales ) @@ -187,14 +228,14 @@ PIVOT ( sum(earnings) FOR year IN (2012, 2013) ) --- !query 12 schema +-- !query 15 schema struct<> --- !query 12 output +-- !query 15 output org.apache.spark.sql.AnalysisException cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0 --- !query 13 +-- !query 16 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -202,14 +243,14 @@ PIVOT ( ceil(sum(earnings)), avg(earnings) + 1 as a1 FOR course IN ('dotNET', 'Java') ) --- !query 13 schema +-- !query 16 schema struct --- !query 13 output +-- !query 16 output 2012 15000 7501.0 20000 20001.0 2013 48000 48001.0 30000 30001.0 --- !query 14 +-- !query 17 SELECT * FROM ( SELECT year, course, earnings FROM courseSales ) @@ -217,8 +258,119 @@ PIVOT ( sum(avg(earnings)) FOR course IN ('dotNET', 'Java') ) --- !query 14 schema +-- !query 17 schema struct<> --- !query 14 output +-- !query 17 output org.apache.spark.sql.AnalysisException It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.; + + +-- !query 18 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN (('dotNET', 2012), ('Java', 2013)) +) +-- !query 18 schema +struct +-- !query 18 output +1 15000 NULL +2 NULL 30000 + + +-- !query 19 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', 2) as c1, ('Java', 1) as c2) +) +-- !query 19 schema +struct +-- !query 19 output +2012 NULL 20000 +2013 48000 NULL + + +-- !query 20 +SELECT * FROM ( + SELECT course, year, earnings, s + FROM courseSales + JOIN years ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, year) IN ('dotNET', 'Java') +) +-- !query 20 schema +struct<> +-- !query 20 output +org.apache.spark.sql.AnalysisException +Invalid pivot value 'dotNET': value data type string does not match pivot column data type struct; + + +-- !query 21 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (s, 2013) +) +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`s`' given input columns: [coursesales.course, coursesales.year, coursesales.earnings]; line 4 pos 15 + + +-- !query 22 +SELECT * FROM courseSales +PIVOT ( + sum(earnings) + FOR year IN (course, 2013) +) +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +Literal expressions required for pivot values, found 'course#x'; + + +-- !query 23 +SELECT * FROM ( + SELECT course, year, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + min(a) + FOR course IN ('dotNET', 'Java') +) +-- !query 23 schema +struct,Java:array> +-- !query 23 output +2012 [1,1] [1,1] +2013 [2,2] [2,2] + + +-- !query 24 +SELECT * FROM ( + SELECT course, year, y, a + FROM courseSales + JOIN yearsWithArray ON year = y +) +PIVOT ( + max(a) + FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) +) +-- !query 24 schema +struct,[2013, Java]:array> +-- !query 24 output +2012 [1,1] NULL +2013 NULL [2,2] From d404e54e644ec7c6873234b9fed72e7f1f41a8b1 Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Wed, 18 Jul 2018 16:18:29 -0500 Subject: [PATCH 1158/1161] [SPARK-24129][K8S] Add option to pass --build-arg's to docker-image-tool.sh ## What changes were proposed in this pull request? Adding `-b arg` option to take `--build-arg` parameters to pass into the docker command ## How was this patch tested? I verified by passing proxy details which fails without this change and succeeds with the changes. Author: Devaraj K Closes #21202 from devaraj-kavali/SPARK-24129. --- bin/docker-image-tool.sh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index a3f1bcffaea57..f36fb43692cf4 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -49,6 +49,7 @@ function build { # Set image build arguments accordingly if this is a source repo and not a distribution archive. IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles BUILD_ARGS=( + ${BUILD_PARAMS} --build-arg img_path=$IMG_PATH --build-arg @@ -57,13 +58,14 @@ function build { else # Not passed as an argument to docker, but used to validate the Spark directory. IMG_PATH="kubernetes/dockerfiles" - BUILD_ARGS=() + BUILD_ARGS=(${BUILD_PARAMS}) fi if [ ! -d "$IMG_PATH" ]; then error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark." fi local BINDING_BUILD_ARGS=( + ${BUILD_PARAMS} --build-arg base_img=$(image_ref spark) ) @@ -101,6 +103,8 @@ Options: -t tag Tag to apply to the built image, or to identify the image to be pushed. -m Use minikube's Docker daemon. -n Build docker image with --no-cache + -b arg Build arg to build or push the image. For multiple build args, this option needs to + be used separately for each build arg. Using minikube when building images will do so directly into minikube's Docker daemon. There is no need to push the images into minikube in that case, they'll be automatically @@ -130,7 +134,8 @@ TAG= BASEDOCKERFILE= PYDOCKERFILE= NOCACHEARG= -while getopts f:mr:t:n option +BUILD_PARAMS= +while getopts f:mr:t:n:b: option do case "${option}" in @@ -139,6 +144,7 @@ do r) REPO=${OPTARG};; t) TAG=${OPTARG};; n) NOCACHEARG="--no-cache";; + b) BUILD_PARAMS=${BUILD_PARAMS}" --build-arg "${OPTARG};; m) if ! which minikube 1>/dev/null; then error "Cannot find minikube." From 753f11516226d0ddf542f95bd67d145a4e899451 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 18 Jul 2018 18:39:23 -0500 Subject: [PATCH 1159/1161] [SPARK-21261][DOCS][SQL] SQL Regex document fix ## What changes were proposed in this pull request? Fix regexes in spark-sql command examples. This takes over https://github.com/apache/spark/pull/18477 ## How was this patch tested? Existing tests. I verified the existing example doesn't work in spark-sql, but new ones does. Author: Sean Owen Closes #21808 from srowen/SPARK-21261. --- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 7b68bb771faf3..bf0c35fe61018 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -272,7 +272,7 @@ case class StringSplit(str: Expression, pattern: Expression) usage = "_FUNC_(str, regexp, rep) - Replaces all substrings of `str` that match `regexp` with `rep`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)', 'num'); + > SELECT _FUNC_('100-200', '(\\d+)', 'num'); num-num """) // scalastyle:on line.size.limit @@ -371,7 +371,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.", examples = """ Examples: - > SELECT _FUNC_('100-200', '(\d+)-(\d+)', 1); + > SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1); 100 """) case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) From cd5d93c0e4ec4573126c6cdda3362814976d11eb Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 19 Jul 2018 09:16:16 +0800 Subject: [PATCH 1160/1161] [SPARK-24854][SQL] Gathering all Avro options into the AvroOptions class ## What changes were proposed in this pull request? In the PR, I propose to put all `Avro` options in new class `AvroOptions` in the same way as for other datasources `JSON` and `CSV`. ## How was this patch tested? It was tested by `AvroSuite` Author: Maxim Gekk Closes #21810 from MaxGekk/avro-options. --- .../spark/sql/avro/AvroFileFormat.scala | 13 +++-- .../apache/spark/sql/avro/AvroOptions.scala | 48 +++++++++++++++++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 6 ++- 3 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 9eb206457809c..1d0f40e1ce92a 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -58,6 +58,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { val conf = spark.sparkContext.hadoopConfiguration + val parsedOptions = new AvroOptions(options) // Schema evolution is not supported yet. Here we only pick a single random sample file to // figure out the schema of the whole dataset. @@ -76,7 +77,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { } // User can specify an optional avro json schema. - val avroSchema = options.get(AvroFileFormat.AvroSchema) + val avroSchema = parsedOptions.schema .map(new Schema.Parser().parse) .getOrElse { val in = new FsInput(sampleFile.getPath, conf) @@ -114,10 +115,9 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - val recordName = options.getOrElse("recordName", "topLevelRecord") - val recordNamespace = options.getOrElse("recordNamespace", "") + val parsedOptions = new AvroOptions(options) val outputAvroSchema = SchemaConverters.toAvroType( - dataSchema, nullable = false, recordName, recordNamespace) + dataSchema, nullable = false, parsedOptions.recordName, parsedOptions.recordNamespace) AvroJob.setOutputKeySchema(job, outputAvroSchema) val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" @@ -160,11 +160,12 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { val broadcastedConf = spark.sparkContext.broadcast(new AvroFileFormat.SerializableConfiguration(hadoopConf)) + val parsedOptions = new AvroOptions(options) (file: PartitionedFile) => { val log = LoggerFactory.getLogger(classOf[AvroFileFormat]) val conf = broadcastedConf.value.value - val userProvidedSchema = options.get(AvroFileFormat.AvroSchema).map(new Schema.Parser().parse) + val userProvidedSchema = parsedOptions.schema.map(new Schema.Parser().parse) // TODO Removes this check once `FileFormat` gets a general file filtering interface method. // Doing input file filtering is improper because we may generate empty tasks that process no @@ -235,8 +236,6 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister { private[avro] object AvroFileFormat { val IgnoreFilesWithoutExtensionProperty = "avro.mapred.ignore.inputs.without.extension" - val AvroSchema = "avroSchema" - class SerializableConfiguration(@transient var value: Configuration) extends Serializable with KryoSerializable { @transient private[avro] lazy val log = LoggerFactory.getLogger(getClass) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala new file mode 100644 index 0000000000000..8721eae3481da --- /dev/null +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap + +/** + * Options for Avro Reader and Writer stored in case insensitive manner. + */ +class AvroOptions(@transient val parameters: CaseInsensitiveMap[String]) + extends Logging with Serializable { + + def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) + + /** + * Optional schema provided by an user in JSON format. + */ + val schema: Option[String] = parameters.get("avroSchema") + + /** + * Top level record name in write result, which is required in Avro spec. + * See https://avro.apache.org/docs/1.8.2/spec.html#schema_record . + * Default value is "topLevelRecord" + */ + val recordName: String = parameters.getOrElse("recordName", "topLevelRecord") + + /** + * Record namespace in write result. Default value is "". + * See Avro spec for details: https://avro.apache.org/docs/1.8.2/spec.html#schema_record . + */ + val recordNamespace: String = parameters.getOrElse("recordNamespace", "") +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 446b42124ceca..f7e9877b7744b 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -578,7 +578,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin val result = spark .read - .option(AvroFileFormat.AvroSchema, avroSchema) + .option("avroSchema", avroSchema) .avro(testAvro) .collect() val expected = spark.read.avro(testAvro).select("string").collect() @@ -598,7 +598,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { | }] |} """.stripMargin - val result = spark.read.option(AvroFileFormat.AvroSchema, avroSchema) + val result = spark + .read + .option("avroSchema", avroSchema) .avro(testAvro).select("missingField").first assert(result === Row("foo")) } From 1a4fda88685bf62d7d8c639ec2d9762107cd447b Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 19 Jul 2018 10:48:10 +0800 Subject: [PATCH 1161/1161] [INFRA] Close stale PR Closes #17422 Closes #17619 Closes #18034 Closes #18229 Closes #18268 Closes #17973 Closes #18125 Closes #18918 Closes #19274 Closes #19456 Closes #19510 Closes #19420 Closes #20090 Closes #20177 Closes #20304 Closes #20319 Closes #20543 Closes #20437 Closes #21261 Closes #21726 Closes #14653 Closes #13143 Closes #17894 Closes #19758 Closes #12951 Closes #17092 Closes #21240 Closes #16910 Closes #12904 Closes #21731 Closes #21095 Added: Closes #19233 Closes #20100 Closes #21453 Closes #21455 Closes #18477 Added: Closes #21812 Closes #21787 Author: hyukjinkwon Closes #21781 from HyukjinKwon/closing-prs.
      keyvalue
      {session.userName} spark.yarn.dist.forceDownloadSchemes (none) - Comma-separated list of schemes for which files will be downloaded to the local disk prior to + Comma-separated list of schemes for which resources will be downloaded to the local disk prior to being added to YARN's distributed cache. For use in cases where the YARN service does not - support schemes that are supported by Spark, like http, https and ftp. + support schemes that are supported by Spark, like http, https and ftp, or jars required to be in the + local YARN client's classpath. Wildcard '*' is denoted to download resources for all the schemes.
      spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path(none) + Add the Kubernetes Volume named VolumeName of the VolumeType type to the driver pod on the path specified in the value. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. +
      spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly(none) + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. +
      spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName](none) + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value, must conform with Kubernetes option format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. +
      spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path(none) + Add the Kubernetes Volume named VolumeName of the VolumeType type to the executor pod on the path specified in the value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint. +
      spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnlyfalse + Specify if the mounted volume is read only or not. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false. +
      spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].options.[OptionName](none) + Configure Kubernetes Volume options passed to the Kubernetes with OptionName as key having specified value. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim. +
      spark.kubernetes.memoryOverheadFactor